JuanJoseMV committed on
Commit 8f96165 · 1 Parent(s): 9995d17

add model logic implementation

.gitignore CHANGED
@@ -16,4 +16,6 @@ build/
 
 # VSCode
 .vscode/
-*.code-workspace
+*.code-workspace
+
+behaviour_model/
README.md CHANGED
@@ -9,6 +9,14 @@ app_file: app.py
 pinned: false
 license: mit
 short_description: Audio and Text Emotion Recognition
+models:
+- links-ads/kk-speech-emotion-recognition
+- openai/whisper-large-v3
+- facebook/wav2vec2-large-xlsr-53
+preload_from_hub:
+- links-ads/kk-speech-emotion-recognition
+- openai/whisper-large-v3
+- facebook/wav2vec2-large-xlsr-53
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,13 +1,40 @@
+import torch
 import gradio as gr
-import plotly.graph_objects as go
 from src.load_html import get_description_html
-
-def process_audio(audio_file):
-    ...
+from src.audio_processor import AudioProcessor
+from src.model.behaviour_model import get_behaviour_model
+from transformers import (
+    pipeline,
+    WavLMForSequenceClassification
+)
 
 
 # Gradio interface
 def create_demo():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    segmentation_model = pipeline(
+        task="automatic-speech-recognition",
+        model="openai/whisper-large-v3-turbo",
+        tokenizer="openai/whisper-large-v3-turbo",
+        device=device
+    )
+
+    emotion_model = WavLMForSequenceClassification.from_pretrained("links-ads/kk-speech-emotion-recognition")
+    emotion_model.to(device)
+    emotion_model.eval()
+
+    behaviour_model = get_behaviour_model(
+        behaviour_model_path="behaviour_model/",
+        device=device,
+    )
+
+    audio_processor = AudioProcessor(
+        emotion_model=emotion_model,
+        segmentation_model=segmentation_model,
+        device=device,
+        behaviour_model=behaviour_model,
+    )
+
     with gr.Blocks() as demo:
         gr.HTML(get_description_html)
 
@@ -17,7 +44,7 @@ def create_demo():
         graph_output = gr.Plot(label="Generated Graph")
 
         submit_button.click(
-            fn=process_audio,
+            fn=audio_processor,
            inputs=audio_input,
            outputs=graph_output
        )
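
Note that this hunk only covers the top of app.py; the part that builds and launches the interface is not touched by this commit and is not shown. For context, a typical Gradio Spaces entry point would look like the following (hypothetical sketch, not taken from this diff):

# Hypothetical entry point; the actual remainder of app.py is not shown here.
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()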
requirements.txt CHANGED
@@ -1,2 +1,5 @@
 gradio==5.24.0
-plotly==6.0.1
+plotly==6.0.1
+torch==2.7.0
+librosa==0.11.0
+transformers==4.51.3
src/audio_processor.py ADDED
@@ -0,0 +1,190 @@
+import time
+import torch
+import librosa
+import numpy as np
+import gradio as gr
+from .generate_graph import create_behaviour_gantt_plot
+from transformers import Wav2Vec2Processor
+
+
+SAMPLING_RATE = 16_000
+
+class AudioProcessor:
+    def __init__(
+        self,
+        emotion_model,
+        segmentation_model,
+        device,
+        behaviour_model=None,
+    ):
+        self.emotion_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+        self.emotion_model = emotion_model
+        self.behaviour_model = behaviour_model
+        self.device = device
+        self.audio_emotion_labels = {
+            0: "Neutralità",
+            1: "Rabbia",
+            2: "Paura",
+            3: "Gioia",
+            4: "Sorpresa",
+            5: "Tristezza",
+            6: "Disgusto",
+        }
+        self.emotion_translation = {
+            "neutrality": "Neutralità",
+            "anger": "Rabbia",
+            "fear": "Paura",
+            "joy": "Gioia",
+            "surprise": "Sorpresa",
+            "sadness": "Tristezza",
+            "disgust": "Disgusto"
+        }
+        self.behaviour_labels = {
+            0: "frustrated",
+            1: "delighted",
+            2: "dysregulated",
+        }
+        self.behaviour_translation = {
+            "frustrated": "frustrazione",
+            "delighted": "incantato",
+            "dysregulated": "disregolazione",
+        }
+        self.segmentation_model = segmentation_model
+
+        self._set_emotion_model()
+        if self.behaviour_model:
+            self._set_behaviour_model()
+
+        self.behaviour_confidence = 0.6
+
+        self.chart_generator = None
+
+    def _set_emotion_model(self):
+        self.emotion_model.to(self.device)
+        self.emotion_model.eval()
+
+    def _set_behaviour_model(self):
+        self.behaviour_model.to(self.device)
+        self.behaviour_model.eval()
+
+    def _prepare_transcribed_text(self, chunks):
+        formatted_timestamps = []
+        predictions = []
+
+        for chunk in chunks:
+            start = chunk[0] / SAMPLING_RATE
+            end = chunk[1] / SAMPLING_RATE
+            formatted_start = time.strftime('%H:%M:%S', time.gmtime(start))
+            formatted_end = time.strftime('%H:%M:%S', time.gmtime(end))
+            formatted_timestamps.append(f"**({formatted_start} - {formatted_end})**")
+
+            predictions.append(f"**[{chunk[2]}]**")
+
+        transcribed_texts = [chunk[3] for chunk in chunks]
+        transcribed_text = "<br/>".join(
+            [
+                f"{formatted_timestamps[i]}: {transcribed_texts[i]} {predictions[i]}" for i in range(len(transcribed_texts))
+            ]
+        )
+
+        print(f"Transcribed text:\n{transcribed_text}")
+
+        return transcribed_text
+
+    def __call__(self, audio_path: str):
+        """
+        Predicts emotion and behaviour labels for a given audio input.
+
+        Args:
+            audio_path (str): Path of the audio file to be processed.
+
+        Returns:
+            tuple: A one-element tuple containing the behaviour Gantt plot (plotly Figure).
+        """
+
+        print("Segmenting audio...")
+        out = self.segmentation_model(
+            inputs=audio_path,
+            return_timestamps=True,
+        )
+
+        # Load the waveform at the model sampling rate; chunks below are sliced by sample index.
+        input_frames, _ = librosa.load(audio_path, sr=SAMPLING_RATE)
+
+        emotion_chunks = []
+        behaviour_chunks = []
+        timestamps = []
+        predicted_labels = []
+        all_probabilities = []
+
+        print("Analyzing chunks...")
+        for chunk in out["chunks"]:
+            # trim audio from timestamps
+            start = int(chunk["timestamp"][0] * SAMPLING_RATE)
+            end = int(chunk["timestamp"][1] * SAMPLING_RATE if chunk["timestamp"][1] else len(input_frames))
+
+            audio = input_frames[start:end]
+
+            inputs = self.emotion_processor(audio, chunk["text"], return_tensors="pt", sampling_rate=SAMPLING_RATE)
+
+            print(f"Inputs: {inputs}")
+
+            if "input_values" in inputs:
+                inputs["input_features"] = inputs.pop("input_values")
+
+            inputs['input_features'] = inputs['input_features'].to(self.device)
+            inputs['input_ids'] = inputs['input_ids'].to(self.device)
+            inputs['text_attention_mask'] = inputs['text_attention_mask'].to(self.device)
+
+            print("Predicting emotion for chunk...")
+            logits = self.emotion_model(**inputs).logits
+
+            logits = logits.detach().cpu()
+            softmax = torch.nn.Softmax(dim=1)
+            probabilities = softmax(logits).squeeze(0)
+
+            prediction = probabilities.argmax().item()
+
+            predicted_label = self.emotion_model.config.id2label[prediction]
+            label_translation = self.emotion_translation[predicted_label]
+
+            emotion_chunks.append(
+                (
+                    start,
+                    end,
+                    label_translation,
+                    chunk["text"],
+                    np.round(probabilities[prediction].item(), 2)
+                )
+            )
+
+            timestamps.append((start, end))
+            predicted_labels.append(label_translation)
+            all_probabilities.append(probabilities[prediction].item())
+
+            # Behaviour prediction on the audio-only features of the same chunk.
+            inputs = self.emotion_processor(audio, return_tensors="pt", sampling_rate=SAMPLING_RATE)
+            if "input_values" in inputs:
+                inputs["input_features"] = inputs.pop("input_values")
+
+            inputs = inputs.input_features.to(self.device)
+            print("Predicting behaviour for chunk...")
+            logits = self.behaviour_model(inputs).logits
+            probabilities = torch.nn.functional.softmax(logits.detach().cpu(), dim=-1).squeeze()
+            behaviour_chunks.append(
+                (
+                    start,
+                    end,
+                    chunk["text"],
+                    np.round(probabilities[2].item(), 2),  # probability of the "dysregulated" class
+                    label_translation,
+                )
+            )
+        behaviour_gantt = create_behaviour_gantt_plot(behaviour_chunks)
+
+        # transcribed_text = self._prepare_transcribed_text(emotion_chunks)
+
+        return (
+            behaviour_gantt,
+            # transcribed_text,
+        )
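
For reference, the out["chunks"] structure that __call__ iterates over is the standard output of the transformers ASR pipeline when return_timestamps=True is set. A minimal illustration with made-up values:

# Illustrative only: shape of the segmentation output consumed by AudioProcessor.__call__.
out = {
    "text": " ciao come stai oggi va tutto bene",
    "chunks": [
        {"timestamp": (0.0, 2.4), "text": " ciao come stai"},
        {"timestamp": (2.4, 5.1), "text": " oggi va tutto bene"},
    ],
}

for chunk in out["chunks"]:
    start_sample = int(chunk["timestamp"][0] * 16_000)  # seconds -> sample index
    end_sample = int(chunk["timestamp"][1] * 16_000)
    print(start_sample, end_sample, chunk["text"])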
src/generate_graph.py ADDED
@@ -0,0 +1,176 @@
+import time
+import plotly.graph_objects as go
+from datetime import datetime, timedelta
+
+
+SAMPLING_RATE = 16_000
+COLOR_MAP = {
+    "Neutralità": "rgb(178, 178, 178)",
+    "Rabbia": "rgb(160, 61, 62)",
+    "Paura": "rgb(91, 57, 136)",
+    "Gioia": "rgb(255, 255, 0)",
+    "Sorpresa": "rgb(60, 175, 175)",
+    "Tristezza": "rgb(64, 106, 173)",
+    "Disgusto": "rgb(100, 153, 65)",
+}
+
+def create_behaviour_gantt_plot(behaviour_chunks, confidence_threshold=60):
+    print("Creating behaviour Gantt plot...")
+    emotion_order = [
+        "Gioia",
+        "Sorpresa",
+        "Disgusto",
+        "Tristezza",
+        "Paura",
+        "Rabbia",
+        "Neutralità"
+    ]
+
+    fig = go.Figure()
+
+    chunk_starts = [start / SAMPLING_RATE for start, _, _, _, _ in behaviour_chunks]
+    chunk_ends = [end / SAMPLING_RATE for _, end, _, _, _ in behaviour_chunks]
+
+    # Create reference time for plotting (starting at 0)
+    # We'll use a base datetime and add seconds
+    base_time = datetime(2_000, 1, 1, 0, 0, 0)  # TODO: change magic numbers
+
+    start_times = [base_time + timedelta(seconds=t) for t in chunk_starts]
+    end_times = [base_time + timedelta(seconds=t) for t in chunk_ends]
+
+    # Calculate midpoints for each chunk (for trend line)
+    mid_times = [base_time + timedelta(seconds=(s + e) / 2) for s, e in zip(chunk_starts, chunk_ends)]
+
+    heights = [height * 100 for _, _, _, height, _ in behaviour_chunks]
+
+    emotions = [emotion for _, _, _, _, emotion in behaviour_chunks]
+
+    hover_texts = []
+    for i, (start, end, label, height, emotion) in enumerate(behaviour_chunks):
+        start_fmt = time.strftime('%H:%M:%S', time.gmtime(start / SAMPLING_RATE))
+        end_fmt = time.strftime('%H:%M:%S', time.gmtime(end / SAMPLING_RATE))
+        duration_seconds = (end - start) / SAMPLING_RATE
+        duration_str = time.strftime('%H:%M:%S', time.gmtime(duration_seconds))
+
+        hover_text = f"Inizio: {start_fmt}<br>Fine: {end_fmt}<br>Durata: {duration_str}<br>Testo: {label}<br>Attendibilità: {height*100:.2f}%<br>Emozione: {emotion}"
+        hover_texts.append(hover_text)
+
+    fig.add_shape(
+        type="rect",
+        x0=start_times[0],
+        x1=end_times[-1],
+        y0=confidence_threshold,
+        y1=100,
+        fillcolor="rgba(188,223,241,0.8)",
+        opacity=0.8,
+        layer="below",
+        line_width=0,
+    )
+
+    fig.add_hline(y=confidence_threshold, line_dash="dash", line_color="black", line_width=1)
+
+    fig.add_trace(
+        go.Scatter(
+            x=mid_times,
+            y=heights,
+            mode='lines',
+            name='Disregolazione',
+            line=dict(
+                color='orange',
+                width=2,
+                shape='spline',  # This enables smoothing
+                smoothing=1.0,  # Adjust smoothing factor
+            ),
+            text=hover_texts,
+            hoverinfo='text',
+            showlegend=False,
+        )
+    )
+
+    emotion_data = {}
+
+    for i, height in enumerate(heights):
+        if height >= confidence_threshold:
+            emotion = emotions[i]
+            if emotion not in emotion_data:
+                emotion_data[emotion] = {
+                    'times': [],
+                    'heights': [],
+                    'hover_texts': []
+                }
+
+            emotion_data[emotion]['times'].append(mid_times[i])
+            emotion_data[emotion]['heights'].append(height)
+            emotion_data[emotion]['hover_texts'].append(hover_texts[i])
+
+    for emotion in emotion_order:
+        color = COLOR_MAP.get(emotion, '#000000')
+
+        if emotion in emotion_data:
+            data = emotion_data[emotion]
+            fig.add_trace(
+                go.Scatter(
+                    x=data['times'],
+                    y=data['heights'],
+                    mode='markers',
+                    name=emotion.capitalize(),
+                    marker=dict(
+                        size=15,
+                        color=color,
+                        symbol='circle'
+                    ),
+                    text=data['hover_texts'],
+                    hoverinfo='text',
+                    showlegend=True,
+                )
+            )
+        else:
+            fig.add_trace(
+                go.Scatter(
+                    x=[None],
+                    y=[None],
+                    mode='markers',
+                    name=emotion.capitalize(),
+                    marker=dict(
+                        size=15,
+                        color=color,
+                        symbol='circle'
+                    ),
+                    showlegend=True,
+                )
+            )
+
+    fig.update_layout(
+        title='Distribuzione della disregolazione',
+        xaxis_title='Tempo',
+        yaxis_title='Attendibilità',
+        xaxis=dict(
+            type='date',
+            tickformat='%H:%M:%S',
+            showline=True,
+            zeroline=False,
+            side='bottom',
+            showgrid=False,
+        ),
+        yaxis=dict(
+            range=[0, 100],
+            tickvals=[0, 20, 40, 60, 80, 100],
+            ticktext=['0%', '20%', '40%', '60%', '80%', '100%'],
+            tickmode='array',
+            showgrid=False,
+        ),
+        legend_title=None,
+        legend=dict(
+            yanchor="top"
+        ),
+        hoverlabel=dict(
+            font_size=12,
+            font_family="Arial"
+        ),
+        paper_bgcolor='white',
+        plot_bgcolor='white',
+    )
+
+    fig.update_traces(hovertemplate=None)
+
+    return fig
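
A minimal usage sketch of the Gantt helper with synthetic chunks in the (start_sample, end_sample, text, confidence, emotion) layout that AudioProcessor produces; the values below are invented for illustration, and the import assumes the repo root is on PYTHONPATH:

# Hypothetical, self-contained example; chunk values are made up.
from src.generate_graph import create_behaviour_gantt_plot

behaviour_chunks = [
    # (start_sample, end_sample, transcribed text, dysregulation confidence, emotion label)
    (0,       48_000,  "ciao come stai",   0.35, "Neutralità"),
    (48_000,  112_000, "non voglio farlo", 0.72, "Rabbia"),
    (112_000, 160_000, "va bene grazie",   0.20, "Gioia"),
]

fig = create_behaviour_gantt_plot(behaviour_chunks, confidence_threshold=60)
fig.show()  # or return it from a Gradio callback, as app.py does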
src/model/behaviour_model.py ADDED
@@ -0,0 +1,28 @@
+import os
+import argparse
+import torch
+from .make_model import make_model
+
+
+hparams_dict = {
+    'HF_MODEL_PATH': 'facebook/wav2vec2-large-xlsr-53',
+    'DATASET': 'recanvo',
+    'MAX_DURATION': 4,
+    'SAMPLING_RATE': 16_000,
+    'OUTPUT_HIDDEN_STATES': True,
+    'CLASSIFIER_NAME': 'multilevel',
+    'CLASSIFIER_PROJ_SIZE': 256,
+    'NUM_LABELS': 3,
+    'LABEL_WEIGHTS': [1.0],
+    'LOSS': 'cross-entropy',
+    'GPU_ID': 0,
+    'RETURN_RAW_ARRAY': False,
+}
+hparams = argparse.Namespace(**hparams_dict)
+
+def get_behaviour_model(behaviour_model_path, device):
+    state_dict = torch.load(os.path.join(behaviour_model_path, 'pytorch_model.bin'), map_location=device)
+    model = make_model(hparams)
+    model.load_state_dict(state_dict)
+
+    return model
src/model/custom_model.py ADDED
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+
+from typing import Optional, Union, Tuple
+from transformers.modeling_outputs import SequenceClassifierOutput
+from .wav2vec2_wrapper import Wav2VecWrapper
+from .multilevel_classifier import MultiLevelDownstreamModel
+
+class CustomModelForAudioClassification(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        assert config.output_hidden_states == True, "The upstream model must return all hidden states"
+        self.config = config
+        self.encoder = Wav2VecWrapper(config)
+
+        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)
+
+    def forward(
+        self,
+        input_features: Optional[torch.LongTensor],
+        length: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        if encoder_outputs is None:
+            encoder_output = self.encoder(
+                input_features,
+                length=length,
+            )
+
+        logits = self.classifier(**encoder_output)
+
+        loss = None
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=encoder_output['encoder_hidden_states']
+        )
src/model/make_model.py ADDED
@@ -0,0 +1,20 @@
+from transformers import (
+    AutoConfig
+)
+from .custom_model import CustomModelForAudioClassification
+
+def make_model(hparams):
+    """ Returns a model instance based on the provided hyperparameters. """
+    hparams = vars(hparams)
+    config = AutoConfig.from_pretrained(hparams['HF_MODEL_PATH'])
+    config.max_duration = hparams['MAX_DURATION']
+    config.sampling_rate = hparams['SAMPLING_RATE']
+    config.output_hidden_states = hparams['OUTPUT_HIDDEN_STATES']
+    config.classifier_name = hparams['CLASSIFIER_NAME']
+    config.classifier_proj_size = hparams['CLASSIFIER_PROJ_SIZE']
+    config.num_labels = hparams['NUM_LABELS']
+    config.label_weights = hparams['LABEL_WEIGHTS']
+    config.lossname = hparams['LOSS']
+    model = CustomModelForAudioClassification(config)
+
+    return model
src/model/multilevel_classifier.py ADDED
@@ -0,0 +1,76 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class MultiLevelDownstreamModel(nn.Module):
+    def __init__(
+        self,
+        model_config,
+        use_conv_output: Optional[bool] = True,
+    ):
+        super().__init__()
+        assert model_config.output_hidden_states == True, "The upstream model must return all hidden states"
+
+        self.model_config = model_config
+        self.use_conv_output = use_conv_output
+
+        self.model_seq = nn.Sequential(
+            nn.Conv1d(self.model_config.hidden_size, self.model_config.classifier_proj_size, 1, padding=0),
+            nn.ReLU(),
+            nn.Dropout(p=0.1),
+            nn.Conv1d(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size, 1, padding=0),
+            nn.ReLU(),
+            nn.Dropout(p=0.1),
+            nn.Conv1d(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size, 1, padding=0)
+        )
+
+        if self.use_conv_output:
+            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + input embeddings
+            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        else:
+            num_layers = self.model_config.num_hidden_layers
+            self.weights = nn.Parameter(torch.zeros(num_layers))
+
+        self.out_layer = nn.Sequential(
+            nn.Linear(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size),
+            nn.ReLU(),
+            nn.Linear(self.model_config.classifier_proj_size, self.model_config.num_labels),
+        )
+
+    def forward(self, encoder_hidden_states, length=None):
+        if self.use_conv_output:
+            stacked_feature = torch.stack(encoder_hidden_states, dim=0)
+        else:
+            stacked_feature = torch.stack(encoder_hidden_states, dim=0)[1:]  # exclude the convolution output
+
+        _, *origin_shape = stacked_feature.shape
+
+        if self.use_conv_output:
+            stacked_feature = stacked_feature.view(self.model_config.num_hidden_layers + 1, -1)
+        else:
+            stacked_feature = stacked_feature.view(self.model_config.num_hidden_layers, -1)
+
+        norm_weights = F.softmax(self.weights, dim=-1)
+
+        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
+        features = weighted_feature.view(*origin_shape)
+
+        features = features.transpose(1, 2)
+        features = self.model_seq(features)
+        features = features.transpose(1, 2)
+
+        if length is not None:
+            length = length.cuda()
+            masks = torch.arange(features.size(1)).expand(length.size(0), -1).cuda() < length.unsqueeze(1)
+            masks = masks.float()
+            features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
+        else:
+            features = torch.mean(features, dim=1)
+
+        predicted = self.out_layer(features)
+        return predicted
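
To make the tensor shapes explicit, here is a small shapes-only sketch of how the classifier consumes the stacked hidden states. The config values are placeholders chosen to mirror a wav2vec2-large-style backbone (not values read from this repo), and the import assumes the repo root is on PYTHONPATH:

# Hypothetical shapes-only sketch (not part of the commit).
import torch
from types import SimpleNamespace
from src.model.multilevel_classifier import MultiLevelDownstreamModel

config = SimpleNamespace(
    output_hidden_states=True,
    hidden_size=1024,          # wav2vec2-large hidden size (assumed)
    classifier_proj_size=256,
    num_hidden_layers=24,
    num_labels=3,
)
clf = MultiLevelDownstreamModel(config, use_conv_output=True)

batch, frames = 2, 49  # roughly 1 s of 16 kHz audio after the conv feature extractor
# 24 transformer layers + the conv/projection output = 25 hidden states
hidden_states = tuple(torch.randn(batch, frames, 1024) for _ in range(25))

logits = clf(hidden_states)  # learned softmax weights mix the 25 layers before pooling
print(logits.shape)          # torch.Size([2, 3])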
src/model/wav2vec2_wrapper.py ADDED
@@ -0,0 +1,93 @@
+import torch
+import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
+from torch import nn
+from transformers import Wav2Vec2Model
+
+class Wav2Vec2EncoderLayer(nn.Module):
+    def __init__(
+        self,
+        config,
+        i
+    ):
+        super().__init__()
+        self.attention = w2v2.Wav2Vec2Attention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=False,
+        )
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.feed_forward = w2v2.Wav2Vec2FeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.config = config
+        self.i = i
+
+    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+        attn_residual = hidden_states
+
+        hidden_states, attn_weights, _ = self.attention(
+            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+        )
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = attn_residual + hidden_states
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states + self.feed_forward(hidden_states)
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+        return outputs
+
+class Wav2VecWrapper(nn.Module):
+    def __init__(
+        self,
+        config,
+    ):
+        super(Wav2VecWrapper, self).__init__()
+        self.config = config
+
+        self.backbone_model = Wav2Vec2Model.from_pretrained(
+            config._name_or_path,
+            output_hidden_states=config.output_hidden_states,
+        )
+        state_dict = self.backbone_model.state_dict()
+
+        self.model_config = self.backbone_model.config
+        self.backbone_model.encoder.layers = nn.ModuleList([Wav2Vec2EncoderLayer(self.model_config, i) for i in range(self.model_config.num_hidden_layers)])
+
+    def forward(self,
+                input_features: torch.Tensor,
+                length: torch.Tensor = None,
+    ):
+        # The convolutional feature extractor is kept frozen.
+        with torch.no_grad():
+            hidden_states = self.backbone_model.feature_extractor(input_features)
+            hidden_states = hidden_states.transpose(1, 2)
+            hidden_states, _ = self.backbone_model.feature_projection(hidden_states)
+
+        if length is not None:
+            length = self.get_feat_extract_output_lengths(length.detach().cpu())
+
+        hidden_states = self.backbone_model.encoder(
+            hidden_states,
+            output_hidden_states=self.config.output_hidden_states
+        ).hidden_states
+
+        return {'encoder_hidden_states': hidden_states, 'length': length}
+
+    def get_feat_extract_output_lengths(self, input_length):
+        def _conv_out_length(input_length, kernel_size, stride):
+            return (input_length - kernel_size) // stride + 1
+        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
+            input_length = _conv_out_length(input_length, kernel_size, stride)
+        return input_length
+
+def prepare_mask(length, shape, dtype):
+    mask = torch.zeros(
+        shape, dtype=dtype
+    )
+    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
+    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+    return mask
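
As a sanity check on get_feat_extract_output_lengths: with the default wav2vec2 feature-extractor configuration (conv_kernel (10, 3, 3, 3, 3, 2, 2) and conv_stride (5, 2, 2, 2, 2, 2, 2), which I believe is what facebook/wav2vec2-large-xlsr-53 ships with), one second of 16 kHz audio maps to 49 encoder frames. A standalone recomputation of the same arithmetic:

# Standalone recomputation of the conv output-length formula used above.
def conv_out_length(input_length, kernel_size, stride):
    return (input_length - kernel_size) // stride + 1

conv_kernel = (10, 3, 3, 3, 3, 2, 2)  # default wav2vec2 feature-extractor kernels (assumed)
conv_stride = (5, 2, 2, 2, 2, 2, 2)   # default wav2vec2 feature-extractor strides (assumed)

length = 16_000  # one second of audio at SAMPLING_RATE
for k, s in zip(conv_kernel, conv_stride):
    length = conv_out_length(length, k, s)

print(length)  # 49 frames, i.e. roughly one frame every 20 ms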