Update app.py
app.py
CHANGED
@@ -4,36 +4,46 @@ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
 import librosa
 
 # Placeholder model (Replace later with your trained model)
-MODEL_NAME = "
-
-# Load the pre-trained model and processor
+MODEL_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
 processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
 model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME)
 
-#
-
-
+# Emotion labels (based on the dataset used to train the model)
+id2label = {
+    0: "Neutral",
+    1: "Happy",
+    2: "Sad",
+    3: "Angry",
+    4: "Surprised",
+    5: "Disgusted",
+    6: "Fearful"
+}
+
+# Function to classify emotions from audio
+def classify_emotion(audio_file):
+    # Load and process audio
     speech, sr = librosa.load(audio_file, sr=16000)
-
-    # Preprocess with Hugging Face's feature extractor
     inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True, truncation=True)
 
-    #
+    # Get predictions
     with torch.no_grad():
         logits = model(**inputs).logits
-
     predicted_class_id = torch.argmax(logits, dim=-1).item()
-
+
+    # Convert class ID to emotion label
+    predicted_emotion = id2label.get(predicted_class_id, "Unknown")
+
+    return f"Predicted Emotion: {predicted_emotion}"
 
 # Gradio Interface
 interface = gr.Interface(
-    fn=
+    fn=classify_emotion,
     inputs=gr.Audio(source="upload", type="filepath"),
     outputs="text",
-    title="
-    description="Upload an audio file
+    title="Speech Emotion Classifier 🎭",
+    description="Upload an audio file and the model will classify its emotion (e.g., Happy, Sad, Angry)."
 )
 
-# Launch the
+# Launch the app
 if __name__ == "__main__":
     interface.launch()
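For reference: the commit hardcodes a seven-entry id2label table, while a transformers checkpoint also declares its own label mapping in its config. Below is a minimal sketch, not part of the commit, for inspecting that mapping and smoke-testing the same pipeline the app uses on a local file; "sample.wav" is a hypothetical path, and whether the checkpoint's classification head and label names match the hardcoded table is an assumption to verify.

import torch
import librosa
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor

MODEL_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"

# Load the checkpoint once and look at the labels its config declares.
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_NAME)
print(model.config.num_labels)  # how many classes the loaded head outputs
print(model.config.id2label)    # the checkpoint's own id -> label mapping

# Smoke-test the same steps classify_emotion performs in app.py
# ("sample.wav" is a hypothetical local file).
speech, sr = librosa.load("sample.wav", sr=16000)
inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class_id = torch.argmax(logits, dim=-1).item()
print(model.config.id2label.get(predicted_class_id, "Unknown"))

If the printed mapping disagrees with the hardcoded table, the config's labels are the ones the model's output indices actually correspond to.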