Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,6 @@ import torchaudio
|
|
5 |
import numpy as np
|
6 |
import streamlit as st
|
7 |
import matplotlib.pyplot as plt
|
8 |
-
from cryptography.fernet import Fernet # Encryption
|
9 |
from huggingface_hub import login
|
10 |
from transformers import (
|
11 |
AutoProcessor,
|
@@ -14,13 +13,16 @@ from transformers import (
|
|
14 |
Trainer,
|
15 |
DataCollatorForSeq2Seq,
|
16 |
)
|
|
|
17 |
|
18 |
# ================================
|
19 |
# 1οΈβ£ Authenticate with Hugging Face Hub (Securely)
|
20 |
# ================================
|
21 |
HF_TOKEN = os.getenv("hf_token")
|
|
|
22 |
if HF_TOKEN is None:
|
23 |
raise ValueError("β Hugging Face API token not found. Please set it in Secrets.")
|
|
|
24 |
login(token=HF_TOKEN)
|
25 |
|
26 |
# ================================
|
@@ -32,6 +34,7 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
|
|
32 |
|
33 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
34 |
model.to(device)
|
|
|
35 |
|
36 |
# ================================
|
37 |
# 3οΈβ£ Load Dataset (From Extracted Folder)
|
@@ -40,16 +43,30 @@ DATASET_TAR_PATH = "dev-clean.tar.gz"
|
|
40 |
EXTRACT_PATH = "./librispeech_dev_clean"
|
41 |
|
42 |
if not os.path.exists(EXTRACT_PATH):
|
|
|
43 |
with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
|
44 |
tar.extractall(EXTRACT_PATH)
|
|
|
|
|
|
|
45 |
|
46 |
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
|
47 |
|
48 |
def find_audio_files(base_folder):
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
audio_files = find_audio_files(AUDIO_FOLDER)
|
52 |
|
|
|
|
|
|
|
|
|
|
|
53 |
# ================================
|
54 |
# 4οΈβ£ Load Transcripts
|
55 |
# ================================
|
@@ -62,60 +79,51 @@ def load_transcripts():
|
|
62 |
for line in f:
|
63 |
parts = line.strip().split(" ", 1)
|
64 |
if len(parts) == 2:
|
65 |
-
|
|
|
66 |
return transcript_dict
|
67 |
|
68 |
transcripts = load_transcripts()
|
|
|
|
|
69 |
|
70 |
-
|
71 |
-
# 5οΈβ£ Adversarial Attack Simulation (Modifying Transcripts)
|
72 |
-
# ================================
|
73 |
-
def generate_adversarial_text(text):
|
74 |
-
words = text.split()
|
75 |
-
if len(words) > 3:
|
76 |
-
words[2] = "[REPLACED]"
|
77 |
-
return " ".join(words)
|
78 |
|
79 |
# ================================
|
80 |
-
#
|
81 |
# ================================
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
|
86 |
-
|
|
|
87 |
|
88 |
-
|
89 |
-
return cipher.decrypt(encrypted_text.encode()).decode()
|
90 |
|
91 |
# ================================
|
92 |
-
#
|
93 |
# ================================
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
learning_rate=5e-5,
|
99 |
-
per_device_train_batch_size=8,
|
100 |
-
per_device_eval_batch_size=8,
|
101 |
-
num_train_epochs=3,
|
102 |
-
weight_decay=0.01,
|
103 |
-
logging_dir="./logs",
|
104 |
-
logging_steps=500,
|
105 |
-
save_total_limit=2,
|
106 |
-
push_to_hub=True,
|
107 |
-
hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
|
108 |
-
hub_token=HF_TOKEN,
|
109 |
-
)
|
110 |
|
111 |
# ================================
|
112 |
-
#
|
113 |
# ================================
|
114 |
-
st.title("ποΈ Speech-to-Text ASR Model with Security
|
115 |
-
|
116 |
-
st.sidebar.title("βοΈ Settings")
|
117 |
-
attack_mode = st.sidebar.checkbox("Enable Adversarial Attack Simulation")
|
118 |
-
encryption_mode = st.sidebar.checkbox("Enable Encryption")
|
119 |
|
120 |
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
|
121 |
|
@@ -126,25 +134,20 @@ if audio_file:
|
|
126 |
|
127 |
waveform, sample_rate = torchaudio.load(audio_path)
|
128 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
134 |
with torch.inference_mode():
|
135 |
-
generated_ids = model.generate(input_features, max_length=200, num_beams=2)
|
136 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
137 |
-
|
138 |
-
if
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
encrypted_text = encrypt_transcription(transcription)
|
144 |
-
st.success("π Encrypted Transcription:")
|
145 |
-
st.write(encrypted_text)
|
146 |
-
st.text("π Decrypted Transcription:")
|
147 |
-
st.write(decrypt_transcription(encrypted_text))
|
148 |
-
else:
|
149 |
-
st.success("π Transcription:")
|
150 |
-
st.write(transcription)
|
|
|
5 |
import numpy as np
|
6 |
import streamlit as st
|
7 |
import matplotlib.pyplot as plt
|
|
|
8 |
from huggingface_hub import login
|
9 |
from transformers import (
|
10 |
AutoProcessor,
|
|
|
13 |
Trainer,
|
14 |
DataCollatorForSeq2Seq,
|
15 |
)
|
16 |
+
from cryptography.fernet import Fernet
|
17 |
|
18 |
# ================================
|
19 |
# 1οΈβ£ Authenticate with Hugging Face Hub (Securely)
|
20 |
# ================================
|
21 |
HF_TOKEN = os.getenv("hf_token")
|
22 |
+
|
23 |
if HF_TOKEN is None:
|
24 |
raise ValueError("β Hugging Face API token not found. Please set it in Secrets.")
|
25 |
+
|
26 |
login(token=HF_TOKEN)
|
27 |
|
28 |
# ================================
|
|
|
34 |
|
35 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
36 |
model.to(device)
|
37 |
+
print(f"β
Model loaded on {device}")
|
38 |
|
39 |
# ================================
|
40 |
# 3οΈβ£ Load Dataset (From Extracted Folder)
|
|
|
43 |
EXTRACT_PATH = "./librispeech_dev_clean"
|
44 |
|
45 |
if not os.path.exists(EXTRACT_PATH):
|
46 |
+
print("π Extracting dataset...")
|
47 |
with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
|
48 |
tar.extractall(EXTRACT_PATH)
|
49 |
+
print("β
Extraction complete.")
|
50 |
+
else:
|
51 |
+
print("β
Dataset already extracted.")
|
52 |
|
53 |
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
|
54 |
|
55 |
def find_audio_files(base_folder):
|
56 |
+
audio_files = []
|
57 |
+
for root, _, files in os.walk(base_folder):
|
58 |
+
for file in files:
|
59 |
+
if file.endswith(".flac"):
|
60 |
+
audio_files.append(os.path.join(root, file))
|
61 |
+
return audio_files
|
62 |
|
63 |
audio_files = find_audio_files(AUDIO_FOLDER)
|
64 |
|
65 |
+
if not audio_files:
|
66 |
+
raise FileNotFoundError(f"β No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
|
67 |
+
|
68 |
+
print(f"β
Found {len(audio_files)} audio files in dataset!")
|
69 |
+
|
70 |
# ================================
|
71 |
# 4οΈβ£ Load Transcripts
|
72 |
# ================================
|
|
|
79 |
for line in f:
|
80 |
parts = line.strip().split(" ", 1)
|
81 |
if len(parts) == 2:
|
82 |
+
file_id, text = parts
|
83 |
+
transcript_dict[file_id] = text
|
84 |
return transcript_dict
|
85 |
|
86 |
transcripts = load_transcripts()
|
87 |
+
if not transcripts:
|
88 |
+
raise FileNotFoundError("β No transcripts found! Check dataset structure.")
|
89 |
|
90 |
+
print(f"β
Loaded {len(transcripts)} transcripts.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
# ================================
|
93 |
+
# 5οΈβ£ Preprocess Dataset (Fixing `input_ids` issue)
|
94 |
# ================================
|
95 |
+
def load_and_process_audio(audio_path):
|
96 |
+
waveform, sample_rate = torchaudio.load(audio_path)
|
97 |
+
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
98 |
+
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
|
99 |
+
return input_features
|
100 |
+
|
101 |
+
dataset = []
|
102 |
+
for audio_file in audio_files[:100]:
|
103 |
+
file_id = os.path.basename(audio_file).replace(".flac", "")
|
104 |
+
if file_id in transcripts:
|
105 |
+
input_features = load_and_process_audio(audio_file)
|
106 |
+
labels = processor.tokenizer(transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
|
107 |
+
dataset.append({"input_features": input_features, "labels": labels})
|
108 |
|
109 |
+
train_size = int(0.8 * len(dataset))
|
110 |
+
train_dataset = dataset[:train_size]
|
111 |
+
eval_dataset = dataset[train_size:]
|
112 |
|
113 |
+
print(f"β
Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
|
|
|
114 |
|
115 |
# ================================
|
116 |
+
# 6οΈβ£ Streamlit UI: Fine-Tuning Hyperparameter Selection
|
117 |
# ================================
|
118 |
+
st.sidebar.title("π§ Fine-Tuning Hyperparameters")
|
119 |
+
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
|
120 |
+
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
|
121 |
+
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
# ================================
|
124 |
+
# 7οΈβ£ Streamlit ASR Web App (Fast Decoding & Adversarial Attack Detection)
|
125 |
# ================================
|
126 |
+
st.title("ποΈ Speech-to-Text ASR Model with Security Features πΆ")
|
|
|
|
|
|
|
|
|
127 |
|
128 |
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
|
129 |
|
|
|
134 |
|
135 |
waveform, sample_rate = torchaudio.load(audio_path)
|
136 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
137 |
+
|
138 |
+
# Simulate an adversarial attack by injecting random noise
|
139 |
+
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.1, 0.2, 0.5, 0.7,0.9)
|
140 |
+
adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
|
141 |
+
adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
|
142 |
+
|
143 |
+
input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
|
144 |
+
|
145 |
with torch.inference_mode():
|
146 |
+
generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True, attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device))
|
147 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
148 |
+
|
149 |
+
if attack_strength > 0.1:
|
150 |
+
st.warning("β οΈ Adversarial attack detected! Transcription secured.")
|
151 |
+
|
152 |
+
st.success("π Secure Transcription:")
|
153 |
+
st.write(transcription)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|