sengourav012 committed on
Commit
3be3204
·
verified ·
1 Parent(s): feeffbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -4
app.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
  from torchvision import transforms
7
  import cv2
8
  from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
 
9
 
10
  # ----------------- Device -----------------
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -84,21 +85,53 @@ def generate_agnostic(image: Image.Image, segmentation):
84
  agnostic_np[segmentation_resized == label] = [128, 128, 128]
85
  return Image.fromarray(agnostic_np)
86
 
87
- def load_model(model_type):
88
- model = UNetGenerator().to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  if model_type == "UNet":
 
90
  checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
 
91
  elif model_type == "GAN":
 
92
  checkpoint = torch.load("viton_gan_full_checkpoint.pth", map_location=device)
 
93
  elif model_type == "Diffusion":
 
94
  checkpoint = torch.load("viton_diffusion_full_checkpoint.pth", map_location=device)
 
95
  else:
96
- raise ValueError("Invalid model type")
97
 
98
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
99
  model.eval()
100
  return model
101
 
 
102
  def generate_tryon_output(agnostic_img, cloth_img, model):
103
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
104
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)
 
6
  from torchvision import transforms
7
  import cv2
8
  from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
9
+ from improved_viton import ImprovedUNetGenerator, load_checkpoint
10
 
11
  # ----------------- Device -----------------
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
85
  agnostic_np[segmentation_resized == label] = [128, 128, 128]
86
  return Image.fromarray(agnostic_np)
87
 
88
+ # def load_model(model_type):
89
+ # model = UNetGenerator().to(device)
90
+ # if model_type == "UNet":
91
+ # checkpoint = torch.load("viton_unet_full_checkpoint.pth", map_location=device)
92
+ # elif model_type == "GAN":
93
+ # checkpoint = torch.load("viton_gan_full_checkpoint.pth", map_location=device)
94
+ # elif model_type == "Diffusion":
95
+ # checkpoint = torch.load("viton_diffusion_full_checkpoint.pth", map_location=device)
96
+ # else:
97
+ # raise ValueError("Invalid model type")
98
+
99
+ # model.load_state_dict(checkpoint['model_state_dict'])
100
+ # model.eval()
101
+ # return model
102
+
103
+ import torch
104
+ from improved_viton import ImprovedUNetGenerator
105
+
106
def load_model(model_type, device):
    """Build the generator for *model_type* and load its checkpoint weights.

    Parameters
    ----------
    model_type : str
        One of ``"UNet"``, ``"GAN"`` or ``"Diffusion"``.
    device : torch.device
        Device the model is placed on and checkpoint tensors are mapped to.

    Returns
    -------
    torch.nn.Module
        The generator in ``eval()`` mode with checkpoint weights loaded.

    Raises
    ------
    ValueError
        If *model_type* is not one of the recognised types.
    KeyError
        If the checkpoint contains none of the expected weight keys.
    """
    # Per-type config: (architecture factory, checkpoint path, weight keys
    # tried in order).  GAN checkpoints are expected to store generator
    # weights only under 'model_G_state_dict'; the others also accept the
    # legacy 'model_state_dict' key.
    _configs = {
        "UNet": (
            lambda: UNetGenerator(),
            "viton_unet_full_checkpoint.pth",
            ("model_G_state_dict", "model_state_dict"),
        ),
        "GAN": (
            lambda: ImprovedUNetGenerator(in_channels=6, out_channels=3),
            "viton_gan_full_checkpoint.pth",
            ("model_G_state_dict",),
        ),
        "Diffusion": (
            lambda: ImprovedUNetGenerator(in_channels=6, out_channels=3),
            "viton_diffusion_full_checkpoint.pth",
            ("model_G_state_dict", "model_state_dict"),
        ),
    }

    try:
        make_model, ckpt_path, weight_keys = _configs[model_type]
    except KeyError:
        # Re-raise as the documented error type; suppress the internal
        # KeyError context deliberately.
        raise ValueError(f"Invalid model type: {model_type}") from None

    model = make_model().to(device)
    checkpoint = torch.load(ckpt_path, map_location=device)

    # First matching key wins; the explicit `is not None` test (rather than
    # `get(a) or get(b)`) avoids rejecting a falsy-but-present value.
    state_dict = next(
        (checkpoint[k] for k in weight_keys if checkpoint.get(k) is not None),
        None,
    )
    if state_dict is None:
        raise KeyError(f"No model weights found in the checkpoint for {model_type}")

    model.load_state_dict(state_dict)
    model.eval()
    return model
133
 
134
+
135
  def generate_tryon_output(agnostic_img, cloth_img, model):
136
  agnostic_tensor = img_transform(agnostic_img).unsqueeze(0).to(device)
137
  cloth_tensor = img_transform(cloth_img).unsqueeze(0).to(device)