Spaces:
Runtime error
Runtime error
File size: 3,020 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
# PyTorch is an optional dependency. When it (or the PyTorch type helpers)
# cannot be imported, ``torch`` is bound to ``None`` and the input classes in
# this module skip their PyTorch-specific behaviour.
try:
    import torch

    from ...impl.pytorch.types import (
        PyTorchFaceModel,
        PyTorchInpaintModel,
        PyTorchModel,
        PyTorchSRModel,
        is_pytorch_face_model,
        is_pytorch_inpaint_model,
        is_pytorch_model,
        is_pytorch_sr_model,
    )
except Exception:
    # Deliberately broader than ImportError so that any ordinary failure while
    # loading torch still degrades gracefully, but no longer a bare
    # ``except:`` -- KeyboardInterrupt and SystemExit must propagate.
    torch = None
import navi
from nodes.base_input import BaseInput
class ModelInput(BaseInput):
    """Input a loaded model.

    When PyTorch is unavailable (the module-level ``torch`` is ``None``) the
    input performs no validation and sets no associated type.
    """

    def __init__(
        self,
        label: str = "Model",
        input_type: navi.ExpressionJson = "PyTorchModel",
    ):
        """Create the input.

        Args:
            label: Label passed through to ``BaseInput``.
            input_type: Navi type expression passed through to ``BaseInput``.
        """
        super().__init__(input_type, label)
        # PyTorchModel is only bound when the guarded import succeeded.
        if torch is not None:
            self.associated_type = PyTorchModel

    def enforce(self, value):
        """Validate ``value`` and return it unchanged.

        Raises:
            AssertionError: If PyTorch is available and ``value`` is not a
                ``torch.nn.Module`` or is rejected by ``is_pytorch_model``.
        """
        if torch is None:
            # Nothing to validate against without PyTorch.
            return value

        # Raised explicitly rather than via ``assert`` so the checks are not
        # stripped when Python runs with -O; the exception type and messages
        # are unchanged for callers that catch AssertionError.
        if not isinstance(value, torch.nn.Module):
            raise AssertionError("Expected a PyTorch model.")
        if not is_pytorch_model(value):
            raise AssertionError("Expected a supported PyTorch model.")
        return value
class SrModelInput(ModelInput):
    """Model input whose type is intersected with ``"PyTorchSRModel"``.

    When PyTorch is unavailable (the module-level ``torch`` is ``None``) the
    input performs no validation of its own.
    """

    def __init__(
        self,
        label: str = "Model",
        input_type: navi.ExpressionJson = "PyTorchModel",
    ):
        """Create the input.

        Args:
            label: Label passed through to ``ModelInput``.
            input_type: Navi type expression; it is intersected with
                ``"PyTorchSRModel"`` before being passed on.
        """
        super().__init__(
            label,
            navi.intersect(input_type, "PyTorchSRModel"),
        )
        # Overrides the associated type set by the parent constructor.
        if torch is not None:
            self.associated_type = PyTorchSRModel

    def enforce(self, value):
        """Validate ``value`` and return it unchanged.

        Raises:
            AssertionError: If PyTorch is available and ``value`` is not a
                ``torch.nn.Module`` or is rejected by ``is_pytorch_sr_model``.
        """
        if torch is None:
            # Nothing to validate against without PyTorch.
            return value

        # Raised explicitly rather than via ``assert`` so the checks are not
        # stripped when Python runs with -O; the exception type and messages
        # are unchanged for callers that catch AssertionError.
        if not isinstance(value, torch.nn.Module):
            raise AssertionError("Expected a PyTorch model.")
        if not is_pytorch_sr_model(value):
            raise AssertionError("Expected a regular Super-Resolution model.")
        return value
class FaceModelInput(ModelInput):
    """Model input whose type is intersected with ``"PyTorchFaceModel"``.

    When PyTorch is unavailable (the module-level ``torch`` is ``None``) the
    input performs no validation of its own.
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "PyTorchModel"
    ):
        """Create the input.

        Args:
            label: Label passed through to ``ModelInput``.
            input_type: Navi type expression; it is intersected with
                ``"PyTorchFaceModel"`` before being passed on.
        """
        super().__init__(
            label,
            navi.intersect(input_type, "PyTorchFaceModel"),
        )
        # Overrides the associated type set by the parent constructor.
        if torch is not None:
            self.associated_type = PyTorchFaceModel

    def enforce(self, value):
        """Validate ``value`` and return it unchanged.

        Raises:
            AssertionError: If PyTorch is available and ``value`` is not a
                ``torch.nn.Module`` or is rejected by
                ``is_pytorch_face_model``.
        """
        if torch is None:
            # Nothing to validate against without PyTorch.
            return value

        # Raised explicitly rather than via ``assert`` so the checks are not
        # stripped when Python runs with -O; the exception type and messages
        # are unchanged for callers that catch AssertionError.
        if not isinstance(value, torch.nn.Module):
            raise AssertionError("Expected a PyTorch model.")
        if not is_pytorch_face_model(value):
            raise AssertionError("Expected a Face-specific Super-Resolution model.")
        return value
class InpaintModelInput(ModelInput):
    """Model input whose type is intersected with ``"PyTorchInpaintModel"``.

    When PyTorch is unavailable (the module-level ``torch`` is ``None``) the
    input performs no validation of its own.
    """

    def __init__(
        self, label: str = "Model", input_type: navi.ExpressionJson = "PyTorchModel"
    ):
        """Create the input.

        Args:
            label: Label passed through to ``ModelInput``.
            input_type: Navi type expression; it is intersected with
                ``"PyTorchInpaintModel"`` before being passed on.
        """
        super().__init__(
            label,
            navi.intersect(input_type, "PyTorchInpaintModel"),
        )
        # Overrides the associated type set by the parent constructor.
        if torch is not None:
            self.associated_type = PyTorchInpaintModel

    def enforce(self, value):
        """Validate ``value`` and return it unchanged.

        Raises:
            AssertionError: If PyTorch is available and ``value`` is not a
                ``torch.nn.Module`` or is rejected by
                ``is_pytorch_inpaint_model``.
        """
        if torch is None:
            # Nothing to validate against without PyTorch.
            return value

        # Raised explicitly rather than via ``assert`` so the checks are not
        # stripped when Python runs with -O; the exception type and messages
        # are unchanged for callers that catch AssertionError.
        if not isinstance(value, torch.nn.Module):
            raise AssertionError("Expected a PyTorch model.")
        if not is_pytorch_inpaint_model(value):
            raise AssertionError("Expected an inpainting-specific model.")
        return value
class TorchScriptInput(BaseInput):
    """Input a JIT traced model.

    The input type passed to ``BaseInput`` is always ``"PyTorchScript"``;
    only the label is configurable.
    """

    def __init__(self, label: str = "Traced Model"):
        """Create the input with the given ``label``."""
        script_type = "PyTorchScript"
        super().__init__(script_type, label)
|