Spaces:
Runtime error
Runtime error
import navi | |
from nodes.base_input import BaseInput | |
from ...impl.onnx.model import OnnxModel, OnnxModels, OnnxRemBg, is_rembg_model | |
from .generic_inputs import DropDownInput | |
class OnnxModelInput(BaseInput):
    """Node input that accepts an ONNX model.

    Forwards ``input_type`` and ``label`` to ``BaseInput`` and records
    ``OnnxModel`` as the associated type of this input.
    """

    def __init__(
        self,
        label: str = "Model",
        input_type: navi.ExpressionJson = "OnnxModel",
    ):
        # Note the argument order: the base class is called with the type
        # expression first and the label second.
        super().__init__(input_type, label)
        self.associated_type = OnnxModel
class OnnxGenericModelInput(OnnxModelInput):
    """ONNX model input for things that aren't background removal.

    The input type is narrowed by intersecting the given ``input_type`` with
    ``"OnnxGenericModel"``.
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "OnnxModel"
    ):
        super().__init__(label, navi.intersect(input_type, "OnnxGenericModel"))

    def enforce(self, value):
        """Validate ``value`` and return it unchanged.

        Raises:
            AssertionError: if ``value`` is not an instance of ``OnnxModels``,
                or if ``is_rembg_model`` reports ``value.bytes`` as a rembg
                model.
        """
        # Raised explicitly instead of via ``assert`` so the checks are not
        # stripped when Python runs with -O. AssertionError is kept so that
        # any existing handlers for the old asserts still match.
        if not isinstance(value, OnnxModels):
            raise AssertionError(
                f"Expected an ONNX model, got {type(value).__name__}"
            )
        if is_rembg_model(value.bytes):
            raise AssertionError("Expected a non-rembg model")
        return value
class OnnxRemBgModelInput(OnnxModelInput):
    """ONNX model input for background removal.

    The input type is narrowed by intersecting the given ``input_type`` with
    ``"OnnxRemBgModel"``, and ``associated_type`` is set to ``OnnxRemBg``.
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "OnnxModel"
    ):
        super().__init__(label, navi.intersect(input_type, "OnnxRemBgModel"))
        self.associated_type = OnnxRemBg

    def enforce(self, value):
        """Validate ``value`` and return it unchanged.

        Raises:
            AssertionError: if ``value`` is not an instance of ``OnnxModels``,
                or if ``is_rembg_model`` does not report ``value.bytes`` as a
                rembg model.
        """
        # Raised explicitly instead of via ``assert`` so the checks are not
        # stripped when Python runs with -O. AssertionError is kept so that
        # any existing handlers for the old asserts still match.
        if not isinstance(value, OnnxModels):
            raise AssertionError(
                f"Expected an ONNX model, got {type(value).__name__}"
            )
        if not is_rembg_model(value.bytes):
            raise AssertionError("Expected a rembg model")
        return value
def OnnxFpDropdown() -> DropDownInput:
    """Create the "Data Type" drop-down with the ``fp32`` and ``fp16`` options.

    ``fp32`` carries the value 0 and ``fp16`` the value 1; each option's type
    is the matching ``FpMode::<name>`` variant.
    """
    precision_names = ("fp32", "fp16")
    # The position of each name in the tuple doubles as its option value.
    fp_options = [
        {"option": name, "value": index, "type": f"FpMode::{name}"}
        for index, name in enumerate(precision_names)
    ]
    return DropDownInput(
        input_type="FpMode",
        label="Data Type",
        options=fp_options,
    )