import os
import time
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Ensure reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
# Improved dataset handling with proper data augmentation
class VITONDataset(Dataset):
    def __init__(self, root_dir, mode='train', transform=None, augment=False):
        """
        Dataset class with error handling and optional data augmentation.

        Args:
            root_dir: Root directory of the dataset
            mode: 'train' or 'test'
            transform: Transforms to apply to images
            augment: Whether to apply data augmentation
        """
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.augment = augment

        # Check that all required directories exist
        img_dir = os.path.join(root_dir, f'{mode}_img')
        cloth_dir = os.path.join(root_dir, f'{mode}_color')
        label_dir = os.path.join(root_dir, f'{mode}_label')
        if not (os.path.exists(img_dir) and os.path.exists(cloth_dir) and os.path.exists(label_dir)):
            raise FileNotFoundError(f"One or more dataset directories not found in {root_dir}")

        # Collect base names for which all three files (person, cloth, label) exist
        self.image_names = []
        for f in sorted(os.listdir(img_dir)):
            if f.endswith('_0.jpg'):
                base_name = f[:-len('_0.jpg')]
                cloth_path = os.path.join(cloth_dir, f"{base_name}_1.jpg")
                label_path = os.path.join(label_dir, f"{base_name}_0.png")
                if os.path.exists(cloth_path) and os.path.exists(label_path):
                    self.image_names.append(base_name)
        print(f"Found {len(self.image_names)} valid samples in {mode} set")

    def __len__(self):
        return len(self.image_names)
    def _apply_augmentation(self, img, cloth, label):
        """Apply data augmentation (flips and photometric jitter)."""
        # Random horizontal flip, applied consistently to all three images
        if random.random() > 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            cloth = cloth.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)
        # Random brightness and contrast adjustment for the person image
        if random.random() > 0.7:
            img = transforms.functional.adjust_brightness(img, random.uniform(0.8, 1.2))
            img = transforms.functional.adjust_contrast(img, random.uniform(0.8, 1.2))
        # Random brightness and saturation jitter for the clothing image
        if random.random() > 0.7:
            cloth = transforms.functional.adjust_brightness(cloth, random.uniform(0.8, 1.2))
            cloth = transforms.functional.adjust_saturation(cloth, random.uniform(0.8, 1.2))
        return img, cloth, label
    def __getitem__(self, idx):
        base_name = self.image_names[idx]

        # Build file paths
        img_path = os.path.join(self.root_dir, f'{self.mode}_img', f"{base_name}_0.jpg")
        cloth_path = os.path.join(self.root_dir, f'{self.mode}_color', f"{base_name}_1.jpg")
        label_path = os.path.join(self.root_dir, f'{self.mode}_label', f"{base_name}_0.png")

        try:
            # Load images; nearest-neighbor resampling keeps label indices intact
            img = Image.open(img_path).convert('RGB').resize((192, 256))
            cloth = Image.open(cloth_path).convert('RGB').resize((192, 256))
            label = Image.open(label_path).convert('L').resize((192, 256), resample=Image.NEAREST)

            # Apply augmentation if enabled (training only)
            if self.augment and self.mode == 'train':
                img, cloth, label = self._apply_augmentation(img, cloth, label)

            # Convert to numpy for mask processing
            img_np = np.array(img)
            label_np = np.array(label)

            # Create the clothing-agnostic person image (upper clothes = label 4)
            agnostic_np = img_np.copy()
            agnostic_np[label_np == 4] = [128, 128, 128]  # Grey out the clothing region

            # Create a binary mask of the clothing region
            cloth_mask = (label_np == 4).astype(np.uint8) * 255
            cloth_mask_img = Image.fromarray(cloth_mask)

            # Apply transforms
            to_tensor = self.transform if self.transform else transforms.ToTensor()
            person_tensor = to_tensor(img)
            agnostic_tensor = to_tensor(Image.fromarray(agnostic_np))
            cloth_tensor = to_tensor(cloth)

            # Handle the cloth mask so its channel count matches the cloth tensor
            if self.transform:
                # Convert to RGB so the 3-channel transform applies consistently
                cloth_mask_rgb = cloth_mask_img.convert('RGB')
                cloth_mask_tensor = to_tensor(cloth_mask_rgb)
            else:
                # Plain ToTensor() on the grayscale mask
                cloth_mask_tensor = transforms.ToTensor()(cloth_mask_img)
                # Expand to 3 channels to match the cloth tensor
                if cloth_tensor.shape[0] == 3:
                    cloth_mask_tensor = cloth_mask_tensor.expand(3, -1, -1)

            # Keep the segmentation mask as integer class indices
            # (convert to one-hot downstream if a model requires it)
            label_tensor = torch.from_numpy(label_np).long()

            sample = {
                'person': person_tensor,
                'agnostic': agnostic_tensor,
                'cloth': cloth_tensor,
                'cloth_mask': cloth_mask_tensor,
                'label': label_tensor,
                'name': base_name
            }
            return sample
        except Exception as e:
            print(f"Error loading sample {base_name}: {e}")
            # Fall back to the next sample; assumes at least one sample loads cleanly
            return self.__getitem__((idx + 1) % len(self.image_names))
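
# A minimal usage sketch (not from the original notebook): building loaders
# from this dataset. The root path, batch size, and worker count are
# placeholder assumptions; the Normalize transform maps images to [-1, 1]
# to match the generator's Tanh output range.
def build_dataloaders(root_dir, batch_size=8, num_workers=2):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_set = VITONDataset(root_dir, mode='train', transform=transform, augment=True)
    test_set = VITONDataset(root_dir, mode='test', transform=transform, augment=False)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    # batch_size=1 for the test loader because test_model saves one named
    # result per batch
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False,
                             num_workers=num_workers)
    return train_loader, test_loader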
# Improved U-Net building blocks: attention gates and residual connections
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        # Non-inplace ReLU avoids autograd errors when tensors are reused
        self.relu = nn.ReLU(inplace=False)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi
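
# Quick shape sanity check for the attention gate (an illustrative sketch,
# not part of the original pipeline): the gating signal g and skip features
# x must share spatial size, and the output keeps the skip tensor's shape.
def _check_attention_block():
    att = AttentionBlock(F_g=64, F_l=64, F_int=32)
    g = torch.randn(2, 64, 128, 96)  # decoder (gating) features
    x = torch.randn(2, 64, 128, 96)  # encoder (skip) features
    out = att(g, x)
    assert out.shape == x.shape  # attention rescales x without reshaping it
    return out.shape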
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)
        # Non-inplace ReLU avoids autograd errors when tensors are reused
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + residual  # non-inplace add, consistent with the blocks above
        out = self.relu(out)
        return out
# Discriminator network for adversarial training (PatchGAN)
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=6):
        super(PatchDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        # Concatenate image and condition along the channel dimension
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
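
# Illustrative check (not from the original notebook): the PatchGAN head
# outputs a grid of per-patch logits rather than a single scalar. For
# 256x192 inputs, the four stride-2 convs give a 16x12 patch map.
def _check_patch_discriminator():
    D = PatchDiscriminator(in_channels=6)
    img = torch.randn(2, 3, 256, 192)
    cond = torch.randn(2, 3, 256, 192)
    patches = D(img, cond)
    assert patches.shape == (2, 1, 16, 12)
    return patches.shape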
class ImprovedUNetGenerator(nn.Module):
    def __init__(self, in_channels=6, out_channels=3):
        super(ImprovedUNetGenerator, self).__init__()
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=False)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=False)
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=False)
        )
        self.enc4 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=False)
        )
        self.enc5 = nn.Sequential(
            nn.Conv2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=False)
        )
        # Bottleneck
        self.bottleneck = ResidualBlock(512)
        # Decoder (input channels double where skip connections are concatenated)
        self.dec5 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=False),
            nn.Dropout(0.5)
        )
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(1024, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=False),
            nn.Dropout(0.5)
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=False)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=False)
        )
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh()
        )
        # Attention gates on the skip connections
        self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        # Bottleneck
        b = self.bottleneck(e5)
        # Decoder with attention-gated skip connections
        d5 = self.dec5(b)
        d5 = torch.cat([self.att4(d5, e4), d5], dim=1)
        d4 = self.dec4(d5)
        d4 = torch.cat([self.att3(d4, e3), d4], dim=1)
        d3 = self.dec3(d4)
        d3 = torch.cat([self.att2(d3, e2), d3], dim=1)
        d2 = self.dec2(d3)
        d2 = torch.cat([self.att1(d2, e1), d2], dim=1)
        d1 = self.dec1(d2)
        return d1
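
# Illustrative forward-pass check (not from the original notebook): a
# 6-channel input (agnostic person + cloth, concatenated) at 256x192 yields
# a 3-channel image of the same spatial size, in [-1, 1] via Tanh.
def _check_generator():
    G = ImprovedUNetGenerator(in_channels=6, out_channels=3)
    x = torch.randn(2, 6, 256, 192)
    y = G(x)
    assert y.shape == (2, 3, 256, 192)
    return y.shape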
# Custom loss functions
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        # Import torchvision.models here to avoid the dependency at module level
        import torchvision.models as models
        # Load pretrained VGG19 (use pretrained=True on torchvision < 0.13)
        # and replace inplace ReLUs with non-inplace ones
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()
        for idx, module in enumerate(vgg):
            if isinstance(module, nn.ReLU):
                vgg[idx] = nn.ReLU(inplace=False)

        # Chain a few selected VGG layers. Note: these are individual layers,
        # not contiguous slices, so this is a lightweight approximation of the
        # usual VGG perceptual loss rather than an exact one.
        feature_layers = [0, 2, 5, 10, 15]
        self.layer_weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        self.model = nn.Sequential()
        for i, layer_idx in enumerate(feature_layers):
            self.model.add_module(f'layer_{i}', vgg[layer_idx])
        for param in self.model.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss()
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x, y):
        # Map from the generator's [-1, 1] Tanh range to [0, 1] before
        # applying ImageNet normalization (assumes inputs use that convention)
        x = (x + 1) / 2
        y = (y + 1) / 2
        x = (x - self.mean) / self.std
        y = (y - self.mean) / self.std
        loss = 0.0
        x_features = x
        y_features = y
        for i, layer in enumerate(self.model):
            x_features = layer(x_features)
            y_features = layer(y_features)
            loss += self.layer_weights[i] * self.criterion(x_features, y_features)
        return loss
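
# Example usage (a sketch): the loss expects 3-channel tensors in the
# generator's [-1, 1] range and rescales them internally. Constructing
# VGGPerceptualLoss downloads the VGG19 weights on first use.
def _check_perceptual_loss():
    loss_fn = VGGPerceptualLoss()
    fake = torch.rand(2, 3, 256, 192) * 2 - 1
    real = torch.rand(2, 3, 256, 192) * 2 - 1
    return loss_fn(fake, real).item()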
# Training setup
def train_model(model_G, model_D=None, train_loader=None, val_loader=None,
                num_epochs=50, device=None, use_gan=True):
    """
    Training loop with optional GAN training, learning-rate scheduling,
    and validation.
    """
    # Anomaly detection catches in-place/autograd errors early, but slows
    # training noticeably; disable it once training is stable
    torch.autograd.set_detect_anomaly(True)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Optimizer and scheduler for the generator
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    scheduler_G = StepLR(optimizer_G, step_size=10, gamma=0.5)

    # Losses
    criterion_L1 = nn.L1Loss()
    criterion_perceptual = VGGPerceptualLoss().to(device)

    # GAN setup (adversarial training requires a discriminator)
    use_gan = use_gan and model_D is not None
    if use_gan:
        optimizer_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0.5, 0.999))
        scheduler_D = StepLR(optimizer_D, step_size=10, gamma=0.5)
        criterion_GAN = nn.MSELoss()  # LSGAN-style objective

    # Track losses for plotting
    train_losses_G = []
    train_losses_D = [] if use_gan else None
    val_losses = []
    # Training loop
    for epoch in range(num_epochs):
        model_G.train()
        if use_gan:
            model_D.train()

        epoch_loss_G = 0.0
        epoch_loss_D = 0.0 if use_gan else None
        start_time = time.time()

        for i, sample in enumerate(train_loader):
            agnostic = sample['agnostic'].to(device)
            cloth = sample['cloth'].to(device)
            target = sample['person'].to(device)

            # Combine inputs: agnostic person + target cloth
            input_tensor = torch.cat([agnostic, cloth], dim=1)

            # -----------------
            # Generator training
            # -----------------
            optimizer_G.zero_grad()

            # Generate the try-on image
            fake_image = model_G(input_tensor)

            # Reconstruction (L1) and perceptual losses
            loss_L1 = criterion_L1(fake_image, target)
            loss_perceptual = criterion_perceptual(fake_image, target)
            loss_G = loss_L1 + 0.1 * loss_perceptual

            # Adversarial loss, if training with a discriminator
            if use_gan:
                # Real labels are 1s here; one-sided label smoothing
                # (e.g. 0.9) is a common alternative for stability
                pred_fake = model_D(fake_image, cloth)
                target_real = torch.ones_like(pred_fake)
                loss_GAN = criterion_GAN(pred_fake, target_real)
                loss_G = loss_G + 0.1 * loss_GAN

            # Backward pass and generator update
            loss_G.backward()
            optimizer_G.step()
            epoch_loss_G += loss_G.item()
            # -----------------
            # Discriminator training
            # -----------------
            if use_gan:
                optimizer_D.zero_grad()

                # Real pairs
                pred_real = model_D(target, cloth)
                loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

                # Fake pairs (detached so gradients don't reach the generator)
                pred_fake = model_D(fake_image.detach(), cloth)
                loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

                # Total discriminator loss
                loss_D = (loss_real + loss_fake) / 2
                loss_D.backward()
                optimizer_D.step()
                epoch_loss_D += loss_D.item()

            # Print progress
            if (i + 1) % 50 == 0:
                time_elapsed = time.time() - start_time
                d_part = f"D Loss: {loss_D.item():.4f}, " if use_gan else ""
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], "
                      f"G Loss: {loss_G.item():.4f}, {d_part}"
                      f"Time: {time_elapsed:.2f}s")

        # Update learning rates
        scheduler_G.step()
        if use_gan:
            scheduler_D.step()
        # Average losses for this epoch
        avg_loss_G = epoch_loss_G / len(train_loader)
        train_losses_G.append(avg_loss_G)
        d_part = ""
        if use_gan:
            avg_loss_D = epoch_loss_D / len(train_loader)
            train_losses_D.append(avg_loss_D)
            d_part = f"Train Loss D: {avg_loss_D:.4f}, "

        # Validation
        if val_loader is not None:
            val_loss = validate_model(model_G, val_loader, device)
            val_losses.append(val_loss)
            print(f"Epoch {epoch+1}, Train Loss G: {avg_loss_G:.4f}, {d_part}"
                  f"Val Loss: {val_loss:.4f}, Time: {time.time()-start_time:.2f}s")
        else:
            print(f"Epoch {epoch+1}, Train Loss G: {avg_loss_G:.4f}, {d_part}"
                  f"Time: {time.time()-start_time:.2f}s")

        # Save a checkpoint and visualize results periodically
        if (epoch + 1) % 5 == 0:
            save_checkpoint(model_G, model_D, optimizer_G,
                            optimizer_D if use_gan else None,
                            epoch, f"checkpoint_epoch_{epoch+1}.pth")
            if val_loader is not None:  # visualization needs a loader
                visualize_results(model_G, val_loader, device, epoch + 1)

    # Plot training losses
    plot_losses(train_losses_G, train_losses_D, val_losses)
    return model_G, model_D
def validate_model(model, val_loader, device):
    """Compute the average L1 loss on the validation set."""
    model.eval()
    val_loss = 0.0
    criterion = nn.L1Loss()
    with torch.no_grad():
        for sample in val_loader:
            agnostic = sample['agnostic'].to(device)
            cloth = sample['cloth'].to(device)
            target = sample['person'].to(device)
            input_tensor = torch.cat([agnostic, cloth], dim=1)
            output = model(input_tensor)
            loss = criterion(output, target)
            val_loss += loss.item()
    return val_loss / len(val_loader)
def visualize_results(model, dataloader, device, epoch):
    """Visualize generated try-on results for the first batch."""
    model.eval()
    for i, sample in enumerate(dataloader):
        if i >= 1:  # Only visualize the first batch
            break
        with torch.no_grad():
            agnostic = sample['agnostic'].to(device)
            cloth = sample['cloth'].to(device)
            target = sample['person'].to(device)
            input_tensor = torch.cat([agnostic, cloth], dim=1)
            output = model(input_tensor)

        # Map tensors from [-1, 1] back to [0, 1] for display
        imgs = []
        for j in range(min(4, output.size(0))):  # Show at most 4 examples
            person_img = (target[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
            agnostic_img = (agnostic[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
            cloth_img = (cloth[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
            output_img = (output[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
            # 2x2 grid: inputs on top, output vs. ground truth below
            row1 = np.hstack([agnostic_img, cloth_img])
            row2 = np.hstack([output_img, person_img])
            imgs.append(np.vstack([row1, row2]))

        # Create and save the figure
        fig, axs = plt.subplots(1, len(imgs), figsize=(15, 5))
        if len(imgs) == 1:
            axs = [axs]
        for j, img in enumerate(imgs):
            axs[j].imshow(np.clip(img, 0, 1))  # clip to the valid display range
            axs[j].set_title(f"Sample {j+1}")
            axs[j].axis('off')
        fig.suptitle(f"Epoch {epoch} Results", fontsize=16)
        plt.tight_layout()
        os.makedirs('results', exist_ok=True)
        plt.savefig(f'results/epoch_{epoch}_samples.png')
        plt.close()
def plot_losses(train_losses_G, train_losses_D=None, val_losses=None):
    """Plot training and validation losses."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses_G, label='Generator Loss')
    if train_losses_D:
        plt.plot(train_losses_D, label='Discriminator Loss')
    if val_losses:
        plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()
    plt.grid(True)
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/loss_plot.png')
    plt.close()
def save_checkpoint(model_G, model_D=None, optimizer_G=None, optimizer_D=None,
                    epoch=None, filename="checkpoint.pth"):
    """Save a model checkpoint."""
    os.makedirs('checkpoints', exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'model_G_state_dict': model_G.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict() if optimizer_G else None,
    }
    if model_D is not None:
        checkpoint['model_D_state_dict'] = model_D.state_dict()
    if optimizer_D is not None:
        checkpoint['optimizer_D_state_dict'] = optimizer_D.state_dict()
    torch.save(checkpoint, f'checkpoints/{filename}')

def load_checkpoint(model_G, model_D=None, optimizer_G=None, optimizer_D=None,
                    filename="checkpoint.pth"):
    """Load a model checkpoint; returns the saved epoch (0 if absent)."""
    checkpoint = torch.load(f'checkpoints/{filename}')
    model_G.load_state_dict(checkpoint['model_G_state_dict'])
    # The optimizer state may have been saved as None, so check the value
    if optimizer_G and checkpoint.get('optimizer_G_state_dict') is not None:
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    if model_D is not None and 'model_D_state_dict' in checkpoint:
        model_D.load_state_dict(checkpoint['model_D_state_dict'])
    if optimizer_D is not None and 'optimizer_D_state_dict' in checkpoint:
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
    return checkpoint.get('epoch', 0)
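
# Example (a sketch; the filename is a placeholder matching train_model's
# periodic-save naming): resuming training from a saved checkpoint.
def resume_from_checkpoint(model_G, model_D=None, optimizer_G=None,
                           optimizer_D=None, filename="checkpoint_epoch_5.pth"):
    start_epoch = load_checkpoint(model_G, model_D, optimizer_G, optimizer_D, filename)
    print(f"Resuming from epoch {start_epoch + 1}")
    return start_epoch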
# Test function
def test_model(model, test_loader, device, result_dir='test_results'):
    """Generate and save test results (assumes batch_size=1 in test_loader)."""
    model.eval()
    os.makedirs(result_dir, exist_ok=True)
    with torch.no_grad():
        for i, sample in enumerate(test_loader):
            agnostic = sample['agnostic'].to(device)
            cloth = sample['cloth'].to(device)
            target = sample['person'].to(device)
            name = sample['name'][0]  # sample name (first item of the batch)

            # Generate the try-on result
            input_tensor = torch.cat([agnostic, cloth], dim=1)
            output = model(input_tensor)

            # Map tensors from [-1, 1] back to [0, 1] for saving
            output_img = (output[0].cpu().permute(1, 2, 0).numpy() + 1) / 2
            target_img = (target[0].cpu().permute(1, 2, 0).numpy() + 1) / 2
            agnostic_img = (agnostic[0].cpu().permute(1, 2, 0).numpy() + 1) / 2
            cloth_img = (cloth[0].cpu().permute(1, 2, 0).numpy() + 1) / 2

            # Save individual images (clipped to the valid [0, 1] range)
            plt.imsave(f'{result_dir}/{name}_output.png', np.clip(output_img, 0, 1))
            plt.imsave(f'{result_dir}/{name}_target.png', np.clip(target_img, 0, 1))

            # Save a 2x2 comparison grid
            fig, ax = plt.subplots(2, 2, figsize=(12, 12))
            ax[0, 0].imshow(np.clip(agnostic_img, 0, 1))
            ax[0, 0].set_title('Person (clothing-agnostic)')
            ax[0, 0].axis('off')
            ax[0, 1].imshow(np.clip(cloth_img, 0, 1))
            ax[0, 1].set_title('Clothing Item')
            ax[0, 1].axis('off')
            ax[1, 0].imshow(np.clip(output_img, 0, 1))
            ax[1, 0].set_title('Generated Result')
            ax[1, 0].axis('off')
            ax[1, 1].imshow(np.clip(target_img, 0, 1))
            ax[1, 1].set_title('Ground Truth')
            ax[1, 1].axis('off')
            plt.tight_layout()
            plt.savefig(f'{result_dir}/{name}_comparison.png')
            plt.close()

            if (i + 1) % 10 == 0:
                print(f"Processed {i+1}/{len(test_loader)} test samples")