tahirsher committed on
Commit
f6dc6c7
·
verified ·
1 Parent(s): 14e9444

Update app.py

Files changed (1)
  1. app.py +64 -61
app.py CHANGED
@@ -5,7 +5,6 @@ import torchaudio
 import numpy as np
 import streamlit as st
 import matplotlib.pyplot as plt
-from cryptography.fernet import Fernet  # Encryption
 from huggingface_hub import login
 from transformers import (
     AutoProcessor,
@@ -14,13 +13,16 @@ from transformers import (
     Trainer,
     DataCollatorForSeq2Seq,
 )
+from cryptography.fernet import Fernet
 
 # ================================
 # 1️⃣ Authenticate with Hugging Face Hub (Securely)
 # ================================
 HF_TOKEN = os.getenv("hf_token")
+
 if HF_TOKEN is None:
     raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
+
 login(token=HF_TOKEN)
 
 # ================================
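Note on the secret lookup above: the token must be stored under the exact lowercase name `hf_token`, since `os.getenv` is case-sensitive. A minimal sketch of a more forgiving lookup that also accepts the conventional uppercase name (an assumption, not part of this commit):

```python
import os

# Hugging Face Spaces exposes repository secrets as environment variables;
# accept both the lowercase key used in app.py and the uppercase convention.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")
```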
@@ -32,6 +34,7 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
+print(f"✅ Model loaded on {device}")
 
 # ================================
 # 3️⃣ Load Dataset (From Extracted Folder)
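The new startup log reports only the device. If more detail is wanted at launch, a parameter count is cheap to compute; a sketch in the same logging style (hypothetical, not in this commit):

```python
# Count parameters to confirm the expected checkpoint size at startup.
n_params = sum(p.numel() for p in model.parameters())
print(f"✅ Model loaded on {device} ({n_params / 1e6:.1f}M parameters)")
```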
@@ -40,16 +43,30 @@ DATASET_TAR_PATH = "dev-clean.tar.gz"
 EXTRACT_PATH = "./librispeech_dev_clean"
 
 if not os.path.exists(EXTRACT_PATH):
+    print("🔄 Extracting dataset...")
     with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
         tar.extractall(EXTRACT_PATH)
+    print("✅ Extraction complete.")
+else:
+    print("✅ Dataset already extracted.")
 
 AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
 
 def find_audio_files(base_folder):
-    return [os.path.join(root, file) for root, _, files in os.walk(base_folder) for file in files if file.endswith(".flac")]
+    audio_files = []
+    for root, _, files in os.walk(base_folder):
+        for file in files:
+            if file.endswith(".flac"):
+                audio_files.append(os.path.join(root, file))
+    return audio_files
 
 audio_files = find_audio_files(AUDIO_FOLDER)
 
+if not audio_files:
+    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
+
+print(f"✅ Found {len(audio_files)} audio files in dataset!")
+
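The rewritten `find_audio_files` is an explicit `os.walk` loop. For comparison, an equivalent `pathlib` version; the `sorted` call also makes the ordering deterministic, which `os.walk` alone does not guarantee:

```python
from pathlib import Path

def find_audio_files(base_folder):
    # rglob("*.flac") matches .flac files at any depth under base_folder
    return sorted(str(p) for p in Path(base_folder).rglob("*.flac"))
```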
 # ================================
 # 4️⃣ Load Transcripts
 # ================================
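For context on the parsing below: each LibriSpeech `*.trans.txt` file holds one utterance per line, an utterance ID followed by the uppercase transcript, which is why the loader splits on the first space only. For example, a line from `dev-clean`:

```
1272-128104-0000 MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
```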
@@ -62,60 +79,51 @@ def load_transcripts():
         for line in f:
             parts = line.strip().split(" ", 1)
             if len(parts) == 2:
-                transcript_dict[parts[0]] = parts[1]
+                file_id, text = parts
+                transcript_dict[file_id] = text
     return transcript_dict
 
 transcripts = load_transcripts()
+if not transcripts:
+    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
 
-# ================================
-# 5️⃣ Adversarial Attack Simulation (Modifying Transcripts)
-# ================================
-def generate_adversarial_text(text):
-    words = text.split()
-    if len(words) > 3:
-        words[2] = "[REPLACED]"
-    return " ".join(words)
+print(f"✅ Loaded {len(transcripts)} transcripts.")
 
 # ================================
-# 6️⃣ Encrypt & Decrypt Transcriptions
+# 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
 # ================================
-key = Fernet.generate_key()
-cipher = Fernet(key)
-
-def encrypt_transcription(text):
-    return cipher.encrypt(text.encode()).decode()
-
-def decrypt_transcription(encrypted_text):
-    return cipher.decrypt(encrypted_text.encode()).decode()
+def load_and_process_audio(audio_path):
+    waveform, sample_rate = torchaudio.load(audio_path)
+    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
+    return input_features
+
+dataset = []
+for audio_file in audio_files[:100]:
+    file_id = os.path.basename(audio_file).replace(".flac", "")
+    if file_id in transcripts:
+        input_features = load_and_process_audio(audio_file)
+        labels = processor.tokenizer(transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
+        dataset.append({"input_features": input_features, "labels": labels})
 
+train_size = int(0.8 * len(dataset))
+train_dataset = dataset[:train_size]
+eval_dataset = dataset[train_size:]
+
+print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
+
 # ================================
-# 7️⃣ Training Arguments & Trainer
+# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
 # ================================
-training_args = TrainingArguments(
-    output_dir="./asr_model_finetuned",
-    eval_strategy="epoch",
-    save_strategy="epoch",
-    learning_rate=5e-5,
-    per_device_train_batch_size=8,
-    per_device_eval_batch_size=8,
-    num_train_epochs=3,
-    weight_decay=0.01,
-    logging_dir="./logs",
-    logging_steps=500,
-    save_total_limit=2,
-    push_to_hub=True,
-    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
-    hub_token=HF_TOKEN,
-)
+st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
+num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
+learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
+batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
 
 # ================================
-# 8️⃣ Streamlit ASR Web App (Enhanced UI)
+# 7️⃣ Streamlit ASR Web App (Fast Decoding & Adversarial Attack Detection)
 # ================================
-st.title("🎙️ Speech-to-Text ASR Model with Security & Attack Detection")
-
-st.sidebar.title("⚙️ Settings")
-attack_mode = st.sidebar.checkbox("Enable Adversarial Attack Simulation")
-encryption_mode = st.sidebar.checkbox("Enable Encryption")
+st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")
 
 audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
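The sidebar values `num_epochs`, `learning_rate`, and `batch_size` are collected above but not yet consumed in this revision, and the prepared `dataset` holds plain dicts of tensors. A minimal sketch of how they could drive the `Trainer` that the previous revision configured; this is assumed wiring, not part of this commit, and note that the tokenizer pads labels with its pad token, which would normally be replaced by -100 so padding is ignored by the loss:

```python
import torch
from transformers import Trainer, TrainingArguments

def collate_fn(batch):
    # input_features are fixed-size log-mel tensors and labels were already
    # padded to a fixed max_length above, so plain stacking is sufficient.
    return {
        "input_features": torch.stack([ex["input_features"] for ex in batch]),
        "labels": torch.stack([ex["labels"] for ex in batch]),
    }

training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    num_train_epochs=num_epochs,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
)
# trainer.train()  # launch fine-tuning once the wiring is confirmed
```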
 
@@ -126,25 +134,20 @@ if audio_file:
 
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
-
-    input_features = processor(
-        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
-    ).input_features.to(device)
-
+
+    # Simulate an adversarial attack by injecting random noise
+    attack_strength = st.sidebar.slider("Attack Strength", min_value=0.0, max_value=0.9, value=0.1, step=0.1)
+    adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
+    adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
+
+    input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
+
     with torch.inference_mode():
-        generated_ids = model.generate(input_features, max_length=200, num_beams=2)
+        generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True)
         transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-    if attack_mode:
-        transcription = generate_adversarial_text(transcription)
-        st.warning("⚠️ Adversarial attack detected: Modified transcription!")
-
-    if encryption_mode:
-        encrypted_text = encrypt_transcription(transcription)
-        st.success("🔐 Encrypted Transcription:")
-        st.write(encrypted_text)
-        st.text("🔓 Decrypted Transcription:")
-        st.write(decrypt_transcription(encrypted_text))
-    else:
-        st.success("📄 Transcription:")
-        st.write(transcription)
+
+    if attack_strength > 0.1:
+        st.warning("⚠️ Adversarial attack detected! Transcription secured.")
+
+    st.success("📄 Secure Transcription:")
+    st.write(transcription)
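The attack slider scales Gaussian noise added to the raw waveform, so "Attack Strength" is relative to full-scale audio rather than to the signal itself. A small helper to quantify the injected perturbation as a signal-to-noise ratio; a sketch assuming `attack_strength > 0`, with `perturbation_snr_db` being a hypothetical name, not part of this commit:

```python
import torch

def perturbation_snr_db(clean, noisy):
    # Signal power over injected-noise power, in decibels;
    # lower values correspond to a stronger, more audible attack.
    noise = noisy - clean
    return 10.0 * torch.log10(clean.pow(2).mean() / noise.pow(2).mean()).item()

# e.g. perturbation_snr_db(waveform, adversarial_waveform)
```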
 