# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import time
import json5
import torch
import numpy as np
from tqdm import tqdm
from utils.util import ValueWindow
from torch.utils.data import DataLoader
from models.vc.Noro.noro_base_trainer import Noro_base_Trainer
from torch.nn import functional as F
from models.base.base_sampler import VariableSampler
from diffusers import get_scheduler
import accelerate
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from models.vc.Noro.noro_model import Noro_VCmodel
from models.vc.Noro.noro_dataset import VCCollator, VCDataset, batch_by_size
from processors.content_extractor import HubertExtractor
from models.vc.Noro.noro_loss import diff_loss, ConstractiveSpeakerLoss
from utils.mel import mel_spectrogram_torch
from utils.f0 import get_f0_features_using_dio, interpolate
from torch.nn.utils.rnn import pad_sequence


class NoroTrainer(Noro_base_Trainer):
    def __init__(self, args, cfg):
        self.args = args
        self.cfg = cfg
        cfg.exp_name = args.exp_name
        self.content_extractor = "mhubert"

        # Initialize accelerator and ensure all processes are ready
        self._init_accelerator()
        self.accelerator.wait_for_everyone()

        # Initialize logger on the main process
        if self.accelerator.is_main_process:
            self.logger = get_logger(args.exp_name, log_level="INFO")

        # Configure noise and speaker usage
        self.use_ref_noise = self.cfg.trans_exp.use_ref_noise

        # Log configuration on the main process
        if self.accelerator.is_main_process:
            self.logger.info(f"use_ref_noise: {self.use_ref_noise}")

        # Initialize a time window for monitoring metrics
        self.time_window = ValueWindow(50)

        # Log the start of training
        if self.accelerator.is_main_process:
            self.logger.info("=" * 56)
            self.logger.info("||\t\tNew training process started.\t\t||")
            self.logger.info("=" * 56)
            self.logger.info("\n")
            self.logger.debug(f"Using {args.log_level.upper()} logging level.")
            self.logger.info(f"Experiment name: {args.exp_name}")
            self.logger.info(f"Experiment directory: {self.exp_dir}")

        # Initialize checkpoint directory
        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
        if self.accelerator.is_main_process:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
            self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # Initialize training counters
        self.batch_count: int = 0
        self.step: int = 0
        self.epoch: int = 0
        self.max_epoch = (
            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
        )
        if self.accelerator.is_main_process:
            self.logger.info(
                f"Max epoch: {self.max_epoch if self.max_epoch < float('inf') else 'Unlimited'}"
            )

        # Check basic configuration
        if self.accelerator.is_main_process:
            self._check_basic_configs()
            self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
            self.keep_last = [
                i if i > 0 else float("inf") for i in self.cfg.train.keep_last
            ]
            self.run_eval = self.cfg.train.run_eval

        # Set random seed
        with self.accelerator.main_process_first():
            self._set_random_seed(self.cfg.train.random_seed)

        # Setup data loader
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                self.logger.info("Building dataset...")
            self.train_dataloader = self._build_dataloader()
            self.speaker_num = len(self.train_dataloader.dataset.speaker2id)
            if self.accelerator.is_main_process:
                self.logger.info("Speaker num: {}".format(self.speaker_num))

        # Build model
        with self.accelerator.main_process_first():
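            # main_process_first(): rank 0 enters this block before the other
            # ranks, so any one-time setup done while building the model is
            # finished before the remaining processes run the same code.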
            if self.accelerator.is_main_process:
                self.logger.info("Building model...")
            self.model, self.w2v = self._build_model()

        # Resume training if specified
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                self.logger.info("Resume training: {}".format(args.resume))
            if args.resume:
                if self.accelerator.is_main_process:
                    self.logger.info("Resuming from checkpoint...")
                ckpt_path = self._load_model(
                    self.checkpoint_dir,
                    args.checkpoint_path,
                    resume_type=args.resume_type,
                )
            self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
            if self.accelerator.is_main_process:
                os.makedirs(self.checkpoint_dir, exist_ok=True)
            if self.accelerator.is_main_process:
                self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # Initialize optimizer & scheduler
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                self.logger.info("Building optimizer and scheduler...")
            self.optimizer = self._build_optimizer()
            self.scheduler = self._build_scheduler()

        # Prepare model, w2v, optimizer, and scheduler for accelerator
        self.model = self._prepare_for_accelerator(self.model)
        self.w2v = self._prepare_for_accelerator(self.w2v)
        self.optimizer = self._prepare_for_accelerator(self.optimizer)
        self.scheduler = self._prepare_for_accelerator(self.scheduler)

        # Build criterion
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                self.logger.info("Building criterion...")
            self.criterion = self._build_criterion()

        self.config_save_path = os.path.join(self.exp_dir, "args.json")
        self.task_type = "VC"
        self.contrastive_speaker_loss = ConstractiveSpeakerLoss()
        if self.accelerator.is_main_process:
            self.logger.info("Task type: {}".format(self.task_type))

    def _init_accelerator(self):
        self.exp_dir = os.path.join(
            os.path.abspath(self.cfg.log_dir), self.args.exp_name
        )
        project_config = ProjectConfiguration(
            project_dir=self.exp_dir,
            logging_dir=os.path.join(self.exp_dir, "log"),
        )
        self.accelerator = accelerate.Accelerator(
            gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
            log_with=self.cfg.train.tracker,
            project_config=project_config,
        )
        if self.accelerator.is_main_process:
            os.makedirs(project_config.project_dir, exist_ok=True)
            os.makedirs(project_config.logging_dir, exist_ok=True)
        self.accelerator.wait_for_everyone()
        with self.accelerator.main_process_first():
            self.accelerator.init_trackers(self.args.exp_name)

    def _build_model(self):
        w2v = HubertExtractor(self.cfg)
        model = Noro_VCmodel(cfg=self.cfg.model, use_ref_noise=self.use_ref_noise)
        return model, w2v

    def _build_dataloader(self):
        np.random.seed(int(time.time()))
        if self.accelerator.is_main_process:
            self.logger.info("Use Dynamic Batchsize...")
        train_dataset = VCDataset(self.cfg.trans_exp)
        train_collate = VCCollator(self.cfg)
        batch_sampler = batch_by_size(
            train_dataset.num_frame_indices,
            train_dataset.get_num_frames,
            max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
            max_sentences=self.cfg.train.max_sentences
            * self.accelerator.num_processes,
            required_batch_size_multiple=self.accelerator.num_processes,
        )
        np.random.shuffle(batch_sampler)
        batches = [
            x[self.accelerator.local_process_index :: self.accelerator.num_processes]
            for x in batch_sampler
            if len(x) % self.accelerator.num_processes == 0
        ]
        train_loader = DataLoader(
            train_dataset,
            collate_fn=train_collate,
            num_workers=self.cfg.train.dataloader.num_worker,
            batch_sampler=VariableSampler(
                batches, drop_last=False, use_random_sampler=True
            ),
            pin_memory=self.cfg.train.dataloader.pin_memory,
        )
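        # Note: batches whose size is not a multiple of the process count were
        # dropped above, and each rank takes an interleaved slice, so every
        # process iterates over the same number of batches per epoch.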
        self.accelerator.wait_for_everyone()
        return train_loader

    def _build_optimizer(self):
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            **self.cfg.train.adam,
        )
        return optimizer

    def _build_scheduler(self):
        lr_scheduler = get_scheduler(
            self.cfg.train.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=self.cfg.train.lr_warmup_steps,
            num_training_steps=self.cfg.train.num_train_steps,
        )
        return lr_scheduler

    def _build_criterion(self):
        criterion = torch.nn.L1Loss(reduction="mean")
        return criterion

    def _dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        json5.dump(
            self.cfg,
            open(path, "w"),
            indent=4,
            sort_keys=True,
            ensure_ascii=False,
            quote_keys=True,
        )

    def load_model(self, checkpoint):
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]
        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scheduler.load_state_dict(checkpoint["scheduler"])

    def _prepare_for_accelerator(self, component):
        if isinstance(component, dict):
            for key in component.keys():
                component[key] = self.accelerator.prepare(component[key])
        else:
            component = self.accelerator.prepare(component)
        return component

    def _train_step(self, batch):
        total_loss = 0.0
        train_losses = {}
        device = self.accelerator.device

        # Move all Tensor data to the specified device
        batch = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
        speech = batch["speech"]
        ref_speech = batch["ref_speech"]

        with torch.set_grad_enabled(False):
            # Extract features and spectrograms
            mel = mel_spectrogram_torch(speech, self.cfg).transpose(1, 2)
            ref_mel = mel_spectrogram_torch(ref_speech, self.cfg).transpose(1, 2)
            mask = batch["mask"]
            ref_mask = batch["ref_mask"]

            # Extract pitch and content features
            audio = speech.cpu().numpy()
            f0s = []
            for i in range(audio.shape[0]):
                wav = audio[i]
                f0 = get_f0_features_using_dio(wav, self.cfg.preprocess)
                f0, _ = interpolate(f0)
                frame_num = len(wav) // self.cfg.preprocess.hop_size
                f0 = torch.from_numpy(f0[:frame_num]).to(speech.device)
                f0s.append(f0)
            pitch = pad_sequence(f0s, batch_first=True, padding_value=0).float()
            pitch = (pitch - pitch.mean(dim=1, keepdim=True)) / (
                pitch.std(dim=1, keepdim=True) + 1e-6
            )  # Normalize pitch (B, T)
            _, content_feature = self.w2v.extract_content_features(
                speech
            )  # semantic (B, T, 768)
            if self.use_ref_noise:
                noisy_ref_mel = mel_spectrogram_torch(
                    batch["noisy_ref_speech"], self.cfg
                ).transpose(1, 2)

        if self.use_ref_noise:
            diff_out, (ref_emb, noisy_ref_emb), (cond_emb, _) = self.model(
                x=mel,
                content_feature=content_feature,
                pitch=pitch,
                x_ref=ref_mel,
                x_mask=mask,
                x_ref_mask=ref_mask,
                noisy_x_ref=noisy_ref_mel,
            )
        else:
            diff_out, (ref_emb, _), (cond_emb, _) = self.model(
                x=mel,
                content_feature=content_feature,
                pitch=pitch,
                x_ref=ref_mel,
                x_mask=mask,
                x_ref_mask=ref_mask,
            )

        if self.use_ref_noise:
            # B x N_query x D
            ref_emb = torch.mean(ref_emb, dim=1)  # B x D
            noisy_ref_emb = torch.mean(noisy_ref_emb, dim=1)  # B x D
            all_ref_emb = torch.cat([ref_emb, noisy_ref_emb], dim=0)  # 2B x D
            all_speaker_ids = torch.cat(
                [batch["speaker_id"], batch["speaker_id"]], dim=0
            )  # 2B
            cs_loss = (
                self.contrastive_speaker_loss(all_ref_emb, all_speaker_ids) * 0.25
            )
            total_loss += cs_loss
            train_losses["ref_loss"] = cs_loss

        diff_loss_x0 = diff_loss(diff_out["x0_pred"], mel, mask=mask)
        total_loss += diff_loss_x0
        train_losses["diff_loss_x0"] = diff_loss_x0

        diff_loss_noise = diff_loss(
            diff_out["noise_pred"], diff_out["noise"], mask=mask
        )
        total_loss += diff_loss_noise
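        # The objective sums the x0-reconstruction and noise-prediction
        # diffusion losses with equal weight; the contrastive speaker loss
        # above (if enabled) is down-weighted by 0.25.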
        train_losses["diff_loss_noise"] = diff_loss_noise
        train_losses["total_loss"] = total_loss

        self.optimizer.zero_grad()
        self.accelerator.backward(total_loss)
        if self.accelerator.sync_gradients:
            self.accelerator.clip_grad_norm_(
                filter(lambda p: p.requires_grad, self.model.parameters()), 0.5
            )
        self.optimizer.step()
        self.scheduler.step()

        for item in train_losses:
            train_losses[item] = train_losses[item].item()

        train_losses["learning_rate"] = f"{self.optimizer.param_groups[0]['lr']:.1e}"
        train_losses["batch_size"] = batch["speaker_id"].shape[0]

        return (train_losses["total_loss"], train_losses, None)

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample)
        over one epoch. See ``train_loop`` for usage.
        """
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].train()
        else:
            self.model.train()
        if isinstance(self.w2v, dict):
            for key in self.w2v.keys():
                self.w2v[key].eval()
        else:
            self.w2v.eval()

        epoch_sum_loss: float = 0.0  # total loss

        # Move the model and the content extractor to the target device
        device = self.accelerator.device
        with device:
            torch.cuda.empty_cache()
        self.model = self.model.to(device)
        self.w2v = self.w2v.to(device)

        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            speech = batch["speech"].cpu().numpy()
            speech = speech[0]
            self.batch_count += 1
            self.step += 1
            # Skip utterances longer than 25 seconds (16 kHz sampling rate)
            if len(speech) >= 16000 * 25:
                continue
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses, _ = self._train_step(batch)
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss += total_loss
                self.current_loss = total_loss
                if isinstance(train_losses, dict):
                    for key, loss in train_losses.items():
                        self.accelerator.log(
                            {"Epoch/Train {} Loss".format(key): loss},
                            step=self.step,
                        )
                if self.accelerator.is_main_process and self.batch_count % 10 == 0:
                    self.echo_log(train_losses, mode="Training")
                self.save_checkpoint()
        self.accelerator.wait_for_everyone()
        return epoch_sum_loss, None

    def train_loop(self):
        r"""Training loop. The public entry of training process."""
        # Wait for every process to be ready before moving on
        self.accelerator.wait_for_everyone()
        # Dump config file
        if self.accelerator.is_main_process:
            self._dump_cfg(self.config_save_path)
        # Wait again so all processes start training together
        self.accelerator.wait_for_everyone()

        # Stop when reaching max epoch or self.cfg.train.num_train_steps
        while (
            self.epoch < self.max_epoch and self.step < self.cfg.train.num_train_steps
        ):
            if self.accelerator.is_main_process:
                self.logger.info("\n")
                self.logger.info("-" * 32)
                self.logger.info("Epoch {}: ".format(self.epoch))
                self.logger.info("Start training...")
            train_total_loss, _ = self._train_epoch()
            self.epoch += 1
            self.accelerator.wait_for_everyone()
            if isinstance(self.scheduler, dict):
                for key in self.scheduler.keys():
                    self.scheduler[key].step()
            else:
                self.scheduler.step()

        # Finish training and save final checkpoint
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self.accelerator.save_state(
                os.path.join(
                    self.checkpoint_dir,
                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_total_loss
                    ),
                )
            )
        self.accelerator.end_training()
        if self.accelerator.is_main_process:
            self.logger.info("Training finished...")

    def save_checkpoint(self):
        self.accelerator.wait_for_everyone()
        # Main process only
        if self.accelerator.is_main_process:
            if self.batch_count % self.save_checkpoint_stride[0] == 0:
                keep_last = self.keep_last[0]
                # Read all folders in self.checkpoint_dir
                all_ckpts = os.listdir(self.checkpoint_dir)
                # Exclude non-folders
                all_ckpts = [
                    ckpt
                    for ckpt in all_ckpts
                    if os.path.isdir(os.path.join(self.checkpoint_dir, ckpt))
                ]
                if len(all_ckpts) > keep_last:
                    # Keep only the last `keep_last` folders in
                    # self.checkpoint_dir, sorted by step
                    # ("epoch-{:04d}_step-{:07d}_loss-{:.6f}")
                    all_ckpts = sorted(
                        all_ckpts, key=lambda x: int(x.split("_")[1].split("-")[1])
                    )
                    for ckpt in all_ckpts[:-keep_last]:
                        shutil.rmtree(os.path.join(self.checkpoint_dir, ckpt))
                checkpoint_filename = "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                    self.epoch, self.step, self.current_loss
                )
                path = os.path.join(self.checkpoint_dir, checkpoint_filename)
                self.logger.info("Saving state to {}...".format(path))
                self.accelerator.save_state(path)
                self.logger.info("Finished saving state.")
        self.accelerator.wait_for_everyone()

    def echo_log(self, losses, mode="Training"):
        message = [
            "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
                mode, self.epoch + 1, self.step, self.time_window.average
            )
        ]
        for key in sorted(losses.keys()):
            if isinstance(losses[key], dict):
                for k, v in losses[key].items():
                    message.append(
                        str(k).split("/")[-1] + "=" + str(round(float(v), 5))
                    )
            else:
                message.append(
                    str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
                )
        self.logger.info(", ".join(message))