tahirsher committed on
Commit
2791b7d
·
verified ·
1 Parent(s): 76c5c38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -58
app.py CHANGED
@@ -1,14 +1,10 @@
1
  import os
2
- import tarfile
3
  import torch
4
  import torchaudio
5
- import numpy as np
6
  import streamlit as st
7
  from huggingface_hub import login
8
- from transformers import (
9
- AutoProcessor,
10
- AutoModelForSpeechSeq2Seq,
11
- )
12
  from cryptography.fernet import Fernet
13
 
14
  # ================================
@@ -24,52 +20,19 @@ def authenticate_hf():
24
  authenticate_hf()
25
 
26
  # ================================
27
- # 2️⃣ Load Model & Processor (Cached)
28
  # ================================
29
  @st.cache_resource
30
  def load_model():
31
- MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
32
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
33
- model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
34
  return processor, model
35
 
36
  processor, model = load_model()
37
 
38
  # ================================
39
- # 3️⃣ Dataset Extraction (Cached)
40
- # ================================
41
- @st.cache_resource
42
- def extract_dataset():
43
- DATASET_TAR_PATH = "dev-clean.tar.gz"
44
- EXTRACT_PATH = "./librispeech_dev_clean"
45
-
46
- if not os.path.exists(EXTRACT_PATH):
47
- with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
48
- tar.extractall(EXTRACT_PATH)
49
- return os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
50
-
51
- AUDIO_FOLDER = extract_dataset()
52
-
53
- # ================================
54
- # 4️⃣ Load Transcripts (Cached)
55
- # ================================
56
- @st.cache_resource
57
- def load_transcripts():
58
- transcripts = {}
59
- for root, _, files in os.walk(AUDIO_FOLDER):
60
- for file in files:
61
- if file.endswith(".txt"):
62
- with open(os.path.join(root, file), "r", encoding="utf-8") as f:
63
- for line in f:
64
- parts = line.strip().split(" ", 1)
65
- if len(parts) == 2:
66
- transcripts[parts[0]] = parts[1]
67
- return transcripts
68
-
69
- transcripts = load_transcripts()
70
-
71
- # ================================
72
- # 5️⃣ Streamlit Sidebar for Fine-Tuning & Security
73
  # ================================
74
  st.sidebar.title("🔧 Fine-Tuning & Security Settings")
75
 
@@ -83,7 +46,7 @@ enable_encryption = st.sidebar.checkbox("🔒 Encrypt Transcription", value=True
83
  show_transcription = st.sidebar.checkbox("📖 Show Transcription", value=False)
84
 
85
  # ================================
86
- # 6️⃣ Encryption Handling (Precomputed Key)
87
  # ================================
88
  encryption_key = Fernet.generate_key()
89
  fernet = Fernet(encryption_key)
@@ -95,9 +58,9 @@ def decrypt_text(encrypted_text):
95
  return fernet.decrypt(encrypted_text).decode()
96
 
97
  # ================================
98
- # 7️⃣ Optimized ASR Web App
99
  # ================================
100
- st.title("🎙️ Speech-to-Text ASR Model Finetuned on Librispeech Corpus with Security Features")
101
 
102
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
103
 
@@ -106,27 +69,29 @@ if audio_file:
106
  with open(audio_path, "wb") as f:
107
  f.write(audio_file.read())
108
 
109
- waveform, sample_rate = torchaudio.load(audio_path)
110
- waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
111
-
112
  # ================================
113
  # ✅ Optimized Adversarial Attack Handling
114
  # ================================
115
- noise = attack_strength * torch.randn_like(waveform)
116
- adversarial_waveform = waveform + noise
117
  adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
118
 
119
  # Remove background noise for speed & accuracy
120
  denoised_waveform = torchaudio.functional.vad(adversarial_waveform, sample_rate=16000)
121
 
122
  # ================================
123
- # ✅ Fast Transcription Processing
124
  # ================================
125
- input_features = processor(denoised_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
126
 
127
- with torch.inference_mode():
128
- generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False)
129
- transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
130
 
131
  if attack_strength > 0.3:
132
  st.warning("⚠️ Adversarial attack detected! Denoising applied.")
@@ -135,7 +100,7 @@ if audio_file:
135
  # ✅ Optimized Encryption Handling
136
  # ================================
137
  if enable_encryption:
138
- encrypted_transcription = encrypt_text(transcription)
139
  st.info("🔒 Transcription is encrypted. Enable 'Show Transcription' to view.")
140
 
141
  if show_transcription:
@@ -146,4 +111,4 @@ if audio_file:
146
  st.write("🔒 [Encrypted] Transcription hidden. Enable 'Show Transcription' to view.")
147
  else:
148
  st.success("📄 Transcription:")
149
- st.write(transcription)
 
1
  import os
 
2
  import torch
3
  import torchaudio
4
+ import librosa
5
  import streamlit as st
6
  from huggingface_hub import login
7
+ from transformers import AutoProcessor, AutoModelForCTC
 
 
 
8
  from cryptography.fernet import Fernet
9
 
10
  # ================================
 
20
  authenticate_hf()
21
 
22
  # ================================
23
+ # 2️⃣ Load Conformer Model & Processor (Cached)
24
  # ================================
25
  @st.cache_resource
26
  def load_model():
27
+ MODEL_NAME = "deepl-project/conformer-finetunning"
28
  processor = AutoProcessor.from_pretrained(MODEL_NAME)
29
+ model = AutoModelForCTC.from_pretrained(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
30
  return processor, model
31
 
32
  processor, model = load_model()
33
 
34
  # ================================
35
+ # 3️⃣ Streamlit Sidebar for Fine-Tuning & Security
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # ================================
37
  st.sidebar.title("🔧 Fine-Tuning & Security Settings")
38
 
 
46
  show_transcription = st.sidebar.checkbox("📖 Show Transcription", value=False)
47
 
48
  # ================================
49
+ # 4️⃣ Encryption Handling (Precomputed Key)
50
  # ================================
51
  encryption_key = Fernet.generate_key()
52
  fernet = Fernet(encryption_key)
 
58
  return fernet.decrypt(encrypted_text).decode()
59
 
60
  # ================================
61
+ # 5️⃣ Optimized ASR Web App
62
  # ================================
63
+ st.title("🎙️ Speech-to-Text ASR Model using Conformer with Security Features")
64
 
65
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
66
 
 
69
  with open(audio_path, "wb") as f:
70
  f.write(audio_file.read())
71
 
72
+ # Load and preprocess the audio file using librosa
73
+ speech, sr = librosa.load(audio_path, sr=16000)
74
+
75
  # ================================
76
  # ✅ Optimized Adversarial Attack Handling
77
  # ================================
78
+ noise = attack_strength * torch.randn_like(torch.tensor(speech))
79
+ adversarial_waveform = torch.tensor(speech) + noise
80
  adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
81
 
82
  # Remove background noise for speed & accuracy
83
  denoised_waveform = torchaudio.functional.vad(adversarial_waveform, sample_rate=16000)
84
 
85
  # ================================
86
+ # ✅ Fast Transcription Processing with Conformer
87
  # ================================
88
+ inputs = processor(denoised_waveform.numpy(), sampling_rate=sr, return_tensors="pt", padding=True).to("cuda" if torch.cuda.is_available() else "cpu")
89
+
90
+ with torch.no_grad():
91
+ logits = model(**inputs).logits
92
 
93
+ predicted_ids = torch.argmax(logits, dim=-1)
94
+ transcription = processor.batch_decode(predicted_ids)
 
95
 
96
  if attack_strength > 0.3:
97
  st.warning("⚠️ Adversarial attack detected! Denoising applied.")
 
100
  # ✅ Optimized Encryption Handling
101
  # ================================
102
  if enable_encryption:
103
+ encrypted_transcription = encrypt_text(transcription[0])
104
  st.info("🔒 Transcription is encrypted. Enable 'Show Transcription' to view.")
105
 
106
  if show_transcription:
 
111
  st.write("🔒 [Encrypted] Transcription hidden. Enable 'Show Transcription' to view.")
112
  else:
113
  st.success("📄 Transcription:")
114
+ st.write(transcription[0])