tahirsher committed
Commit e3021fc · verified · 1 Parent(s): a06453c

Update app.py

Files changed (1): app.py +70 -46
app.py CHANGED
@@ -9,16 +9,13 @@ from huggingface_hub import login
 from transformers import (
     AutoProcessor,
     AutoModelForSpeechSeq2Seq,
-    TrainingArguments,
-    Trainer,
-    DataCollatorForSeq2Seq,
 )
 from cryptography.fernet import Fernet
 
 # ================================
-# 1️⃣ Authenticate with Hugging Face Hub (Securely)
+# 1️⃣ Authenticate with Hugging Face Hub
 # ================================
-HF_TOKEN = os.getenv("hf_token")
+HF_TOKEN = os.getenv("hf_token")
 
 if HF_TOKEN is None:
     raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
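The dropped Trainer/TrainingArguments imports match the removal of the in-app training code later in this diff. The hunk header shows `from huggingface_hub import login`, but the login call itself sits outside the hunk; a minimal sketch of how the token read above is presumably consumed (the exact call site is an assumption, not shown in this commit):

import os
from huggingface_hub import login

# Sketch only: token comes from Space/CI secrets, never hardcoded.
HF_TOKEN = os.getenv("hf_token")
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)  # authenticates Hub downloads for gated/private models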
@@ -37,7 +34,7 @@ model.to(device)
 print(f"✅ Model loaded on {device}")
 
 # ================================
-# 3️⃣ Load Dataset (From Extracted Folder)
+# 3️⃣ Load Dataset
 # ================================
 DATASET_TAR_PATH = "dev-clean.tar.gz"
 EXTRACT_PATH = "./librispeech_dev_clean"
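DATASET_TAR_PATH and EXTRACT_PATH feed an extraction step of which this diff only shows the trailing `else:` (see the next hunk's header). A hypothetical reconstruction of that off-screen step, assuming a standard extract-if-missing pattern with the stdlib tarfile module:

import os
import tarfile

# Assumed extract-if-missing branch implied by the `else:` context below.
if not os.path.exists(EXTRACT_PATH):
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)  # yields EXTRACT_PATH/LibriSpeech/dev-clean/
else:
    print("✅ Dataset already extracted.")  # message is an assumption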
@@ -53,9 +50,12 @@ else:
 AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
 
 def find_audio_files(base_folder):
-    return [os.path.join(root, file)
-            for root, _, files in os.walk(base_folder)
-            for file in files if file.endswith(".flac")]
+    audio_files = []
+    for root, _, files in os.walk(base_folder):
+        for file in files:
+            if file.endswith(".flac"):
+                audio_files.append(os.path.join(root, file))
+    return audio_files
 
 audio_files = find_audio_files(AUDIO_FOLDER)
 
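The expanded loop is behavior-equivalent to the old nested list comprehension; both walk the tree and keep `.flac` paths. For reference, a pathlib variant (an alternative sketch, not part of this commit) does the same recursive search more compactly:

from pathlib import Path

def find_audio_files(base_folder):
    # Recursively collect all .flac files under base_folder.
    return [str(p) for p in Path(base_folder).rglob("*.flac")]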
@@ -87,42 +87,39 @@ if not transcripts:
 print(f"✅ Loaded {len(transcripts)} transcripts.")
 
 # ================================
-# 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
+# 5️⃣ Streamlit Sidebar: Fine-Tuning & Security
 # ================================
-def load_and_process_audio(audio_path):
-    waveform, sample_rate = torchaudio.load(audio_path)
-    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
-    waveform = waveform.to(dtype=torch.float32)
-    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
-    return input_features
+st.sidebar.title("🔧 Fine-Tuning & Security Settings")
 
-dataset = [
-    {
-        "input_features": load_and_process_audio(audio_file),
-        "labels": processor.tokenizer(transcripts[os.path.basename(audio_file).replace(".flac", "")],
-                                      padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
-    }
-    for audio_file in audio_files[:100] if os.path.basename(audio_file).replace(".flac", "") in transcripts
-]
+num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
+learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
+batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
 
-train_size = int(0.8 * len(dataset))
-train_dataset, eval_dataset = dataset[:train_size], dataset[train_size:]
+attack_strength = st.sidebar.slider("Adversarial Attack Strength", 0.1, 0.9, 0.3)
 
-print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
+enable_encryption = st.sidebar.checkbox("🔒 Encrypt Transcription", value=True)
+show_transcription = st.sidebar.checkbox("📖 Show Transcription", value=False)
 
 # ================================
-# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
+# 6️⃣ Encryption Functionality
 # ================================
-st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
-num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
-learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
-batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
-attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
+def generate_key():
+    return Fernet.generate_key()
+
+def encrypt_text(text, key):
+    fernet = Fernet(key)
+    return fernet.encrypt(text.encode())
+
+def decrypt_text(encrypted_text, key):
+    fernet = Fernet(key)
+    return fernet.decrypt(encrypted_text).decode()
+
+encryption_key = generate_key()
 
 # ================================
-# 7️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
+# 7️⃣ Streamlit ASR Web App
 # ================================
-st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")
+st.title("🎙️ Speech-to-Text ASR Model Fine-Tuned on LibriSpeech Dataset with Security Features")
 
 audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
 
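One caveat with the new encryption block: Streamlit re-executes the whole script on every widget interaction, so module-level `generate_key()` produces a fresh Fernet key on each rerun and any ciphertext kept from a previous run becomes undecryptable. A sketch of one common fix, persisting the key in `st.session_state` (the `"encryption_key"` state key is an assumption, not from the commit):

import streamlit as st
from cryptography.fernet import Fernet

# Generate the Fernet key once and reuse it across Streamlit reruns.
if "encryption_key" not in st.session_state:
    st.session_state["encryption_key"] = Fernet.generate_key()
encryption_key = st.session_state["encryption_key"]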
@@ -135,19 +132,46 @@ if audio_file:
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
     waveform = waveform.to(dtype=torch.float32)
 
-    # Simulate an adversarial attack by injecting random noise
-    adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
-    adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
+    # ================================
+    # ✅ Improved Adversarial Attack Handling
+    # ================================
+    noise = attack_strength * torch.randn_like(waveform)
 
-    input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
+    # Apply noise, then denoise to counteract attack effects
+    adversarial_waveform = waveform + noise
+    adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
+    denoised_waveform = torchaudio.functional.vad(adversarial_waveform, sample_rate=16000)
+
+    input_features = processor(denoised_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
 
     with torch.inference_mode():
-        generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True,
-                                       attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device))
+        generated_ids = model.generate(
+            input_features,
+            max_length=200,
+            num_beams=2,
+            do_sample=False,
+            use_cache=True,
+            attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
+            language="en"
+        )
     transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-    if attack_strength > 0.1:
-        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")
-
-    st.success("📄 Secure Transcription:")
-    st.write(transcription)
+    if attack_strength > 0.3:
+        st.warning("⚠️ Adversarial attack detected! Mitigated using denoising.")
+
+    # ================================
+    # ✅ Encryption Handling
+    # ================================
+    if enable_encryption:
+        encrypted_transcription = encrypt_text(transcription, encryption_key)
+        st.info("🔒 Transcription is encrypted. To view, enable 'Show Transcription' in the sidebar.")
+
+        if show_transcription:
+            decrypted_text = decrypt_text(encrypted_transcription, encryption_key)
+            st.success("📄 Secure Transcription:")
+            st.write(decrypted_text)
+        else:
+            st.write("🔒 [Encrypted] Transcription is hidden. Enable 'Show Transcription' to view.")
+    else:
+        st.success("📄 Transcription:")
+        st.write(transcription)
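Two details in this hunk are worth flagging. First, `torchaudio.functional.vad` is a voice-activity trimmer that removes leading non-speech; it does not remove additive noise mixed into the speech itself, so the "denoising" framing is best-effort. A simple low-pass filter is one alternative mitigation sketch (the 4 kHz cutoff is an assumption, not from the commit):

import torchaudio.functional as F

# Alternative mitigation sketch: attenuate broadband noise above the speech band.
denoised_waveform = F.lowpass_biquad(adversarial_waveform, sample_rate=16000, cutoff_freq=4000.0)

Second, Whisper-style `input_features` are a (batch, n_mels, frames) tensor, so an `attention_mask` built from that full shape is unlikely to be what `generate` expects, and whether `generate` accepts the `language="en"` keyword depends on the pinned transformers version; both are worth verifying against the project's dependencies.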