#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://arxiv.org/abs/2306.16962
https://huggingface.co/audeering/wav2vec2-large-robust-24-ft-age-gender
"""
import argparse

import librosa
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model, Wav2Vec2PreTrainedModel

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        default=(project_path / "pretrained_models/wav2vec2-large-robust-6-ft-age-gender").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--speech_file",
        # default=(project_path / "data/examples/voicemail-female-1.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-female-2.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-female-3.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-male-1.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-male-2.wav").as_posix(),
        default=(project_path / "data/examples/speech-male-1.wav").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


class ModelHead(nn.Module):
    """Regression/classification head on top of the pooled wav2vec2 features."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    """wav2vec2 backbone with two heads: age (1 output) and gender (3 classes)."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        # mean-pool the frame-level hidden states into one utterance-level vector
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender


def main():
    args = get_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor: Wav2Vec2Processor = Wav2Vec2Processor.from_pretrained(args.model_path)
    model = AgeGenderModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    # signal; the model expects 16 kHz mono audio
    signal, sample_rate = librosa.load(args.speech_file, sr=16000)

    inputs = processor(signal, sampling_rate=sample_rate)
    y = inputs["input_values"][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    with torch.no_grad():
        _, age, gender = model(y)

    print(f"age: {age}")
    # gender probabilities are ordered: female, male, child
    print(f"gender: {gender}")
    return


if __name__ == '__main__':
    main()
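
# ---------------------------------------------------------------------------
# Hedged sketch, not part of the original script: according to the audeering
# model card (https://huggingface.co/audeering/wav2vec2-large-robust-24-ft-age-gender),
# the age head outputs a value in [0, 1] that can be mapped to roughly 0-100
# years, and the gender softmax columns are ordered (female, male, child).
# `interpret_outputs` is a hypothetical helper name, shown here only to
# illustrate turning the raw tensors into readable values.
def interpret_outputs(age: torch.Tensor, gender: torch.Tensor):
    """Map raw model outputs to an approximate age in years and a gender label."""
    age_years = float(age.squeeze()) * 100.0
    labels = ["female", "male", "child"]
    gender_label = labels[int(torch.argmax(gender, dim=1))]
    return age_years, gender_label
# Example usage inside main(), after the forward pass:
#     age_years, label = interpret_outputs(age, gender)
#     print(f"estimated age: {age_years:.1f} years, gender: {label}")
# ---------------------------------------------------------------------------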