Spaces:
Runtime error
Runtime error
File size: 1,819 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import navi
from nodes.base_input import BaseInput
from ...impl.onnx.model import OnnxModel, OnnxModels, OnnxRemBg, is_rembg_model
from .generic_inputs import DropDownInput
class OnnxModelInput(BaseInput):
    """Input for onnx model

    Accepts any ONNX model matching the navi type expression ``input_type``.

    Args:
        label: Display label for the input (defaults to ``"Model"``).
        input_type: navi type expression for the input (defaults to
            ``"OnnxModel"``); subclasses narrow it further.
    """

    def __init__(
        self,
        label: str = "Model",
        input_type: navi.ExpressionJson = "OnnxModel",
    ):
        # BaseInput takes the type expression first, then the label.
        super().__init__(input_type, label)
        # Python-side type tied to this input; subclasses may override it.
        self.associated_type = OnnxModel
class OnnxGenericModelInput(OnnxModelInput):
    """ONNX model input for things that aren't background removal.

    The navi input type is narrowed by intersecting ``input_type`` with
    ``"OnnxGenericModel"``. At runtime, ``enforce`` rejects any model whose
    bytes ``is_rembg_model`` recognises as a background-removal model.

    Args:
        label: Display label for the input (defaults to ``"Model"``).
        input_type: Base navi type expression (defaults to ``"OnnxModel"``).
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "OnnxModel"
    ):
        super().__init__(label, navi.intersect(input_type, "OnnxGenericModel"))

    def enforce(self, value):
        """Check that ``value`` is a non-rembg ONNX model and return it unchanged.

        Raises:
            AssertionError: if ``value`` is not an instance of ``OnnxModels``,
                or if ``is_rembg_model`` reports its bytes as a rembg model.
        """
        # Explicit raises instead of ``assert``: assert statements are removed
        # under ``python -O``, which would silently disable this validation.
        # AssertionError is kept so existing handlers still match.
        if not isinstance(value, OnnxModels):
            raise AssertionError(
                f"Expected an ONNX model, got {type(value).__name__}"
            )
        if is_rembg_model(value.bytes):
            raise AssertionError("Expected a non-rembg model")
        return value
class OnnxRemBgModelInput(OnnxModelInput):
    """ONNX model input for background removal.

    The navi input type is narrowed by intersecting ``input_type`` with
    ``"OnnxRemBgModel"``, and the associated Python type is ``OnnxRemBg``.
    At runtime, ``enforce`` only accepts models whose bytes
    ``is_rembg_model`` recognises as background-removal models.

    Args:
        label: Display label for the input (defaults to ``"Model"``).
        input_type: Base navi type expression (defaults to ``"OnnxModel"``).
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "OnnxModel"
    ):
        super().__init__(label, navi.intersect(input_type, "OnnxRemBgModel"))
        # Override the generic OnnxModel association set by the base class.
        self.associated_type = OnnxRemBg

    def enforce(self, value):
        """Check that ``value`` is a rembg ONNX model and return it unchanged.

        Raises:
            AssertionError: if ``value`` is not an instance of ``OnnxModels``,
                or if ``is_rembg_model`` does not report its bytes as a rembg
                model.
        """
        # Explicit raises instead of ``assert``: assert statements are removed
        # under ``python -O``, which would silently disable this validation.
        # AssertionError is kept so existing handlers still match.
        if not isinstance(value, OnnxModels):
            raise AssertionError(
                f"Expected an ONNX model, got {type(value).__name__}"
            )
        if not is_rembg_model(value.bytes):
            raise AssertionError("Expected a rembg model")
        return value
def OnnxFpDropdown() -> DropDownInput:
    """Build the "Data Type" dropdown for choosing ONNX floating-point precision.

    Two options are offered: ``fp32`` (value 0, navi type ``FpMode::fp32``)
    and ``fp16`` (value 1, navi type ``FpMode::fp16``). The dropdown's own
    navi type is ``FpMode``.
    """
    # (displayed option, stored value, navi enum variant)
    precision_table = (
        ("fp32", 0, "FpMode::fp32"),
        ("fp16", 1, "FpMode::fp16"),
    )
    precision_options = [
        {"option": shown, "value": stored, "type": variant}
        for shown, stored, variant in precision_table
    ]
    return DropDownInput(
        label="Data Type",
        input_type="FpMode",
        options=precision_options,
    )
|