Update app.py
Browse files
app.py
CHANGED
@@ -37,34 +37,35 @@ model.to(device)
|
|
37 |
print(f"β
Model loaded on {device}")
|
38 |
|
39 |
# ================================
|
40 |
-
# 3οΈβ£ Load Dataset (
|
41 |
# ================================
|
42 |
DATASET_TAR_PATH = "dev-clean.tar.gz"
|
43 |
EXTRACT_PATH = "./librispeech_dev_clean"
|
44 |
-
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
|
45 |
|
46 |
-
if not os.path.exists(
|
47 |
print("π Extracting dataset...")
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
print("β
Extraction complete.")
|
52 |
-
except Exception as e:
|
53 |
-
raise RuntimeError(f"β Dataset extraction failed: {e}")
|
54 |
else:
|
55 |
print("β
Dataset already extracted.")
|
56 |
|
|
|
|
|
57 |
def find_audio_files(base_folder):
|
58 |
-
return [os.path.join(root, file)
|
|
|
|
|
59 |
|
60 |
audio_files = find_audio_files(AUDIO_FOLDER)
|
61 |
|
62 |
if not audio_files:
|
63 |
raise FileNotFoundError(f"β No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
|
|
|
64 |
print(f"β
Found {len(audio_files)} audio files in dataset!")
|
65 |
|
66 |
# ================================
|
67 |
-
# 4οΈβ£ Load Transcripts
|
68 |
# ================================
|
69 |
def load_transcripts():
|
70 |
transcript_dict = {}
|
@@ -82,10 +83,11 @@ def load_transcripts():
|
|
82 |
transcripts = load_transcripts()
|
83 |
if not transcripts:
|
84 |
raise FileNotFoundError("β No transcripts found! Check dataset structure.")
|
|
|
85 |
print(f"β
Loaded {len(transcripts)} transcripts.")
|
86 |
|
87 |
# ================================
|
88 |
-
# 5οΈβ£ Preprocess Dataset (
|
89 |
# ================================
|
90 |
def load_and_process_audio(audio_path):
|
91 |
waveform, sample_rate = torchaudio.load(audio_path)
|
@@ -94,17 +96,17 @@ def load_and_process_audio(audio_path):
|
|
94 |
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
|
95 |
return input_features
|
96 |
|
97 |
-
dataset = [
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
104 |
|
105 |
train_size = int(0.8 * len(dataset))
|
106 |
-
train_dataset = dataset[:train_size]
|
107 |
-
eval_dataset = dataset[train_size:]
|
108 |
|
109 |
print(f"β
Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
|
110 |
|
@@ -118,7 +120,7 @@ batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value
|
|
118 |
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
|
119 |
|
120 |
# ================================
|
121 |
-
# 7οΈβ£ Streamlit ASR Web App (
|
122 |
# ================================
|
123 |
st.title("ποΈ Speech-to-Text ASR Model with Security Features πΆ")
|
124 |
|
@@ -133,18 +135,19 @@ if audio_file:
|
|
133 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
134 |
waveform = waveform.to(dtype=torch.float32)
|
135 |
|
136 |
-
#
|
137 |
-
|
138 |
-
adversarial_waveform = torch.clamp(
|
139 |
|
140 |
input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
|
141 |
|
142 |
with torch.inference_mode():
|
143 |
-
generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True,
|
|
|
144 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
145 |
|
146 |
if attack_strength > 0.1:
|
147 |
st.warning("β οΈ Adversarial attack detected! Transcription may be affected.")
|
148 |
|
149 |
st.success("π Secure Transcription:")
|
150 |
-
st.write(transcription)
|
|
|
37 |
print(f"β
Model loaded on {device}")
|
38 |
|
39 |
# ================================
|
40 |
+
# 3οΈβ£ Load Dataset (From Extracted Folder)
|
41 |
# ================================
|
42 |
DATASET_TAR_PATH = "dev-clean.tar.gz"      # LibriSpeech dev-clean archive, expected in the working dir
EXTRACT_PATH = "./librispeech_dev_clean"   # destination folder for the extracted dataset

# Extract the archive once; a pre-existing folder is treated as "already done".
if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        # Security: extractall() without a filter will happily follow
        # "../"-style member names out of EXTRACT_PATH (path traversal).
        # The "data" filter (PEP 706, default in 3.14) rejects such members;
        # fall back for interpreters whose extractall() lacks `filter=`.
        try:
            tar.extractall(EXTRACT_PATH, filter="data")
        except TypeError:
            tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# LibriSpeech archives unpack to <root>/LibriSpeech/dev-clean/<speaker>/<chapter>/*.flac
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
|
54 |
+
|
55 |
def find_audio_files(base_folder):
    """Recursively collect the paths of all .flac files under *base_folder*.

    Walks the directory tree in ``os.walk`` order and returns a list of
    full paths; returns an empty list when the folder is missing or empty.
    """
    matches = []
    for dirpath, _subdirs, filenames in os.walk(base_folder):
        for filename in filenames:
            if filename.endswith(".flac"):
                matches.append(os.path.join(dirpath, filename))
    return matches
|
59 |
|
60 |
# Collect the dataset audio up front and stop early when the extracted
# layout does not match what AUDIO_FOLDER expects.
audio_files = find_audio_files(AUDIO_FOLDER)

if not audio_files:
    error_message = f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!"
    raise FileNotFoundError(error_message)

print(f"✅ Found {len(audio_files)} audio files in dataset!")
|
66 |
|
67 |
# ================================
|
68 |
+
# 4οΈβ£ Load Transcripts
|
69 |
# ================================
|
70 |
def load_transcripts():
|
71 |
transcript_dict = {}
|
|
|
83 |
# Build the utterance-id -> reference-text mapping and fail fast when the
# dataset ships without transcript files.
transcripts = load_transcripts()

if not transcripts:
    error_message = "❌ No transcripts found! Check dataset structure."
    raise FileNotFoundError(error_message)

print(f"✅ Loaded {len(transcripts)} transcripts.")
|
88 |
|
89 |
# ================================
|
90 |
+
# 5οΈβ£ Preprocess Dataset (Fixing `input_ids` issue)
|
91 |
# ================================
|
92 |
def load_and_process_audio(audio_path):
|
93 |
waveform, sample_rate = torchaudio.load(audio_path)
|
|
|
96 |
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
|
97 |
return input_features
|
98 |
|
99 |
+
def _utterance_id(audio_path):
    """Derive the LibriSpeech utterance ID from an audio file path.

    Uses ``os.path.splitext`` to strip only the trailing extension; the
    previous ``.replace(".flac", "")`` would have removed *every*
    occurrence of the substring, not just the suffix.
    """
    return os.path.splitext(os.path.basename(audio_path))[0]


# Pair each of the first 100 audio files with its tokenized transcript,
# skipping files whose utterance ID has no transcript entry.
dataset = [
    {
        "input_features": load_and_process_audio(audio_file),
        "labels": processor.tokenizer(
            transcripts[_utterance_id(audio_file)],
            padding="max_length", truncation=True, return_tensors="pt",
        ).input_ids[0],
    }
    for audio_file in audio_files[:100]
    if _utterance_id(audio_file) in transcripts
]

# 80/20 train/eval split; no shuffling, so order follows extraction order.
train_size = int(0.8 * len(dataset))
train_dataset, eval_dataset = dataset[:train_size], dataset[train_size:]

print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
|
112 |
|
|
|
120 |
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
|
121 |
|
122 |
# ================================
|
123 |
+
# 7οΈβ£ Streamlit ASR Web App (Fast Decoding & Security Features)
|
124 |
# ================================
|
125 |
st.title("ποΈ Speech-to-Text ASR Model with Security Features πΆ")
|
126 |
|
|
|
135 |
# Whisper expects 16 kHz float32 mono input; resample whatever the upload was.
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
waveform = waveform.to(dtype=torch.float32)

# Simulate an adversarial attack by injecting random noise
# (attack_strength of 0 leaves the waveform unchanged), then clamp back
# into valid audio range [-1, 1].
adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)

# Convert the (possibly perturbed) waveform into log-mel input features on `device`.
input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)

# inference_mode disables autograd bookkeeping for faster decoding.
with torch.inference_mode():
    # NOTE(review): attention_mask here is shaped like input_features
    # (3-D: batch x mel-bins x frames); Whisper's generate() normally takes
    # no mask, or a 2-D batch x frames mask — confirm this is intentional.
    generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True,
                                   attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device))
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Warn only above the slider's default (0.1), so the untouched default
# setting is not flagged as an attack.
if attack_strength > 0.1:
    st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")

st.success("🔒 Secure Transcription:")
st.write(transcription)
|