"""
Predict speaker age and gender from speech with a wav2vec 2.0 model.

Paper: https://arxiv.org/abs/2306.16962
Model: https://huggingface.co/audeering/wav2vec2-large-robust-24-ft-age-gender

Note: the default ``--model_path`` below points to a local copy of the
6-layer variant (wav2vec2-large-robust-6-ft-age-gender) of the same model
family.
"""
|
import argparse

import torch
import torch.nn as nn
import librosa
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model, Wav2Vec2PreTrainedModel

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        default=(project_path / "pretrained_models/wav2vec2-large-robust-6-ft-age-gender").as_posix(),
        type=str,
        help="directory containing the fine-tuned age/gender checkpoint",
    )
    parser.add_argument(
        "--speech_file",
        default=(project_path / "data/examples/speech-male-1.wav").as_posix(),
        type=str,
        help="path to the audio file to run inference on",
    )
    args = parser.parse_args()
    return args


class ModelHead(nn.Module):
    """Two-layer projection head shared by the age and gender outputs."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
|

class AgeGenderModel(Wav2Vec2PreTrainedModel):
    """wav2vec 2.0 backbone with an age head (regression, 1 output) and a
    gender head (3 classes)."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        # Mean-pool the frame-level hidden states into one utterance embedding.
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)

        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)

        return hidden_states, logits_age, logits_gender
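

# Hypothetical convenience helper (not part of the upstream example code):
# converts the raw head outputs into readable values. The gender class order
# female / male / child follows the audeering model card; verify it against
# your checkpoint before relying on it.
GENDER_LABELS = ("female", "male", "child")


def decode_outputs(logits_age: torch.Tensor, probs_gender: torch.Tensor):
    # Age is predicted on a 0...1 scale; multiply by 100 for years.
    age_years = float(logits_age.squeeze()) * 100.0
    gender_idx = int(torch.argmax(probs_gender, dim=-1).item())
    probs = dict(zip(GENDER_LABELS, probs_gender.squeeze().tolist()))
    return age_years, GENDER_LABELS[gender_idx], probs

# Usage, given the outputs of AgeGenderModel.forward:
#     _, age, gender = model(y)
#     age_years, gender_label, probs = decode_outputs(age, gender)
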
|
def main():
    args = get_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor: Wav2Vec2Processor = Wav2Vec2Processor.from_pretrained(args.model_path)
    model = AgeGenderModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    # The model expects 16 kHz audio; librosa resamples on load.
    signal, sample_rate = librosa.load(args.speech_file, sr=16000)

    y = processor(signal, sampling_rate=sample_rate)
    y = y["input_values"][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    with torch.no_grad():
        _, age, gender = model(y)

    # Age is on a 0...1 scale (x100 for years); gender is a probability
    # distribution over the three classes.
    print(f"age: {age}")
    print(f"gender: {gender}")
    return


if __name__ == '__main__':
    main()