tahirsher committed on
Commit
1cf13ee
·
verified ·
1 Parent(s): cf4699e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -22
app.py CHANGED
@@ -37,38 +37,34 @@ model.to(device)
37
  print(f"✅ Model loaded on {device}")
38
 
39
  # ================================
40
- # 3️⃣ Load Dataset (From Extracted Folder)
41
  # ================================
42
  DATASET_TAR_PATH = "dev-clean.tar.gz"
43
  EXTRACT_PATH = "./librispeech_dev_clean"
 
44
 
45
- if not os.path.exists(EXTRACT_PATH):
46
  print("🔄 Extracting dataset...")
47
- with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
48
- tar.extractall(EXTRACT_PATH)
49
- print("✅ Extraction complete.")
 
 
 
50
  else:
51
  print("✅ Dataset already extracted.")
52
 
53
- AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
54
-
55
  def find_audio_files(base_folder):
56
- audio_files = []
57
- for root, _, files in os.walk(base_folder):
58
- for file in files:
59
- if file.endswith(".flac"):
60
- audio_files.append(os.path.join(root, file))
61
- return audio_files
62
 
63
  audio_files = find_audio_files(AUDIO_FOLDER)
64
 
65
  if not audio_files:
66
  raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
67
-
68
  print(f"✅ Found {len(audio_files)} audio files in dataset!")
69
 
70
  # ================================
71
- # 4️⃣ Load Transcripts
72
  # ================================
73
  def load_transcripts():
74
  transcript_dict = {}
@@ -86,11 +82,10 @@ def load_transcripts():
86
  transcripts = load_transcripts()
87
  if not transcripts:
88
  raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
89
-
90
  print(f"✅ Loaded {len(transcripts)} transcripts.")
91
 
92
  # ================================
93
- # 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
94
  # ================================
95
  def load_and_process_audio(audio_path):
96
  waveform, sample_rate = torchaudio.load(audio_path)
@@ -123,9 +118,9 @@ batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value
123
  attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
124
 
125
  # ================================
126
- # 7️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
127
  # ================================
128
- st.title("🎙️ Speech-to-Text ASR Model with Security Features Trained on Libri Speech dataset 🎢")
129
 
130
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
131
 
@@ -138,9 +133,9 @@ if audio_file:
138
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
139
  waveform = waveform.to(dtype=torch.float32)
140
 
141
- # Simulate an adversarial attack by injecting random noise
142
- adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
143
- adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
144
 
145
  input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
146
 
 
37
  print(f"✅ Model loaded on {device}")
38
 
39
  # ================================
40
+ # 3️⃣ Load Dataset (With Fixes)
41
  # ================================
42
  DATASET_TAR_PATH = "dev-clean.tar.gz"
43
  EXTRACT_PATH = "./librispeech_dev_clean"
44
+ AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
45
 
46
+ if not os.path.exists(AUDIO_FOLDER):
47
  print("🔄 Extracting dataset...")
48
+ try:
49
+ with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
50
+ tar.extractall(EXTRACT_PATH)
51
+ print("✅ Extraction complete.")
52
+ except Exception as e:
53
+ raise RuntimeError(f"❌ Dataset extraction failed: {e}")
54
  else:
55
  print("✅ Dataset already extracted.")
56
 
 
 
57
  def find_audio_files(base_folder):
58
+ return [os.path.join(root, file) for root, _, files in os.walk(base_folder) for file in files if file.endswith(".flac")]
 
 
 
 
 
59
 
60
  audio_files = find_audio_files(AUDIO_FOLDER)
61
 
62
  if not audio_files:
63
  raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
 
64
  print(f"✅ Found {len(audio_files)} audio files in dataset!")
65
 
66
  # ================================
67
+ # 4️⃣ Load Transcripts (Fixed Mapping)
68
  # ================================
69
  def load_transcripts():
70
  transcript_dict = {}
 
82
  transcripts = load_transcripts()
83
  if not transcripts:
84
  raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
 
85
  print(f"✅ Loaded {len(transcripts)} transcripts.")
86
 
87
  # ================================
88
+ # 5️⃣ Preprocess Dataset (Fixed `input_ids` Issue)
89
  # ================================
90
  def load_and_process_audio(audio_path):
91
  waveform, sample_rate = torchaudio.load(audio_path)
 
118
  attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
119
 
120
  # ================================
121
+ # 7️⃣ Streamlit ASR Web App (Fixed Security & Processing)
122
  # ================================
123
+ st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎢")
124
 
125
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
126
 
 
133
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
134
  waveform = waveform.to(dtype=torch.float32)
135
 
136
+ # Apply adversarial attack noise with limit
137
+ noise = torch.randn_like(waveform) * attack_strength
138
+ adversarial_waveform = torch.clamp(waveform + noise, -1.0, 1.0)
139
 
140
  input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
141