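# NOTE(review): the lines "Spaces:", "Sleeping", "Sleeping" that preceded the
# imports look like page-header text captured when this script was copied from
# a Hugging Face Spaces page; kept here as a comment so the file parses.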
import argparse
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.configs.config import Config
from src.data.dataset import create_meta_learning_dataloader
from src.models.encoder import SpeakerEncoder
from src.trainers.meta_trainer import MetaTrainer
def parse_args():
    """Parse the command-line options of the speaker-encoder training script.

    Options (help texts are in Chinese):
        --data_dir    required path to the dataset root directory
        --checkpoint  optional checkpoint path to resume training from
        --no_wandb    flag that disables Weights & Biases logging

    Returns:
        argparse.Namespace with ``data_dir`` (str), ``checkpoint``
        (str or None) and ``no_wandb`` (bool).
    """
    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = (
        ('--data_dir', dict(type=str, required=True, help='数据集根目录')),
        ('--checkpoint', dict(type=str, help='恢复训练的检查点路径')),
        ('--no_wandb', dict(action='store_true', help='禁用Weights & Biases日志')),
    )
    cli = argparse.ArgumentParser(description='训练说话人编码器')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def _make_meta_loader(config, n_tasks):
    """Build a meta-learning loader from ``config`` with ``n_tasks`` tasks.

    All episode settings (n_way, k_shot, k_query, batch_size, num_workers)
    are read from ``config.meta_learning``; only the task count differs
    between the training and validation loaders.
    """
    meta = config.meta_learning
    return create_meta_learning_dataloader(
        root_dir=config.data.root_dir,
        n_way=meta.n_way,
        k_shot=meta.k_shot,
        k_query=meta.k_query,
        n_tasks=n_tasks,
        batch_size=meta.batch_size,
        num_workers=meta.num_workers,
    )


def main():
    """Train the speaker encoder with meta-learning and keep the best model.

    Parses CLI arguments, builds the training and validation loaders, the
    model and the trainer, optionally resumes from ``--checkpoint``, then for
    each remaining epoch trains, validates, and saves a checkpoint whenever
    validation accuracy beats the best value seen so far.
    """
    args = parse_args()

    # Load the configuration and point it at the dataset given on the CLI.
    config = Config()
    config.data.root_dir = args.data_dir

    n_tasks = config.meta_learning.n_tasks
    train_loader = _make_meta_loader(config, n_tasks)
    # Validation uses roughly a tenth of the training task count. Clamp to at
    # least one task so n_tasks < 10 does not yield an empty loader.
    # NOTE(review): this draws from the same root_dir as training, so it is
    # not a held-out split -- confirm whether that is intended.
    val_loader = _make_meta_loader(config, max(1, n_tasks // 10))

    # NOTE(review): hidden_dim / embedding_dim are hard-coded rather than
    # taken from config.
    model = SpeakerEncoder(
        input_dim=config.audio.n_mels,
        hidden_dim=256,
        embedding_dim=512,
    )

    trainer = MetaTrainer(
        model=model,
        config=config,
        use_wandb=not args.no_wandb,
    )

    # Start below any real accuracy so the first validated epoch always
    # produces a checkpoint (a 0 start never saves while accuracy stays 0).
    start_epoch = 0
    best_val_acc = float('-inf')
    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        # Assumes the 3-tuple mirrors save_checkpoint(epoch=, loss=, acc=),
        # i.e. the third element is the saved validation accuracy -- confirm.
        ckpt_epoch, _, ckpt_acc = trainer.load_checkpoint(args.checkpoint)
        start_epoch = ckpt_epoch + 1
        # This script only writes checkpoints on a new best, so restoring the
        # stored accuracy prevents a worse resumed epoch from overwriting it.
        # NOTE(review): for the same reason, resuming restarts after the best
        # epoch rather than the last one trained.
        if ckpt_acc is not None:
            best_val_acc = ckpt_acc

    print("Starting training...")
    num_epochs = config.training.num_epochs
    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        train_loss, train_acc = trainer.train_epoch(train_loader)
        print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        val_loss, val_acc = trainer.validate(val_loader)
        print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Save only when validation accuracy improves on the best so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            trainer.save_checkpoint(
                epoch=epoch,
                loss=val_loss,
                acc=val_acc,
            )
            print(f"Saved new best model with validation accuracy: {val_acc:.4f}")
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()