Spaces:
Runtime error
Runtime error
# pylint: disable=relative-beyond-top-level | |
from __future__ import annotations | |
from typing import List, Optional, Union | |
import numpy as np | |
import navi | |
from nodes.base_input import BaseInput, ErrorValue | |
from ...impl.color.color import Color | |
from ...utils.format import format_color_with_channels, format_image_with_channels | |
from ...utils.utils import get_h_w_c | |
class AudioInput(BaseInput): | |
"""Input a 1D Audio NumPy array""" | |
def __init__(self, label: str = "Audio"): | |
super().__init__("Audio", label) | |
class ImageInput(BaseInput): | |
"""Input a 2D Image NumPy array""" | |
def __init__( | |
self, | |
label: str = "Image", | |
image_type: navi.ExpressionJson = "Image | Color", | |
channels: Union[int, List[int], None] = None, | |
allow_colors: bool = False, | |
): | |
base_type = [navi.Image(channels=channels)] | |
if allow_colors: | |
base_type.append(navi.Color(channels=channels)) | |
image_type = navi.intersect(image_type, base_type) | |
super().__init__(image_type, label) | |
self.channels: Optional[List[int]] = ( | |
[channels] if isinstance(channels, int) else channels | |
) | |
self.allow_colors: bool = allow_colors | |
self.associated_type = np.ndarray | |
if self.allow_colors: | |
self.associated_type = Union[np.ndarray, Color] | |
def enforce(self, value): | |
if isinstance(value, Color): | |
if not self.allow_colors: | |
raise ValueError( | |
f"The input {self.label} does not accept colors, but was connected with one." | |
) | |
if self.channels is not None and value.channels not in self.channels: | |
expected = format_color_with_channels(self.channels, plural=True) | |
actual = format_color_with_channels([value.channels]) | |
raise ValueError( | |
f"The input {self.label} only supports {expected} but was given {actual}." | |
) | |
return value | |
assert isinstance(value, np.ndarray) | |
_, _, c = get_h_w_c(value) | |
if self.channels is not None and c not in self.channels: | |
expected = format_image_with_channels(self.channels, plural=True) | |
actual = format_image_with_channels([c]) | |
raise ValueError( | |
f"The input {self.label} only supports {expected} but was given {actual}." | |
) | |
assert value.dtype == np.float32, "Expected the input image to be normalized." | |
if c == 1 and value.ndim == 3: | |
value = value[:, :, 0] | |
return value | |
def get_error_value(self, value) -> ErrorValue: | |
def get_channels(channel: int) -> str: | |
if channel == 1: | |
return "Grayscale" | |
if channel == 3: | |
return "RGB" | |
if channel == 4: | |
return "RGBA" | |
return f"{channel}-channel" | |
if isinstance(value, Color): | |
return { | |
"type": "formatted", | |
"formatString": f"{get_channels(value.channels)} Color", | |
} | |
elif isinstance(value, np.ndarray): | |
h, w, c = get_h_w_c(value) | |
return { | |
"type": "formatted", | |
"formatString": f"{get_channels(c)} Image {w}x{h}", | |
} | |
else: | |
return super().get_error_value(value) | |
class VideoInput(BaseInput): | |
"""Input a 3D Video NumPy array""" | |
def __init__(self, label: str = "Video"): | |
super().__init__("Video", label) | |