# voice-clone-app/src/train.py
# hengjie yang
# Initial commit: Voice Clone App with Gradio interface
# 9580089
import argparse
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from src.configs.config import Config
from src.models.encoder import SpeakerEncoder
from src.data.dataset import create_meta_learning_dataloader
from src.trainers.meta_trainer import MetaTrainer
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the speaker-encoder training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so calling
            ``parse_args()`` with no arguments behaves as before.

    Returns:
        A namespace with three attributes:
        ``data_dir`` (str, required), ``checkpoint`` (str, ``None`` when the
        option is omitted) and ``no_wandb`` (bool, ``False`` unless the flag
        is passed).
    """
    parser = argparse.ArgumentParser(description='训练说话人编码器')
    # Root directory of the dataset (mandatory).
    parser.add_argument('--data_dir', type=str, required=True, help='数据集根目录')
    # Optional checkpoint path used to resume a previous run.
    parser.add_argument('--checkpoint', type=str, help='恢复训练的检查点路径')
    # Opt-out switch for Weights & Biases logging.
    parser.add_argument('--no_wandb', action='store_true', help='禁用Weights & Biases日志')
    return parser.parse_args(argv)
def _build_loader(config, n_tasks):
    """Create a meta-learning dataloader over ``config.data.root_dir`` with ``n_tasks`` tasks."""
    meta = config.meta_learning
    return create_meta_learning_dataloader(
        root_dir=config.data.root_dir,
        n_way=meta.n_way,
        k_shot=meta.k_shot,
        k_query=meta.k_query,
        n_tasks=n_tasks,
        batch_size=meta.batch_size,
        num_workers=meta.num_workers,
    )


def main():
    """Train the speaker encoder from the command line.

    Reads the CLI options, builds the training and validation dataloaders,
    the ``SpeakerEncoder`` model and a ``MetaTrainer``, optionally resumes
    from a checkpoint, then runs the epoch loop. A checkpoint is saved each
    time the validation accuracy exceeds the best value seen in this run.
    """
    args = parse_args()

    # Load the configuration and point it at the requested dataset.
    config = Config()
    config.data.root_dir = args.data_dir

    # Training dataloader.
    train_loader = _build_loader(config, config.meta_learning.n_tasks)

    # Validation dataloader: a tenth of the training task count, clamped to
    # at least one task so that n_tasks < 10 does not request zero tasks.
    # NOTE(review): both loaders read the same root_dir; whether validation
    # data is actually held out cannot be told from this file - confirm.
    val_loader = _build_loader(
        config, max(1, config.meta_learning.n_tasks // 10)
    )

    # Model.
    model = SpeakerEncoder(
        input_dim=config.audio.n_mels,
        hidden_dim=256,
        embedding_dim=512,
    )

    # Trainer.
    trainer = MetaTrainer(
        model=model,
        config=config,
        use_wandb=not args.no_wandb,
    )

    # Resume from a checkpoint when one was given.
    start_epoch = 0
    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        # NOTE(review): the two discarded return values are presumably the
        # stored loss and accuracy; best_val_acc below still restarts at 0
        # after a resume - confirm whether it should be restored from them.
        start_epoch, _, _ = trainer.load_checkpoint(args.checkpoint)
        # Continue with the epoch after the one stored in the checkpoint.
        start_epoch += 1

    # Training loop.
    print("Starting training...")
    best_val_acc = 0
    for epoch in range(start_epoch, config.training.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.training.num_epochs}")

        # Train for one epoch.
        train_loss, train_acc = trainer.train_epoch(train_loader)
        print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        # Validate.
        val_loss, val_acc = trainer.validate(val_loader)
        print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Keep only checkpoints that improve on the best validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            trainer.save_checkpoint(
                epoch=epoch,
                loss=val_loss,
                acc=val_acc,
            )
            print(f"Saved new best model with validation accuracy: {val_acc:.4f}")
# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()