Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -64,7 +64,8 @@ class UNetGenerator(nn.Module):
|
|
64 |
# ----------------- Image Transforms -----------------
|
65 |
img_transform = transforms.Compose([
|
66 |
transforms.Resize((256, 192)),
|
67 |
-
transforms.ToTensor()
|
|
|
68 |
])
|
69 |
|
70 |
# ----------------- Helper Functions -----------------
|
@@ -108,24 +109,34 @@ def load_model(model_type):
|
|
108 |
model.eval()
|
109 |
return model
|
110 |
|
111 |
-
def generate_tryon_output(agnostic_img, cloth_img, model):
|
112 |
agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
|
113 |
cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
|
114 |
input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
|
115 |
|
116 |
with torch.no_grad():
|
117 |
output = model(input_tensor)
|
118 |
-
|
|
|
119 |
output_img = (output_img + 1) / 2
|
120 |
-
output_img = (output_img
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
# ----------------- Inference Pipeline -----------------
|
124 |
def virtual_tryon(person_image, cloth_image, model_type):
|
125 |
segmentation = get_segmentation(person_image)
|
126 |
agnostic = generate_agnostic(person_image, segmentation)
|
127 |
model = load_model(model_type)
|
128 |
-
result = generate_tryon_output(agnostic, cloth_image, model)
|
129 |
return agnostic, result
|
130 |
|
131 |
# ----------------- Gradio Interface -----------------
|
|
|
64 |
# ----------------- Image Transforms -----------------
|
65 |
img_transform = transforms.Compose([
|
66 |
transforms.Resize((256, 192)),
|
67 |
+
transforms.ToTensor(),
|
68 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
69 |
])
|
70 |
|
71 |
# ----------------- Helper Functions -----------------
|
|
|
109 |
model.eval()
|
110 |
return model
|
111 |
|
112 |
+
def generate_tryon_output(person_img, agnostic_img, cloth_img, segmentation, model):
|
113 |
agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
|
114 |
cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
|
115 |
input_tensor = torch.cat([agnostic_tensor, cloth_tensor], dim=1)
|
116 |
|
117 |
with torch.no_grad():
|
118 |
output = model(input_tensor)
|
119 |
+
|
120 |
+
output_img = output[0].cpu().permute(1, 2, 0).numpy()
|
121 |
output_img = (output_img + 1) / 2
|
122 |
+
output_img = np.clip(output_img, 0, 1)
|
123 |
+
|
124 |
+
person_np = np.array(person_img.resize((192, 256))).astype(np.float32) / 255.0
|
125 |
+
segmentation_resized = cv2.resize(segmentation.astype(np.uint8), (192, 256), interpolation=cv2.INTER_NEAREST)
|
126 |
+
blend_mask = (segmentation_resized == 0).astype(np.float32)
|
127 |
+
blend_mask = np.expand_dims(blend_mask, axis=2)
|
128 |
+
|
129 |
+
final_output = blend_mask * person_np + (1 - blend_mask) * output_img
|
130 |
+
final_output = (final_output * 255).astype(np.uint8)
|
131 |
+
|
132 |
+
return Image.fromarray(final_output)
|
133 |
|
134 |
# ----------------- Inference Pipeline -----------------
|
135 |
def virtual_tryon(person_image, cloth_image, model_type):
|
136 |
segmentation = get_segmentation(person_image)
|
137 |
agnostic = generate_agnostic(person_image, segmentation)
|
138 |
model = load_model(model_type)
|
139 |
+
result = generate_tryon_output(person_image, agnostic, cloth_image, segmentation, model)
|
140 |
return agnostic, result
|
141 |
|
142 |
# ----------------- Gradio Interface -----------------
|