# improved_viton.py
import os
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.optim.lr_scheduler import StepLR
import random
import cv2
# Ensure reproducibility
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['PYTHONHASHSEED'] = str(seed)
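# A minimal companion sketch (an addition, not in the uploaded file): when a
# DataLoader uses num_workers > 0, seeding each worker keeps the random
# augmentations reproducible too; pass this via the DataLoader's worker_init_fn.
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)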
# Improved dataset handling with proper data augmentation
class VITONDataset(Dataset):
def __init__(self, root_dir, mode='train', transform=None, augment=False):
"""
Enhanced dataset class with better error handling and data augmentation
Args:
root_dir: Root directory of the dataset
mode: 'train' or 'test'
transform: Transforms to apply to images
augment: Whether to apply data augmentation
"""
self.root_dir = root_dir
self.mode = mode
self.transform = transform
self.augment = augment
# Check if directories exist
img_dir = os.path.join(root_dir, f'{mode}_img')
cloth_dir = os.path.join(root_dir, f'{mode}_color')
label_dir = os.path.join(root_dir, f'{mode}_label')
if not os.path.exists(img_dir) or not os.path.exists(cloth_dir) or not os.path.exists(label_dir):
raise FileNotFoundError(f"One or more dataset directories not found in {root_dir}")
# Get all image names
self.image_names = []
for f in sorted(os.listdir(img_dir)):
            if f.endswith('_0.jpg'):
                # Keep only samples whose cloth image and parsing label also exist
                base_name = f.replace('_0.jpg', '')
cloth_path = os.path.join(cloth_dir, f"{base_name}_1.jpg")
label_path = os.path.join(label_dir, f"{base_name}_0.png")
if os.path.exists(cloth_path) and os.path.exists(label_path):
self.image_names.append(base_name)
print(f"Found {len(self.image_names)} valid samples in {mode} set")
def __len__(self):
return len(self.image_names)
def _apply_augmentation(self, img, cloth, label):
"""Apply data augmentation"""
# Random horizontal flip
if random.random() > 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
cloth = cloth.transpose(Image.FLIP_LEFT_RIGHT)
label = label.transpose(Image.FLIP_LEFT_RIGHT)
# Random brightness and contrast adjustment for person image
if random.random() > 0.7:
img = transforms.functional.adjust_brightness(img, random.uniform(0.8, 1.2))
img = transforms.functional.adjust_contrast(img, random.uniform(0.8, 1.2))
# Random color jitter for clothing
if random.random() > 0.7:
cloth = transforms.functional.adjust_brightness(cloth, random.uniform(0.8, 1.2))
cloth = transforms.functional.adjust_saturation(cloth, random.uniform(0.8, 1.2))
return img, cloth, label
def __getitem__(self, idx):
base_name = self.image_names[idx]
# Build file paths
img_path = os.path.join(self.root_dir, f'{self.mode}_img', f"{base_name}_0.jpg")
cloth_path = os.path.join(self.root_dir, f'{self.mode}_color', f"{base_name}_1.jpg")
label_path = os.path.join(self.root_dir, f'{self.mode}_label', f"{base_name}_0.png")
try:
# Load images
img = Image.open(img_path).convert('RGB').resize((192, 256))
cloth = Image.open(cloth_path).convert('RGB').resize((192, 256))
label = Image.open(label_path).convert('L').resize((192, 256), resample=Image.NEAREST)
# Apply augmentation if enabled
if self.augment and self.mode == 'train':
img, cloth, label = self._apply_augmentation(img, cloth, label)
# Convert label to numpy for processing
img_np = np.array(img)
label_np = np.array(label)
            # Create clothing-agnostic person image (grey out upper clothes, parsing label 4)
            agnostic_np = img_np.copy()
            agnostic_np[label_np == 4] = [128, 128, 128]
# Create cloth mask (binary mask of clothing)
cloth_mask = (label_np == 4).astype(np.uint8) * 255
cloth_mask_img = Image.fromarray(cloth_mask)
# Apply transforms
to_tensor = self.transform if self.transform else transforms.ToTensor()
person_tensor = to_tensor(img)
agnostic_tensor = to_tensor(Image.fromarray(agnostic_np))
cloth_tensor = to_tensor(cloth)
# Fix: Handle cloth mask properly
if self.transform:
# Convert to RGB for consistent channel handling
cloth_mask_rgb = Image.fromarray(cloth_mask).convert('RGB')
cloth_mask_tensor = to_tensor(cloth_mask_rgb)
else:
# Simple ToTensor() normalization for grayscale image
cloth_mask_tensor = transforms.ToTensor()(cloth_mask_img)
# If needed, expand to 3 channels
if cloth_tensor.shape[0] == 3:
cloth_mask_tensor = cloth_mask_tensor.expand(3, -1, -1)
            # Keep the parsing mask as class indices (no one-hot; suits nn.CrossEntropyLoss)
label_tensor = torch.from_numpy(label_np).long()
sample = {
'person': person_tensor,
'agnostic': agnostic_tensor,
'cloth': cloth_tensor,
'cloth_mask': cloth_mask_tensor,
'label': label_tensor,
'name': base_name
}
return sample
except Exception as e:
print(f"Error loading sample {base_name}: {e}")
# Return a valid sample as fallback - get a different index
return self.__getitem__((idx + 1) % len(self.image_names))
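# Illustrative sanity sketch (the root path is an assumption): each sample is a
# dict of tensors; with the default ToTensor transform, images come out as
# 3x256x192 floats and 'label' stays a 256x192 LongTensor of parsing classes.
def _check_dataset_sample(root_dir='data/viton'):
    dataset = VITONDataset(root_dir, mode='train')
    sample = dataset[0]
    assert sample['person'].shape == (3, 256, 192)
    assert sample['cloth_mask'].shape == (3, 256, 192)
    assert sample['label'].shape == (256, 192) and sample['label'].dtype == torch.long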
# Improved U-Net with residual connections and attention
class AttentionBlock(nn.Module):
    """Additive attention gate (Attention U-Net style): the decoder signal g
    produces a spatial mask that gates the encoder skip features x."""
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
self.W_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(F_int)
)
self.W_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(F_int)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
# Fixed: Change inplace ReLU to non-inplace
self.relu = nn.ReLU(inplace=False)
def forward(self, g, x):
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
return x * psi
class ResidualBlock(nn.Module):
def __init__(self, in_channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(in_channels)
# Fixed: Change inplace ReLU to non-inplace
self.relu = nn.ReLU(inplace=False)
def forward(self, x):
residual = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += residual
out = self.relu(out)
return out
# PatchGAN discriminator for adversarial training: classifies each receptive-field
# patch of the concatenated (image, condition) pair as real or fake
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels=6):
super(PatchDiscriminator, self).__init__()
def discriminator_block(in_filters, out_filters, normalization=True):
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalization:
layers.append(nn.BatchNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=False)) # Fixed: inplace=False
return layers
self.model = nn.Sequential(
*discriminator_block(in_channels, 64, normalization=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.ZeroPad2d((1, 0, 1, 0)),
nn.Conv2d(512, 1, 4, padding=1, bias=False)
)
def forward(self, img_A, img_B):
# Concatenate image and condition
img_input = torch.cat((img_A, img_B), 1)
return self.model(img_input)
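# Quick shape check (a sketch, not part of the original file): with 256x192
# inputs, the four stride-2 blocks plus the final padded conv yield a 16x12
# patch map, which the ones_like/zeros_like targets in train_model adapt to.
def _check_discriminator_shape():
    D = PatchDiscriminator(in_channels=6)
    with torch.no_grad():
        patches = D(torch.randn(2, 3, 256, 192), torch.randn(2, 3, 256, 192))
    assert patches.shape == (2, 1, 16, 12), patches.shape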
class ImprovedUNetGenerator(nn.Module):
def __init__(self, in_channels=6, out_channels=3):
super(ImprovedUNetGenerator, self).__init__()
# Encoder
self.enc1 = nn.Sequential(
nn.Conv2d(in_channels, 64, 4, 2, 1),
nn.LeakyReLU(0.2, inplace=False) # Fixed: inplace=False
)
self.enc2 = nn.Sequential(
nn.Conv2d(64, 128, 4, 2, 1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=False) # Fixed: inplace=False
)
self.enc3 = nn.Sequential(
nn.Conv2d(128, 256, 4, 2, 1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=False) # Fixed: inplace=False
)
self.enc4 = nn.Sequential(
nn.Conv2d(256, 512, 4, 2, 1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=False) # Fixed: inplace=False
)
self.enc5 = nn.Sequential(
nn.Conv2d(512, 512, 4, 2, 1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=False) # Fixed: inplace=False
)
# Bottleneck
self.bottleneck = ResidualBlock(512)
# Decoder
self.dec5 = nn.Sequential(
nn.ConvTranspose2d(512, 512, 4, 2, 1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=False), # Fixed: inplace=False
nn.Dropout(0.5)
)
self.dec4 = nn.Sequential(
nn.ConvTranspose2d(1024, 256, 4, 2, 1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=False), # Fixed: inplace=False
nn.Dropout(0.5)
)
self.dec3 = nn.Sequential(
nn.ConvTranspose2d(512, 128, 4, 2, 1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=False) # Fixed: inplace=False
)
self.dec2 = nn.Sequential(
nn.ConvTranspose2d(256, 64, 4, 2, 1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=False) # Fixed: inplace=False
)
self.dec1 = nn.Sequential(
nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
nn.Tanh()
)
# Attention gates
self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
def forward(self, x):
# Encoder
e1 = self.enc1(x)
e2 = self.enc2(e1)
e3 = self.enc3(e2)
e4 = self.enc4(e3)
e5 = self.enc5(e4)
# Bottleneck
b = self.bottleneck(e5)
# Decoder with attention and skip connections
d5 = self.dec5(b)
d5 = torch.cat([self.att4(d5, e4), d5], dim=1)
d4 = self.dec4(d5)
d4 = torch.cat([self.att3(d4, e3), d4], dim=1)
d3 = self.dec3(d4)
d3 = torch.cat([self.att2(d3, e2), d3], dim=1)
d2 = self.dec2(d3)
d2 = torch.cat([self.att1(d2, e1), d2], dim=1)
d1 = self.dec1(d2)
return d1
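# Quick shape check (a sketch, not part of the original file): the generator
# maps the 6-channel (agnostic + cloth) input back to a 3-channel image at the
# same 256x192 resolution, with Tanh output in [-1, 1].
def _check_generator_shape():
    G = ImprovedUNetGenerator(in_channels=6, out_channels=3)
    with torch.no_grad():
        out = G(torch.randn(2, 6, 256, 192))
    assert out.shape == (2, 3, 256, 192), out.shape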
# Custom loss functions
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        # Import vgg here to avoid a hard dependency at module import time
        import torchvision.models as models
        # Load pretrained VGG19 (newer torchvision prefers the `weights=` argument)
        vgg = models.vgg19(pretrained=True).features.eval()
        # Replace inplace ReLU with non-inplace version
        for idx, module in enumerate(vgg):
            if isinstance(module, nn.ReLU):
                vgg[idx] = nn.ReLU(inplace=False)
        # Fixed: slice VGG into five consecutive blocks at these boundaries so each
        # block output is a genuine intermediate feature map; picking only the single
        # layers at these indices skipped everything in between and mismatched the
        # five layer weights
        boundaries = [0, 2, 5, 10, 15, 20]
        self.blocks = nn.ModuleList(
            nn.Sequential(*[vgg[j] for j in range(boundaries[i], boundaries[i + 1])])
            for i in range(len(boundaries) - 1)
        )
        self.layer_weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
        for param in self.blocks.parameters():
            param.requires_grad = False
        self.criterion = nn.L1Loss()
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
    def forward(self, x, y):
        # Inputs are Tanh-scaled to [-1, 1]; map to [0, 1] before applying the
        # ImageNet normalization that VGG expects
        x = (x + 1) / 2
        y = (y + 1) / 2
        x = (x - self.mean) / self.std
        y = (y - self.mean) / self.std
        loss = 0.0
        x_features, y_features = x, y
        for weight, block in zip(self.layer_weights, self.blocks):
            x_features = block(x_features)
            y_features = block(y_features)
            loss += weight * self.criterion(x_features, y_features)
        return loss
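# Usage sketch (an addition; downloads VGG19 weights on first run): the loss
# takes [-1, 1] image pairs and returns a scalar; identical inputs score zero.
def _check_perceptual_loss():
    criterion = VGGPerceptualLoss()
    x = torch.rand(2, 3, 256, 192) * 2 - 1
    assert criterion(x, x).item() == 0.0
    assert criterion(x, torch.zeros_like(x)).item() > 0.0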
# Training setup
def train_model(model_G, model_D=None, train_loader=None, val_loader=None,
num_epochs=50, device=None, use_gan=True):
"""
Improved training function with GAN training, learning rate scheduler, and validation
"""
    # Anomaly detection helps catch autograd issues (e.g. inplace ops) but slows
    # training; disable it once training is stable
    torch.autograd.set_detect_anomaly(True)
    # Adversarial training requires a discriminator; downgrade gracefully if absent
    use_gan = use_gan and model_D is not None
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Optimizers
optimizer_G = torch.optim.Adam(model_G.parameters(), lr=2e-4, betas=(0.5, 0.999))
scheduler_G = StepLR(optimizer_G, step_size=10, gamma=0.5)
# Losses
criterion_L1 = nn.L1Loss()
criterion_perceptual = VGGPerceptualLoss().to(device)
# GAN setup
if use_gan and model_D is not None:
optimizer_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0.5, 0.999))
scheduler_D = StepLR(optimizer_D, step_size=10, gamma=0.5)
criterion_GAN = nn.MSELoss()
# Lists to store losses for plotting
train_losses_G = []
train_losses_D = [] if use_gan else None
val_losses = []
# Training loop
for epoch in range(num_epochs):
model_G.train()
if use_gan and model_D is not None:
model_D.train()
epoch_loss_G = 0.0
epoch_loss_D = 0.0 if use_gan else None
start_time = time.time()
for i, sample in enumerate(train_loader):
agnostic = sample['agnostic'].to(device)
cloth = sample['cloth'].to(device)
target = sample['person'].to(device)
cloth_mask = sample['cloth_mask'].to(device)
# Combine inputs
input_tensor = torch.cat([agnostic, cloth], dim=1)
# -----------------
# Generator training
# -----------------
optimizer_G.zero_grad()
# Generate fake image
fake_image = model_G(input_tensor)
# Calculate L1 loss
loss_L1 = criterion_L1(fake_image, target)
# Calculate perceptual loss
loss_perceptual = criterion_perceptual(fake_image, target)
# Calculate total generator loss
loss_G = loss_L1 + 0.1 * loss_perceptual
# Add GAN loss if using adversarial training
            if use_gan and model_D is not None:
                # Adversarial loss: push the discriminator to label the generated
                # image as real (all-ones target; no label smoothing applied here)
                pred_fake = model_D(fake_image, cloth)
                target_real = torch.ones_like(pred_fake)
                loss_GAN = criterion_GAN(pred_fake, target_real)
                # Total generator loss with GAN component
                loss_G += 0.1 * loss_GAN
# Backward pass and optimize generator
loss_G.backward()
optimizer_G.step()
epoch_loss_G += loss_G.item()
# -----------------
# Discriminator training (if using GAN)
# -----------------
if use_gan and model_D is not None:
optimizer_D.zero_grad()
                # Real loss (ones_like matches the patch-map shape and device)
                pred_real = model_D(target, cloth)
                loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
                # Fake loss; detach so generator gradients don't flow through here
                pred_fake = model_D(fake_image.detach(), cloth)
                loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
# Total discriminator loss
loss_D = (loss_real + loss_fake) / 2
# Backward pass and optimize discriminator
loss_D.backward()
optimizer_D.step()
epoch_loss_D += loss_D.item()
# Print progress
if (i+1) % 50 == 0:
time_elapsed = time.time() - start_time
print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], "
f"G Loss: {loss_G.item():.4f}, "
f"{'D Loss: ' + f'{loss_D.item():.4f}, ' if use_gan else ''}"
f"Time: {time_elapsed:.2f}s")
# Update learning rates
scheduler_G.step()
if use_gan and model_D is not None:
scheduler_D.step()
# Calculate average losses for this epoch
avg_loss_G = epoch_loss_G / len(train_loader)
train_losses_G.append(avg_loss_G)
if use_gan:
avg_loss_D = epoch_loss_D / len(train_loader)
train_losses_D.append(avg_loss_D)
# Validation
if val_loader is not None:
val_loss = validate_model(model_G, val_loader, device)
val_losses.append(val_loss)
print(f"Epoch {epoch+1}, Train Loss G: {avg_loss_G:.4f}, "
f"{'Train Loss D: ' + f'{avg_loss_D:.4f}, ' if use_gan else ''}"
f"Val Loss: {val_loss:.4f}, "
f"Time: {time.time()-start_time:.2f}s")
else:
print(f"Epoch {epoch+1}, Train Loss G: {avg_loss_G:.4f}, "
f"{'Train Loss D: ' + f'{avg_loss_D:.4f}, ' if use_gan else ''}"
f"Time: {time.time()-start_time:.2f}s")
        # Save a checkpoint and visualize sample results periodically
        if (epoch+1) % 5 == 0:
            save_checkpoint(model_G, model_D, optimizer_G,
                            optimizer_D if use_gan else None,
                            epoch, f"checkpoint_epoch_{epoch+1}.pth")
            if val_loader is not None:
                visualize_results(model_G, val_loader, device, epoch+1)
# Plot training losses
plot_losses(train_losses_G, train_losses_D, val_losses)
return model_G, model_D
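# End-to-end wiring sketch (paths, batch size, and the use of the test split for
# validation are illustrative assumptions, not settings from the original file):
def run_training(root_dir='data/viton', num_epochs=50, batch_size=8):
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map images to [-1, 1]
    ])
    train_set = VITONDataset(root_dir, mode='train', transform=transform, augment=True)
    val_set = VITONDataset(root_dir, mode='test', transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=4, worker_init_fn=seed_worker)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    model_G = ImprovedUNetGenerator(in_channels=6, out_channels=3).to(device)
    model_D = PatchDiscriminator(in_channels=6).to(device)
    return train_model(model_G, model_D, train_loader, val_loader,
                       num_epochs=num_epochs, device=device, use_gan=True)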
def validate_model(model, val_loader, device):
"""Validate the model on validation set"""
model.eval()
val_loss = 0.0
criterion = nn.L1Loss()
with torch.no_grad():
for sample in val_loader:
agnostic = sample['agnostic'].to(device)
cloth = sample['cloth'].to(device)
target = sample['person'].to(device)
input_tensor = torch.cat([agnostic, cloth], dim=1)
output = model(input_tensor)
loss = criterion(output, target)
val_loss += loss.item()
return val_loss / len(val_loader)
def visualize_results(model, dataloader, device, epoch):
"""Visualize generated try-on results"""
model.eval()
# Get a batch of samples
for i, sample in enumerate(dataloader):
if i >= 1: # Only visualize first batch
break
with torch.no_grad():
agnostic = sample['agnostic'].to(device)
cloth = sample['cloth'].to(device)
target = sample['person'].to(device)
input_tensor = torch.cat([agnostic, cloth], dim=1)
output = model(input_tensor)
# Convert tensors for visualization
imgs = []
for j in range(min(4, output.size(0))): # Show max 4 examples
person_img = (target[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
agnostic_img = (agnostic[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
cloth_img = (cloth[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
output_img = (output[j].cpu().permute(1, 2, 0).numpy() + 1) / 2
# Combine images for visualization
row1 = np.hstack([agnostic_img, cloth_img])
row2 = np.hstack([output_img, person_img])
combined = np.vstack([row1, row2])
imgs.append(combined)
# Create figure
fig, axs = plt.subplots(1, len(imgs), figsize=(15, 5))
if len(imgs) == 1:
axs = [axs]
for j, img in enumerate(imgs):
axs[j].imshow(img)
axs[j].set_title(f"Sample {j+1}")
axs[j].axis('off')
fig.suptitle(f"Epoch {epoch} Results", fontsize=16)
plt.tight_layout()
# Save figure
os.makedirs('results', exist_ok=True)
plt.savefig(f'results/epoch_{epoch}_samples.png')
plt.close()
def plot_losses(train_losses_G, train_losses_D=None, val_losses=None):
"""Plot training and validation losses"""
plt.figure(figsize=(10, 5))
plt.plot(train_losses_G, label='Generator Loss')
if train_losses_D:
plt.plot(train_losses_D, label='Discriminator Loss')
if val_losses:
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Losses')
plt.legend()
plt.grid(True)
os.makedirs('results', exist_ok=True)
plt.savefig('results/loss_plot.png')
plt.close()
def save_checkpoint(model_G, model_D=None, optimizer_G=None, optimizer_D=None, epoch=None, filename="checkpoint.pth"):
"""Save model checkpoint"""
os.makedirs('checkpoints', exist_ok=True)
checkpoint = {
'epoch': epoch,
'model_G_state_dict': model_G.state_dict(),
'optimizer_G_state_dict': optimizer_G.state_dict() if optimizer_G else None,
}
if model_D is not None:
checkpoint['model_D_state_dict'] = model_D.state_dict()
if optimizer_D is not None:
checkpoint['optimizer_D_state_dict'] = optimizer_D.state_dict()
torch.save(checkpoint, f'checkpoints/{filename}')
def load_checkpoint(model_G, model_D=None, optimizer_G=None, optimizer_D=None, filename="checkpoint.pth"):
"""Load model checkpoint"""
    # map_location lets CPU-only machines load GPU-trained checkpoints
    checkpoint = torch.load(f'checkpoints/{filename}', map_location='cpu')
model_G.load_state_dict(checkpoint['model_G_state_dict'])
if optimizer_G and 'optimizer_G_state_dict' in checkpoint:
optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
if model_D is not None and 'model_D_state_dict' in checkpoint:
model_D.load_state_dict(checkpoint['model_D_state_dict'])
if optimizer_D is not None and 'optimizer_D_state_dict' in checkpoint:
optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
return checkpoint.get('epoch', 0)
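# Resume sketch (an addition, not in the original file): rebuild the networks
# and optimizers exactly as in training, restore their state, and continue
# from the epoch stored in the checkpoint.
def resume_training(filename="checkpoint.pth", device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_G = ImprovedUNetGenerator().to(device)
    model_D = PatchDiscriminator().to(device)
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0.5, 0.999))
    start_epoch = load_checkpoint(model_G, model_D, optimizer_G, optimizer_D, filename)
    return model_G, model_D, optimizer_G, optimizer_D, start_epoch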
# Test function
def test_model(model, test_loader, device, result_dir='test_results'):
"""Generate and save test results"""
model.eval()
os.makedirs(result_dir, exist_ok=True)
with torch.no_grad():
for i, sample in enumerate(test_loader):
agnostic = sample['agnostic'].to(device)
cloth = sample['cloth'].to(device)
target = sample['person'].to(device)
            name = sample['name'][0]  # sample name (assumes batch_size=1 in the test loader)
# Generate try-on result
input_tensor = torch.cat([agnostic, cloth], dim=1)
output = model(input_tensor)
# Save images
output_img = (output[0].cpu().permute(1, 2, 0).numpy() + 1) / 2
target_img = (target[0].cpu().permute(1, 2, 0).numpy() + 1) / 2
agnostic_img = (agnostic[0].cpu().permute(1, 2, 0).numpy() + 1) / 2
cloth_img = (cloth[0].cpu().permute(1, 2, 0).numpy() + 1) / 2
# Save individual images
plt.imsave(f'{result_dir}/{name}_output.png', output_img)
plt.imsave(f'{result_dir}/{name}_target.png', target_img)
# Save comparison grid
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
ax[0, 0].imshow(agnostic_img)
ax[0, 0].set_title('Person (w/o clothes)')
ax[0, 0].axis('off')
ax[0, 1].imshow(cloth_img)
ax[0, 1].set_title('Clothing Item')
ax[0, 1].axis('off')
ax[1, 0].imshow(output_img)
ax[1, 0].set_title('Generated Result')
ax[1, 0].axis('off')
ax[1, 1].imshow(target_img)
ax[1, 1].set_title('Ground Truth')
ax[1, 1].axis('off')
plt.tight_layout()
plt.savefig(f'{result_dir}/{name}_comparison.png')
plt.close()
if (i+1) % 10 == 0:
print(f"Processed {i+1}/{len(test_loader)} test samples")