"""Command-line entry point for training the speaker encoder with meta-learning."""

import argparse
import torch
from pathlib import Path
from torch.utils.data import DataLoader

from src.configs.config import Config
from src.models.encoder import SpeakerEncoder
from src.data.dataset import create_meta_learning_dataloader
from src.trainers.meta_trainer import MetaTrainer


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``data_dir`` (required), ``checkpoint``
        (``None`` unless given) and ``no_wandb`` (``False`` unless the flag is
        passed).
    """
    parser = argparse.ArgumentParser(description='训练说话人编码器')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='数据集根目录')
    parser.add_argument('--checkpoint', type=str,
                        help='恢复训练的检查点路径')
    parser.add_argument('--no_wandb', action='store_true',
                        help='禁用Weights & Biases日志')
    return parser.parse_args(argv)


def _build_loader(config, n_tasks):
    """Create a meta-learning dataloader from ``config`` with ``n_tasks`` tasks."""
    meta = config.meta_learning
    return create_meta_learning_dataloader(
        root_dir=config.data.root_dir,
        n_way=meta.n_way,
        k_shot=meta.k_shot,
        k_query=meta.k_query,
        n_tasks=n_tasks,
        batch_size=meta.batch_size,
        num_workers=meta.num_workers,
    )


def main():
    """Train the speaker encoder and checkpoint on validation improvements.

    Builds the configuration, the training and validation dataloaders, the
    model and the trainer, optionally resumes from ``--checkpoint``, then runs
    the epoch loop. A checkpoint is saved whenever the validation accuracy
    exceeds the best value seen so far in this run.
    """
    args = parse_args()

    # Load the configuration and point it at the requested dataset.
    config = Config()
    config.data.root_dir = args.data_dir

    # Training dataloader.
    train_loader = _build_loader(config, config.meta_learning.n_tasks)

    # Validation dataloader: a tenth of the training task count, but never
    # zero tasks (plain integer division yields 0 when n_tasks < 10).
    # NOTE(review): it reads from the same root_dir as the training loader.
    val_loader = _build_loader(
        config, max(1, config.meta_learning.n_tasks // 10)
    )

    # Model; hidden and embedding sizes are hard-coded here.
    model = SpeakerEncoder(
        input_dim=config.audio.n_mels,
        hidden_dim=256,
        embedding_dim=512,
    )

    # Trainer.
    trainer = MetaTrainer(
        model=model,
        config=config,
        use_wandb=not args.no_wandb,
    )

    # Resume from a checkpoint when one was given; training continues with
    # the epoch after the one returned by load_checkpoint.
    start_epoch = 0
    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        start_epoch, _, _ = trainer.load_checkpoint(args.checkpoint)
        start_epoch += 1

    # Training loop.
    print("Starting training...")
    # NOTE(review): this restarts at 0 even after a resume, so the first
    # resumed epoch with a positive accuracy always saves a checkpoint.
    # Presumably the values discarded above are the saved loss/accuracy --
    # confirm and seed best_val_acc from them if so.
    best_val_acc = 0

    for epoch in range(start_epoch, config.training.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.training.num_epochs}")

        # Train for one epoch.
        train_loss, train_acc = trainer.train_epoch(train_loader)
        print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        # Validate.
        val_loss, val_acc = trainer.validate(val_loader)
        print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Keep only strict improvements.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            trainer.save_checkpoint(
                epoch=epoch,
                loss=val_loss,
                acc=val_acc,
            )
            print(f"Saved new best model with validation accuracy: {val_acc:.4f}")


if __name__ == '__main__':
    main()