Update app.py
app.py CHANGED
@@ -37,38 +37,34 @@ model.to(device)
 print(f"✅ Model loaded on {device}")
 
 # ================================
-# 3️⃣ Load Dataset (
+# 3️⃣ Load Dataset (With Fixes)
 # ================================
 DATASET_TAR_PATH = "dev-clean.tar.gz"
 EXTRACT_PATH = "./librispeech_dev_clean"
+AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
 
-if not os.path.exists(
+if not os.path.exists(AUDIO_FOLDER):
     print("Extracting dataset...")
+    try:
+        with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
+            tar.extractall(EXTRACT_PATH)
+        print("✅ Extraction complete.")
+    except Exception as e:
+        raise RuntimeError(f"❌ Dataset extraction failed: {e}")
 else:
     print("✅ Dataset already extracted.")
 
-AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
-
 def find_audio_files(base_folder):
-    audio_files = []
-    for root, _, files in os.walk(base_folder):
-        for file in files:
-            if file.endswith(".flac"):
-                audio_files.append(os.path.join(root, file))
-    return audio_files
+    return [os.path.join(root, file) for root, _, files in os.walk(base_folder) for file in files if file.endswith(".flac")]
 
 audio_files = find_audio_files(AUDIO_FOLDER)
 
 if not audio_files:
     raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
-
 print(f"✅ Found {len(audio_files)} audio files in dataset!")
 
 # ================================
-# 4️⃣ Load Transcripts
+# 4️⃣ Load Transcripts (Fixed Mapping)
 # ================================
 def load_transcripts():
     transcript_dict = {}
@@ -86,11 +82,10 @@ def load_transcripts():
 transcripts = load_transcripts()
 if not transcripts:
     raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
-
 print(f"✅ Loaded {len(transcripts)} transcripts.")
 
 # ================================
-# 5️⃣ Preprocess Dataset (
+# 5️⃣ Preprocess Dataset (Fixed `input_ids` Issue)
 # ================================
 def load_and_process_audio(audio_path):
     waveform, sample_rate = torchaudio.load(audio_path)
@@ -123,9 +118,9 @@ batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value
 attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
 
 # ================================
-# 7️⃣ Streamlit ASR Web App (
+# 7️⃣ Streamlit ASR Web App (Fixed Security & Processing)
 # ================================
-st.title("🎙️ Speech-to-Text ASR Model with Security Features
+st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")
 
 audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
 
@@ -138,9 +133,9 @@ if audio_file:
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
     waveform = waveform.to(dtype=torch.float32)
 
-    #
-    adversarial_waveform = torch.clamp(
+    # Apply adversarial attack noise with limit
+    noise = torch.randn_like(waveform) * attack_strength
+    adversarial_waveform = torch.clamp(waveform + noise, -1.0, 1.0)
 
     input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
 
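Two illustrative sketches for the parts the hunks above collapse. First, the body of load_transcripts() is elided by the second hunk; for the standard LibriSpeech dev-clean layout, where every <speaker>-<chapter>.trans.txt line reads "<utterance-id> <text>", a minimal loader could look like the sketch below. The function name, the dict keying, and the base-folder argument are assumptions for illustration, not necessarily what app.py does.

import os

def load_transcripts_sketch(base_folder):
    # Walk the extracted dataset and map each utterance ID to its transcript text.
    transcript_dict = {}
    for root, _, files in os.walk(base_folder):
        for file in files:
            if file.endswith(".trans.txt"):
                with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                    for line in f:
                        # Each line: "<utterance-id> <transcript text>"
                        utt_id, _, text = line.strip().partition(" ")
                        if utt_id and text:
                            transcript_dict[utt_id] = text
    return transcript_dict

Second, the last hunk stops right after input_features is built, so the generate/decode step is not shown. A self-contained sketch of the same noise-injection idea end to end, assuming a Whisper checkpoint ("openai/whisper-small" is a placeholder, not taken from this commit) and a hypothetical helper name:

import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumed checkpoint; the repo may load a different model.
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)

def transcribe_with_noise(audio_path, attack_strength=0.1):
    # Load and resample to Whisper's expected 16 kHz mono float waveform.
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    waveform = waveform.to(dtype=torch.float32)

    # Same perturbation as the diff: scaled Gaussian noise, clamped to the valid audio range.
    noise = torch.randn_like(waveform) * attack_strength
    adversarial_waveform = torch.clamp(waveform + noise, -1.0, 1.0)

    # Convert to log-mel features and decode with the model.
    input_features = processor(
        adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)
    with torch.no_grad():
        predicted_ids = model.generate(input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

The torch.clamp call keeps the perturbed signal inside the [-1.0, 1.0] range that torchaudio uses for float waveforms, so the Attack Strength slider bounds the distortion rather than producing out-of-range samples.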