Spaces:
Sleeping
Sleeping
updated the app file
Browse files
app.py
CHANGED
@@ -1,54 +1,54 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
from ced_model.feature_extraction_ced import CedFeatureExtractor
|
3 |
-
from ced_model.modeling_ced import CedForAudioClassification
|
4 |
-
import torchaudio
|
5 |
-
import torch
|
6 |
-
import os
|
7 |
-
import soundfile as sf
|
8 |
-
|
9 |
-
model_name = "mispeech/ced-tiny"
|
10 |
-
feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
|
11 |
-
model = CedForAudioClassification.from_pretrained(model_name)
|
12 |
-
|
13 |
-
st.title("Audio Classification App")
|
14 |
-
st.subheader("Trained on 50 classes of ESC 50 dataset")
|
15 |
-
st.write("Upload an audio file to predict its class.")
|
16 |
-
|
17 |
-
audio_file = st.file_uploader("Upload Audio File", type=["wav"])
|
18 |
-
|
19 |
-
if audio_file is not None:
|
20 |
-
st.write(f"Uploaded file: {audio_file.name}")
|
21 |
-
|
22 |
-
try:
|
23 |
-
temp_file_path = "temp.wav"
|
24 |
-
with open(temp_file_path, "wb") as f:
|
25 |
-
f.write(audio_file.read())
|
26 |
-
|
27 |
-
try:
|
28 |
-
audio, sampling_rate = torchaudio.load(temp_file_path)
|
29 |
-
except Exception:
|
30 |
-
st.warning("Fallback to soundfile for audio loading.")
|
31 |
-
audio_data, sampling_rate = sf.read(temp_file_path)
|
32 |
-
audio = torch.tensor(audio_data).unsqueeze(0)
|
33 |
-
|
34 |
-
if sampling_rate != 16000:
|
35 |
-
st.warning("Resampling audio to 16000 Hz...")
|
36 |
-
resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
|
37 |
-
audio = resampler(audio)
|
38 |
-
sampling_rate = 16000
|
39 |
-
|
40 |
-
inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
|
41 |
-
|
42 |
-
with torch.no_grad():
|
43 |
-
logits = model(**inputs).logits
|
44 |
-
|
45 |
-
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
46 |
-
predicted_label = model.config.id2label[predicted_class_id]
|
47 |
-
|
48 |
-
st.success(f"Predicted Class: {predicted_label}")
|
49 |
-
|
50 |
-
os.remove(temp_file_path)
|
51 |
-
except Exception as e:
|
52 |
-
st.error(f"An error occurred: {e}")
|
53 |
-
else:
|
54 |
-
st.info("Please upload a .wav audio file to continue.")
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from ced_model.feature_extraction_ced import CedFeatureExtractor
|
3 |
+
from ced_model.modeling_ced import CedForAudioClassification
|
4 |
+
import torchaudio
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
import soundfile as sf
|
8 |
+
|
9 |
+
model_name = "mispeech/ced-tiny"
|
10 |
+
feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
|
11 |
+
model = CedForAudioClassification.from_pretrained(model_name)
|
12 |
+
|
13 |
+
st.title("Audio Classification App")
|
14 |
+
st.subheader("Trained on 50 classes of ESC 50 dataset")
|
15 |
+
st.write("Upload an audio file to predict its class.")
|
16 |
+
|
17 |
+
audio_file = st.file_uploader("Upload Audio File", type=["wav","mp3","m4a"])
|
18 |
+
|
19 |
+
if audio_file is not None:
|
20 |
+
st.write(f"Uploaded file: {audio_file.name}")
|
21 |
+
|
22 |
+
try:
|
23 |
+
temp_file_path = "temp.wav"
|
24 |
+
with open(temp_file_path, "wb") as f:
|
25 |
+
f.write(audio_file.read())
|
26 |
+
|
27 |
+
try:
|
28 |
+
audio, sampling_rate = torchaudio.load(temp_file_path)
|
29 |
+
except Exception:
|
30 |
+
st.warning("Fallback to soundfile for audio loading.")
|
31 |
+
audio_data, sampling_rate = sf.read(temp_file_path)
|
32 |
+
audio = torch.tensor(audio_data).unsqueeze(0)
|
33 |
+
|
34 |
+
if sampling_rate != 16000:
|
35 |
+
st.warning("Resampling audio to 16000 Hz...")
|
36 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
|
37 |
+
audio = resampler(audio)
|
38 |
+
sampling_rate = 16000
|
39 |
+
|
40 |
+
inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
|
41 |
+
|
42 |
+
with torch.no_grad():
|
43 |
+
logits = model(**inputs).logits
|
44 |
+
|
45 |
+
predicted_class_id = torch.argmax(logits, dim=-1).item()
|
46 |
+
predicted_label = model.config.id2label[predicted_class_id]
|
47 |
+
|
48 |
+
st.success(f"Predicted Class: {predicted_label}")
|
49 |
+
|
50 |
+
os.remove(temp_file_path)
|
51 |
+
except Exception as e:
|
52 |
+
st.error(f"An error occurred: {e}")
|
53 |
+
else:
|
54 |
+
st.info("Please upload a .wav audio file to continue.")
|