import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
from cryptography.fernet import Fernet

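# Authenticate with the Hugging Face Hub using the token stored in the environment / Space secrets.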
HF_TOKEN = os.getenv("hf_token")

if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

login(token=HF_TOKEN)

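# Load the ASR model and its processor, and move the model to GPU when available.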
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Model loaded on {device}")

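# Extract the LibriSpeech dev-clean archive (expected next to this script) unless it was already unpacked.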
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

if not os.path.exists(EXTRACT_PATH):
    print("📂 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

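# Recursively collect all .flac audio files from the extracted dataset.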
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")

def find_audio_files(base_folder):
    audio_files = []
    for root, _, files in os.walk(base_folder):
        for file in files:
            if file.endswith(".flac"):
                audio_files.append(os.path.join(root, file))
    return audio_files

audio_files = find_audio_files(AUDIO_FOLDER)

if not audio_files:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")

print(f"✅ Found {len(audio_files)} audio files in dataset!")

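# LibriSpeech transcripts are plain-text files with one "<file_id> <text>" pair per line.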
def load_transcripts():
    transcript_dict = {}
    for root, _, files in os.walk(AUDIO_FOLDER):
        for file in files:
            if file.endswith(".txt"):
                with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                    for line in f:
                        parts = line.strip().split(" ", 1)
                        if len(parts) == 2:
                            file_id, text = parts
                            transcript_dict[file_id] = text
    return transcript_dict

transcripts = load_transcripts()
if not transcripts:
    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")

print(f"✅ Loaded {len(transcripts)} transcripts.")

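# Resample each clip to 16 kHz and convert it into model input features; matching transcripts become tokenized labels.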
def load_and_process_audio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
    return input_features

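# Build a small dataset from the first 100 clips that have a matching transcript.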
dataset = []
for audio_file in audio_files[:100]:
    file_id = os.path.basename(audio_file).replace(".flac", "")
    if file_id in transcripts:
        input_features = load_and_process_audio(audio_file)
        labels = processor.tokenizer(transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
        dataset.append({"input_features": input_features, "labels": labels})

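# Simple 80/20 split into training and evaluation sets.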
train_size = int(0.8 * len(dataset))
train_dataset = dataset[:train_size]
eval_dataset = dataset[train_size:]

print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")

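# Streamlit sidebar: expose fine-tuning hyperparameters.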
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)

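# Main UI: upload an audio file, optionally perturb it, and transcribe it.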
st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")

audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    audio_path = "temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())

    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

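    # Simulate an adversarial attack by adding Gaussian noise scaled by the chosen attack strength.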
    attack_strength = st.sidebar.select_slider("Attack Strength", options=[0.0, 0.1, 0.2, 0.5, 0.7, 0.9], value=0.1)
    adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
    adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)

    input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)

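    # Transcribe the (possibly perturbed) audio with beam search.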
    with torch.inference_mode():
        generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True)
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial attack detected! Transcription secured.")

    st.success("🔒 Secure Transcription:")
    st.write(transcription)