File size: 3,567 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# pylint: disable=relative-beyond-top-level
from __future__ import annotations

from typing import List, Optional, Union

import numpy as np

import navi
from nodes.base_input import BaseInput, ErrorValue

from ...impl.color.color import Color
from ...utils.format import format_color_with_channels, format_image_with_channels
from ...utils.utils import get_h_w_c


class AudioInput(BaseInput):
    """Node input registered under the ``"Audio"`` type.

    Described by its author as carrying a 1D audio NumPy array; this class
    adds no runtime validation of its own on top of ``BaseInput``.
    """

    def __init__(self, label: str = "Audio"):
        input_type = "Audio"
        super().__init__(input_type, label)


class ImageInput(BaseInput):
    """Input holding an image as a NumPy array (and, if enabled, a ``Color``).

    The declared navi type is ``image_type`` intersected with
    ``navi.Image(channels=channels)`` (plus ``navi.Color(channels=channels)``
    when ``allow_colors`` is set). :meth:`enforce` performs the matching
    runtime checks on incoming values.
    """

    def __init__(
        self,
        label: str = "Image",
        image_type: navi.ExpressionJson = "Image | Color",
        channels: Union[int, List[int], None] = None,
        allow_colors: bool = False,
    ):
        """
        Args:
            label: Display label of the input.
            image_type: navi type expression, intersected with the Image
                (and optionally Color) type built from ``channels``.
            channels: Allowed channel count(s). A single int is treated as a
                one-element list; ``None`` places no restriction.
            allow_colors: Whether a ``Color`` value is accepted in addition
                to a NumPy image.
        """
        base_type = [navi.Image(channels=channels)]
        if allow_colors:
            base_type.append(navi.Color(channels=channels))
        image_type = navi.intersect(image_type, base_type)
        super().__init__(image_type, label)

        # Normalize to a list. A list argument is copied so that later
        # mutation of the caller's list cannot change what this input accepts.
        if isinstance(channels, int):
            normalized: Optional[List[int]] = [channels]
        elif channels is not None:
            normalized = list(channels)
        else:
            normalized = None
        self.channels: Optional[List[int]] = normalized
        self.allow_colors: bool = allow_colors

        self.associated_type = (
            Union[np.ndarray, Color] if allow_colors else np.ndarray
        )

    def enforce(self, value):
        """Validate ``value`` against this input's constraints.

        Returns:
            A ``Color`` unchanged, or the image array. A 3D array whose
            channel count is 1 is returned with its last axis dropped
            (``value[:, :, 0]``).

        Raises:
            ValueError: if a Color is given but colors are not allowed, if
                the channel count is not in ``self.channels``, or if the
                array's dtype is not float32.
            TypeError: if ``value`` is neither a ``Color`` nor a NumPy array.
        """
        if isinstance(value, Color):
            if not self.allow_colors:
                raise ValueError(
                    f"The input {self.label} does not accept colors, but was connected with one."
                )

            if self.channels is not None and value.channels not in self.channels:
                expected = format_color_with_channels(self.channels, plural=True)
                actual = format_color_with_channels([value.channels])
                raise ValueError(
                    f"The input {self.label} only supports {expected} but was given {actual}."
                )

            return value

        # Explicit checks rather than ``assert``: asserts are removed when
        # Python runs with -O, which would let invalid values through.
        if not isinstance(value, np.ndarray):
            raise TypeError(
                f"The input {self.label} expected an image but was given a value of type {type(value).__name__}."
            )

        _, _, c = get_h_w_c(value)

        if self.channels is not None and c not in self.channels:
            expected = format_image_with_channels(self.channels, plural=True)
            actual = format_image_with_channels([c])
            raise ValueError(
                f"The input {self.label} only supports {expected} but was given {actual}."
            )

        if value.dtype != np.float32:
            raise ValueError("Expected the input image to be normalized.")

        # Drop the trailing axis of a 3D single-channel array so callers
        # receive a 2D array.
        if c == 1 and value.ndim == 3:
            value = value[:, :, 0]

        return value

    @staticmethod
    def _channel_name(channels: int) -> str:
        """Return a readable name for a channel count (e.g. 3 -> "RGB")."""
        return {1: "Grayscale", 3: "RGB", 4: "RGBA"}.get(
            channels, f"{channels}-channel"
        )

    def get_error_value(self, value) -> ErrorValue:
        """Describe ``value`` for error reporting.

        Colors and NumPy arrays get a ``"formatted"`` description naming the
        channel layout (and, for arrays, ``{w}x{h}``). Any other value is
        delegated to ``BaseInput.get_error_value``.
        """
        if isinstance(value, Color):
            return {
                "type": "formatted",
                "formatString": f"{self._channel_name(value.channels)} Color",
            }
        if isinstance(value, np.ndarray):
            h, w, c = get_h_w_c(value)
            return {
                "type": "formatted",
                "formatString": f"{self._channel_name(c)} Image {w}x{h}",
            }
        return super().get_error_value(value)


class VideoInput(BaseInput):
    """Node input registered under the ``"Video"`` type.

    Described by its author as carrying a 3D video NumPy array; this class
    adds no runtime validation of its own on top of ``BaseInput``.
    """

    def __init__(self, label: str = "Video"):
        input_type = "Video"
        super().__init__(input_type, label)