# app.py — uploaded to Hugging Face by tahirsher, commit a4a32f2 ("Update app.py", verified), 4.55 kB.
import os
import tempfile

import librosa
import streamlit as st
import torch
import torchaudio
from cryptography.fernet import Fernet
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForCTC
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Cache to prevent re-authentication)
# ================================
@st.cache_resource
def authenticate_hf():
    """Log in to the Hugging Face Hub using the ``hf_token`` environment variable.

    Decorated with ``st.cache_resource`` so that Streamlit reruns reuse the
    cached result instead of logging in again.

    Raises:
        ValueError: if ``hf_token`` is not set in the environment.
    """
    token = os.getenv("hf_token")
    if token is not None:
        login(token=token)
        return
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")


authenticate_hf()
# ================================
# 2️⃣ Load Conformer Model & Processor (Cached)
# ================================
@st.cache_resource
def load_model():
    """Load the fine-tuned Conformer CTC model together with its processor.

    The model is placed on CUDA when ``torch.cuda.is_available()`` is true and
    on CPU otherwise. ``st.cache_resource`` lets reruns reuse the loaded objects.

    Returns:
        tuple: ``(processor, model)``.
    """
    model_id = "deepl-project/conformer-finetunning"
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    proc = AutoProcessor.from_pretrained(model_id)
    ctc_model = AutoModelForCTC.from_pretrained(model_id)
    return proc, ctc_model.to(target_device)


processor, model = load_model()
# ================================
# 3️⃣ Streamlit Sidebar for Fine-Tuning & Security
# ================================
st.sidebar.title("🔧 Fine-Tuning & Security Settings")

# NOTE(review): num_epochs, learning_rate and batch_size are not read anywhere
# else in this file, so these widgets currently have no effect on the app.
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider(
    "Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5
)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)

# Scales the Gaussian noise added to the waveform before transcription; values
# above 0.3 trigger the on-page warning. The minimum is 0.0 so an unperturbed
# (clean) transcription is possible; the default stays at 0.3.
attack_strength = st.sidebar.slider(
    "Adversarial Attack Strength", min_value=0.0, max_value=0.9, value=0.3
)
enable_encryption = st.sidebar.checkbox("🔒 Encrypt Transcription", value=True)
show_transcription = st.sidebar.checkbox("📖 Show Transcription", value=False)
# ================================
# 4️⃣ Encryption Handling (key generated at module execution)
# ================================
# NOTE: the key is created at module level, so a new one is generated every
# time this script executes; a token is only decryptable within the same run.
encryption_key = Fernet.generate_key()
fernet = Fernet(encryption_key)


def encrypt_text(text):
    """Encrypt the string *text* with the module-level Fernet instance.

    Returns:
        bytes: the Fernet token.
    """
    plaintext_bytes = text.encode()
    return fernet.encrypt(plaintext_bytes)


def decrypt_text(encrypted_text):
    """Decrypt a token produced by ``encrypt_text`` and return it as a string."""
    plaintext_bytes = fernet.decrypt(encrypted_text)
    return plaintext_bytes.decode()
# ================================
# 5️⃣ Optimized ASR Web App
# ================================
def _load_uploaded_audio(uploaded, target_sr=16000):
    """Decode a Streamlit upload with librosa; returns ``(waveform, sample_rate)``."""
    suffix = os.path.splitext(uploaded.name)[1] or ".wav"
    # One unique temp file per request: a fixed filename would be shared and
    # overwritten by concurrent sessions, and was never cleaned up. Keeping the
    # real suffix avoids labelling mp3/flac data as ".wav".
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        # getvalue() returns the whole upload regardless of the stream position.
        tmp.write(uploaded.getvalue())
        tmp_path = tmp.name
    try:
        return librosa.load(tmp_path, sr=target_sr)
    finally:
        os.remove(tmp_path)


def _perturb_and_trim(speech, strength, sample_rate):
    """Add Gaussian noise scaled by *strength*, clamp to [-1, 1], then VAD-trim the front."""
    clean = torch.as_tensor(speech)
    noisy = torch.clamp(clean + strength * torch.randn_like(clean), -1.0, 1.0)
    # vad() trims leading silence/quiet sound; it is not a denoiser.
    trimmed = torchaudio.functional.vad(noisy, sample_rate=sample_rate)
    # If vad() finds no speech it can return an empty tensor; fall back to the
    # untrimmed waveform rather than feeding an empty array to the model.
    return trimmed if trimmed.numel() > 0 else noisy


def _transcribe(waveform, sample_rate):
    """Run argmax (greedy) CTC decoding on *waveform*; returns the list of decoded strings."""
    inputs = processor(
        waveform.numpy(), sampling_rate=sample_rate, return_tensors="pt", padding=True
    ).to(model.device)
    # Ensure a batch dimension. Item assignment is required: setting
    # ``inputs.input_values`` as an attribute would not change the mapping
    # that ``model(**inputs)`` unpacks.
    values = inputs.get("input_values")
    if values is not None and values.dim() == 1:
        inputs["input_values"] = values.unsqueeze(0)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)


st.title("🎙️ Speech-to-Text ASR Model using Conformer with Security Features")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    speech, sr = _load_uploaded_audio(audio_file, target_sr=16000)

    # ================================
    # ✅ Simulated adversarial perturbation + leading-silence trim
    # ================================
    denoised_waveform = _perturb_and_trim(speech, attack_strength, sr)

    # ================================
    # ✅ Transcription with Conformer
    # ================================
    transcription = _transcribe(denoised_waveform, sr)

    # NOTE(review): this "detection" only compares the sidebar slider value to
    # 0.3; the VAD trim above runs unconditionally.
    if attack_strength > 0.3:
        st.warning("⚠️ Adversarial attack detected! Denoising applied.")

    # ================================
    # ✅ Encryption Handling
    # ================================
    if enable_encryption:
        encrypted_transcription = encrypt_text(transcription[0])
        st.info("🔒 Transcription is encrypted. Enable 'Show Transcription' to view.")
        if show_transcription:
            decrypted_text = decrypt_text(encrypted_transcription)
            st.success("📄 Secure Transcription:")
            st.write(decrypted_text)
        else:
            st.write("🔒 [Encrypted] Transcription hidden. Enable 'Show Transcription' to view.")
    else:
        st.success("📄 Transcription:")
        st.write(transcription[0])