Commit 8f96165
Parent(s): 9995d17

add model logic implementation

Files changed:
- .gitignore +3 -1
- README.md +8 -0
- app.py +32 -5
- requirements.txt +4 -1
- src/audio_processor.py +190 -0
- src/generate_graph.py +176 -0
- src/model/behaviour_model.py +28 -0
- src/model/custom_model.py +38 -0
- src/model/make_model.py +20 -0
- src/model/multilevel_classifier.py +76 -0
- src/model/wav2vec2_wrapper.py +93 -0
.gitignore CHANGED
@@ -16,4 +16,6 @@ build/
 
 # VSCode
 .vscode/
-*.code-workspace
+*.code-workspace
+
+behaviour_model/
README.md CHANGED
@@ -9,6 +9,14 @@ app_file: app.py
 pinned: false
 license: mit
 short_description: Audio and Text Emotion Recognition
+models:
+- links-ads/kk-speech-emotion-recognition
+- openai/whisper-large-v3
+- facebook/wav2vec2-large-xlsr-53
+preload_from_hub:
+- links-ads/kk-speech-emotion-recognition
+- openai/whisper-large-v3
+- facebook/wav2vec2-large-xlsr-53
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,13 +1,40 @@
+import torch
 import gradio as gr
-import plotly.graph_objects as go
 from src.load_html import get_description_html
-
-
-
+from src.audio_processor import AudioProcessor
+from src.model.behaviour_model import get_behaviour_model
+from transformers import (
+    pipeline,
+    WavLMForSequenceClassification
+)


 # Gradio interface
 def create_demo():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    segmentation_model = pipeline(
+        task="automatic-speech-recognition",
+        model="openai/whisper-large-v3-turbo",
+        tokenizer="openai/whisper-large-v3-turbo",
+        device=device
+    )
+
+    emotion_model = WavLMForSequenceClassification.from_pretrained("links-ads/kk-speech-emotion-recognition")
+    emotion_model.to(device)
+    emotion_model.eval()
+
+    behaviour_model = get_behaviour_model(
+        behaviour_model_path="behaviour_model/",
+        device=device,
+    )
+
+    audio_processor = AudioProcessor(
+        emotion_model=emotion_model,
+        segmentation_model=segmentation_model,
+        device=device,
+        behaviour_model=behaviour_model,
+    )
+
     with gr.Blocks() as demo:
         gr.HTML(get_description_html)
 
@@ -17,7 +44,7 @@ def create_demo():
         graph_output = gr.Plot(label="Generated Graph")
 
         submit_button.click(
-            fn=
+            fn=audio_processor,
            inputs=audio_input,
            outputs=graph_output
        )
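The hunks above skip the middle of create_demo(), where audio_input and submit_button are defined, so the wiring is easier to follow with a minimal sketch of the whole Blocks body. The gr.Audio and gr.Button components and the HTML stand-in are assumptions for illustration; only the click wiring and graph_output mirror this commit.

# Minimal sketch of the full Blocks body; gr.Audio, gr.Button and the HTML
# stand-in are assumptions, not part of this diff.
import gradio as gr

def create_demo_sketch(audio_processor):
    with gr.Blocks() as demo:
        gr.HTML("<h1>Audio and Text Emotion Recognition</h1>")  # stand-in for get_description_html
        audio_input = gr.Audio(type="filepath", label="Audio")  # assumed input component
        submit_button = gr.Button("Submit")                     # assumed trigger
        graph_output = gr.Plot(label="Generated Graph")

        # AudioProcessor is callable, so the instance itself is the event handler.
        submit_button.click(
            fn=audio_processor,
            inputs=audio_input,
            outputs=graph_output,
        )
    return demo

The returned Blocks object would then be served the usual Gradio way (demo.launch()), which is not shown in this diff.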
requirements.txt CHANGED
@@ -1,2 +1,5 @@
 gradio==5.24.0
-plotly==6.0.1
+plotly==6.0.1
+torch==2.7.0
+librosa==0.11.0
+transformers==4.51.3
src/audio_processor.py ADDED
@@ -0,0 +1,190 @@
import time
import torch
import librosa
import numpy as np
import gradio as gr
from .generate_graph import create_behaviour_gantt_plot
from transformers import Wav2Vec2Processor


SAMPLING_RATE = 16_000

class AudioProcessor:
    def __init__(
        self,
        emotion_model,
        segmentation_model,
        device,
        behaviour_model=None,
    ):
        self.emotion_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.emotion_model = emotion_model
        self.behaviour_model = behaviour_model
        self.device = device
        self.audio_emotion_labels = {
            0: "Neutralità",
            1: "Rabbia",
            2: "Paura",
            3: "Gioia",
            4: "Sorpresa",
            5: "Tristezza",
            6: "Disgusto",
        }
        self.emotion_translation = {
            "neutrality": "Neutralità",
            "anger": "Rabbia",
            "fear": "Paura",
            "joy": "Gioia",
            "surprise": "Sorpresa",
            "sadness": "Tristezza",
            "disgust": "Disgusto"
        }
        self.behaviour_labels = {
            0: "frustrated",
            1: "delighted",
            2: "dysregulated",
        }
        self.behaviour_translation = {
            "frustrated": "frustazione",
            "delighted": "incantato",
            "dysregulated": "disregolazione",
        }
        self.segmentation_model = segmentation_model

        self._set_emotion_model()
        if self.behaviour_model:
            self._set_behaviour_model()

        self.behaviour_confidence = 0.6

        self.chart_generator = None

    def _set_emotion_model(self):
        self.emotion_model.to(self.device)
        self.emotion_model.eval()

    def _set_behaviour_model(self):
        self.behaviour_model.to(self.device)
        self.behaviour_model.eval()

    def _prepare_transcribed_text(self, chunks):
        formated_timestamps = []
        predictions = []

        for chunk in chunks:
            start = chunk[0] / SAMPLING_RATE
            end = chunk[1] / SAMPLING_RATE
            formated_start = time.strftime('%H:%M:%S', time.gmtime(start))
            formated_end = time.strftime('%H:%M:%S', time.gmtime(end))
            formated_timestamps.append(f"**({formated_start} - {formated_end})**")

            predictions.append(f"**[{chunk[2]}]**")

        transcribed_texts = [chunk[3] for chunk in chunks]
        transcribed_text = "<br/>".join(
            [
                f"{formated_timestamps[i]}: {transcribed_texts[i]} {predictions[i]}" for i in range(len(transcribed_texts))
            ]
        )

        print(f"Transcribed text:\n{transcribed_text}")

        return transcribed_text

    def __call__(self, audio_path: str):
        """
        Predicts emotion and behaviour labels for a given audio input.

        Args:
            audio_path (str): Path of the audio file to be processed.

        Returns:
            tuple: A one-element tuple with the behaviour Gantt Plotly figure.
        """

        print("Segmenting audio...")
        out = self.segmentation_model(
            inputs=audio_path,
            return_timestamps=True,
        )

        # Load the waveform so the chunks below can be sliced from it.
        input_frames, _ = librosa.load(audio_path, sr=SAMPLING_RATE)

        emotion_chunks = []
        behaviour_chunks = []
        timestamps = []
        predicted_labels = []
        all_probabilities = []

        print("Analyzing chunks...")
        for chunk in out["chunks"]:
            # trim audio from timestamps
            start = int(chunk["timestamp"][0] * SAMPLING_RATE)
            end = int(chunk["timestamp"][1] * SAMPLING_RATE if chunk["timestamp"][1] else len(input_frames))

            audio = input_frames[start:end]

            inputs = self.emotion_processor(audio, chunk["text"], return_tensors="pt", sampling_rate=SAMPLING_RATE)

            print(f"Inputs: {inputs}")

            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")

            inputs['input_features'] = inputs['input_features'].to(self.device)
            inputs['input_ids'] = inputs['input_ids'].to(self.device)
            inputs['text_attention_mask'] = inputs['text_attention_mask'].to(self.device)

            print("Predicting emotion for chunk...")
            logits = self.emotion_model(**inputs).logits

            logits = logits.detach().cpu()
            softmax = torch.nn.Softmax(dim=1)
            probabilities = softmax(logits).squeeze(0)

            prediction = probabilities.argmax().item()

            # id2label lives on the classification model's config, not on the processor
            predicted_label = self.emotion_model.config.id2label[prediction]
            label_translation = self.emotion_translation[predicted_label]

            emotion_chunks.append(
                (
                    start,
                    end,
                    label_translation,
                    chunk["text"],
                    np.round(probabilities[prediction].item(), 2)
                )
            )

            timestamps.append((start, end))
            predicted_labels.append(label_translation)
            all_probabilities.append(probabilities[prediction].item())

            inputs = self.emotion_processor(audio, return_tensors="pt", sampling_rate=SAMPLING_RATE)
            if "input_values" in inputs:
                inputs["input_features"] = inputs.pop("input_values")

            inputs = inputs.input_features.to(self.device)
            print("Predicting behaviour for chunk...")
            logits = self.behaviour_model(inputs).logits
            probabilities = torch.nn.functional.softmax(logits.detach().cpu(), dim=-1).squeeze()
            behaviour_chunks.append(
                (
                    start,
                    end,
                    chunk["text"],
                    np.round(probabilities[2].item(), 2),
                    label_translation,
                )
            )
        behaviour_gantt = create_behaviour_gantt_plot(behaviour_chunks)

        # transcribed_text = self._prepare_transcribed_text(emotion_chunks)

        return (
            behaviour_gantt,
            # transcribed_text,
        )
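For reference, a minimal end-to-end sketch of driving AudioProcessor outside Gradio, mirroring how app.py builds it. The "sample.wav" path is a placeholder, and running this downloads the referenced checkpoints.

# Minimal usage sketch (mirrors app.py); "sample.wav" is a placeholder path.
import torch
from transformers import pipeline, WavLMForSequenceClassification
from src.audio_processor import AudioProcessor
from src.model.behaviour_model import get_behaviour_model

device = "cuda" if torch.cuda.is_available() else "cpu"

segmentation_model = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    device=device,
)
emotion_model = WavLMForSequenceClassification.from_pretrained(
    "links-ads/kk-speech-emotion-recognition"
)
behaviour_model = get_behaviour_model("behaviour_model/", device=device)

processor = AudioProcessor(
    emotion_model=emotion_model,
    segmentation_model=segmentation_model,
    device=device,
    behaviour_model=behaviour_model,
)

# __call__ segments the file with Whisper, classifies each chunk, and returns
# a one-element tuple containing the behaviour Gantt figure.
(fig,) = processor("sample.wav")
fig.show()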
src/generate_graph.py ADDED
@@ -0,0 +1,176 @@
import time
import plotly.graph_objects as go
from datetime import datetime, timedelta


SAMPLING_RATE = 16_000
COLOR_MAP = {
    "Neutralità": "rgb(178, 178, 178)",
    "Rabbia": "rgb(160, 61, 62)",
    "Paura": "rgb(91, 57, 136)",
    "Gioia": "rgb(255, 255, 0)",
    "Sorpresa": "rgb(60, 175, 175)",
    "Tristezza": "rgb(64, 106, 173)",
    "Disgusto": "rgb(100, 153, 65)",
}

def create_behaviour_gantt_plot(behaviour_chunks, confidence_threshold=60):
    print("Creating behaviour Gantt plot...")
    emotion_order = [
        "Gioia",
        "Sorpresa",
        "Disgusto",
        "Tristezza",
        "Paura",
        "Rabbia",
        "Neutralità"
    ]

    fig = go.Figure()

    chunk_starts = [start/SAMPLING_RATE for start, _, _, _, _ in behaviour_chunks]
    chunk_ends = [end/SAMPLING_RATE for _, end, _, _, _ in behaviour_chunks]

    # Create reference time for plotting (starting at 0)
    # We'll use a base datetime and add seconds
    base_time = datetime(2_000, 1, 1, 0, 0, 0)  # TODO: change magic numbers

    start_times = [base_time + timedelta(seconds=t) for t in chunk_starts]
    end_times = [base_time + timedelta(seconds=t) for t in chunk_ends]

    # Calculate midpoints for each chunk (for trend line)
    mid_times = [base_time + timedelta(seconds=(s+e)/2) for s, e in zip(chunk_starts, chunk_ends)]

    heights = [height * 100 for _, _, _, height, _ in behaviour_chunks]

    emotions = [emotion for _, _, _, _, emotion in behaviour_chunks]

    hover_texts = []
    for i, (start, end, label, height, emotion) in enumerate(behaviour_chunks):
        start_fmt = time.strftime('%H:%M:%S', time.gmtime(start / SAMPLING_RATE))
        end_fmt = time.strftime('%H:%M:%S', time.gmtime(end / SAMPLING_RATE))
        duration_seconds = (end - start) / SAMPLING_RATE
        duration_str = time.strftime('%H:%M:%S', time.gmtime(duration_seconds))

        hover_text = f"Inizio: {start_fmt}<br>Fine: {end_fmt}<br>Durata: {duration_str}<br>Testo: {label}<br>Attendibilità: {height*100:.2f}%<br>Emozione: {emotion}"
        hover_texts.append(hover_text)

    fig.add_shape(
        type="rect",
        x0=start_times[0],
        x1=end_times[-1],
        y0=confidence_threshold,
        y1=100,
        fillcolor="rgba(188,223,241,0.8)",
        opacity=0.8,
        layer="below",
        line_width=0,
    )

    fig.add_hline(y=confidence_threshold, line_dash="dash", line_color="black", line_width=1)

    fig.add_trace(
        go.Scatter(
            x=mid_times,
            y=heights,
            mode='lines',
            name='Disregolazione',
            line=dict(
                color='orange',
                width=2,
                shape='spline',  # This enables smoothing
                smoothing=1.0,  # Adjust smoothing factor
            ),
            text=hover_texts,
            hoverinfo='text',
            showlegend=False,
        )
    )

    emotion_data = {}

    for i, height in enumerate(heights):
        if height >= confidence_threshold:
            emotion = emotions[i]
            if emotion not in emotion_data:
                emotion_data[emotion] = {
                    'times': [],
                    'heights': [],
                    'hover_texts': []
                }

            emotion_data[emotion]['times'].append(mid_times[i])
            emotion_data[emotion]['heights'].append(height)
            emotion_data[emotion]['hover_texts'].append(hover_texts[i])

    for emotion in emotion_order:
        color = COLOR_MAP.get(emotion, '#000000')

        if emotion in emotion_data:
            data = emotion_data[emotion]
            fig.add_trace(
                go.Scatter(
                    x=data['times'],
                    y=data['heights'],
                    mode='markers',
                    name=emotion.capitalize(),
                    marker=dict(
                        size=15,
                        color=color,
                        symbol='circle'
                    ),
                    text=data['hover_texts'],
                    hoverinfo='text',
                    showlegend=True,
                )
            )
        else:
            fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode='markers',
                    name=emotion.capitalize(),
                    marker=dict(
                        size=15,
                        color=color,
                        symbol='circle'
                    ),
                    showlegend=True,
                )
            )

    fig.update_layout(
        title='Distribuzione della disregolazione',
        xaxis_title='Tempo',
        yaxis_title='Attendibilità',
        xaxis=dict(
            type='date',
            tickformat='%H:%M:%S',
            showline=True,
            zeroline=False,
            side='bottom',
            showgrid=False,
        ),
        yaxis=dict(
            range=[0, 100],
            tickvals=[0, 20, 40, 60, 80, 100],
            ticktext=['0%', '20%', '40%', '60%', '80%', '100%'],
            tickmode='array',
            showgrid=False,
        ),
        legend_title=None,
        legend=dict(
            yanchor="top"
        ),
        hoverlabel=dict(
            font_size=12,
            font_family="Arial"
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )

    fig.update_traces(hovertemplate=None)

    return fig
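For clarity, create_behaviour_gantt_plot expects five-element tuples of (start_frame, end_frame, text, dysregulation probability in [0, 1], emotion label), with frame indices at 16 kHz, exactly as built in AudioProcessor.__call__. A standalone sketch with made-up example data:

# Standalone sketch of the expected input; the values below are invented.
from src.generate_graph import create_behaviour_gantt_plot

dummy_chunks = [
    (0,       48_000,  "first chunk",  0.35, "Neutralità"),   # 0-3 s
    (48_000,  112_000, "second chunk", 0.72, "Rabbia"),       # 3-7 s
    (112_000, 160_000, "third chunk",  0.55, "Tristezza"),    # 7-10 s
]

fig = create_behaviour_gantt_plot(dummy_chunks, confidence_threshold=60)
fig.show()  # only points at or above the 60% band get an emotion-coloured marker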
src/model/behaviour_model.py ADDED
@@ -0,0 +1,28 @@
import os
import argparse
import torch
from .make_model import make_model


hparams_dict = {
    'HF_MODEL_PATH': 'facebook/wav2vec2-large-xlsr-53',
    'DATASET': 'recanvo',
    'MAX_DURATION': 4,
    'SAMPLING_RATE': 16_000,
    'OUTPUT_HIDDEN_STATES': True,
    'CLASSIFIER_NAME': 'multilevel',
    'CLASSIFIER_PROJ_SIZE': 256,
    'NUM_LABELS': 3,
    'LABEL_WEIGHTS': [1.0],
    'LOSS': 'cross-entropy',
    'GPU_ID': 0,
    'RETURN_RAW_ARRAY': False,
}
hparams = argparse.Namespace(**hparams_dict)

def get_behaviour_model(behaviour_model_path, device):
    state_dict = torch.load(os.path.join(behaviour_model_path, 'pytorch_model.bin'), map_location=device)
    model = make_model(hparams)
    model.load_state_dict(state_dict)

    return model
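Note that behaviour_model/ is excluded by the .gitignore change above, so the fine-tuned checkpoint is expected to be present at runtime rather than tracked in the repository. A small sketch of the expected layout and loading call; the directory contents beyond pytorch_model.bin are an assumption, and building the model downloads the facebook/wav2vec2-large-xlsr-53 backbone.

# Expected local layout (assumed; only pytorch_model.bin is actually read):
#
#   behaviour_model/
#   └── pytorch_model.bin
#
from src.model.behaviour_model import get_behaviour_model

model = get_behaviour_model("behaviour_model/", device="cpu")
model.eval()  # AudioProcessor also calls .to(device) and .eval() on it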
src/model/custom_model.py ADDED
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn

from typing import Optional, Union, Tuple
from transformers.modeling_outputs import SequenceClassifierOutput
from .wav2vec2_wrapper import Wav2VecWrapper
from .multilevel_classifier import MultiLevelDownstreamModel

class CustomModelForAudioClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.output_hidden_states == True, "The upstream model must return all hidden states"
        self.config = config
        self.encoder = Wav2VecWrapper(config)

        self.classifier = MultiLevelDownstreamModel(config, use_conv_output=True)

    def forward(
        self,
        input_features: Optional[torch.LongTensor],
        length: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        if encoder_outputs is None:
            encoder_output = self.encoder(
                input_features,
                length=length,
            )

        logits = self.classifier(**encoder_output)

        loss = None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_output['encoder_hidden_states']
        )
src/model/make_model.py ADDED
@@ -0,0 +1,20 @@
from transformers import (
    AutoConfig
)
from .custom_model import CustomModelForAudioClassification

def make_model(hparams):
    """ Returns a model instance based on the provided hyperparameters. """
    hparams = vars(hparams)
    config = AutoConfig.from_pretrained(hparams['HF_MODEL_PATH'])
    config.max_duration = hparams['MAX_DURATION']
    config.sampling_rate = hparams['SAMPLING_RATE']
    config.output_hidden_states = hparams['OUTPUT_HIDDEN_STATES']
    config.classifier_name = hparams['CLASSIFIER_NAME']
    config.classifier_proj_size = hparams['CLASSIFIER_PROJ_SIZE']
    config.num_labels = hparams['NUM_LABELS']
    config.label_weights = hparams['LABEL_WEIGHTS']
    config.lossname = hparams['LOSS']
    model = CustomModelForAudioClassification(config)

    return model
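To make the data flow concrete, a shape sketch of the forward pass built by make_model with the hparams defined in behaviour_model.py. The expected shapes in the comments are my reading of the xlsr-53 configuration (24 transformer layers, hidden size 1024), and running this downloads the backbone weights.

# Shape sketch of the forward pass (downloads the xlsr-53 backbone on first run).
import torch
from src.model.behaviour_model import hparams
from src.model.make_model import make_model

model = make_model(hparams).eval()

dummy = torch.zeros(1, 16_000)          # (batch, samples): 1 s of silence at 16 kHz
with torch.no_grad():
    out = model(input_features=dummy)

print(out.logits.shape)                 # expected (1, 3): frustrated / delighted / dysregulated
print(len(out.hidden_states))           # expected 25: encoder input plus 24 transformer layers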
src/model/multilevel_classifier.py ADDED
@@ -0,0 +1,76 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F


class MultiLevelDownstreamModel(nn.Module):
    def __init__(
        self,
        model_config,
        use_conv_output: Optional[bool] = True,
    ):
        super().__init__()
        assert model_config.output_hidden_states == True, "The upstream model must return all hidden states"

        self.model_config = model_config
        self.use_conv_output = use_conv_output

        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, self.model_config.classifier_proj_size, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size, 1, padding=0)
        )

        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + input embeddings
            self.weights = nn.Parameter(torch.ones(num_layers)/num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))

        self.out_layer = nn.Sequential(
            nn.Linear(self.model_config.classifier_proj_size, self.model_config.classifier_proj_size),
            nn.ReLU(),
            nn.Linear(self.model_config.classifier_proj_size, self.model_config.num_labels),
        )

    def forward(self, encoder_hidden_states, length=None):
        if self.use_conv_output:
            stacked_feature = torch.stack(encoder_hidden_states, dim=0)
        else:
            stacked_feature = torch.stack(encoder_hidden_states, dim=0)[1:]  # exclude the convolution output

        _, *origin_shape = stacked_feature.shape

        if self.use_conv_output:
            stacked_feature = stacked_feature.view(self.model_config.num_hidden_layers + 1, -1)
        else:
            stacked_feature = stacked_feature.view(self.model_config.num_hidden_layers, -1)

        norm_weights = F.softmax(self.weights, dim=-1)

        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)

        features = features.transpose(1, 2)
        features = self.model_seq(features)
        features = features.transpose(1, 2)

        if length is not None:
            length = length.cuda()
            masks = torch.arange(features.size(1)).expand(length.size(0), -1).cuda() < length.unsqueeze(1)
            masks = masks.float()
            features = (features * masks.unsqueeze(-1)).sum(1) / length.unsqueeze(1)
        else:
            features = torch.mean(features, dim=1)

        predicted = self.out_layer(features)
        return predicted
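The core of this classifier is a learnable, softmax-normalised weighted sum over all encoder layers before the convolutional projection head. A tiny self-contained toy of that tensor manipulation, with invented dimensions (13 layer outputs, batch 2, 50 frames, hidden size 8):

# Toy illustration of the layer-weighted sum used above; dimensions are made up.
import torch
import torch.nn.functional as F

num_layers, batch, frames, hidden = 13, 2, 50, 8
hidden_states = [torch.randn(batch, frames, hidden) for _ in range(num_layers)]

weights = torch.ones(num_layers) / num_layers    # as in the use_conv_output=True branch
stacked = torch.stack(hidden_states, dim=0)      # (13, 2, 50, 8)
_, *origin_shape = stacked.shape

flat = stacked.view(num_layers, -1)              # (13, 2*50*8)
norm_weights = F.softmax(weights, dim=-1)        # sums to 1 across layers
weighted = (norm_weights.unsqueeze(-1) * flat).sum(dim=0)

features = weighted.view(*origin_shape)          # back to (2, 50, 8)
print(features.shape)                            # torch.Size([2, 50, 8])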
src/model/wav2vec2_wrapper.py ADDED
@@ -0,0 +1,93 @@
import torch
import transformers.models.wav2vec2.modeling_wav2vec2 as w2v2
from torch import nn
from transformers import Wav2Vec2Model

class Wav2Vec2EncoderLayer(nn.Module):
    def __init__(
        self,
        config,
        i
    ):
        super().__init__()
        self.attention = w2v2.Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = w2v2.Wav2Vec2FeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config
        self.i = i

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        attn_residual = hidden_states

        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)
        return outputs

class Wav2VecWrapper(nn.Module):
    def __init__(
        self,
        config,
    ):
        super(Wav2VecWrapper, self).__init__()
        self.config = config

        self.backbone_model = Wav2Vec2Model.from_pretrained(
            config._name_or_path,
            output_hidden_states=config.output_hidden_states,
        )
        state_dict = self.backbone_model.state_dict()

        self.model_config = self.backbone_model.config
        self.backbone_model.encoder.layers = nn.ModuleList([Wav2Vec2EncoderLayer(self.model_config, i) for i in range(self.model_config.num_hidden_layers)])

    def forward(self,
        input_features: torch.Tensor,
        length: torch.Tensor = None,
    ):
        with torch.no_grad():
            hidden_states = self.backbone_model.feature_extractor(input_features)
            hidden_states = hidden_states.transpose(1, 2)
            hidden_states, _ = self.backbone_model.feature_projection(hidden_states)

        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())

        hidden_states = self.backbone_model.encoder(
            hidden_states,
            output_hidden_states=self.config.output_hidden_states
        ).hidden_states

        return {'encoder_hidden_states': hidden_states, 'length': length}

    def get_feat_extract_output_lengths(self, input_length):
        def _conv_out_length(input_length, kernel_size, stride):
            return (input_length - kernel_size) // stride + 1
        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length

def prepare_mask(length, shape, dtype):
    mask = torch.zeros(
        shape, dtype=dtype
    )
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask
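get_feat_extract_output_lengths simply folds the convolutional front-end arithmetic over the waveform length. A worked example, assuming the standard wav2vec2 feature-extractor geometry (conv_kernel 10,3,3,3,3,2,2 and conv_stride 5,2,2,2,2,2,2); verify against the loaded config before relying on the exact numbers.

# Output-length arithmetic for 4 s of 16 kHz audio (MAX_DURATION in behaviour_model.py),
# assuming the standard wav2vec2 conv geometry noted above.
def conv_out_length(n, kernel_size, stride):
    return (n - kernel_size) // stride + 1

length = 4 * 16_000
for k, s in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = conv_out_length(length, k, s)

print(length)  # 199 encoder frames, roughly one frame every 20 ms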