avsv committed
Commit 936f253 · 1 Parent(s): c4f284d

✅ Fix: use correct extractor for superb/wav2vec2-base-superb-er
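Background: this checkpoint publishes a preprocessor (feature-extractor) config but, being an audio-classification model, likely ships no tokenizer files; `Wav2Vec2Processor.from_pretrained` bundles a tokenizer with the feature extractor and therefore fails to load it, while `AutoFeatureExtractor` loads cleanly. A minimal sketch of the corrected load, assuming only that transformers is installed:

from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

# AutoFeatureExtractor resolves to Wav2Vec2FeatureExtractor for this checkpoint;
# no tokenizer is required for sequence classification.
extractor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er")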

Files changed (1)
  1. app.py +6 -12
app.py CHANGED
@@ -2,18 +2,16 @@ import streamlit as st
 import torch
 import torchaudio
 import tempfile
-from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
 from pydub import AudioSegment
+from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification
 
-# Load model + processor (cached for performance)
 @st.cache_resource
 def load_model():
-    processor = Wav2Vec2Processor.from_pretrained("superb/wav2vec2-base-superb-er")
+    extractor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
     model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er")
     model.eval()
-    return processor, model
+    return extractor, model
 
-# Convert MP3/WAV to 16kHz mono WAV
 def convert_to_wav(uploaded_file):
     audio = AudioSegment.from_file(uploaded_file)
     audio = audio.set_frame_rate(16000).set_channels(1)
@@ -21,29 +19,26 @@ def convert_to_wav(uploaded_file):
     audio.export(temp_path, format="wav")
     return temp_path
 
-# Map prediction index to emotion label
 def get_emotion_label(logits):
     emotions = ["angry", "happy", "neutral", "sad"]
     scores = torch.softmax(torch.tensor(logits), dim=0).tolist()
     top_idx = scores.index(max(scores))
     return emotions[top_idx], scores
 
-# Analyze emotion from audio
 def analyze_emotion(audio_path):
-    processor, model = load_model()
+    extractor, model = load_model()
     waveform, sr = torchaudio.load(audio_path)
-
     if sr != 16000:
         waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
 
-    inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
+    inputs = extractor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0]
 
     emotion, scores = get_emotion_label(logits)
     return emotion.capitalize(), scores
 
-# --- Streamlit UI ---
+# Streamlit UI
 st.set_page_config(page_title="🎧 Audio Emotion Detector", layout="centered")
 st.title("🎧 Audio Emotion Analysis (Wav2Vec2)")
 
@@ -62,4 +57,3 @@ if uploaded_file:
     emotions = ["angry", "happy", "neutral", "sad"]
     for i, label in enumerate(emotions):
         st.write(f"- **{label.capitalize()}**: {scores[i]*100:.2f}%")
-