Spaces:
Running
Running
import os | |
# from tqdm import tqdm | |
from PIL import Image | |
import numpy as np | |
import sys | |
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from .u2net_cloth_seg.data.base_dataset import Normalize_image | |
from .u2net_cloth_seg.utils.saving_utils import load_checkpoint_mgpu | |
from .u2net_cloth_seg.networks import U2NET | |
class U2NETParser:
    """Produce a binary mask from a 4-class U^2-Net segmentation model.

    The network is built with ``in_ch=3, out_ch=4``. Pixels predicted as
    class 1 are reported as foreground (255) and every other class as
    background (0). NOTE(review): which garment/region class 1 corresponds to
    depends on the checkpoint's label map — confirm before relying on it.
    """

    # Side length of the square image the network is run on; masks are
    # resized back to the caller's image size afterwards.
    INPUT_SIZE = 768

    def __init__(self, checkpoint_path, device="cuda"):
        """Load the checkpoint and prepare the input preprocessing pipeline.

        Args:
            checkpoint_path: Path handed to ``load_checkpoint_mgpu`` to restore
                the model weights.
            device: Torch device (name or ``torch.device``) that the model and
                its inputs are placed on. Defaults to ``"cuda"``, the value
                this class previously hard-coded; pass ``"cpu"`` to run
                without a GPU.
        """
        self.cp_path = checkpoint_path
        self.device = torch.device(device)
        # ToTensor scales pixels to [0, 1]; Normalize_image(0.5, 0.5) is then
        # applied to the tensor (presumably mapping to [-1, 1] — confirm in
        # u2net_cloth_seg.data.base_dataset).
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            Normalize_image(0.5, 0.5),
        ])
        self.model = self.load_model()

    def load_model(self):
        """Build U^2-Net, restore weights from ``self.cp_path`` and ready it.

        Returns:
            The model moved to ``self.device`` and switched to eval mode.
        """
        model = U2NET(in_ch=3, out_ch=4)
        model = load_checkpoint_mgpu(model, self.cp_path)
        # nn.Module.to() and .eval() both return the module itself.
        return model.to(self.device).eval()

    def get_image_mask(self, img):
        """Predict a class-1 mask for ``img`` at the image's original size.

        Args:
            img: A PIL image of any mode. It is converted to RGB first, so
                grayscale, palette and RGBA inputs produce the 3-channel
                tensor the network (``in_ch=3``) expects.

        Returns:
            A single-channel (mode ``"L"``) PIL image with the same size as
            ``img``. Pixels predicted as class 1 are 255 and all others are 0.
            The mask is upscaled with bicubic resampling, so pixels along
            class boundaries can take intermediate grey values.
        """
        original_size = img.size

        # Convert before resizing so palette images are resampled on real
        # colours rather than palette indices.
        net_input = img.convert("RGB").resize(
            (self.INPUT_SIZE, self.INPUT_SIZE), Image.BICUBIC
        )
        image_tensor = self.img_transforms(net_input).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(image_tensor)

        # outputs[0] is the primary prediction: (1, C, H, W) class scores for
        # our single-image batch. A per-pixel log_softmax would not change
        # which class wins, so take the argmax of the raw scores directly;
        # [0] drops the batch dimension, leaving an (H, W) label map.
        labels = torch.argmax(outputs[0], dim=1)[0]
        label_arr = labels.cpu().numpy()

        mask_arr = np.where(label_arr == 1, 255, 0).astype(np.uint8)
        # A 2-D uint8 array is interpreted by Pillow as mode "L".
        mask = Image.fromarray(mask_arr)
        return mask.resize(original_size, Image.BICUBIC)