File size: 3,155 Bytes
9580089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import torch
from pathlib import Path
from torch.utils.data import DataLoader

from src.configs.config import Config
from src.models.encoder import SpeakerEncoder
from src.data.dataset import create_meta_learning_dataloader
from src.trainers.meta_trainer import MetaTrainer

def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``parse_args()`` callers behave exactly as before.

    Returns:
        argparse.Namespace with:
            data_dir: str, required dataset root directory.
            checkpoint: str or None, checkpoint path to resume from.
            no_wandb: bool, True only when ``--no_wandb`` is passed.
    """
    parser = argparse.ArgumentParser(description='训练说话人编码器')
    parser.add_argument('--data_dir', type=str, required=True, help='数据集根目录')
    parser.add_argument('--checkpoint', type=str, help='恢复训练的检查点路径')
    parser.add_argument('--no_wandb', action='store_true', help='禁用Weights & Biases日志')
    # Passing None makes argparse read sys.argv[1:]; an explicit list lets
    # the parser be driven programmatically (tests, notebooks, wrappers).
    return parser.parse_args(argv)

def main():
    """Run the full training procedure for the speaker encoder.

    Steps, in order:
      1. Parse command-line options and store ``--data_dir`` into the config.
      2. Build the training and validation data loaders through
         ``create_meta_learning_dataloader``.
      3. Construct the ``SpeakerEncoder`` model and the ``MetaTrainer``.
      4. Optionally resume from ``--checkpoint``.
      5. Loop over epochs, printing train/validation metrics and saving a
         checkpoint whenever validation accuracy exceeds the best seen so
         far in this run.
    """
    args = parse_args()

    # Load the configuration and point it at the requested dataset root.
    config = Config()
    config.data.root_dir = args.data_dir
    meta_cfg = config.meta_learning

    # Settings shared by the training and validation loaders; only the
    # number of tasks differs between the two.
    # NOTE(review): both loaders read from the same root_dir; confirm that
    # the dataset code keeps training and validation data apart.
    loader_kwargs = dict(
        root_dir=config.data.root_dir,
        n_way=meta_cfg.n_way,
        k_shot=meta_cfg.k_shot,
        k_query=meta_cfg.k_query,
        batch_size=meta_cfg.batch_size,
        num_workers=meta_cfg.num_workers,
    )

    # Training data loader.
    train_loader = create_meta_learning_dataloader(
        n_tasks=meta_cfg.n_tasks,
        **loader_kwargs
    )

    # Validation data loader: one tenth of the training task count, but
    # never fewer than one task, since n_tasks // 10 is 0 for n_tasks < 10
    # and would leave the validation loader empty.
    val_n_tasks = max(1, meta_cfg.n_tasks // 10)
    val_loader = create_meta_learning_dataloader(
        n_tasks=val_n_tasks,
        **loader_kwargs
    )

    # Build the model.
    model = SpeakerEncoder(
        input_dim=config.audio.n_mels,
        hidden_dim=256,
        embedding_dim=512
    )

    # Build the trainer.
    trainer = MetaTrainer(
        model=model,
        config=config,
        use_wandb=not args.no_wandb
    )

    # Resume from a checkpoint when one was given on the command line.
    start_epoch = 0
    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        start_epoch, _, _ = trainer.load_checkpoint(args.checkpoint)
        # Continue with the epoch after the one stored in the checkpoint.
        start_epoch += 1

    # Start training.
    print("Starting training...")
    # NOTE(review): the best accuracy restarts at 0 even when resuming, so
    # the first resumed epoch with val_acc > 0 triggers a new save. The
    # values discarded from load_checkpoint may hold the stored accuracy;
    # confirm against MetaTrainer before seeding best_val_acc from them.
    best_val_acc = 0

    for epoch in range(start_epoch, config.training.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.training.num_epochs}")

        # Train for one epoch.
        train_loss, train_acc = trainer.train_epoch(train_loader)
        print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        # Validate.
        val_loss, val_acc = trainer.validate(val_loader)
        print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

        # Save a checkpoint only when validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            trainer.save_checkpoint(
                epoch=epoch,
                loss=val_loss,
                acc=val_acc
            )
            print(f"Saved new best model with validation accuracy: {val_acc:.4f}")
            
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()