Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import numpy as np
|
|
6 |
from torchvision import transforms
|
7 |
import cv2
|
8 |
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
|
|
|
9 |
|
10 |
# ----------------- Device -----------------
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -84,21 +85,53 @@ def generate_agnostic(image: Image.Image, segmentation):
|
|
84 |
agnostic_np[segmentation_resized == label] = [128, 128, 128]
|
85 |
return Image.fromarray(agnostic_np)
|
86 |
|
87 |
-
def load_model(model_type):
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
if model_type == "UNet":
|
|
|
90 |
checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
|
|
|
91 |
elif model_type == "GAN":
|
|
|
92 |
checkpoint = torch.load("viton_gan_full_checkpoint.pth", map_location=device)
|
|
|
93 |
elif model_type == "Diffusion":
|
|
|
94 |
checkpoint = torch.load("viton_diffusion_full_checkpoint.pth", map_location=device)
|
|
|
95 |
else:
|
96 |
-
raise ValueError("Invalid model type")
|
97 |
|
98 |
-
model
|
|
|
|
|
|
|
|
|
99 |
model.eval()
|
100 |
return model
|
101 |
|
|
|
102 |
def generate_tryon_output(agnostic_img, cloth_img, model):
|
103 |
agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
|
104 |
cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
|
|
|
6 |
from torchvision import transforms
|
7 |
import cv2
|
8 |
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
|
9 |
+
from improved_viton import ImprovedUNetGenerator, load_checkpoint
|
10 |
|
11 |
# ----------------- Device -----------------
|
12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
85 |
agnostic_np[segmentation_resized == label] = [128, 128, 128]
|
86 |
return Image.fromarray(agnostic_np)
|
87 |
|
88 |
+
# def load_model(model_type):
|
89 |
+
# model = UNetGenerator().to(device)
|
90 |
+
# if model_type == "UNet":
|
91 |
+
# checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
|
92 |
+
# elif model_type == "GAN":
|
93 |
+
# checkpoint = torch.load("viton_gan_full_checkpoint.pth", map_location=device)
|
94 |
+
# elif model_type == "Diffusion":
|
95 |
+
# checkpoint = torch.load("viton_diffusion_full_checkpoint.pth", map_location=device)
|
96 |
+
# else:
|
97 |
+
# raise ValueError("Invalid model type")
|
98 |
+
|
99 |
+
# model.load_state_dict(checkpoint['model_state_dict'])
|
100 |
+
# model.eval()
|
101 |
+
# return model
|
102 |
+
|
103 |
+
import torch
|
104 |
+
from improved_viton import ImprovedUNetGenerator
|
105 |
+
|
106 |
+
def load_model(model_type, device):
|
107 |
+
# Initialize generator architecture
|
108 |
+
# model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
|
109 |
+
|
110 |
+
# Load appropriate checkpoint
|
111 |
if model_type == "UNet":
|
112 |
+
model = UNetGenerator().to(device)
|
113 |
checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
|
114 |
+
state_dict = checkpoint.get('model_G_state_dict') or checkpoint.get('model_state_dict')
|
115 |
elif model_type == "GAN":
|
116 |
+
model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
|
117 |
checkpoint = torch.load("viton_gan_full_checkpoint.pth", map_location=device)
|
118 |
+
state_dict = checkpoint.get('model_G_state_dict')
|
119 |
elif model_type == "Diffusion":
|
120 |
+
model = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
|
121 |
checkpoint = torch.load("viton_diffusion_full_checkpoint.pth", map_location=device)
|
122 |
+
state_dict = checkpoint.get('model_G_state_dict') or checkpoint.get('model_state_dict')
|
123 |
else:
|
124 |
+
raise ValueError(f"Invalid model type: {model_type}")
|
125 |
|
126 |
+
# Load the state dict into the model
|
127 |
+
if state_dict is None:
|
128 |
+
raise KeyError(f"No model weights found in the checkpoint for {model_type}")
|
129 |
+
|
130 |
+
model.load_state_dict(state_dict)
|
131 |
model.eval()
|
132 |
return model
|
133 |
|
134 |
+
|
135 |
def generate_tryon_output(agnostic_img, cloth_img, model):
|
136 |
agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
|
137 |
cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
|