tahirsher committed
Commit 941924a · verified · 1 Parent(s): a312467

Update app.py

Files changed (1):
  1. app.py +42 -83
app.py CHANGED
@@ -5,6 +5,7 @@ import torchaudio
 import numpy as np
 import streamlit as st
 import matplotlib.pyplot as plt
+from cryptography.fernet import Fernet  # Encryption
 from huggingface_hub import login
 from transformers import (
     AutoProcessor,
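The new import adds a third-party dependency: the `cryptography` package must be present in the Space's environment (typically via requirements.txt) or the app fails at import time. A minimal availability guard, shown only as a sketch and not part of the commit:

    try:
        from cryptography.fernet import Fernet  # noqa: F401  (used later for encryption)
    except ImportError as exc:
        raise SystemExit("Missing dependency: pip install cryptography") from exc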
@@ -18,10 +19,8 @@ from transformers import (
 # 1️⃣ Authenticate with Hugging Face Hub (Securely)
 # ================================
 HF_TOKEN = os.getenv("hf_token")
-
 if HF_TOKEN is None:
     raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
-
 login(token=HF_TOKEN)

 # ================================
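For context, `os.getenv("hf_token")` reads a plain environment variable, which Hugging Face Spaces populates from the repo's Secrets settings. The fail-fast guard above generalizes to any required secret; a standalone sketch, where the `require_env` helper name is made up for illustration:

    import os

    def require_env(name: str) -> str:
        # Hypothetical helper: fail fast when a required secret is absent.
        value = os.getenv(name)
        if value is None:
            raise ValueError(f"Missing required environment variable: {name}")
        return value

    # Usage mirroring the app: HF_TOKEN = require_env("hf_token")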
@@ -33,7 +32,6 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
-print(f"✅ Model loaded on {device}")

 # ================================
 # 3️⃣ Load Dataset (From Extracted Folder)
@@ -42,89 +40,53 @@ DATASET_TAR_PATH = "dev-clean.tar.gz"
 EXTRACT_PATH = "./librispeech_dev_clean"

 if not os.path.exists(EXTRACT_PATH):
-    print("🔄 Extracting dataset...")
     with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
         tar.extractall(EXTRACT_PATH)
-    print("✅ Extraction complete.")
-else:
-    print("✅ Dataset already extracted.")

 AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")

 def find_audio_files(base_folder):
-    """Recursively search for all .flac files in subdirectories."""
-    audio_files = []
-    for root, _, files in os.walk(base_folder):
-        for file in files:
-            if file.endswith(".flac"):
-                audio_files.append(os.path.join(root, file))
-    return audio_files
+    return [os.path.join(root, file) for root, _, files in os.walk(base_folder) for file in files if file.endswith(".flac")]

 audio_files = find_audio_files(AUDIO_FOLDER)

-if not audio_files:
-    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
-
-print(f"✅ Found {len(audio_files)} audio files in dataset!")
-
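The rewritten `find_audio_files` compresses the old `os.walk` loop into a single comprehension with identical behavior; note the commit also drops the empty-result check, so a missing dataset now surfaces later rather than at startup. An equivalent formulation with `pathlib`, shown for comparison only (not what the commit uses):

    from pathlib import Path

    def find_audio_files(base_folder):
        # Recursive glob for .flac files; same result set as the os.walk version.
        return [str(p) for p in Path(base_folder).rglob("*.flac")]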
 # ================================
 # 4️⃣ Load Transcripts
 # ================================
 def load_transcripts():
-    """Loads transcript text files and maps them to audio files."""
     transcript_dict = {}
     for root, _, files in os.walk(AUDIO_FOLDER):
         for file in files:
-            if file.endswith(".txt"):  # Transcript files
+            if file.endswith(".txt"):
                 with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                     for line in f:
                         parts = line.strip().split(" ", 1)
                         if len(parts) == 2:
-                            file_id, text = parts
-                            transcript_dict[file_id] = text
+                            transcript_dict[parts[0]] = parts[1]
     return transcript_dict

 transcripts = load_transcripts()
-if not transcripts:
-    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
-
-print(f"✅ Loaded {len(transcripts)} transcripts.")
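The parser assumes LibriSpeech's `*.trans.txt` convention: each line is an utterance ID, one space, then the transcript. A quick standalone check of the split logic, with an illustrative sample line:

    # Illustrative LibriSpeech-style transcript line.
    line = "1272-128104-0000 MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES"
    parts = line.strip().split(" ", 1)
    assert parts[0] == "1272-128104-0000"          # utterance ID -> dict key
    assert parts[1].startswith("MISTER QUILTER")   # transcript text -> dict value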
 
 # ================================
-# 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
+# 5️⃣ Adversarial Attack Simulation (Modifying Transcripts)
 # ================================
-def load_and_process_audio(audio_path):
-    """Loads and processes a single audio file into model format."""
-    waveform, sample_rate = torchaudio.load(audio_path)
-    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
-
-    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
-    return input_features
-
-dataset = []
-for audio_file in audio_files[:100]:  # Limit to 100 for faster processing
-    file_id = os.path.basename(audio_file).replace(".flac", "")
-
-    if file_id in transcripts:
-        input_features = load_and_process_audio(audio_file)
-        labels = processor.tokenizer(transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
-
-        dataset.append({"input_features": input_features, "labels": labels})
-
-train_size = int(0.8 * len(dataset))
-train_dataset = dataset[:train_size]
-eval_dataset = dataset[train_size:]
-
-print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
+def generate_adversarial_text(text):
+    words = text.split()
+    if len(words) > 3:
+        words[2] = "[REPLACED]"
+    return " ".join(words)

 # ================================
-# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
+# 6️⃣ Encrypt & Decrypt Transcriptions
 # ================================
-st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
-
-num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
-learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
-batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
+key = Fernet.generate_key()
+cipher = Fernet(key)
+
+def encrypt_transcription(text):
+    return cipher.encrypt(text.encode()).decode()
+
+def decrypt_transcription(encrypted_text):
+    return cipher.decrypt(encrypted_text.encode()).decode()

 # ================================
 # 7️⃣ Training Arguments & Trainer
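Both new helpers are easy to exercise in isolation. One caveat worth noting: `Fernet.generate_key()` runs at module import, so ciphertexts are only decryptable within the same app session; a restart generates a fresh key. A minimal round-trip sketch:

    from cryptography.fernet import Fernet

    # Word substitution: index 2 is replaced when the text has more than 3 words.
    text = "the quick brown fox jumps"
    words = text.split()
    words[2] = "[REPLACED]"
    assert " ".join(words) == "the quick [REPLACED] fox jumps"

    # Fernet round-trip: encrypt() returns a URL-safe base64 token.
    cipher = Fernet(Fernet.generate_key())
    token = cipher.encrypt(text.encode()).decode()
    assert cipher.decrypt(token.encode()).decode() == text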
@@ -133,10 +95,10 @@ training_args = TrainingArguments(
     output_dir="./asr_model_finetuned",
     eval_strategy="epoch",
     save_strategy="epoch",
-    learning_rate=learning_rate,
-    per_device_train_batch_size=batch_size,
-    per_device_eval_batch_size=batch_size,
-    num_train_epochs=num_epochs,
+    learning_rate=5e-5,
+    per_device_train_batch_size=8,
+    per_device_eval_batch_size=8,
+    num_train_epochs=3,
     weight_decay=0.01,
     logging_dir="./logs",
     logging_steps=500,
@@ -146,20 +108,14 @@ training_args = TrainingArguments(
     hub_token=HF_TOKEN,
 )

-data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, return_tensors="pt")
-
-trainer = Trainer(
-    model=model,
-    args=training_args,
-    train_dataset=train_dataset,
-    eval_dataset=eval_dataset,
-    data_collator=data_collator,
-)
-
 # ================================
-# 8️⃣ Streamlit ASR Web App (Fast Decoding)
+# 8️⃣ Streamlit ASR Web App (Enhanced UI)
 # ================================
-st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")
+st.title("🎙️ Speech-to-Text ASR Model with Security & Attack Detection")
+
+st.sidebar.title("⚙️ Settings")
+attack_mode = st.sidebar.checkbox("Enable Adversarial Attack Simulation")
+encryption_mode = st.sidebar.checkbox("Enable Encryption")

 audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
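Worth noting: with the Trainer and data collator removed above, `training_args` is now configured but never consumed in this version of the file. The two new checkboxes drive the branching in the upload handler below; because Streamlit re-runs the whole script on every interaction, the booleans are re-evaluated on each run. A self-contained illustration of the pattern:

    import streamlit as st

    st.sidebar.title("⚙️ Settings")
    attack_mode = st.sidebar.checkbox("Enable Adversarial Attack Simulation")
    encryption_mode = st.sidebar.checkbox("Enable Encryption")

    # Both values are plain bools, refreshed on every script re-run.
    st.write(f"attack_mode={attack_mode}, encryption_mode={encryption_mode}")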
 
@@ -176,16 +132,19 @@ if audio_file:
     ).input_features.to(device)

     with torch.inference_mode():
-        generated_ids = model.generate(
-            input_features,
-            max_length=200,
-            num_beams=2,
-            do_sample=False,
-            use_cache=True,
-            language="en",
-            attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
-        )
+        generated_ids = model.generate(input_features, max_length=200, num_beams=2)
     transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

-    st.success("📄 Transcription:")
-    st.write(transcription)
+    if attack_mode:
+        transcription = generate_adversarial_text(transcription)
+        st.warning("⚠️ Adversarial attack detected: Modified transcription!")
+
+    if encryption_mode:
+        encrypted_text = encrypt_transcription(transcription)
+        st.success("🔐 Encrypted Transcription:")
+        st.write(encrypted_text)
+        st.text("🔓 Decrypted Transcription:")
+        st.write(decrypt_transcription(encrypted_text))
+    else:
+        st.success("📄 Transcription:")
+        st.write(transcription)
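The slimmed-down `generate` call keeps beam search (`num_beams=2`) but drops the explicit `language="en"` and hand-built attention mask, so the model falls back to its default decoding settings. A self-contained sketch of the same decode path against one second of silence, using `openai/whisper-tiny.en` purely as an illustrative checkpoint (the app's actual MODEL_NAME sits outside this diff):

    import numpy as np
    import torch
    from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

    checkpoint = "openai/whisper-tiny.en"  # illustrative stand-in for MODEL_NAME
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint)

    audio = np.zeros(16000, dtype=np.float32)  # 1 s of 16 kHz silence
    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

    with torch.inference_mode():
        generated_ids = model.generate(input_features, max_length=200, num_beams=2)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])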