Spaces:
Sleeping
Sleeping
File size: 3,155 Bytes
9580089 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import argparse
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from src.configs.config import Config
from src.models.encoder import SpeakerEncoder
from src.data.dataset import create_meta_learning_dataloader
from src.trainers.meta_trainer import MetaTrainer
def parse_args():
    """Define the training CLI and parse it from ``sys.argv``.

    Returns:
        argparse.Namespace with:
          * ``data_dir`` (str, required) - dataset root directory.
          * ``checkpoint`` (str or None) - checkpoint path to resume from.
          * ``no_wandb`` (bool) - True disables Weights & Biases logging.
    """
    cli = argparse.ArgumentParser(description='训练说话人编码器')
    # (flag, add_argument keyword options) in the order they appear in --help.
    option_specs = (
        ('--data_dir', dict(type=str, required=True, help='数据集根目录')),
        ('--checkpoint', dict(type=str, help='恢复训练的检查点路径')),
        ('--no_wandb', dict(action='store_true', help='禁用Weights & Biases日志')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _build_loader(config, n_tasks):
    """Create a meta-learning dataloader over ``config.data.root_dir``.

    Every setting except the number of tasks comes from ``config.meta_learning``,
    so the training and validation loaders always share the same episode shape
    (n_way / k_shot / k_query), batch size and worker count.
    """
    meta = config.meta_learning
    return create_meta_learning_dataloader(
        root_dir=config.data.root_dir,
        n_way=meta.n_way,
        k_shot=meta.k_shot,
        k_query=meta.k_query,
        n_tasks=n_tasks,
        batch_size=meta.batch_size,
        num_workers=meta.num_workers,
    )


def main():
    """Train the speaker encoder with meta-learning, keeping the best checkpoint.

    Parses CLI arguments with ``parse_args``, builds training/validation loaders
    and a ``SpeakerEncoder``, optionally resumes from ``--checkpoint``, then runs
    epochs ``start_epoch .. config.training.num_epochs - 1``. After each epoch a
    checkpoint is saved only when validation accuracy strictly improves on the
    best value seen so far.
    """
    args = parse_args()

    # Load the default configuration and point it at the dataset from the CLI.
    config = Config()
    config.data.root_dir = args.data_dir

    train_loader = _build_loader(config, config.meta_learning.n_tasks)
    # Validation uses ~1/10 as many tasks as training. Clamp to at least one task
    # so a small n_tasks (< 10) cannot yield an empty validation loader.
    # NOTE(review): validation samples from the same root_dir as training, so it
    # is not a held-out split — confirm whether a separate val directory is intended.
    val_loader = _build_loader(config, max(1, config.meta_learning.n_tasks // 10))

    model = SpeakerEncoder(
        input_dim=config.audio.n_mels,
        hidden_dim=256,
        embedding_dim=512,
    )

    trainer = MetaTrainer(
        model=model,
        config=config,
        use_wandb=not args.no_wandb,
    )

    start_epoch = 0
    best_val_acc = 0
    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        # Assumes load_checkpoint returns (epoch, loss, acc), mirroring the
        # save_checkpoint(epoch=, loss=, acc=) call below — confirm in MetaTrainer.
        last_epoch, _, checkpoint_acc = trainer.load_checkpoint(args.checkpoint)
        # Resume from the epoch after the one stored in the checkpoint.
        start_epoch = last_epoch + 1
        # This script only writes best-so-far checkpoints, so the stored accuracy
        # is the bar to beat; without restoring it the first resumed epoch would
        # overwrite a better checkpoint with a worse one.
        if checkpoint_acc is not None:
            best_val_acc = checkpoint_acc

    print("Starting training...")
    for epoch in range(start_epoch, config.training.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.training.num_epochs}")

        train_loss, train_acc = trainer.train_epoch(train_loader)
        print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        val_loss, val_acc = trainer.validate(val_loader)
        print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Persist only checkpoints that strictly improve validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            trainer.save_checkpoint(
                epoch=epoch,
                loss=val_loss,
                acc=val_acc,
            )
            print(f"Saved new best model with validation accuracy: {val_acc:.4f}")
if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()