tahirsher commited on
Commit
a06453c
·
verified ·
1 Parent(s): 1cf13ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -27
app.py CHANGED
@@ -37,34 +37,35 @@ model.to(device)
37
  print(f"βœ… Model loaded on {device}")
38
 
39
  # ================================
40
- # 3️⃣ Load Dataset (With Fixes)
41
  # ================================
42
  DATASET_TAR_PATH = "dev-clean.tar.gz"
43
  EXTRACT_PATH = "./librispeech_dev_clean"
44
- AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
45
 
46
- if not os.path.exists(AUDIO_FOLDER):
47
  print("πŸ”„ Extracting dataset...")
48
- try:
49
- with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
50
- tar.extractall(EXTRACT_PATH)
51
- print("βœ… Extraction complete.")
52
- except Exception as e:
53
- raise RuntimeError(f"❌ Dataset extraction failed: {e}")
54
  else:
55
  print("βœ… Dataset already extracted.")
56
 
 
 
57
  def find_audio_files(base_folder):
58
- return [os.path.join(root, file) for root, _, files in os.walk(base_folder) for file in files if file.endswith(".flac")]
 
 
59
 
60
  audio_files = find_audio_files(AUDIO_FOLDER)
61
 
62
  if not audio_files:
63
  raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
 
64
  print(f"βœ… Found {len(audio_files)} audio files in dataset!")
65
 
66
  # ================================
67
- # 4️⃣ Load Transcripts (Fixed Mapping)
68
  # ================================
69
  def load_transcripts():
70
  transcript_dict = {}
@@ -82,10 +83,11 @@ def load_transcripts():
82
  transcripts = load_transcripts()
83
  if not transcripts:
84
  raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
 
85
  print(f"βœ… Loaded {len(transcripts)} transcripts.")
86
 
87
  # ================================
88
- # 5️⃣ Preprocess Dataset (Fixed `input_ids` Issue)
89
  # ================================
90
  def load_and_process_audio(audio_path):
91
  waveform, sample_rate = torchaudio.load(audio_path)
@@ -94,17 +96,17 @@ def load_and_process_audio(audio_path):
94
  input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
95
  return input_features
96
 
97
- dataset = []
98
- for audio_file in audio_files[:100]:
99
- file_id = os.path.basename(audio_file).replace(".flac", "")
100
- if file_id in transcripts:
101
- input_features = load_and_process_audio(audio_file)
102
- labels = processor.tokenizer(transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
103
- dataset.append({"input_features": input_features, "labels": labels})
 
104
 
105
  train_size = int(0.8 * len(dataset))
106
- train_dataset = dataset[:train_size]
107
- eval_dataset = dataset[train_size:]
108
 
109
  print(f"βœ… Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
110
 
@@ -118,7 +120,7 @@ batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value
118
  attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
119
 
120
  # ================================
121
- # 7️⃣ Streamlit ASR Web App (Fixed Security & Processing)
122
  # ================================
123
  st.title("πŸŽ™οΈ Speech-to-Text ASR Model with Security Features 🎢")
124
 
@@ -133,18 +135,19 @@ if audio_file:
133
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
134
  waveform = waveform.to(dtype=torch.float32)
135
 
136
- # Apply adversarial attack noise with limit
137
- noise = torch.randn_like(waveform) * attack_strength
138
- adversarial_waveform = torch.clamp(waveform + noise, -1.0, 1.0)
139
 
140
  input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
141
 
142
  with torch.inference_mode():
143
- generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True, attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device))
 
144
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
145
 
146
  if attack_strength > 0.1:
147
  st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")
148
 
149
  st.success("πŸ“„ Secure Transcription:")
150
- st.write(transcription)
 
37
# Confirm which device (CPU/GPU) the model ended up on.
print(f"βœ… Model loaded on {device}")

# ================================
# 3️⃣ Load Dataset (From Extracted Folder)
# ================================
# Local LibriSpeech dev-clean archive and the folder it is unpacked into.
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"
 
44
 
45
# Unpack the archive on first run; later runs reuse the extracted folder.
if os.path.exists(EXTRACT_PATH):
    print("βœ… Dataset already extracted.")
else:
    print("πŸ”„ Extracting dataset...")
    # NOTE(review): extractall with no filter trusts the archive's member
    # paths — fine for a local LibriSpeech tarball, but consider
    # filter="data" (Python 3.12+) if the source is ever untrusted.
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("βœ… Extraction complete.")
52
 
53
# The tarball unpacks to LibriSpeech/dev-clean inside EXTRACT_PATH.
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
55
def find_audio_files(base_folder):
    """Recursively collect the paths of every ``.flac`` file under *base_folder*.

    Returns an empty list when the folder does not exist or holds no FLAC files.
    """
    matches = []
    for dirpath, _dirnames, filenames in os.walk(base_folder):
        for name in filenames:
            if name.endswith(".flac"):
                matches.append(os.path.join(dirpath, name))
    return matches
59
 
60
# Index every clip in the extracted dataset and fail fast when empty —
# an empty list means the archive layout is not what we expect.
audio_files = find_audio_files(AUDIO_FOLDER)

if not audio_files:
    raise FileNotFoundError(
        f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!"
    )

print(f"βœ… Found {len(audio_files)} audio files in dataset!")
66
 
67
  # ================================
68
+ # 4️⃣ Load Transcripts
69
  # ================================
70
  def load_transcripts():
71
  transcript_dict = {}
 
83
# Build the utterance-id → text map; without it no labels can be produced,
# so abort immediately rather than train on nothing.
transcripts = load_transcripts()
if not transcripts:
    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")

print(f"βœ… Loaded {len(transcripts)} transcripts.")
88
 
89
  # ================================
90
+ # 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
91
  # ================================
92
  def load_and_process_audio(audio_path):
93
  waveform, sample_rate = torchaudio.load(audio_path)
 
96
  input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
97
  return input_features
98
 
99
# Pair each of the first 100 clips with its tokenized transcript.
# Clips whose utterance id has no transcript entry are skipped.
dataset = []
for audio_file in audio_files[:100]:
    utt_id = os.path.basename(audio_file).replace(".flac", "")
    if utt_id not in transcripts:
        continue
    dataset.append({
        "input_features": load_and_process_audio(audio_file),
        "labels": processor.tokenizer(transcripts[utt_id],
                                      padding="max_length", truncation=True,
                                      return_tensors="pt").input_ids[0],
    })

# Simple 80/20 ordered split into train and eval subsets.
train_size = int(0.8 * len(dataset))
train_dataset = dataset[:train_size]
eval_dataset = dataset[train_size:]

print(f"βœ… Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
112
 
 
120
# Noise magnitude the user injects to simulate an adversarial attack.
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)

# ================================
# 7️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
# ================================
st.title("πŸŽ™οΈ Speech-to-Text ASR Model with Security Features 🎢")
126
 
 
135
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
136
  waveform = waveform.to(dtype=torch.float32)
137
 
138
+ # Simulate an adversarial attack by injecting random noise
139
+ adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
140
+ adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
141
 
142
  input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
143
 
144
  with torch.inference_mode():
145
+ generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True,
146
+ attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device))
147
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
148
 
149
  if attack_strength > 0.1:
150
  st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")
151
 
152
  st.success("πŸ“„ Secure Transcription:")
153
+ st.write(transcription)