File size: 934 Bytes
c1f2d61
 
 
 
 
d1ab29e
 
 
 
 
c1f2d61
 
 
d1ab29e
c1f2d61
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import bigvgan
from huggingface_hub import hf_hub_download

class BigVGANVocoder:
    """Inference wrapper around a pretrained BigVGAN vocoder.

    On construction the checkpoint is loaded, weight normalization is
    removed, and the network is put in eval mode and moved to the target
    device. ``infer_waveform`` then runs the network on a mel spectrogram
    with gradients disabled.
    """

    def __init__(self, device=None, model_id='nvidia/bigvgan_v2_44khz_128band_512x'):
        """Load the pretrained model and prepare it for inference.

        Args:
            device: Target device, given as a string (``'cuda'``,
                ``'cuda:0'``, ``'cpu'``) or a ``torch.device``. When ``None``,
                ``'cuda'`` is used if available, otherwise ``'cpu'``.
            model_id: Identifier forwarded to ``BigVGAN.from_pretrained``.
                Defaults to the checkpoint this class always used.
        """
        # Set default device if none provided
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Decide on the device *type* rather than comparing with the literal
        # 'cuda': an indexed string such as 'cuda:0' or a torch.device object
        # would never equal 'cuda', leaving the CUDA kernel switched off even
        # though the model runs on a GPU.
        on_cuda = torch.device(device).type == 'cuda'

        # Load the pretrained model
        self.model = bigvgan.BigVGAN.from_pretrained(
            model_id,
            use_cuda_kernel=on_cuda,
        )
        self.model.remove_weight_norm()
        self.model.eval().to(device)

        # Stored exactly as supplied (or as defaulted above) so existing
        # callers that inspect `vocoder.device` see the same value as before.
        self.device = device
        self.h = self.model.h  # This holds config like sampling_rate, etc.

    @torch.no_grad()
    def infer_waveform(self, mel):
        """Generate waveforms from a mel spectrogram.

        Args:
            mel: Tensor shaped ``[B, n_mels, T]``. A single unbatched
                ``[n_mels, T]`` spectrogram is also accepted and treated as a
                batch of one. BigVGAN expects ``n_mels`` to match
                ``model.h.n_mels``, typically 128.

        Returns:
            The model output with dimension 1 squeezed, i.e. ``[B, T_out]``
            assuming the model emits ``[B, 1, T_out]`` — TODO confirm against
            the loaded checkpoint.
        """
        # Promote an unbatched spectrogram to a batch of one so the model
        # always receives a 3-D input.
        if mel.dim() == 2:
            mel = mel.unsqueeze(0)

        mel = mel.to(self.device)
        wav_gen = self.model(mel)
        return wav_gen.squeeze(1)  # Returns [B, T]