import os
# from tqdm import tqdm
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from .u2net_cloth_seg.data.base_dataset import Normalize_image
from .u2net_cloth_seg.utils.saving_utils import load_checkpoint_mgpu
from .u2net_cloth_seg.networks import U2NET


class U2NETParser:
    """Wrap a 4-output-channel U2NET model and extract a binary mask for one class.

    The parser loads the checkpoint once at construction time and then serves
    ``get_image_mask`` calls, each of which returns a single-channel mask that
    marks the pixels whose highest-scoring output channel is ``TARGET_CLASS``.
    """

    # Square resolution every image is resized to before inference.
    INPUT_SIZE = (768, 768)
    # Index of the output channel that is turned into the mask.
    # NOTE(review): the semantic meaning of class 1 is defined by the
    # checkpoint's training labels, not by this file -- confirm with the
    # model's documentation.
    TARGET_CLASS = 1

    def __init__(self, checkpoint_path, device=None):
        """Build the preprocessing pipeline and load the model.

        Args:
            checkpoint_path: Path handed to ``load_checkpoint_mgpu``.
            device: Optional torch device (or device string) to run on.
                When omitted, CUDA is used if available and the CPU
                otherwise, so GPU hosts behave exactly as before while
                CPU-only hosts no longer fail.
        """
        self.cp_path = checkpoint_path
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            Normalize_image(0.5, 0.5),
        ])
        self.model = self.load_model()

    def load_model(self):
        """Instantiate U2NET, load the checkpoint and switch to eval mode.

        Returns:
            The model, moved to ``self.device`` and set to evaluation mode.
        """
        model = U2NET(in_ch=3, out_ch=4)
        model = load_checkpoint_mgpu(model, self.cp_path)
        model = model.to(self.device)
        return model.eval()

    def get_image_mask(self, img):
        """Return the ``TARGET_CLASS`` mask for a PIL image.

        Args:
            img: A ``PIL.Image.Image``. Non-RGB images (e.g. RGBA, L, P) are
                converted to RGB because the network is built with
                ``in_ch=3``.

        Returns:
            A mode ``"L"`` PIL image with the same size as ``img``. Before
            the final resize, pixels assigned to ``TARGET_CLASS`` are 255 and
            all others are 0; the bicubic resize back to the original size
            can produce intermediate values along mask edges.
        """
        original_size = img.size
        if img.mode != "RGB":
            img = img.convert("RGB")

        resized = img.resize(self.INPUT_SIZE, Image.BICUBIC)
        batch = self.img_transforms(resized).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # The first element of the model output is used as the
            # prediction (assumed shape (1, C, H, W) -- TODO confirm against
            # U2NET.forward).
            scores = self.model(batch)[0]
            # Log-softmax is monotonic, so taking the arg-max over the raw
            # scores selects the same channel without the extra pass.
            labels = torch.argmax(scores, dim=1).squeeze(0)

        label_arr = labels.cpu().numpy()
        mask_arr = np.where(label_arr == self.TARGET_CLASS, 255, 0).astype(np.uint8)

        # A 2-D uint8 array is interpreted as mode "L" automatically; the
        # explicit ``mode`` argument is deprecated in recent Pillow releases.
        mask_img = Image.fromarray(mask_arr)
        return mask_img.resize(original_size, Image.BICUBIC)