# ClothQuill / parser / u2net_parser.py
# Bismay
# Initial commit
# 475e066
import os
# from tqdm import tqdm
from PIL import Image
import numpy as np
import sys
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from .u2net_cloth_seg.data.base_dataset import Normalize_image
from .u2net_cloth_seg.utils.saving_utils import load_checkpoint_mgpu
from .u2net_cloth_seg.networks import U2NET
class U2NETParser:
    """Produce a binary mask for one predicted class using a U2NET checkpoint.

    The network is built with 3 input channels and 4 output channels, loaded
    from ``checkpoint_path`` once at construction time and kept in eval mode
    on ``self.device``.
    """

    # (width, height) every input image is resized to before inference.
    INPUT_SIZE = (768, 768)
    # Predicted class index that becomes foreground (255) in the mask.
    # NOTE(review): presumably one specific garment class of the 4-channel
    # model -- confirm against the checkpoint's label map.
    TARGET_CLASS = 1

    def __init__(self, checkpoint_path, device=None):
        """Build the preprocessing pipeline and load the model.

        Args:
            checkpoint_path: Path to the U2NET checkpoint, forwarded to
                ``load_checkpoint_mgpu``.
            device: Optional torch device (string or ``torch.device``) to run
                inference on. When omitted, CUDA is used if available and
                the CPU otherwise, so the parser also works on GPU-less
                hosts.
        """
        self.cp_path = checkpoint_path
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Must be set before load_model(), which moves the network onto it.
        self.device = torch.device(device)
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            Normalize_image(0.5, 0.5),
        ])
        self.model = self.load_model()

    def load_model(self):
        """Instantiate U2NET, load the checkpoint and return it in eval mode.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.
        """
        model = U2NET(in_ch=3, out_ch=4)
        model = load_checkpoint_mgpu(model, self.cp_path)
        return model.to(self.device).eval()

    def get_image_mask(self, img):
        """Segment ``img`` and return the mask of ``TARGET_CLASS``.

        Args:
            img: A PIL image. Non-RGB images are converted to RGB because
                the network is built with three input channels.

        Returns:
            A mode "L" PIL image with the same size as ``img``: 255 where
            the predicted class equals ``TARGET_CLASS`` and 0 elsewhere at
            network resolution. The final bicubic resize may introduce
            intermediate grey values along mask edges.
        """
        original_size = img.size  # PIL reports (width, height)
        if img.mode != "RGB":
            img = img.convert("RGB")
        resized = img.resize(self.INPUT_SIZE, Image.BICUBIC)

        # Add the batch dimension the network expects and move to its device.
        batch = self.img_transforms(resized).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(batch)
            # Only the first element of the model's return value is used.
            # NOTE(review): presumably the fused full-resolution prediction
            # -- confirm against the U2NET implementation.
            scores = F.log_softmax(outputs[0], dim=1)
            # Per-pixel winning class; [0] drops the batch axis -> (H, W).
            class_map = torch.argmax(scores, dim=1)[0]

        labels = class_map.cpu().numpy()
        mask = np.where(labels == self.TARGET_CLASS, 255, 0).astype(np.uint8)

        # A 2-D uint8 array is interpreted as an 8-bit greyscale ("L") image.
        mask_img = Image.fromarray(mask)
        return mask_img.resize(original_size, Image.BICUBIC)