tahirsher committed on
Commit 15b7647 · verified · 1 Parent(s): eda3536

Update app.py

Files changed (1)
  1. app.py +36 -94
app.py CHANGED
@@ -3,63 +3,42 @@ import torch
 import torchaudio
 import streamlit as st
 from huggingface_hub import login
-from transformers import AutoProcessor, AutoModelForCTC
-from cryptography.fernet import Fernet
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

 # ================================
-# 1️⃣ Authenticate with Hugging Face Hub (Cache to prevent re-authentication)
+# 1️⃣ Authenticate with Hugging Face Hub (Securely)
 # ================================
-@st.cache_resource
-def authenticate_hf():
-    HF_TOKEN = os.getenv("hf_token")
-    if HF_TOKEN is None:
-        raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
-    login(token=HF_TOKEN)
+HF_TOKEN = os.getenv("hf_token")

-authenticate_hf()
+if HF_TOKEN is None:
+    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
+
+login(token=HF_TOKEN)

 # ================================
-# 2️⃣ Load Conformer Model & Processor (Cached)
+# 2️⃣ Load Conformer Model & Processor
 # ================================
-@st.cache_resource
-def load_model():
-    MODEL_NAME = "deepl-project/conformer-finetunning"
-    processor = AutoProcessor.from_pretrained(MODEL_NAME)
-    model = AutoModelForCTC.from_pretrained(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
-    return processor, model
+MODEL_NAME = "facebook/wav2vec2-conformer-rel-pos-large"
+processor = AutoProcessor.from_pretrained(MODEL_NAME)
+model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

-processor, model = load_model()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
+print(f"✅ Conformer Model loaded on {device}")

 # ================================
-# 3️⃣ Streamlit Sidebar for Fine-Tuning & Security
+# 3️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
 # ================================
-st.sidebar.title("🔧 Fine-Tuning & Security Settings")
-
+st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
 num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
 learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
 batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
-
-attack_strength = st.sidebar.slider("Adversarial Attack Strength", 0.1, 0.9, 0.3)
-
-enable_encryption = st.sidebar.checkbox("🔒 Encrypt Transcription", value=True)
-show_transcription = st.sidebar.checkbox("📖 Show Transcription", value=False)
+attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)

 # ================================
-# 4️⃣ Encryption Handling (Precomputed Key)
+# 4️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
 # ================================
-encryption_key = Fernet.generate_key()
-fernet = Fernet(encryption_key)
-
-def encrypt_text(text):
-    return fernet.encrypt(text.encode())
-
-def decrypt_text(encrypted_text):
-    return fernet.decrypt(encrypted_text).decode()
-
-# ================================
-# 5️⃣ Optimized ASR Web App
-# ================================
-st.title("🎙️ Speech-to-Text ASR Model using Conformer with Security Features")
+st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")

 audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

@@ -68,62 +47,25 @@ if audio_file:
     with open(audio_path, "wb") as f:
         f.write(audio_file.read())

-    # Load and preprocess the audio file using torchaudio
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
     waveform = waveform.to(dtype=torch.float32)

-    # ================================
-    # ✅ Optimized Adversarial Attack Handling
-    # ================================
-    noise = attack_strength * torch.randn_like(waveform)
-    adversarial_waveform = waveform + noise
+    # Simulate an adversarial attack by injecting random noise
+    adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
     adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
-
-    # ================================
-    # ✅ Preprocess Audio with Processor (Corrected)
-    # ================================
-    # Ensure the input has batch dimension (even if it's one example)
-    inputs = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
-
-    # Check the structure of the returned `inputs` to understand what it contains
-    st.write("Processor Output:", inputs)
-
-    # Extract the correct key (input_features or input_values depending on the model)
-    if "input_features" in inputs:
-        input_features = inputs["input_features"]
-    elif "input_values" in inputs:
-        input_features = inputs["input_values"]
-    else:
-        raise ValueError("❌ The processor output does not contain 'input_features' or 'input_values'.")
-
-    input_features = input_features.to("cuda" if torch.cuda.is_available() else "cpu")
-
-    # ================================
-    # ✅ Fast Transcription Processing with Conformer
-    # ================================
-    with torch.no_grad():
-        logits = model(input_features).logits
-
-    predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = processor.batch_decode(predicted_ids)
-
-    if attack_strength > 0.3:
-        st.warning("⚠️ Adversarial attack detected! Denoising applied.")
-
-    # ================================
-    # ✅ Optimized Encryption Handling
-    # ================================
-    if enable_encryption:
-        encrypted_transcription = encrypt_text(transcription[0])
-        st.info("🔒 Transcription is encrypted. Enable 'Show Transcription' to view.")
-
-        if show_transcription:
-            decrypted_text = decrypt_text(encrypted_transcription)
-            st.success("📄 Secure Transcription:")
-            st.write(decrypted_text)
-        else:
-            st.write("🔒 [Encrypted] Transcription hidden. Enable 'Show Transcription' to view.")
-    else:
-        st.success("📄 Transcription:")
-        st.write(transcription[0])
+
+    inputs = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
+    input_features = inputs.input_values.to(device)
+    attention_mask = inputs.attention_mask.to(device) if "attention_mask" in inputs else None
+
+    with torch.inference_mode():
+        generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True,
+                                       attention_mask=attention_mask)
+    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    if attack_strength > 0.1:
+        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")
+
+    st.success("📄 Secure Transcription:")
+    st.write(transcription)
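
For quick sanity-checking of the Conformer transcription path outside Streamlit, a minimal sketch follows. It is an illustration only, not part of this commit: it assumes the CTC-fine-tuned checkpoint facebook/wav2vec2-conformer-rel-pos-large-960h-ft with the Wav2Vec2ConformerForCTC head (greedy CTC decoding instead of the generate() call above), and a placeholder local file sample.wav.

# Minimal offline sketch (assumptions: the CTC-fine-tuned checkpoint
# "facebook/wav2vec2-conformer-rel-pos-large-960h-ft" and a local "sample.wav";
# adapt both names to your own setup).
import torch
import torchaudio
from transformers import AutoProcessor, Wav2Vec2ConformerForCTC

MODEL_NAME = "facebook/wav2vec2-conformer-rel-pos-large-960h-ft"  # assumed CTC checkpoint
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ConformerForCTC.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()

# Load the audio, resample to the 16 kHz rate the processor expects, collapse to mono
waveform, sample_rate = torchaudio.load("sample.wav")  # placeholder path
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
waveform = waveform.mean(dim=0)

inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.inference_mode():
    logits = model(inputs.input_values.to(device)).logits

# Greedy CTC decoding: argmax over the vocabulary, then collapse with batch_decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]
print(transcription)

The argmax plus batch_decode step mirrors the decoding flow of the previous CTC-based version of app.py shown in the removed lines above.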