#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://arxiv.org/abs/2306.16962

https://huggingface.co/audeering/wav2vec2-large-robust-24-ft-age-gender
"""
import argparse

import torch
import torch.nn as nn
import librosa
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model, Wav2Vec2PreTrainedModel

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        # default=(project_path / "pretrained_models/wav2vec2-large-robust-6-ft-age-gender").as_posix(),
        default=(project_path / "pretrained_models/wav2vec2-large-robust-6-ft-age-gender").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--speech_file",
        # default=(project_path / "data/examples/voicemail-female-1.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-female-2.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-female-3.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-male-1.wav").as_posix(),
        # default=(project_path / "data/examples/voicemail-male-2.wav").as_posix(),
        default=(project_path / "data/examples/speech-male-1.wav").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args
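# Example invocation, assuming the repository's default paths shown above; the
# script filename used here is hypothetical:
#
#   python3 infer_age_gender.py \
#       --model_path pretrained_models/wav2vec2-large-robust-6-ft-age-gender \
#       --speech_file data/examples/speech-male-1.wav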


class ModelHead(nn.Module):
    """Projection head (dense -> tanh -> dense) applied to the pooled wav2vec2 embedding."""

    def __init__(self, config, num_labels):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)        # single regression output for age
        self.gender = ModelHead(config, 3)     # three classes: female, male, child
        self.init_weights()

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        # mean-pool the frame-level hidden states into one utterance-level embedding
        hidden_states = torch.mean(hidden_states, dim=1)

        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)

        return hidden_states, logits_age, logits_gender
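

# Convenience sketch, not part of the original model code: converts the raw model
# outputs into readable values. It assumes (per the audeering model card for
# wav2vec2-large-robust-24-ft-age-gender) that the age head predicts age as a
# fraction of 100 years and that the gender probabilities are ordered
# (female, male, child); verify this against the checkpoint you actually load.
def interpret_outputs(logits_age: torch.Tensor, logits_gender: torch.Tensor):
    age_years = float(logits_age.squeeze()) * 100.0
    gender_labels = ["female", "male", "child"]
    gender_label = gender_labels[int(torch.argmax(logits_gender, dim=1).item())]
    return age_years, gender_label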


def main():
    args = get_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor: Wav2Vec2Processor = Wav2Vec2Processor.from_pretrained(args.model_path)
    model = AgeGenderModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    # load the waveform, resampled to the 16 kHz rate the model expects
    signal, sample_rate = librosa.load(args.speech_file, sr=16000)

    # normalize the waveform and shape it into a (1, num_samples) batch on the device
    y = processor(signal, sampling_rate=sample_rate)
    y = y["input_values"][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)
    with torch.no_grad():
        _, age, gender = model(y)
    print(f"age: {age}")
    # gender probabilities are ordered: female, male, child
    print(f"gender: {gender}")

    return


if __name__ == '__main__':
    main()