#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://arxiv.org/abs/2306.16962
https://huggingface.co/audeering/wav2vec2-large-robust-24-ft-age-gender
"""
import argparse
import torch
import torch.nn as nn
import librosa
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model, Wav2Vec2PreTrainedModel
from project_settings import project_path
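
# Inference demo: load a wav2vec2 checkpoint fine-tuned for age and gender prediction
# (see the links in the module docstring) and run it on a single speech file.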
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        default=(project_path / "pretrained_models/wav2vec2-large-robust-6-ft-age-gender").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--speech_file",
        # default=(project_path / "data/examples/voicemail-female-1.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-female-2.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-female-3.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-male-1.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-male-2.wav").as_posix(),
        default=(project_path / "data/examples/speech-male-1.wav").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


class ModelHead(nn.Module):
    """Prediction head: dense -> tanh -> dropout -> linear projection, one per task."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    """wav2vec2 backbone with two heads: age regression (1 output) and gender classification (3 classes)."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        # average the frame-level hidden states over time into one utterance-level embedding
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender
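
# For a batch of size B, AgeGenderModel.forward returns:
#   hidden_states: (B, hidden_size) time-averaged wav2vec2 embedding
#   logits_age:    (B, 1) raw age regression output
#   logits_gender: (B, 3) softmax probabilities; the label order (female, male, child)
#                  is taken from the comment in main() and the linked model card.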


def main():
    args = get_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor: Wav2Vec2Processor = Wav2Vec2Processor.from_pretrained(args.model_path)
    model = AgeGenderModel.from_pretrained(args.model_path)
    # move the model to the same device as the input tensor, otherwise CUDA inference fails
    model.to(device)
    model.eval()

    # load the waveform and resample it to the 16 kHz rate wav2vec2 expects
    signal, sample_rate = librosa.load(args.speech_file, sr=16000)

    y = processor(signal, sampling_rate=sample_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    with torch.no_grad():
        _, age, gender = model(y)

    print(f"age: {age}")
    # gender probabilities, in the order: female, male, child
    print(f"gender: {gender}")
return
if __name__ == '__main__':
main()