File size: 1,819 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import navi
from nodes.base_input import BaseInput

from ...impl.onnx.model import OnnxModel, OnnxModels, OnnxRemBg, is_rembg_model
from .generic_inputs import DropDownInput


class OnnxModelInput(BaseInput):
    """Input for onnx model.

    Base class for the ONNX model inputs in this module. It forwards the navi
    type expression and label to ``BaseInput`` and records ``OnnxModel`` as the
    Python type associated with the input.

    Args:
        label: Display label of the input (defaults to ``"Model"``).
        input_type: navi type expression for the input (defaults to
            ``"OnnxModel"``). Subclasses pass a narrowed expression here.
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "OnnxModel"
    ) -> None:
        # BaseInput takes (input_type, label) positionally, the reverse of this
        # constructor's parameter order.
        super().__init__(input_type, label)
        # Set after super().__init__ so this value is the one that sticks;
        # subclasses may overwrite it with a narrower type.
        self.associated_type = OnnxModel


class OnnxGenericModelInput(OnnxModelInput):
    """ONNX model input for things that aren't background removal.

    The navi type is narrowed to the intersection of ``input_type`` and
    ``"OnnxGenericModel"``. At runtime, :meth:`enforce` rejects any value that
    ``is_rembg_model`` identifies as a background-removal (rembg) model.

    Args:
        label: Display label of the input (defaults to ``"Model"``).
        input_type: navi type expression to intersect with
            ``"OnnxGenericModel"`` (defaults to ``"OnnxModel"``).
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "OnnxModel"
    ):
        super().__init__(label, navi.intersect(input_type, "OnnxGenericModel"))
        # NOTE(review): unlike OnnxRemBgModelInput, associated_type is left as
        # the parent's OnnxModel instead of a generic-only type; confirm this
        # is intended.

    def enforce(self, value):
        """Check that *value* is a non-rembg ONNX model and return it unchanged.

        Raises:
            AssertionError: if *value* is not an instance of ``OnnxModels``,
                or if ``is_rembg_model`` reports its bytes as a rembg model.
        """
        # Raised explicitly rather than via ``assert`` so validation still
        # happens under ``python -O``. AssertionError is kept so callers that
        # catch the original exception type keep working.
        if not isinstance(value, OnnxModels):
            raise AssertionError(
                f"Expected an ONNX model, got {type(value).__name__}"
            )
        if is_rembg_model(value.bytes):
            raise AssertionError("Expected a non-rembg model")
        return value


class OnnxRemBgModelInput(OnnxModelInput):
    """ONNX model input for background removal.

    The navi type is narrowed to the intersection of ``input_type`` and
    ``"OnnxRemBgModel"``, and the associated Python type is ``OnnxRemBg``. At
    runtime, :meth:`enforce` accepts only values that ``is_rembg_model``
    identifies as background-removal (rembg) models.

    Args:
        label: Display label of the input (defaults to ``"Model"``).
        input_type: navi type expression to intersect with
            ``"OnnxRemBgModel"`` (defaults to ``"OnnxModel"``).
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "OnnxModel"
    ):
        super().__init__(label, navi.intersect(input_type, "OnnxRemBgModel"))
        # Overrides the parent's OnnxModel with the rembg-specific type.
        self.associated_type = OnnxRemBg

    def enforce(self, value):
        """Check that *value* is a rembg ONNX model and return it unchanged.

        Raises:
            AssertionError: if *value* is not an instance of ``OnnxModels``,
                or if ``is_rembg_model`` does not report its bytes as a rembg
                model.
        """
        # Raised explicitly rather than via ``assert`` so validation still
        # happens under ``python -O``. AssertionError is kept so callers that
        # catch the original exception type keep working.
        if not isinstance(value, OnnxModels):
            raise AssertionError(
                f"Expected an ONNX model, got {type(value).__name__}"
            )
        if not is_rembg_model(value.bytes):
            raise AssertionError("Expected a rembg model")
        return value


def OnnxFpDropdown() -> DropDownInput:
    """Create the "Data Type" dropdown for choosing ONNX floating-point precision.

    Offers two options of navi type ``FpMode``:
    ``fp32`` with value 0 (type ``FpMode::fp32``), listed first, and
    ``fp16`` with value 1 (type ``FpMode::fp16``).
    """
    # The option's position doubles as its numeric value, and its navi type is
    # the matching FpMode variant.
    precision_choices = [
        {"option": mode_name, "value": position, "type": f"FpMode::{mode_name}"}
        for position, mode_name in enumerate(("fp32", "fp16"))
    ]
    return DropDownInput(
        label="Data Type",
        input_type="FpMode",
        options=precision_choices,
    )