✅ Fix: use correct extractor for superb/wav2vec2-base-superb-er
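This checkpoint ships a feature extractor but no tokenizer, so Wav2Vec2Processor (which bundles both) cannot be loaded for it; the old load_model() also returned nothing, leaving analyze_emotion() with no extractor or model to call. The fix loads an AutoFeatureExtractor alongside the classifier, returns both from load_model(), and builds the model inputs with the extractor.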
app.py CHANGED

@@ -2,18 +2,16 @@ import streamlit as st
 import torch
 import torchaudio
 import tempfile
-from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
 from pydub import AudioSegment
+from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

-# Load model + processor (cached for performance)
 @st.cache_resource
 def load_model():
-
+    extractor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
     model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er")
     model.eval()
-    return
+    return extractor, model

-# Convert MP3/WAV to 16kHz mono WAV
 def convert_to_wav(uploaded_file):
     audio = AudioSegment.from_file(uploaded_file)
     audio = audio.set_frame_rate(16000).set_channels(1)
@@ -21,29 +19,26 @@ def convert_to_wav(uploaded_file):
     audio.export(temp_path, format="wav")
     return temp_path

-# Map prediction index to emotion label
 def get_emotion_label(logits):
     emotions = ["angry", "happy", "neutral", "sad"]
     scores = torch.softmax(torch.tensor(logits), dim=0).tolist()
     top_idx = scores.index(max(scores))
     return emotions[top_idx], scores

-# Analyze emotion from audio
 def analyze_emotion(audio_path):
-
+    extractor, model = load_model()
     waveform, sr = torchaudio.load(audio_path)
-
     if sr != 16000:
         waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)

-    inputs =
+    inputs = extractor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
     with torch.no_grad():
         logits = model(**inputs).logits[0]

     emotion, scores = get_emotion_label(logits)
     return emotion.capitalize(), scores

-#
+# Streamlit UI
 st.set_page_config(page_title="🎧 Audio Emotion Detector", layout="centered")
 st.title("🎧 Audio Emotion Analysis (Wav2Vec2)")

@@ -62,4 +57,3 @@ if uploaded_file:
     emotions = ["angry", "happy", "neutral", "sad"]
     for i, label in enumerate(emotions):
         st.write(f"- **{label.capitalize()}**: {scores[i]*100:.2f}%")
-
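For reference, a minimal standalone sketch of the fixed inference path outside Streamlit. Assumptions beyond the diff: sample.wav is a hypothetical mono test file, and the labels are read from model.config.id2label rather than the hard-coded list in app.py, since the checkpoint's label order is not guaranteed to match ["angry", "happy", "neutral", "sad"]:

    import torch
    import torchaudio
    from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

    MODEL_ID = "superb/wav2vec2-base-superb-er"

    # Load the feature extractor and classifier once, mirroring load_model() above.
    extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID)
    model.eval()

    # "sample.wav" is a placeholder path; mono audio assumed, as in the app.
    waveform, sr = torchaudio.load("sample.wav")
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)

    # The extractor normalizes the raw waveform and packs it into model inputs.
    inputs = extractor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    # Use the checkpoint's own label mapping instead of a hard-coded list.
    probs = torch.softmax(logits, dim=0)
    for idx, p in sorted(enumerate(probs.tolist()), key=lambda x: -x[1]):
        print(f"{model.config.id2label[idx]}: {p * 100:.2f}%")

If id2label disagrees with the hard-coded list, get_emotion_label() would report correct probabilities under the wrong names, so printing the mapping once is cheap insurance.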