import os
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

from ..models.encoder import SpeakerEncoder
from ..configs.config import Config, TrainingConfig
class MetaTrainer:
    """Meta-learning trainer for few-shot voice cloning."""

    def __init__(
        self,
        model: SpeakerEncoder,
        config: Config,
        use_wandb: bool = True
    ):
        self.model = model
        self.config = config
        self.use_wandb = use_wandb

        self.device = torch.device(config.training.device)
        self.model = self.model.to(self.device)

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.training.learning_rate
        )
        self.criterion = nn.CrossEntropyLoss()

        if use_wandb:
            wandb.init(project="voice-cloning", config=config)
    def compute_loss(
        self,
        support_data: Dict[str, torch.Tensor],
        query_data: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, float]:
        """
        Compute the meta-learning (prototypical) loss for one episode.

        Args:
            support_data:
                - mel_spec: [n_way*k_shot, n_mels, time]
                - speaker_ids: [n_way*k_shot], labels remapped to 0..n_way-1
            query_data:
                - mel_spec: [n_way*k_query, n_mels, time]
                - speaker_ids: [n_way*k_query], labels remapped to 0..n_way-1

        Returns:
            loss: scalar loss tensor
            acc: classification accuracy over the query set
        """
        # Embed the support and query sets
        support_mel = support_data['mel_spec'].to(self.device)  # [n_way*k_shot, n_mels, time]
        query_mel = query_data['mel_spec'].to(self.device)      # [n_way*k_query, n_mels, time]

        support_embeds = self.model(support_mel)  # [n_way*k_shot, embedding_dim]
        query_embeds = self.model(query_mel)      # [n_way*k_query, embedding_dim]

        # Compute one centroid (prototype) per speaker from the support set
        centroids = []
        for speaker_idx in range(self.config.meta_learning.n_way):
            speaker_mask = (support_data['speaker_ids'] == speaker_idx).to(self.device)
            speaker_embeds = support_embeds[speaker_mask]  # [k_shot, embedding_dim]
            centroid = speaker_embeds.mean(dim=0)          # [embedding_dim]
            centroids.append(centroid)
        centroids = torch.stack(centroids)  # [n_way, embedding_dim]

        # Similarity (dot product) between each query embedding and each centroid
        similarities = torch.matmul(query_embeds, centroids.T)  # [n_way*k_query, n_way]

        # Cross-entropy over the similarity logits
        target = query_data['speaker_ids'].to(self.device)  # [n_way*k_query]
        loss = self.criterion(similarities, target)

        # Query-set accuracy
        pred = similarities.argmax(dim=1)  # [n_way*k_query]
        acc = (pred == target).float().mean().item()

        return loss, acc
    def train_epoch(self, dataloader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch of episodes; returns (avg_loss, avg_acc)."""
        self.model.train()
        total_loss = 0.0
        total_acc = 0.0

        with tqdm(dataloader, desc="Training") as pbar:
            for batch_idx, (support_batch, query_batch) in enumerate(pbar):
                self.optimizer.zero_grad()
                loss, acc = self.compute_loss(support_batch, query_batch)
                loss.backward()

                # Gradient clipping to stabilize meta-training
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3.0)
                self.optimizer.step()

                total_loss += loss.item()
                total_acc += acc

                pbar.set_postfix({
                    'loss': total_loss / (batch_idx + 1),
                    'acc': total_acc / (batch_idx + 1)
                })

                if self.use_wandb:
                    wandb.log({
                        'batch_loss': loss.item(),
                        'batch_acc': acc
                    })

        avg_loss = total_loss / len(dataloader)
        avg_acc = total_acc / len(dataloader)
        return avg_loss, avg_acc
    def validate(self, dataloader: DataLoader) -> Tuple[float, float]:
        """Evaluate on validation episodes; returns (avg_loss, avg_acc)."""
        self.model.eval()
        total_loss = 0.0
        total_acc = 0.0

        with torch.no_grad():
            for support_batch, query_batch in dataloader:
                loss, acc = self.compute_loss(support_batch, query_batch)
                total_loss += loss.item()
                total_acc += acc

        avg_loss = total_loss / len(dataloader)
        avg_acc = total_acc / len(dataloader)
        return avg_loss, avg_acc
    def save_checkpoint(self, epoch: int, loss: float, acc: float):
        """Save model and optimizer state for the given epoch."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'acc': acc
        }

        checkpoint_path = os.path.join(
            self.config.training.checkpoint_dir,
            f'checkpoint_epoch_{epoch}.pt'
        )
        os.makedirs(self.config.training.checkpoint_dir, exist_ok=True)
        torch.save(checkpoint, checkpoint_path)
    def load_checkpoint(self, checkpoint_path: str):
        """Load a checkpoint; returns (epoch, loss, acc)."""
        # map_location ensures checkpoints saved on GPU also load on CPU-only hosts
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['loss'], checkpoint['acc']
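

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline: it builds
    # one synthetic episode and checks that compute_loss runs end to end. The
    # no-argument SpeakerEncoder() and Config() constructor calls, and the mel
    # dimensions below, are illustrative assumptions; adapt them to the real
    # project configs. Because of the relative imports above, run this as a
    # module (python -m <package>.meta_trainer) rather than as a script.
    config = Config()
    n_way = config.meta_learning.n_way
    k_shot = config.meta_learning.k_shot
    k_query = config.meta_learning.k_query

    model = SpeakerEncoder()  # assumed to map [B, n_mels, time] -> [B, embedding_dim]
    trainer = MetaTrainer(model, config, use_wandb=False)

    n_mels, n_frames = 80, 100  # assumed episode dimensions
    support = {
        'mel_spec': torch.randn(n_way * k_shot, n_mels, n_frames),
        'speaker_ids': torch.arange(n_way).repeat_interleave(k_shot),
    }
    query = {
        'mel_spec': torch.randn(n_way * k_query, n_mels, n_frames),
        'speaker_ids': torch.arange(n_way).repeat_interleave(k_query),
    }

    loss, acc = trainer.compute_loss(support, query)
    print(f"episode loss={loss.item():.4f}, acc={acc:.4f}")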