jacob-c's picture
diffuserfix
d1ab29e
raw
history blame contribute delete
934 Bytes
import torch
import bigvgan
from huggingface_hub import hf_hub_download
class BigVGANVocoder:
    """Inference wrapper around a pretrained BigVGAN vocoder.

    On construction the checkpoint is loaded through
    ``bigvgan.BigVGAN.from_pretrained``, ``remove_weight_norm()`` is called
    on it, and the model is switched to eval mode and moved to the chosen
    device. :meth:`infer_waveform` then runs the model on mel spectrograms
    with gradient tracking disabled.

    Attributes:
        model: The loaded BigVGAN model (eval mode, on ``device``).
        device: The device exactly as supplied by the caller, or the
            auto-selected ``'cuda'`` / ``'cpu'`` string when none was given.
        h: ``model.h`` -- holds config like sampling_rate, etc.
    """

    # Checkpoint used when the caller does not pick one explicitly.
    DEFAULT_MODEL_ID = 'nvidia/bigvgan_v2_44khz_128band_512x'

    def __init__(self, device=None, *, model_id=DEFAULT_MODEL_ID):
        """Load the pretrained model and prepare it for inference.

        Args:
            device: Target device (e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'``
                or a ``torch.device``). When ``None``, ``'cuda'`` is used
                if ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.
            model_id: Identifier passed to ``BigVGAN.from_pretrained``.
                Defaults to the checkpoint this class always used before
                the parameter existed.
        """
        # Set default device if none provided
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Load the pretrained model. The CUDA-kernel flag is derived from
        # the normalised device type so that 'cuda:0' or a torch.device
        # object enable it just like the bare string 'cuda' does.
        # NOTE(review): per the BigVGAN README the CUDA kernel is compiled
        # on first use -- confirm the build toolchain exists on GPU hosts.
        self.model = bigvgan.BigVGAN.from_pretrained(
            model_id,
            use_cuda_kernel=self._is_cuda_device(device),
        )
        self.model.remove_weight_norm()
        self.model.eval().to(device)

        # Stored as given (not normalised) so existing callers that compare
        # against a string keep working.
        self.device = device
        self.h = self.model.h  # This holds config like sampling_rate, etc.

    @staticmethod
    def _is_cuda_device(device):
        """Return True when *device* refers to any CUDA device."""
        return torch.device(device).type == 'cuda'

    @torch.no_grad()
    def infer_waveform(self, mel):
        """Generate waveforms from mel spectrograms.

        Args:
            mel: Mel spectrogram tensor of shape ``[B, n_mels, T]``.
                BigVGAN expects ``n_mels`` to equal ``model.h.n_mels``
                (typically 128 -- verify against the loaded checkpoint).

        Returns:
            The model output with dimension 1 squeezed out, i.e. ``[B, T]``
            when the model emits a single channel.
        """
        mel = mel.to(self.device)
        wav_gen = self.model(mel)
        return wav_gen.squeeze(1)  # Returns [B, T]