import logging

import torch.nn as nn
from transformers import HubertModel, Wav2Vec2FeatureExtractor

# Keep numba's logger at WARNING and above to cut down on log noise.
logging.getLogger("numba").setLevel(logging.WARNING)


class CNHubert(nn.Module):
    def __init__(self, cnhubert_base_path):
        super().__init__()
        # Load the pretrained HuBERT weights and the matching feature
        # extractor from the same checkpoint directory.
        self.model = HubertModel.from_pretrained(cnhubert_base_path)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cnhubert_base_path)

    def forward(self, x):
        # x is a raw 16 kHz waveform; the feature extractor turns it into
        # model-ready input_values, which are moved back onto x's device.
        input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
        # Return frame-level features: (batch, frames, hidden_size).
        feats = self.model(input_values)["last_hidden_state"]
        return feats


def get_model(cnhubert_base_path):
    # Convenience constructor: build the wrapper and put it in eval mode.
    model = CNHubert(cnhubert_base_path)
    model.eval()
    return model
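

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving this module, assuming a local checkpoint
# directory named "chinese-hubert-base" (hypothetical path) that contains
# both the HuBERT weights and a Wav2Vec2FeatureExtractor config, plus mono
# 16 kHz audio. The random waveform below stands in for real audio.
if __name__ == "__main__":
    import torch

    model = get_model("chinese-hubert-base")  # hypothetical checkpoint path
    wav16k = torch.randn(16000)  # 1 second of placeholder 16 kHz audio
    with torch.no_grad():
        feats = model(wav16k)
    # For a base-sized HuBERT, one second of audio yields features of
    # roughly shape (1, 49, 768): (batch, frames, hidden_size).
    print(feats.shape)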