Update app.py
app.py
CHANGED
@@ -5,6 +5,7 @@ import torchaudio
 import numpy as np
 import streamlit as st
 import matplotlib.pyplot as plt
+from cryptography.fernet import Fernet  # Encryption
 from huggingface_hub import login
 from transformers import (
     AutoProcessor,
@@ -18,10 +19,8 @@ from transformers import (
 # 1️⃣ Authenticate with Hugging Face Hub (Securely)
 # ================================
 HF_TOKEN = os.getenv("hf_token")
-
 if HF_TOKEN is None:
     raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
-
 login(token=HF_TOKEN)
 
 # ================================
@@ -33,7 +32,6 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
-print(f"✅ Model loaded on {device}")
 
 # ================================
 # 3️⃣ Load Dataset (From Extracted Folder)
@@ -42,89 +40,53 @@ DATASET_TAR_PATH = "dev-clean.tar.gz"
 EXTRACT_PATH = "./librispeech_dev_clean"
 
 if not os.path.exists(EXTRACT_PATH):
-    print("📦 Extracting dataset...")
     with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
         tar.extractall(EXTRACT_PATH)
-    print("✅ Extraction complete.")
-else:
-    print("✅ Dataset already extracted.")
 
 AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
 
 def find_audio_files(base_folder):
-
-    audio_files = []
-    for root, _, files in os.walk(base_folder):
-        for file in files:
-            if file.endswith(".flac"):
-                audio_files.append(os.path.join(root, file))
-    return audio_files
+    return [os.path.join(root, file) for root, _, files in os.walk(base_folder) for file in files if file.endswith(".flac")]
 
 audio_files = find_audio_files(AUDIO_FOLDER)
 
-if not audio_files:
-    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
-
-print(f"✅ Found {len(audio_files)} audio files in dataset!")
-
 # ================================
 # 4️⃣ Load Transcripts
 # ================================
 def load_transcripts():
-    """Loads transcript text files and maps them to audio files."""
     transcript_dict = {}
     for root, _, files in os.walk(AUDIO_FOLDER):
         for file in files:
-            if file.endswith(".txt"):
+            if file.endswith(".txt"):
                 with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                     for line in f:
                         parts = line.strip().split(" ", 1)
                         if len(parts) == 2:
-                            file_id, text = parts
-                            transcript_dict[file_id] = text
+                            transcript_dict[parts[0]] = parts[1]
     return transcript_dict
 
 transcripts = load_transcripts()
-if not transcripts:
-    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
-
-print(f"✅ Loaded {len(transcripts)} transcripts.")
 
 # ================================
-# 5️⃣
+# 5️⃣ Adversarial Attack Simulation (Modifying Transcripts)
 # ================================
-def load_and_process_audio(audio_file):
-
-
-
-
-    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
-    return input_features
-
-dataset = []
-for audio_file in audio_files[:100]:  # Limit to 100 for faster processing
-    file_id = os.path.basename(audio_file).replace(".flac", "")
-
-    if file_id in transcripts:
-        input_features = load_and_process_audio(audio_file)
-        labels = processor.tokenizer(transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
-
-        dataset.append({"input_features": input_features, "labels": labels})
-
-train_size = int(0.8 * len(dataset))
-train_dataset = dataset[:train_size]
-eval_dataset = dataset[train_size:]
-
-print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
+def generate_adversarial_text(text):
+    words = text.split()
+    if len(words) > 3:
+        words[2] = "[REPLACED]"
+    return " ".join(words)
 
 # ================================
-# 6️⃣
+# 6️⃣ Encrypt & Decrypt Transcriptions
 # ================================
-
+key = Fernet.generate_key()
+cipher = Fernet(key)
+
+def encrypt_transcription(text):
+    return cipher.encrypt(text.encode()).decode()
 
-
-
-    batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
+def decrypt_transcription(encrypted_text):
+    return cipher.decrypt(encrypted_text.encode()).decode()
 
 # ================================
 # 7️⃣ Training Arguments & Trainer
@@ -133,10 +95,10 @@ training_args = TrainingArguments(
     output_dir="./asr_model_finetuned",
     eval_strategy="epoch",
     save_strategy="epoch",
-    learning_rate=
-    per_device_train_batch_size=
-    per_device_eval_batch_size=
-    num_train_epochs=
+    learning_rate=5e-5,
+    per_device_train_batch_size=8,
+    per_device_eval_batch_size=8,
+    num_train_epochs=3,
     weight_decay=0.01,
     logging_dir="./logs",
     logging_steps=500,
@@ -146,20 +108,14 @@ training_args = TrainingArguments(
     hub_token=HF_TOKEN,
 )
 
-data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, return_tensors="pt")
-
-trainer = Trainer(
-    model=model,
-    args=training_args,
-    train_dataset=train_dataset,
-    eval_dataset=eval_dataset,
-    data_collator=data_collator,
-)
-
 # ================================
-# 8️⃣ Streamlit ASR Web App (
+# 8️⃣ Streamlit ASR Web App (Enhanced UI)
 # ================================
-st.title("🎙️ Speech-to-Text ASR Model with
+st.title("🎙️ Speech-to-Text ASR Model with Security & Attack Detection")
+
+st.sidebar.title("⚙️ Settings")
+attack_mode = st.sidebar.checkbox("Enable Adversarial Attack Simulation")
+encryption_mode = st.sidebar.checkbox("Enable Encryption")
 
 audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
 
@@ -176,16 +132,19 @@ if audio_file:
     ).input_features.to(device)
 
     with torch.inference_mode():
-        generated_ids = model.generate(
-            input_features,
-            max_length=200,
-            num_beams=2,
-            do_sample=False,
-            use_cache=True,
-            language="en",
-            attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
-        )
+        generated_ids = model.generate(input_features, max_length=200, num_beams=2)
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-
-
+    if attack_mode:
+        transcription = generate_adversarial_text(transcription)
+        st.warning("⚠️ Adversarial attack detected: Modified transcription!")
+
+    if encryption_mode:
+        encrypted_text = encrypt_transcription(transcription)
+        st.success("🔒 Encrypted Transcription:")
+        st.write(encrypted_text)
+        st.text("🔓 Decrypted Transcription:")
+        st.write(decrypt_transcription(encrypted_text))
+    else:
+        st.success("📝 Transcription:")
+        st.write(transcription)
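When the new "Enable Adversarial Attack Simulation" toggle is on, `generate_adversarial_text` simulates tampering by overwriting the third word of any transcription longer than three words. A quick self-contained check (the input string is hypothetical):

```python
def generate_adversarial_text(text):
    # Overwrite the third word to simulate a tampered transcription.
    words = text.split()
    if len(words) > 3:
        words[2] = "[REPLACED]"
    return " ".join(words)

print(generate_adversarial_text("the quick brown fox"))
# -> "the quick [REPLACED] fox"
```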
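The encryption toggle wraps a standard `cryptography.fernet` round trip. A minimal sketch of what `encrypt_transcription`/`decrypt_transcription` do, assuming the `cryptography` package is installed; note that the key is regenerated on every app start, so ciphertexts are only decryptable within the same session:

```python
from cryptography.fernet import Fernet

key = Fernet.generate_key()  # fresh symmetric key per app run
cipher = Fernet(key)

# Fernet operates on bytes, so encode before encrypting and decode for display.
token = cipher.encrypt("sample transcription".encode())
assert cipher.decrypt(token).decode() == "sample transcription"
```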