# OTA_TextAligner / train.py
import os
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from optimizers import build_optimizer

# AlignerModel and ForwardSumLoss are used below but were never imported.
# The module path here is an assumption -- point it at wherever these are
# actually defined in your project.
from model import AlignerModel, ForwardSumLoss
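
# For reference, a minimal sketch of what a forward-sum alignment loss
# computes, in the style of "One TTS Alignment To Rule Them All" (Badlani et
# al.). This is NOT the ForwardSumLoss imported above -- it is an illustrative
# stand-in, assuming attn_logprob holds unnormalized alignment scores of
# shape (batch, 1, T_mel, T_text).
class ForwardSumLossSketch(nn.Module):
    def __init__(self, blank_logprob=-1.0):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
        self.blank_logprob = blank_logprob

    def forward(self, attn_logprob, in_lens, out_lens):
        # Drop the channel dim: (B, T_mel, T_text).
        attn_logprob = attn_logprob.squeeze(1)
        # Prepend a "blank" column so CTC is allowed to skip text tokens.
        attn_logprob = nn.functional.pad(attn_logprob, (1, 0), value=self.blank_logprob)
        total = 0.0
        for b in range(attn_logprob.size(0)):
            t_len, m_len = int(in_lens[b]), int(out_lens[b])
            # The CTC target is every text position, in order: 1..t_len.
            target = torch.arange(1, t_len + 1, device=attn_logprob.device).unsqueeze(0)
            # (m_len, 1, t_len + 1), log-normalized over the text axis.
            logprobs = self.log_softmax(attn_logprob[b, :m_len, : t_len + 1]).unsqueeze(1)
            total = total + self.ctc_loss(
                logprobs, target,
                input_lengths=torch.tensor([m_len]),
                target_lengths=torch.tensor([t_len]),
            )
        return total / attn_logprob.size(0)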

def train_aligner(config, train_dataloader, val_dataloader, device, log_dir, epochs=100):
    """Train the text-to-mel aligner and return the trained model."""
    # Create model
    aligner = AlignerModel().to(device)

    # Define loss function
    forward_sum_loss = ForwardSumLoss()

    # Setup optimizer and LR schedule
    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 5e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
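    # These fields match torch.optim.lr_scheduler.OneCycleLR's arguments, so
    # build_optimizer (defined in optimizers.py, not shown here) presumably
    # returns an optimizer paired with a OneCycleLR schedule that is stepped
    # once per batch, as done in the loop below.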
    optimizer, scheduler = build_optimizer(
        {"params": aligner.parameters(), "optimizer_params": {}, "scheduler_params": scheduler_params})

    # Setup TensorBoard writer
    writer = SummaryWriter(log_dir=log_dir)

    # Create directories for checkpoints and attention plots
    os.makedirs(os.path.join(log_dir, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(log_dir, 'attention_plots'), exist_ok=True)

    # Track best validation loss
    best_val_loss = float('inf')

    # Weight applied to the forward-sum loss (the only loss term here)
    fwd_sum_loss_weight = config.get('fwd_sum_loss_weight', 1.0)
    # Training loop
    for epoch in range(1, epochs + 1):
        aligner.train()
        train_losses = []
        start_time = time.time()

        # Training phase
        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]")
        for i, batch in enumerate(pbar):
            batch = [b.to(device) for b in batch]
            text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

            # Forward pass
            attn_soft, attn_logprob = aligner(spec=mel_input,
                                              spec_len=mel_input_length,
                                              text=text_input,
                                              text_len=text_input_length,
                                              attn_prior=attn_prior)

            # Calculate the weighted forward-sum loss
            loss = fwd_sum_loss_weight * forward_sum_loss(attn_logprob=attn_logprob,
                                                          in_lens=text_input_length,
                                                          out_lens=mel_input_length)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping (returns the pre-clip norm, useful for logging)
            grad_norm = nn.utils.clip_grad_norm_(aligner.parameters(), config.get('grad_clip', 5.0))
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            # Log to TensorBoard
            global_step = (epoch - 1) * len(train_dataloader) + i
            writer.add_scalar('train/total_loss', loss.item(), global_step)
            writer.add_scalar('train/grad_norm', grad_norm.item(), global_step)

            # Update the progress bar
            train_losses.append(loss.item())
            pbar.set_description(f"Epoch {epoch}/{epochs} [Train] Loss: {loss.item():.4f}")

        # Average training loss for this epoch
        avg_train_loss = sum(train_losses) / len(train_losses)
        # Validation phase
        aligner.eval()
        val_losses = []
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
                batch = [b.to(device) for b in batch]
                text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

                # Forward pass
                attn_soft, attn_logprob = aligner(spec=mel_input,
                                                  spec_len=mel_input_length,
                                                  text=text_input,
                                                  text_len=text_input_length,
                                                  attn_prior=attn_prior)

                # Calculate loss (weighted the same way as training loss,
                # so train/val curves stay comparable)
                val_loss = fwd_sum_loss_weight * forward_sum_loss(attn_logprob=attn_logprob,
                                                                  in_lens=text_input_length,
                                                                  out_lens=mel_input_length)
                val_losses.append(val_loss.item())

        # Average validation loss
        avg_val_loss = sum(val_losses) / len(val_losses)
        # Log epoch averages to TensorBoard
        writer.add_scalar('epoch/train_loss', avg_train_loss, epoch)
        writer.add_scalar('epoch/val_loss', avg_val_loss, epoch)

        # Save model if it's the best so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': aligner.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(log_dir, 'checkpoints', 'best_model.pt'))

        # Save a checkpoint every N epochs
        if epoch % config.get('save_every', 10) == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': aligner.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(log_dir, 'checkpoints', f'checkpoint_epoch_{epoch}.pt'))

        # Print summary for this epoch
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch}/{epochs} completed in {epoch_time:.2f}s | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # Plot and save attention matrices for visualization
        if epoch % config.get('plot_every', 10) == 0:
            plot_attention_matrices(aligner, val_dataloader, device,
                                    os.path.join(log_dir, 'attention_plots', f'epoch_{epoch}'),
                                    num_samples=4)

    writer.close()
    print(f"Training completed. Best validation loss: {best_val_loss:.4f}")
    return aligner
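
# The script calls plot_attention_matrices() without defining or importing
# it. The version below is a minimal sketch, assuming the aligner's soft
# attention has shape (batch, 1, T_mel, T_text); swap in your project's real
# plotting utility if it exists elsewhere.
def plot_attention_matrices(aligner, dataloader, device, out_dir, num_samples=4):
    os.makedirs(out_dir, exist_ok=True)
    aligner.eval()
    with torch.no_grad():
        batch = next(iter(dataloader))
        batch = [b.to(device) for b in batch]
        text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch
        attn_soft, _ = aligner(spec=mel_input, spec_len=mel_input_length,
                               text=text_input, text_len=text_input_length,
                               attn_prior=attn_prior)
        for idx in range(min(num_samples, attn_soft.size(0))):
            # Crop each sample to its true lengths before plotting
            attn = attn_soft[idx, 0,
                             :int(mel_input_length[idx]),
                             :int(text_input_length[idx])].cpu().numpy()
            fig, ax = plt.subplots(figsize=(8, 6))
            ax.imshow(attn.T, aspect='auto', origin='lower', interpolation='none')
            ax.set_xlabel('Mel frames')
            ax.set_ylabel('Text tokens')
            fig.savefig(os.path.join(out_dir, f'sample_{idx}.png'), bbox_inches='tight')
            plt.close(fig)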

# Main execution
if __name__ == "__main__":
    def length_to_mask(lengths):
        """Build a boolean padding mask (True = padded position) from a 1-D
        length tensor. Not used by train_aligner itself; a standalone helper."""
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask + 1, lengths.unsqueeze(1))
        return mask
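
    # Minimal setup for the variables used below. The config keys are
    # assumptions; adjust them to match your config file's actual layout.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    log_dir = config.get('log_dir', 'logs/aligner')  # assumed config key
    epochs = config.get('epochs', 100)               # assumed config key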
    # `config`, `train_dataloader`, and `val_dataloader` are assumed to be
    # built elsewhere in your main script (config file + data pipeline).
    aligner = train_aligner(
        config=config,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        log_dir=log_dir,
        epochs=epochs,
    )
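
    # In a separate inference script, the best checkpoint saved above can be
    # reloaded like this (the path mirrors the torch.save call in train_aligner):
    #
    #     ckpt = torch.load(os.path.join(log_dir, 'checkpoints', 'best_model.pt'),
    #                       map_location=device)
    #     aligner = AlignerModel().to(device)
    #     aligner.load_state_dict(ckpt['model_state_dict'])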