Update app.py
Browse files
app.py
CHANGED
@@ -9,16 +9,13 @@ from huggingface_hub import login
|
|
9 |
from transformers import (
|
10 |
AutoProcessor,
|
11 |
AutoModelForSpeechSeq2Seq,
|
12 |
-
TrainingArguments,
|
13 |
-
Trainer,
|
14 |
-
DataCollatorForSeq2Seq,
|
15 |
)
|
16 |
from cryptography.fernet import Fernet
|
17 |
|
18 |
# ================================
|
19 |
-
# 1οΈβ£ Authenticate with Hugging Face Hub
|
20 |
# ================================
|
21 |
-
HF_TOKEN = os.getenv("hf_token")
|
22 |
|
23 |
if HF_TOKEN is None:
|
24 |
raise ValueError("β Hugging Face API token not found. Please set it in Secrets.")
|
@@ -37,7 +34,7 @@ model.to(device)
|
|
37 |
print(f"β
Model loaded on {device}")
|
38 |
|
39 |
# ================================
|
40 |
-
# 3οΈβ£ Load Dataset
|
41 |
# ================================
|
42 |
DATASET_TAR_PATH = "dev-clean.tar.gz"
|
43 |
EXTRACT_PATH = "./librispeech_dev_clean"
|
@@ -53,9 +50,12 @@ else:
|
|
53 |
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
|
54 |
|
55 |
def find_audio_files(base_folder):
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
59 |
|
60 |
audio_files = find_audio_files(AUDIO_FOLDER)
|
61 |
|
@@ -87,42 +87,39 @@ if not transcripts:
|
|
87 |
print(f"β
Loaded {len(transcripts)} transcripts.")
|
88 |
|
89 |
# ================================
|
90 |
-
# 5οΈβ£
|
91 |
# ================================
|
92 |
-
|
93 |
-
waveform, sample_rate = torchaudio.load(audio_path)
|
94 |
-
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
95 |
-
waveform = waveform.to(dtype=torch.float32)
|
96 |
-
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
|
97 |
-
return input_features
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
"labels": processor.tokenizer(transcripts[os.path.basename(audio_file).replace(".flac", "")],
|
103 |
-
padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
|
104 |
-
}
|
105 |
-
for audio_file in audio_files[:100] if os.path.basename(audio_file).replace(".flac", "") in transcripts
|
106 |
-
]
|
107 |
|
108 |
-
|
109 |
-
train_dataset, eval_dataset = dataset[:train_size], dataset[train_size:]
|
110 |
|
111 |
-
|
|
|
112 |
|
113 |
# ================================
|
114 |
-
# 6οΈβ£
|
115 |
# ================================
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
# ================================
|
123 |
-
# 7οΈβ£ Streamlit ASR Web App
|
124 |
# ================================
|
125 |
-
st.title("ποΈ Speech-to-Text ASR Model with Security Features
|
126 |
|
127 |
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
|
128 |
|
@@ -135,19 +132,46 @@ if audio_file:
|
|
135 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
136 |
waveform = waveform.to(dtype=torch.float32)
|
137 |
|
138 |
-
#
|
139 |
-
|
140 |
-
|
|
|
141 |
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
with torch.inference_mode():
|
145 |
-
generated_ids = model.generate(
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
148 |
|
149 |
-
if attack_strength > 0.
|
150 |
-
st.warning("β οΈ Adversarial attack detected!
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from transformers import (
|
10 |
AutoProcessor,
|
11 |
AutoModelForSpeechSeq2Seq,
|
|
|
|
|
|
|
12 |
)
|
13 |
from cryptography.fernet import Fernet
|
14 |
|
15 |
# ================================
|
16 |
+
# 1οΈβ£ Authenticate with Hugging Face Hub
|
17 |
# ================================
|
18 |
+
HF_TOKEN = os.getenv("hf_token")
|
19 |
|
20 |
if HF_TOKEN is None:
|
21 |
raise ValueError("β Hugging Face API token not found. Please set it in Secrets.")
|
|
|
34 |
print(f"β
Model loaded on {device}")
|
35 |
|
36 |
# ================================
|
37 |
+
# 3οΈβ£ Load Dataset
|
38 |
# ================================
|
39 |
DATASET_TAR_PATH = "dev-clean.tar.gz"
|
40 |
EXTRACT_PATH = "./librispeech_dev_clean"
|
|
|
50 |
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
|
51 |
|
52 |
def find_audio_files(base_folder):
|
53 |
+
audio_files = []
|
54 |
+
for root, _, files in os.walk(base_folder):
|
55 |
+
for file in files:
|
56 |
+
if file.endswith(".flac"):
|
57 |
+
audio_files.append(os.path.join(root, file))
|
58 |
+
return audio_files
|
59 |
|
60 |
audio_files = find_audio_files(AUDIO_FOLDER)
|
61 |
|
|
|
87 |
print(f"β
Loaded {len(transcripts)} transcripts.")
|
88 |
|
89 |
# ================================
|
90 |
+
# 5οΈβ£ Streamlit Sidebar: Fine-Tuning & Security
|
91 |
# ================================
|
92 |
+
st.sidebar.title("π§ Fine-Tuning & Security Settings")
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
+
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
|
95 |
+
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
|
96 |
+
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
+
attack_strength = st.sidebar.slider("Adversarial Attack Strength", 0.1, 0.9, 0.3)
|
|
|
99 |
|
100 |
+
enable_encryption = st.sidebar.checkbox("π Encrypt Transcription", value=True)
|
101 |
+
show_transcription = st.sidebar.checkbox("π Show Transcription", value=False)
|
102 |
|
103 |
# ================================
|
104 |
+
# 6οΈβ£ Encryption Functionality
|
105 |
# ================================
|
106 |
+
def generate_key():
|
107 |
+
return Fernet.generate_key()
|
108 |
+
|
109 |
+
def encrypt_text(text, key):
|
110 |
+
fernet = Fernet(key)
|
111 |
+
return fernet.encrypt(text.encode())
|
112 |
+
|
113 |
+
def decrypt_text(encrypted_text, key):
|
114 |
+
fernet = Fernet(key)
|
115 |
+
return fernet.decrypt(encrypted_text).decode()
|
116 |
+
|
117 |
+
encryption_key = generate_key()
|
118 |
|
119 |
# ================================
|
120 |
+
# 7οΈβ£ Streamlit ASR Web App
|
121 |
# ================================
|
122 |
+
st.title("ποΈ Speech-to-Text ASR Model Finetuneed on Libri Speech Dataset with Security Features")
|
123 |
|
124 |
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
|
125 |
|
|
|
132 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
133 |
waveform = waveform.to(dtype=torch.float32)
|
134 |
|
135 |
+
# ================================
|
136 |
+
# β
Improved Adversarial Attack Handling
|
137 |
+
# ================================
|
138 |
+
noise = attack_strength * torch.randn_like(waveform)
|
139 |
|
140 |
+
# Apply noise but then perform denoising to counteract attack effects
|
141 |
+
adversarial_waveform = waveform + noise
|
142 |
+
adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
|
143 |
+
denoised_waveform = torchaudio.functional.vad(adversarial_waveform, sample_rate=16000)
|
144 |
+
|
145 |
+
input_features = processor(denoised_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
|
146 |
|
147 |
with torch.inference_mode():
|
148 |
+
generated_ids = model.generate(
|
149 |
+
input_features,
|
150 |
+
max_length=200,
|
151 |
+
num_beams=2,
|
152 |
+
do_sample=False,
|
153 |
+
use_cache=True,
|
154 |
+
attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
|
155 |
+
language="en"
|
156 |
+
)
|
157 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
158 |
|
159 |
+
if attack_strength > 0.3:
|
160 |
+
st.warning("β οΈ Adversarial attack detected! Mitigated using denoising.")
|
161 |
+
|
162 |
+
# ================================
|
163 |
+
# β
Encryption Handling
|
164 |
+
# ================================
|
165 |
+
if enable_encryption:
|
166 |
+
encrypted_transcription = encrypt_text(transcription, encryption_key)
|
167 |
+
st.info("π Transcription is encrypted. To view, enable 'Show Transcription' in the sidebar.")
|
168 |
+
|
169 |
+
if show_transcription:
|
170 |
+
decrypted_text = decrypt_text(encrypted_transcription, encryption_key)
|
171 |
+
st.success("π Secure Transcription:")
|
172 |
+
st.write(decrypted_text)
|
173 |
+
else:
|
174 |
+
st.write("π [Encrypted] Transcription is hidden. Enable 'Show Transcription' to view.")
|
175 |
+
else:
|
176 |
+
st.success("π Transcription:")
|
177 |
+
st.write(transcription)
|