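"""Streamlit ASR demo: loads a seq2seq speech-to-text model, fine-tunes it on
LibriSpeech dev-clean, and serves transcription with a correct-and-retrain loop."""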
import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
from dataclasses import dataclass
from typing import Any, Dict, List
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
)
# ================================
# 1️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
# Load ASR model and processor
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"βœ… Model loaded on {device}")
# ================================
# 2️⃣ Load Dataset (Recursively from Extracted Path)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"
# Extract dataset if not already extracted
if not os.path.exists(EXTRACT_PATH):
print("πŸ”„ Extracting dataset...")
with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
tar.extractall(EXTRACT_PATH)
print("βœ… Extraction complete.")
else:
print("βœ… Dataset already extracted.")
# Base directory where audio files are stored
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
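# LibriSpeech layout: dev-clean/<speaker>/<chapter>/<utterance>.flac, with one
# <speaker>-<chapter>.trans.txt transcript file per chapter directory.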
# Recursively find all `.flac` files inside the dataset directory
def find_audio_files(base_folder):
"""Recursively search for all .flac files in subdirectories."""
audio_files = []
for root, _, files in os.walk(base_folder):
for file in files:
if file.endswith(".flac"):
audio_files.append(os.path.join(root, file))
return audio_files
# Get all audio files
audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
print(f"✅ Found {len(audio_files)} audio files in dataset!")
# ================================
# 3️⃣ Preprocess Dataset (Fixed input_features)
# ================================
def load_and_process_audio(audio_path):
"""Loads and processes a single audio file into model format."""
waveform, sample_rate = torchaudio.load(audio_path)
# Resample to 16kHz
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
# Convert to model input format
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
return input_features
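# The original script left `labels` empty, which gives the model nothing to
# learn from (and breaks loss computation). The helper below is a minimal
# sketch that recovers each reference transcript from the
# <speaker>-<chapter>.trans.txt files LibriSpeech ships alongside the audio
# (lines of the form "<utterance-id> <TEXT>").
def load_transcript(audio_path):
    """Look up the reference transcript for a LibriSpeech .flac file."""
    directory = os.path.dirname(audio_path)
    utterance_id = os.path.splitext(os.path.basename(audio_path))[0]
    for name in os.listdir(directory):
        if name.endswith(".trans.txt"):
            with open(os.path.join(directory, name), encoding="utf-8") as f:
                for line in f:
                    key, _, text = line.strip().partition(" ")
                    if key == utterance_id:
                        return text
    return ""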
# Manually create dataset structure, pairing each audio clip with its
# tokenized reference transcript as labels
dataset = [
    {
        "input_features": load_and_process_audio(f),
        "labels": processor.tokenizer(load_transcript(f)).input_ids,
    }
    for f in audio_files[:100]  # small subset keeps the demo quick
]
# Split dataset into train and eval (Recommended Fix)
train_size = int(0.9 * len(dataset))
train_dataset = dataset[:train_size]
eval_dataset = dataset[train_size:]
print(f"βœ… Dataset Loaded! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
output_dir="./asr_model_finetuned",
eval_strategy="epoch", # Fix: Proper evaluation
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=3,
weight_decay=0.01,
logging_dir="./logs",
logging_steps=500,
save_total_limit=2,
    push_to_hub=True,  # requires a logged-in Hugging Face token, otherwise saving to the Hub fails
)
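# DataCollatorForSeq2Seq pads token ids, so it cannot batch the 2-D log-mel
# `input_features` produced above. The dataclass below is a sketch of the
# speech-specific collator pattern from the Hugging Face Whisper fine-tuning
# guide: it pads audio features and labels separately and masks label padding
# with -100 so the loss ignores it.
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Pad the log-mel input features to a uniform length
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # Pad the tokenized transcripts and mask padding tokens out of the loss
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch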
# Data collator (pads speech features and text labels independently)
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)
# Define Trainer (Fixed `processing_class` warning)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset, # Fix: Providing eval_dataset
processing_class=processor, # Fix: Replacing deprecated `tokenizer`
data_collator=data_collator,
)
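# Note: Streamlit re-executes this whole script on every interaction, so the
# model and trainer are rebuilt on each rerun; wrapping the loading code with
# st.cache_resource would avoid the repeated reloads.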
# ================================
# 5️⃣ Fine-Tuning Execution
# ================================
if st.button("Start Fine-Tuning"):
with st.spinner("Fine-tuning in progress... Please wait!"):
trainer.train()
        st.success("✅ Fine-Tuning Completed! Model updated.")
# ================================
# 6️⃣ Streamlit ASR Web App
# ================================
st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎶")
# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
# Save uploaded file temporarily
audio_path = "temp_audio.wav"
with open(audio_path, "wb") as f:
f.write(audio_file.read())
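    # The upload is written byte-for-byte; torchaudio generally detects the
    # real codec from the file header, so the .wav name is only nominal.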
# Load and process audio
waveform, sample_rate = torchaudio.load(audio_path)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
# Convert audio to model input
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
    # Perform ASR inference; a seq2seq ASR model decodes with generate(),
    # not an argmax over logits from a bare forward pass
    with torch.no_grad():
        input_tensor = input_features.unsqueeze(0).to(device)
        predicted_ids = model.generate(input_tensor)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
# Display transcription
    st.success("📄 Transcription:")
st.write(transcription)
# ================================
# 7️⃣ Fine-Tune Model with User Correction
# ================================
    user_correction = st.text_area("🔧 Correct the transcription (if needed):", transcription)
if st.button("Fine-Tune with Correction"):
if user_correction:
corrected_input = processor.tokenizer(user_correction).input_ids
            # Add the corrected example to the training split (the Trainer holds
            # a reference to this list, so the next train() call will see it)
            train_dataset.append({"input_features": input_features, "labels": corrected_input})
# Perform quick re-training (1 epoch)
trainer.args.num_train_epochs = 1
trainer.train()
            st.success("✅ Model fine-tuned with new correction! Try another audio file.")
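            # Caveat: the script (and the in-memory dataset) is rebuilt on every
            # Streamlit rerun, so corrections persist only for the current session
            # unless the updated model is saved or pushed to the Hub.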