FAYO
model
1ef9436
import logging
import torch.nn as nn
from transformers import HubertModel, Wav2Vec2FeatureExtractor
# Raise the threshold of the "numba" logger so only WARNING and more severe
# records from it are emitted; DEBUG/INFO records are dropped.
logging.getLogger(name="numba").setLevel(level=logging.WARNING)
class CNHubert(nn.Module):
    """Pair a pretrained HuBERT encoder with its matching feature extractor.

    Both components are loaded via ``from_pretrained`` from the same
    ``cnhubert_base_path`` (any path or identifier that ``transformers``'
    ``from_pretrained`` accepts).
    """

    def __init__(self, cnhubert_base_path):
        """Load the HuBERT model and its feature extractor.

        Args:
            cnhubert_base_path: Location passed to both ``from_pretrained`` calls.
        """
        super().__init__()
        self.model = HubertModel.from_pretrained(cnhubert_base_path)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cnhubert_base_path)

    def forward(self, x):
        """Run ``x`` through the feature extractor and HuBERT encoder.

        Args:
            x: Audio as a tensor. Its ``.device`` decides where the extracted
                features are placed. The extractor is told the audio is sampled
                at 16 kHz. (Presumably a mono waveform; TODO confirm expected
                shape against callers.)

        Returns:
            The ``last_hidden_state`` entry of the model's output.
        """
        features = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000)
        # Move the extracted values onto the caller's device and cast them to the
        # model's dtype. The extractor produces float32 values in the usual case,
        # which a model converted with ``.half()`` (or any non-float32 dtype)
        # would reject with a dtype mismatch. For a float32 model the cast is a
        # no-op.
        input_values = features.input_values.to(device=x.device, dtype=self.model.dtype)
        return self.model(input_values)["last_hidden_state"]
def get_model(cnhubert_base_path):
    """Build a ``CNHubert`` from ``cnhubert_base_path`` and return it in eval mode.

    Args:
        cnhubert_base_path: Forwarded unchanged to ``CNHubert``.

    Returns:
        The constructed ``CNHubert`` instance, already switched to eval mode.
    """
    # nn.Module.eval() switches the module to evaluation mode and returns the
    # very same module object, so the call can be chained.
    return CNHubert(cnhubert_base_path).eval()