sengourav012 committed
Commit a224cb9 · verified · 1 Parent(s): 3be3204

Upload improved_viton.py

Files changed (1)
  1. improved_viton.py +946 -0
improved_viton.py ADDED
@@ -0,0 +1,946 @@
import os
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.optim.lr_scheduler import StepLR
import random
import cv2

# Ensure reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)


# Improved dataset handling with proper data augmentation
class VITONDataset(Dataset):
    def __init__(self, root_dir, mode='train', transform=None, augment=False):
        """
        Enhanced dataset class with better error handling and data augmentation

        Args:
            root_dir: Root directory of the dataset
            mode: 'train' or 'test'
            transform: Transforms to apply to images
            augment: Whether to apply data augmentation
        """
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.augment = augment

        # Check if directories exist
        img_dir = os.path.join(root_dir, f'{mode}_img')
        cloth_dir = os.path.join(root_dir, f'{mode}_color')
        label_dir = os.path.join(root_dir, f'{mode}_label')

        if not os.path.exists(img_dir) or not os.path.exists(cloth_dir) or not os.path.exists(label_dir):
            raise FileNotFoundError(f"One or more dataset directories not found in {root_dir}")

        # Collect sample names for which the person image, cloth image, and label map all exist
        self.image_names = []
        for f in sorted(os.listdir(img_dir)):
            if f.endswith('.jpg'):
                base_name = f.replace('_0.jpg', '')
                cloth_path = os.path.join(cloth_dir, f"{base_name}_1.jpg")
                label_path = os.path.join(label_dir, f"{base_name}_0.png")

                if os.path.exists(cloth_path) and os.path.exists(label_path):
                    self.image_names.append(base_name)

        print(f"Found {len(self.image_names)} valid samples in {mode} set")

    def __len__(self):
        return len(self.image_names)

    def _apply_augmentation(self, img, cloth, label):
        """Apply data augmentation"""
        # Random horizontal flip
        if random.random() > 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            cloth = cloth.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # Random brightness and contrast adjustment for the person image
        if random.random() > 0.7:
            img = transforms.functional.adjust_brightness(img, random.uniform(0.8, 1.2))
            img = transforms.functional.adjust_contrast(img, random.uniform(0.8, 1.2))

        # Random color jitter for the clothing image
        if random.random() > 0.7:
            cloth = transforms.functional.adjust_brightness(cloth, random.uniform(0.8, 1.2))
            cloth = transforms.functional.adjust_saturation(cloth, random.uniform(0.8, 1.2))

        return img, cloth, label

    def __getitem__(self, idx):
        base_name = self.image_names[idx]

        # Build file paths
        img_path = os.path.join(self.root_dir, f'{self.mode}_img', f"{base_name}_0.jpg")
        cloth_path = os.path.join(self.root_dir, f'{self.mode}_color', f"{base_name}_1.jpg")
        label_path = os.path.join(self.root_dir, f'{self.mode}_label', f"{base_name}_0.png")

        try:
            # Load images
            img = Image.open(img_path).convert('RGB').resize((192, 256))
            cloth = Image.open(cloth_path).convert('RGB').resize((192, 256))
            label = Image.open(label_path).convert('L').resize((192, 256), resample=Image.NEAREST)

            # Apply augmentation if enabled
            if self.augment and self.mode == 'train':
                img, cloth, label = self._apply_augmentation(img, cloth, label)

            # Convert to numpy for processing
            img_np = np.array(img)
            label_np = np.array(label)

            # Create agnostic person image (remove upper clothes, parsing label 4)
            agnostic_np = img_np.copy()
            agnostic_np[label_np == 4] = [128, 128, 128]  # Grey out clothing region

            # Create cloth mask (binary mask of clothing)
            cloth_mask = (label_np == 4).astype(np.uint8) * 255
            cloth_mask_img = Image.fromarray(cloth_mask)

            # Apply transforms
            to_tensor = self.transform if self.transform else transforms.ToTensor()

            person_tensor = to_tensor(img)
            agnostic_tensor = to_tensor(Image.fromarray(agnostic_np))
            cloth_tensor = to_tensor(cloth)

            # Handle the cloth mask so its channel count matches the cloth tensor
            if self.transform:
                # Convert to RGB for consistent channel handling
                cloth_mask_rgb = Image.fromarray(cloth_mask).convert('RGB')
                cloth_mask_tensor = to_tensor(cloth_mask_rgb)
            else:
                # Plain ToTensor() on the grayscale mask
                cloth_mask_tensor = transforms.ToTensor()(cloth_mask_img)

            # If needed, expand the mask to 3 channels to match the cloth image
            if cloth_tensor.shape[0] == 3:
                cloth_mask_tensor = cloth_mask_tensor.expand(3, -1, -1)

            # Segmentation mask as a LongTensor of class indices (not one-hot)
            label_tensor = torch.from_numpy(label_np).long()

            sample = {
                'person': person_tensor,
                'agnostic': agnostic_tensor,
                'cloth': cloth_tensor,
                'cloth_mask': cloth_mask_tensor,
                'label': label_tensor,
                'name': base_name
            }

            return sample

        except Exception as e:
            print(f"Error loading sample {base_name}: {e}")
            # Fall back to the next sample so a single bad file does not stop training
            return self.__getitem__((idx + 1) % len(self.image_names))
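
# Usage sketch (illustrative, not part of the original training code): the rest of this
# file assumes image tensors in [-1, 1] (the generator ends in Tanh and the plotting code
# rescales with (x + 1) / 2), so a Normalize(0.5, 0.5) transform is a reasonable default:
#
#   transform = transforms.Compose([
#       transforms.ToTensor(),
#       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#   ])
#   train_set = VITONDataset('path/to/viton', mode='train', transform=transform, augment=True)
#   train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)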


# Improved U-Net with residual connections and attention
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        # Fixed: Change inplace ReLU to non-inplace
        self.relu = nn.ReLU(inplace=False)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)

        return x * psi


class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)
        # Fixed: Change inplace ReLU to non-inplace
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = self.relu(out)
        return out


# Discriminator network for adversarial training
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=6):
        super(PatchDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        # Concatenate image and condition along the channel dimension
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
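
# Note (added for clarity): PatchDiscriminator outputs a spatial map of real/fake
# scores (roughly 16x12 for 256x192 inputs) rather than a single scalar; train_model()
# below compares this map against all-ones / all-zeros targets with MSELoss (LSGAN-style).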


class ImprovedUNetGenerator(nn.Module):
    """U-Net generator with attention-gated skip connections.

    Takes a 6-channel input (agnostic person image concatenated with the target
    cloth) and produces a 3-channel try-on image in [-1, 1] (Tanh output).
    """

    def __init__(self, in_channels=6, out_channels=3):
        super(ImprovedUNetGenerator, self).__init__()

        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=False)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=False)
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=False)
        )
        self.enc4 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=False)
        )
        self.enc5 = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=False)
        )

        # Bottleneck
        self.bottleneck = ResidualBlock(512)

        # Decoder (input channels double where attention-gated skips are concatenated)
        self.dec5 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=False),
            nn.Dropout(0.5)
        )
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(1024, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=False),
            nn.Dropout(0.5)
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False)
        )
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh()
        )

        # Attention gates
        self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)

        # Bottleneck
        b = self.bottleneck(e5)

        # Decoder with attention and skip connections
        d5 = self.dec5(b)
        d5 = torch.cat([self.att4(d5, e4), d5], dim=1)

        d4 = self.dec4(d5)
        d4 = torch.cat([self.att3(d4, e3), d4], dim=1)

        d3 = self.dec3(d4)
        d3 = torch.cat([self.att2(d3, e2), d3], dim=1)

        d2 = self.dec2(d3)
        d2 = torch.cat([self.att1(d2, e1), d2], dim=1)

        d1 = self.dec1(d2)

        return d1


# Custom loss functions
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        # Import here to avoid a hard dependency at module import time
        import torchvision.models as models

        # Load pretrained VGG-19 and make sure it only uses non-inplace operations
        vgg = models.vgg19(pretrained=True).features.eval()

        # Replace inplace ReLU with the non-inplace version
        for idx, module in enumerate(vgg):
            if isinstance(module, nn.ReLU):
                vgg[idx] = nn.ReLU(inplace=False)

        self.model = nn.Sequential()

        # Selected VGG layer indices and their loss weights (one weight per selected layer)
        feature_layers = [0, 2, 5, 10, 15]
        self.layer_weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

        for i in range(len(feature_layers)):
            self.model.add_module(f'layer_{i}', vgg[feature_layers[i]])

        for param in self.model.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss()
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x, y):
        # Inputs are assumed to be in [-1, 1] (Tanh / Normalize(0.5, 0.5) convention);
        # map them to [0, 1] before applying the ImageNet normalization.
        x = (x + 1) / 2
        y = (y + 1) / 2
        x = (x - self.mean) / self.std
        y = (y - self.mean) / self.std

        loss = 0.0
        x_features = x
        y_features = y

        for i, layer in enumerate(self.model):
            x_features = layer(x_features)
            y_features = layer(y_features)
            loss += self.layer_weights[i] * self.criterion(x_features, y_features)

        return loss


# Training setup
def train_model(model_G, model_D=None, train_loader=None, val_loader=None,
                num_epochs=50, device=None, use_gan=True):
    """
    Improved training function with GAN training, learning rate schedulers, and validation
    """
    torch.autograd.set_detect_anomaly(True)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Adversarial training only makes sense when a discriminator is provided
    use_gan = use_gan and model_D is not None

    # Optimizers
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    scheduler_G = StepLR(optimizer_G, step_size=10, gamma=0.5)

    # Losses
    criterion_L1 = nn.L1Loss()
    criterion_perceptual = VGGPerceptualLoss().to(device)

    # GAN setup
    if use_gan:
        optimizer_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0.5, 0.999))
        scheduler_D = StepLR(optimizer_D, step_size=10, gamma=0.5)
        criterion_GAN = nn.MSELoss()

    # Lists to store losses for plotting
    train_losses_G = []
    train_losses_D = [] if use_gan else None
    val_losses = []

    # Training loop
    for epoch in range(num_epochs):
        model_G.train()
        if use_gan:
            model_D.train()

        epoch_loss_G = 0.0
        epoch_loss_D = 0.0 if use_gan else None
        start_time = time.time()

        for i, sample in enumerate(train_loader):
            agnostic = sample['agnostic'].to(device)
            cloth = sample['cloth'].to(device)
            target = sample['person'].to(device)
            cloth_mask = sample['cloth_mask'].to(device)

            # Combine inputs: agnostic person + target cloth
            input_tensor = torch.cat([agnostic, cloth], dim=1)

            # -----------------
            #  Generator training
            # -----------------
            optimizer_G.zero_grad()

            # Generate fake image
            fake_image = model_G(input_tensor)

            # L1 reconstruction loss
            loss_L1 = criterion_L1(fake_image, target)

            # Perceptual loss
            loss_perceptual = criterion_perceptual(fake_image, target)

            # Total generator loss
            loss_G = loss_L1 + 0.1 * loss_perceptual

            # Add adversarial loss if using GAN training
            if use_gan:
                # The generator tries to make the discriminator predict "real" (all ones)
                pred_fake = model_D(fake_image, cloth)
                target_real = torch.ones_like(pred_fake)
                loss_GAN = criterion_GAN(pred_fake, target_real)

                loss_G = loss_G + 0.1 * loss_GAN

            # Backward pass and optimize generator
            loss_G.backward()
            optimizer_G.step()

            epoch_loss_G += loss_G.item()

            # -----------------
            #  Discriminator training (if using GAN)
            # -----------------
            if use_gan:
                optimizer_D.zero_grad()

                # Real loss
                pred_real = model_D(target, cloth)
                target_real = torch.ones_like(pred_real)
                loss_real = criterion_GAN(pred_real, target_real)

                # Fake loss (detach so no generator gradients are computed here)
                pred_fake = model_D(fake_image.detach(), cloth)
                target_fake = torch.zeros_like(pred_fake)
                loss_fake = criterion_GAN(pred_fake, target_fake)

                # Total discriminator loss
                loss_D = (loss_real + loss_fake) / 2

                # Backward pass and optimize discriminator
                loss_D.backward()
                optimizer_D.step()

                epoch_loss_D += loss_D.item()

            # Print progress
            if (i + 1) % 50 == 0:
                time_elapsed = time.time() - start_time
                msg = (f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], "
                       f"G Loss: {loss_G.item():.4f}, ")
                if use_gan:
                    msg += f"D Loss: {loss_D.item():.4f}, "
                print(msg + f"Time: {time_elapsed:.2f}s")

        # Update learning rates
        scheduler_G.step()
        if use_gan:
            scheduler_D.step()

        # Calculate average losses for this epoch
        avg_loss_G = epoch_loss_G / len(train_loader)
        train_losses_G.append(avg_loss_G)

        if use_gan:
            avg_loss_D = epoch_loss_D / len(train_loader)
            train_losses_D.append(avg_loss_D)

        # Validation and epoch summary
        epoch_msg = f"Epoch {epoch+1}, Train Loss G: {avg_loss_G:.4f}, "
        if use_gan:
            epoch_msg += f"Train Loss D: {avg_loss_D:.4f}, "
        if val_loader is not None:
            val_loss = validate_model(model_G, val_loader, device)
            val_losses.append(val_loss)
            epoch_msg += f"Val Loss: {val_loss:.4f}, "
        print(epoch_msg + f"Time: {time.time()-start_time:.2f}s")

        # Save model checkpoint periodically
        if (epoch + 1) % 5 == 0:
            save_checkpoint(model_G, model_D, optimizer_G, optimizer_D if use_gan else None,
                            epoch, f"checkpoint_epoch_{epoch+1}.pth")

        # Visualize some results
        if (epoch + 1) % 5 == 0 and val_loader is not None:
            visualize_results(model_G, val_loader, device, epoch + 1)

    # Plot training losses
    plot_losses(train_losses_G, train_losses_D, val_losses)

    return model_G, model_D


def validate_model(model, val_loader, device):
    """Validate the model on validation set"""
    model.eval()
    val_loss = 0.0
    criterion = nn.L1Loss()

    with torch.no_grad():
        for sample in val_loader:
            agnostic = sample['agnostic'].to(device)
            cloth = sample['cloth'].to(device)
            target = sample['person'].to(device)

            input_tensor = torch.cat([agnostic, cloth], dim=1)
            output = model(input_tensor)

            loss = criterion(output, target)
            val_loss += loss.item()

    return val_loss / len(val_loader)


def visualize_results(model, dataloader, device, epoch):
    """Visualize generated try-on results"""
    model.eval()

    # Get a batch of samples
    for i, sample in enumerate(dataloader):
        if i >= 1:  # Only visualize the first batch
            break

        with torch.no_grad():
            agnostic = sample['agnostic'].to(device)
            cloth = sample['cloth'].to(device)
            target = sample['person'].to(device)

            input_tensor = torch.cat([agnostic, cloth], dim=1)
            output = model(input_tensor)

        # Convert tensors from [-1, 1] back to [0, 1] for visualization
        imgs = []
        for j in range(min(4, output.size(0))):  # Show at most 4 examples
            person_img = (target[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
            agnostic_img = (agnostic[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
            cloth_img = (cloth[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
            output_img = (output[j].cpu().permute(1, 2, 0).numpy() + 1) / 2

            # Combine images into a 2x2 grid for this sample
            row1 = np.hstack([agnostic_img, cloth_img])
            row2 = np.hstack([output_img, person_img])
            combined = np.clip(np.vstack([row1, row2]), 0, 1)

            imgs.append(combined)

        # Create figure
        fig, axs = plt.subplots(1, len(imgs), figsize=(15, 5))
        if len(imgs) == 1:
            axs = [axs]

        for j, img in enumerate(imgs):
            axs[j].imshow(img)
            axs[j].set_title(f"Sample {j+1}")
            axs[j].axis('off')

        fig.suptitle(f"Epoch {epoch} Results", fontsize=16)
        plt.tight_layout()

        # Save figure
        os.makedirs('results', exist_ok=True)
        plt.savefig(f'results/epoch_{epoch}_samples.png')
        plt.close()


def plot_losses(train_losses_G, train_losses_D=None, val_losses=None):
    """Plot training and validation losses"""
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses_G, label='Generator Loss')

    if train_losses_D:
        plt.plot(train_losses_D, label='Discriminator Loss')

    if val_losses:
        plt.plot(val_losses, label='Validation Loss')

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()
    plt.grid(True)

    os.makedirs('results', exist_ok=True)
    plt.savefig('results/loss_plot.png')
    plt.close()


def save_checkpoint(model_G, model_D=None, optimizer_G=None, optimizer_D=None, epoch=None, filename="checkpoint.pth"):
    """Save model checkpoint"""
    os.makedirs('checkpoints', exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_G_state_dict': model_G.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict() if optimizer_G else None,
    }

    if model_D is not None:
        checkpoint['model_D_state_dict'] = model_D.state_dict()

    if optimizer_D is not None:
        checkpoint['optimizer_D_state_dict'] = optimizer_D.state_dict()

    torch.save(checkpoint, f'checkpoints/{filename}')


def load_checkpoint(model_G, model_D=None, optimizer_G=None, optimizer_D=None, filename="checkpoint.pth"):
    """Load model checkpoint"""
    checkpoint = torch.load(f'checkpoints/{filename}')

    model_G.load_state_dict(checkpoint['model_G_state_dict'])

    if optimizer_G and checkpoint.get('optimizer_G_state_dict'):
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])

    if model_D is not None and 'model_D_state_dict' in checkpoint:
        model_D.load_state_dict(checkpoint['model_D_state_dict'])

    if optimizer_D is not None and checkpoint.get('optimizer_D_state_dict'):
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

    return checkpoint.get('epoch', 0)
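
# Resuming example (sketch; the checkpoint file name and variable names are hypothetical):
#   start_epoch = load_checkpoint(generator, discriminator, optimizer_G, optimizer_D,
#                                 filename="checkpoint_epoch_25.pth")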


# Test function
def test_model(model, test_loader, device, result_dir='test_results'):
    """Generate and save test results"""
    model.eval()
    os.makedirs(result_dir, exist_ok=True)

    with torch.no_grad():
        for i, sample in enumerate(test_loader):
            agnostic = sample['agnostic'].to(device)
            cloth = sample['cloth'].to(device)
            target = sample['person'].to(device)
            name = sample['name'][0]  # Sample name (assumes batch_size=1)

            # Generate try-on result
            input_tensor = torch.cat([agnostic, cloth], dim=1)
            output = model(input_tensor)

            # Map tensors from [-1, 1] back to [0, 1] for saving
            output_img = np.clip((output[0].cpu().permute(1, 2, 0).numpy() + 1) / 2, 0, 1)
            target_img = np.clip((target[0].cpu().permute(1, 2, 0).numpy() + 1) / 2, 0, 1)
            agnostic_img = np.clip((agnostic[0].cpu().permute(1, 2, 0).numpy() + 1) / 2, 0, 1)
            cloth_img = np.clip((cloth[0].cpu().permute(1, 2, 0).numpy() + 1) / 2, 0, 1)

            # Save individual images
            plt.imsave(f'{result_dir}/{name}_output.png', output_img)
            plt.imsave(f'{result_dir}/{name}_target.png', target_img)

            # Save comparison grid
            fig, ax = plt.subplots(2, 2, figsize=(12, 12))
            ax[0, 0].imshow(agnostic_img)
            ax[0, 0].set_title('Person (w/o clothes)')
            ax[0, 0].axis('off')

            ax[0, 1].imshow(cloth_img)
            ax[0, 1].set_title('Clothing Item')
            ax[0, 1].axis('off')

            ax[1, 0].imshow(output_img)
            ax[1, 0].set_title('Generated Result')
            ax[1, 0].axis('off')

            ax[1, 1].imshow(target_img)
            ax[1, 1].set_title('Ground Truth')
            ax[1, 1].axis('off')

            plt.tight_layout()
            plt.savefig(f'{result_dir}/{name}_comparison.png')
            plt.close()

            if (i + 1) % 10 == 0:
                print(f"Processed {i+1}/{len(test_loader)} test samples")
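

# ---------------------------------------------------------------------------
# Example entry point (illustrative sketch, not part of the original upload).
# The dataset path, batch size, and epoch count are assumptions; adjust to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tensors in [-1, 1], matching the Tanh output and the (x + 1) / 2 un-scaling above
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    data_root = 'data/viton'  # hypothetical dataset location
    train_set = VITONDataset(data_root, mode='train', transform=transform, augment=True)
    test_set = VITONDataset(data_root, mode='test', transform=transform)

    train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)
    val_loader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)  # test_model() assumes batch_size=1

    generator = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
    discriminator = PatchDiscriminator(in_channels=6).to(device)

    generator, discriminator = train_model(generator, discriminator,
                                           train_loader=train_loader,
                                           val_loader=val_loader,
                                           num_epochs=50, device=device, use_gan=True)

    test_model(generator, test_loader, device)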