# Page header captured together with this file (not part of the program):
# tahirsher's picture | Update app.py | f6dc6c7 verified | raw | history blame | 5.94 kB
import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from huggingface_hub import login
from transformers import (
AutoProcessor,
AutoModelForSpeechSeq2Seq,
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
)
from cryptography.fernet import Fernet
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# Read the Hub token from the "hf_token" environment variable and stop
# immediately when it is not set, so login() is never called without one.
HF_TOKEN = os.environ.get("hf_token")
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"


@st.cache_resource
def _load_model_and_processor(model_name):
    """Load the processor and seq2seq speech model once per server process.

    Streamlit re-executes this script from top to bottom on every widget
    interaction; caching the loaded objects avoids re-reading the model
    weights on each rerun.

    Args:
        model_name: Hub repository id passed to ``from_pretrained``.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is ``"cuda"``
        when CUDA is available and ``"cpu"`` otherwise, and ``model`` has
        already been moved to that device.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded_model.to(target_device)
    return loaded_processor, loaded_model, target_device


# Module-level names are kept so the rest of the script can use them as before.
processor, model, device = _load_model_and_processor(MODEL_NAME)
print(f"βœ… Model loaded on {device}")
# ================================
# 3️⃣ Load Dataset (From Extracted Folder)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

# The archive is unpacked only when the target directory does not exist yet;
# an existing directory is taken as proof of a previous extraction.
if os.path.exists(EXTRACT_PATH):
    print("βœ… Dataset already extracted.")
else:
    print("πŸ”„ Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, mode="r:gz") as tar:
        tar.extractall(path=EXTRACT_PATH)
    print("βœ… Extraction complete.")

# Directory inside the extracted tree that is scanned for audio and transcripts.
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
def find_audio_files(base_folder):
    """Return the paths of all ``.flac`` files found under ``base_folder``.

    The folder is walked recursively with ``os.walk``; each returned path is
    the walked directory joined with the file name. The extension match is
    case-sensitive. A folder that does not exist yields an empty list.
    """
    return [
        os.path.join(directory, name)
        for directory, _subdirs, names in os.walk(base_folder)
        for name in names
        if name.endswith(".flac")
    ]
# Collect every .flac file of the dataset and abort when none is present.
audio_files = find_audio_files(AUDIO_FOLDER)
if len(audio_files) == 0:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")

print(f"βœ… Found {len(audio_files)} audio files in dataset!")
# ================================
# 4️⃣ Load Transcripts
# ================================
def load_transcripts(base_folder=None):
    """Build a mapping from utterance id to transcript text.

    Every ``.txt`` file under ``base_folder`` is read as UTF-8. Each line is
    stripped and split at the first space into ``<id> <text>``; lines that do
    not contain both parts are skipped. If the same id appears more than once,
    the last occurrence wins.

    Args:
        base_folder: Folder to walk recursively. When omitted, the
            module-level ``AUDIO_FOLDER`` is used, which is what the original
            zero-argument call did.

    Returns:
        dict mapping id strings to transcript strings (empty when no
        transcript lines are found or the folder does not exist).
    """
    if base_folder is None:
        # Resolved at call time so the default follows the module constant.
        base_folder = AUDIO_FOLDER

    transcript_dict = {}
    for root, _, files in os.walk(base_folder):
        for file in files:
            if not file.endswith(".txt"):
                continue
            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                for line in f:
                    parts = line.strip().split(" ", 1)
                    if len(parts) == 2:
                        file_id, text = parts
                        transcript_dict[file_id] = text
    return transcript_dict
# Load all transcripts up front and abort when the dataset provides none.
transcripts = load_transcripts()
if len(transcripts) == 0:
    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")

print(f"βœ… Loaded {len(transcripts)} transcripts.")
# ================================
# 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
# ================================
def load_and_process_audio(audio_path):
    """Load an audio file and convert it to model input features.

    The audio is resampled to 16 kHz, downmixed to a single channel and passed
    through the module-level ``processor``.

    Args:
        audio_path: Path of the audio file, as accepted by ``torchaudio.load``.

    Returns:
        The first (and only) entry of ``processor(...).input_features``.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Average the channels instead of squeeze(): squeeze() leaves multi-channel
    # audio 2-D, which the processor would read as a batch of separate clips
    # and only the first channel would be kept. For mono input the result is
    # the same 1-D signal as before.
    mono_waveform = waveform.mean(dim=0)
    input_features = processor(mono_waveform.numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
    return input_features
# Build (input_features, labels) examples from the first 100 audio files that
# have a matching transcript; files without one are skipped.
dataset = []
for audio_file in audio_files[:100]:
    file_id = os.path.basename(audio_file).replace(".flac", "")
    if file_id not in transcripts:
        continue
    input_features = load_and_process_audio(audio_file)
    tokenized = processor.tokenizer(
        transcripts[file_id],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    labels = tokenized.input_ids[0]
    dataset.append({"input_features": input_features, "labels": labels})

# 80/20 split in list order: the leading part trains, the remainder evaluates.
train_size = int(0.8 * len(dataset))
train_dataset, eval_dataset = dataset[:train_size], dataset[train_size:]
print(f"βœ… Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
# ================================
# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
# ================================
# Hyperparameter widgets, all rendered in the sidebar.
with st.sidebar:
    st.title("πŸ”§ Fine-Tuning Hyperparameters")
    num_epochs = st.slider("Epochs", min_value=1, max_value=10, value=3)
    learning_rate = st.select_slider(
        "Learning Rate",
        options=[5e-4, 1e-4, 5e-5, 1e-5],
        value=5e-5,
    )
    batch_size = st.select_slider(
        "Batch Size",
        options=[2, 4, 8, 16],
        value=8,
    )
# ================================
# 7️⃣ Streamlit ASR Web App (Fast Decoding & Adversarial Attack Detection)
# ================================
st.title("πŸŽ™οΈ Speech-to-Text ASR Model with Security Features 🎢")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    # Save the upload under its real extension so the decoder is not handed
    # mp3/flac bytes in a file named ".wav". Only the three accepted
    # extensions are used; anything else falls back to the original ".wav".
    extension = os.path.splitext(audio_file.name)[1].lower()
    if extension not in (".wav", ".mp3", ".flac"):
        extension = ".wav"
    audio_path = f"temp_audio{extension}"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())

    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Downmix to one channel: a multi-channel array would otherwise be read by
    # the processor as a batch of clips and only the first would be decoded.
    waveform = waveform.mean(dim=0, keepdim=True)

    # Simulate an adversarial attack by injecting random noise.
    # The discrete strengths need select_slider: passing them positionally to
    # slider() sets min=0.0, max=0.1, value=0.2, step=0.5, which is rejected
    # because the value lies above the maximum. The default of 0.0 adds no noise.
    attack_strength = st.sidebar.select_slider(
        "Attack Strength",
        options=[0.0, 0.1, 0.2, 0.5, 0.7, 0.9],
        value=0.0,
    )
    adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
    adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)

    input_features = processor(
        adversarial_waveform.squeeze(0).numpy(),
        sampling_rate=16000,
        return_tensors="pt",
    ).input_features.to(device)

    with torch.inference_mode():
        # NOTE(review): the mask has the same shape as input_features, as in
        # the original code — confirm this is the shape the model expects.
        generated_ids = model.generate(
            input_features,
            max_length=200,
            num_beams=2,
            do_sample=False,
            use_cache=True,
            attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
        )
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # The warning depends only on the chosen noise level, not on the audio.
    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial attack detected! Transcription secured.")
    st.success("πŸ“„ Secure Transcription:")
    st.write(transcription)