File size: 1,251 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from typing import Dict, List, Tuple

import numpy as np
import onnxruntime as ort
from PIL import Image
from PIL.Image import Image as PILImage


class BaseSession:
    """Base wrapper around an ONNX Runtime session and its image preprocessing.

    Stores the inference session together with the per-channel mean/std and
    target size used by :meth:`normalize` to turn a PIL image into the
    model's input feed. Subclasses are expected to override :meth:`predict`;
    the base implementation raises ``NotImplementedError``.
    """

    def __init__(
        self,
        inner_session: ort.InferenceSession,
        mean: Tuple[float, float, float],
        std: Tuple[float, float, float],
        size: Tuple[int, int],
    ):
        """Store the session and preprocessing parameters.

        Args:
            inner_session: Session whose first declared input receives the
                tensor built by :meth:`normalize`.
            mean: Per-channel (R, G, B) value subtracted after scaling.
            std: Per-channel (R, G, B) divisor applied after mean subtraction.
            size: Target size handed to PIL ``Image.resize``, i.e.
                ``(width, height)``.
        """
        self.inner_session = inner_session
        self.mean = mean
        self.std = std
        self.size = size

    def normalize(self, img: PILImage) -> Dict[str, np.ndarray]:
        """Build the input feed for ``inner_session`` from ``img``.

        The image is converted to RGB and resized to ``self.size`` with
        Lanczos resampling. It is then scaled by its own maximum pixel value,
        so the brightest value becomes 1.0. Next it is standardised per
        channel with ``self.mean`` / ``self.std``. Finally it is rearranged
        from HWC to CHW with a leading batch axis of 1.

        Args:
            img: Source image in any mode that ``convert("RGB")`` accepts.

        Returns:
            A single-entry dict mapping the session's first input name to a
            ``float32`` array of shape ``(1, 3, height, width)``.
        """
        im = img.convert("RGB").resize(self.size, Image.LANCZOS)
        im_ary = np.array(im)

        # Scale by the image's own peak rather than a fixed 255. Guard the
        # divisor: an all-black image has a maximum of 0, and 0 / 0 would
        # turn every pixel into NaN and poison the model output. With the
        # guard, such an image simply stays all zeros.
        im_ary = im_ary / max(np.max(im_ary), 1e-6)

        # Broadcast the per-channel statistics over the trailing RGB axis.
        mean = np.asarray(self.mean, dtype=np.float64)
        std = np.asarray(self.std, dtype=np.float64)
        standardized = (im_ary - mean) / std

        # HWC -> CHW, then prepend the batch dimension.
        chw = standardized.transpose((2, 0, 1))

        model_input_name = self.inner_session.get_inputs()[0].name

        return {model_input_name: np.expand_dims(chw, 0).astype(np.float32)}

    def predict(self, _: PILImage) -> List[PILImage]:
        """Run the model on an image; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError