ameliabb0913 committed
Commit 3a18141 · verified · 1 Parent(s): fa298aa

Update app.py

Files changed (1):
  1. app.py +25 -15
app.py CHANGED
@@ -4,36 +4,46 @@ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
 import librosa
 
 # Placeholder model (Replace later with your trained model)
-MODEL_NAME = "facebook/wav2vec2-base-960h"
-
-# Load the pre-trained model and processor
+MODEL_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
 processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
 model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME)
 
-# Function to process audio input
-def classify_audio(audio_file):
-    # Load the audio file
+# Emotion labels (based on the dataset used to train the model)
+id2label = {
+    0: "Neutral",
+    1: "Happy",
+    2: "Sad",
+    3: "Angry",
+    4: "Surprised",
+    5: "Disgusted",
+    6: "Fearful"
+}
+
+# Function to classify emotions from audio
+def classify_emotion(audio_file):
+    # Load and process audio
     speech, sr = librosa.load(audio_file, sr=16000)
-
-    # Preprocess with Hugging Face's feature extractor
     inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True, truncation=True)
 
-    # Make predictions
+    # Get predictions
     with torch.no_grad():
         logits = model(**inputs).logits
-
     predicted_class_id = torch.argmax(logits, dim=-1).item()
-    return f"Predicted Class: {predicted_class_id}"
+
+    # Convert class ID to emotion label
+    predicted_emotion = id2label.get(predicted_class_id, "Unknown")
+
+    return f"Predicted Emotion: {predicted_emotion}"
 
 # Gradio Interface
 interface = gr.Interface(
-    fn=classify_audio,
+    fn=classify_emotion,
     inputs=gr.Audio(source="upload", type="filepath"),
     outputs="text",
-    title="Wav2Vec2 Audio Classification",
-    description="Upload an audio file, and the model will classify it."
+    title="Speech Emotion Classifier 🎭",
+    description="Upload an audio file and the model will classify its emotion (e.g., Happy, Sad, Angry)."
 )
 
-# Launch the Gradio demo
+# Launch the app
 if __name__ == "__main__":
     interface.launch()
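
One caveat on the new mapping: the 0-6 dictionary is hand-written, and the checkpoint now named in MODEL_NAME appears (going by its model card) to be a dimensional emotion model predicting arousal/dominance/valence rather than seven discrete classes. A minimal sketch of sanity-checking the hand-written labels against the checkpoint's own config; the fallback logic below is illustrative, not part of the commit:

# Sketch: inspect what the checkpoint's head actually predicts before trusting
# a hand-written label dict. config.id2label always exists on a transformers
# model; heads without trained label names default to generic "LABEL_0"-style entries.
from transformers import Wav2Vec2ForSequenceClassification

model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
)
config_labels = model.config.id2label
print(config_labels)  # compare against the hand-written id2label before shipping

# Assumption: only override the hand-written dict when the config carries real names.
if config_labels and not any(str(name).startswith("LABEL_") for name in config_labels.values()):
    id2label = {int(idx): name for idx, name in config_labels.items()}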
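
Separately, gr.Audio(source="upload", type="filepath") uses the Gradio 3.x signature; if the Space's requirements pull in Gradio 4, that parameter was renamed and would need to become sources=["upload"].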