|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.tensorboard import SummaryWriter |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import time |
|
from tqdm import tqdm |
|
from optimizers import build_optimizer |
|
|
|
def train_aligner(config, accelerator, train_dataloader, val_dataloader, device, log_dir, epochs=100):
    """Train an AlignerModel with the forward-sum (alignment) loss.

    Runs a standard train/validate loop, logs per-step and per-epoch metrics
    to TensorBoard, checkpoints the best model by validation loss, and
    periodically saves numbered checkpoints and attention plots.

    Args:
        config: dict-like configuration. Keys read here: 'optimizer_params'
            ('lr', 'pct_start'), 'fwd_sum_loss_weight', 'grad_clip',
            'save_every', 'plot_every'.
        accelerator: accepted for interface compatibility; currently unused
            in this function.
        train_dataloader: yields (text, text_len, mel, mel_len, attn_prior)
            tensor batches.
        val_dataloader: same batch layout, used for per-epoch validation.
        device: torch device onto which the model and every batch are moved.
        log_dir: directory for TensorBoard logs, 'checkpoints/' and plots.
        epochs: total number of training epochs.

    Returns:
        The trained aligner model.
    """
    # NOTE(review): AlignerModel / ForwardSumLoss / plot_attention_matrices
    # are not imported in the visible part of this file — confirm they are
    # provided by an import elsewhere.
    aligner = AlignerModel().to(device)
    forward_sum_loss = ForwardSumLoss()

    # Schedule length is derived from the loader size so the LR curve spans
    # exactly the full training run.
    scheduler_params = {
        "max_lr": float(config['optimizer_params'].get('lr', 5e-4)),
        "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
    optimizer, scheduler = build_optimizer(
        {"params": aligner.parameters(), "optimizer_params": {}, "scheduler_params": scheduler_params})

    writer = SummaryWriter(log_dir=log_dir)
    os.makedirs(os.path.join(log_dir, 'checkpoints'), exist_ok=True)

    best_val_loss = float('inf')
    # FIX: this weight was previously read from config but never applied to
    # the loss. It now scales the training objective; the default of 1.0
    # keeps prior behavior for configs that do not set it. Validation loss
    # is left unweighted so reported/checkpointed values stay comparable
    # across configs (a constant factor does not change the argmin anyway).
    fwd_sum_loss_weight = config.get('fwd_sum_loss_weight', 1.0)

    for epoch in range(1, epochs + 1):
        # ---- training pass ----
        aligner.train()
        train_losses = []
        start_time = time.time()

        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]")
        for i, batch in enumerate(pbar):
            batch = [b.to(device) for b in batch]
            text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

            attn_soft, attn_logprob = aligner(spec=mel_input,
                                              spec_len=mel_input_length,
                                              text=text_input,
                                              text_len=text_input_length,
                                              attn_prior=attn_prior)

            loss = fwd_sum_loss_weight * forward_sum_loss(attn_logprob=attn_logprob,
                                                          in_lens=text_input_length,
                                                          out_lens=mel_input_length)

            optimizer.zero_grad()
            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(aligner.parameters(), config.get('grad_clip', 5.0))
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            global_step = (epoch - 1) * len(train_dataloader) + i
            writer.add_scalar('train/total_loss', loss.item(), global_step)
            writer.add_scalar('train/grad_norm', grad_norm, global_step)
            train_losses.append(loss.item())
            pbar.set_description(f"Epoch {epoch}/{epochs} [Train] Loss: {loss.item():.4f}")

        avg_train_loss = sum(train_losses) / len(train_losses)

        # ---- validation pass (no gradients) ----
        aligner.eval()
        val_losses = []
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
                batch = [b.to(device) for b in batch]
                text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch

                attn_soft, attn_logprob = aligner(spec=mel_input,
                                                  spec_len=mel_input_length,
                                                  text=text_input,
                                                  text_len=text_input_length,
                                                  attn_prior=attn_prior)

                val_loss = forward_sum_loss(attn_logprob=attn_logprob,
                                            in_lens=text_input_length,
                                            out_lens=mel_input_length)
                val_losses.append(val_loss.item())

        avg_val_loss = sum(val_losses) / len(val_losses)

        writer.add_scalar('epoch/train_loss', avg_train_loss, epoch)
        writer.add_scalar('epoch/val_loss', avg_val_loss, epoch)

        # ---- checkpointing ----
        def _checkpoint_payload():
            # Built lazily so state_dict() is only materialized when saving.
            return {
                'epoch': epoch,
                'model_state_dict': aligner.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(_checkpoint_payload(),
                       os.path.join(log_dir, 'checkpoints', 'best_model.pt'))

        if epoch % config.get('save_every', 10) == 0:
            torch.save(_checkpoint_payload(),
                       os.path.join(log_dir, 'checkpoints', f'checkpoint_epoch_{epoch}.pt'))

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch}/{epochs} completed in {epoch_time:.2f}s | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if epoch % config.get('plot_every', 10) == 0:
            plot_attention_matrices(aligner, val_dataloader, device,
                                    os.path.join(log_dir, 'attention_plots', f'epoch_{epoch}'),
                                    num_samples=4)

    writer.close()
    print(f"Training completed. Best validation loss: {best_val_loss:.4f}")
    return aligner
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
def length_to_mask(lengths): |
|
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) |
|
mask = torch.gt(mask+1, lengths.unsqueeze(1)) |
|
return mask |
|
|
|
|
|
train_aligner( |
|
config=config, |
|
train_dataloader=train_dataloader, |
|
val_dataloader=val_dataloader, |
|
device=device, |
|
log_dir=log_dir, |
|
epochs=epoch |
|
) |