File size: 1,806 Bytes
475e066
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
# from tqdm import tqdm
from PIL import Image
import numpy as np
import sys

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from .u2net_cloth_seg.data.base_dataset import Normalize_image
from .u2net_cloth_seg.utils.saving_utils import load_checkpoint_mgpu

from .u2net_cloth_seg.networks import U2NET

class U2NETParser:
    """Produce a single-class binary mask from an image using a U^2-Net model.

    The network is built with 3 input channels and 4 output classes
    (``U2NET(in_ch=3, out_ch=4)``). ``get_image_mask`` keeps one of those
    output classes (class 1 by default) as a white-on-black ``"L"`` image.
    """

    # Square side length every input image is resized to before inference.
    INPUT_SIZE = 768

    def __init__(self, checkpoint_path, device="cuda"):
        """Build the preprocessing pipeline and load the model weights.

        Args:
            checkpoint_path: Path to the U^2-Net weights; forwarded to
                ``load_checkpoint_mgpu``.
            device: Torch device the model and its inputs are moved to.
                Defaults to ``"cuda"``, the value previously hard-coded.
        """
        self.cp_path = checkpoint_path
        self.device = device
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            Normalize_image(0.5, 0.5),
        ])
        self.model = self.load_model()

    def load_model(self):
        """Instantiate U2NET, load the checkpoint, move it to ``self.device``.

        Returns:
            The model in eval mode.
        """
        model = U2NET(in_ch=3, out_ch=4)
        model = load_checkpoint_mgpu(model, self.cp_path)
        return model.to(self.device).eval()

    def get_image_mask(self, img, target_class=1):
        """Segment ``img`` and return a mask of the pixels labelled ``target_class``.

        Args:
            img: A PIL image of any mode; it is converted to RGB before
                inference because the network expects 3 input channels.
            target_class: Output class index to keep (0..3 for this 4-class
                model). Defaults to 1, the class previously hard-coded.

        Returns:
            A mode-"L" PIL image with the same size as ``img``. Pixels of
            ``target_class`` are 255 and all others are 0, except along mask
            edges, where the final bicubic resize produces intermediate values.
        """
        original_size = img.size  # (width, height); restored at the end
        # Convert before ToTensor so that RGBA/L/P inputs do not produce a
        # 4- or 1-channel tensor (in_ch=3). For RGB input this is a plain copy.
        img = img.convert("RGB").resize(
            (self.INPUT_SIZE, self.INPUT_SIZE), Image.BICUBIC
        )
        image_tensor = self.img_transforms(img).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Only the first element of the model output is used; presumably
            # the final fused map. TODO confirm against networks.U2NET.
            logits = self.model(image_tensor)[0]

        # Per-pixel class index over the channel dim. A softmax/log-softmax
        # is monotonic per pixel, so it would not change the argmax.
        labels = torch.argmax(logits, dim=1).squeeze(0)
        label_arr = labels.cpu().numpy()

        mask = np.where(label_arr == target_class, 255, 0).astype(np.uint8)
        mask_img = Image.fromarray(mask)  # 2-D uint8 array -> mode "L"
        # NOTE(review): BICUBIC upscaling softens the binary edges; kept as in
        # the original. Use Image.NEAREST if a strictly binary mask is needed.
        return mask_img.resize(original_size, Image.BICUBIC)