import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
from datasets import load_dataset
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
# ================================
# 1️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
# Load ASR model and processor
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Model loaded on {device}")
# ================================
# 2️⃣ Load Dataset (LibriSpeech) from Extracted Path
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz" # Uploaded dataset in your Hugging Face Space
EXTRACT_PATH = "./librispeech_dev_clean" # Extracted dataset folder
# Extract dataset only if not already extracted
if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")
# ✅ Load dataset from the extracted folder; dev-clean corresponds to the
# "validation" split of the "clean" config (there is no plain "train" split)
dataset = load_dataset("librispeech_asr", "clean", split="validation", data_dir=EXTRACT_PATH, trust_remote_code=True)
print(f"✅ Dataset Loaded Successfully! Size: {len(dataset)}")
# ================================
# 3️⃣ Preprocess Dataset
# ================================
def preprocess_audio(batch):
    """Converts raw audio to a model-compatible format."""
    audio = batch["audio"]
    waveform, sample_rate = torchaudio.load(audio["path"])
    # Resample to 16 kHz (the rate most ASR models expect)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Convert to model input format; Whisper-style processors return
    # "input_features" while Wav2Vec2-style processors return "input_values"
    features = processor(waveform.squeeze().numpy(), sampling_rate=16000)
    key = "input_features" if "input_features" in features else "input_values"
    batch["input_values"] = features[key][0]
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    return batch
# Apply preprocessing
dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
print("✅ Dataset Preprocessed! Ready for Fine-Tuning.")
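# Hedged sanity check (column names assumed from preprocess_audio above):
# inspect one processed example before committing to a full training run.
sample = dataset[0]
print(f"Feature length: {len(sample['input_values'])}, first label ids: {sample['labels'][:10]}")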
# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="no",  # no eval_dataset is wired in yet (see Trainer below)
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,  # requires a logged-in Hugging Face account with write access
    save_on_each_node=True,  # improves checkpoint stability in multi-node training
    # Re-enable once an eval set and a compute_metrics function exist:
    # evaluation_strategy="epoch",
    # metric_for_best_model="wer",
    # greater_is_better=False,
    # load_best_model_at_end=True,
)
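# Hedged sketch: a word-error-rate metric for when an eval_dataset is added
# later. Assumes the `evaluate` package (plus `jiwer`) is installed; it is
# not wired into the Trainer below because there is no eval set yet, and the
# argmax-over-logits decoding is only a rough proxy without generate().
import evaluate

wer_metric = evaluate.load("wer")

def compute_metrics(eval_pred):
    logits, label_ids = eval_pred
    pred_ids = np.argmax(logits, axis=-1)
    # Restore ignored label positions (-100) to the pad token before decoding
    label_ids = np.where(label_ids == -100, processor.tokenizer.pad_token_id, label_ids)
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}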
# Data collator (for dynamic padding)
data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
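# Hedged alternative: DataCollatorForSeq2Seq pads token ids, not raw audio
# features, so speech fine-tuning recipes usually define a custom collator
# like the sketch below. It assumes the "input_values"/"labels" columns built
# above; the feature key may be "input_features" for Whisper-style processors.
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class SpeechDataCollator:
    processor: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Pad the audio features and the label ids separately
        audio_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]
        batch = self.processor.feature_extractor.pad(audio_features, return_tensors="pt")
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # Mask label padding with -100 so it is ignored by the loss
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].eq(0), -100
        )
        return batch

# To try it, swap it in for the collator above:
# data_collator = SpeechDataCollator(processor)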
# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=None,  # no validation dataset for now
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)
# ================================
# 5️⃣ Fine-Tuning Execution
# ================================
if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    st.success("✅ Fine-Tuning Completed! Model updated.")
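    # Hedged follow-up: persist the fine-tuned weights and processor so they
    # can be reloaded after a restart (path matches output_dir above)
    trainer.save_model(training_args.output_dir)
    processor.save_pretrained(training_args.output_dir)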
# ================================
# 6️⃣ Streamlit ASR Web App
# ================================
st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎶")
# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    # Save uploaded file temporarily
    audio_path = "temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())
    # Load and resample audio to 16 kHz
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Convert audio to model input (key name depends on the processor type)
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    input_tensor = inputs.get("input_features", inputs.get("input_values")).to(device)
    # Perform ASR inference; a seq2seq model decodes with generate(),
    # not with an argmax over a single forward pass
    with torch.no_grad():
        predicted_ids = model.generate(input_tensor)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    # Display transcription
    st.success("📄 Transcription:")
    st.write(transcription)
    # ================================
    # 7️⃣ Fine-Tune Model with User Correction
    # ================================
    # (kept inside the upload block so `transcription` is always in scope)
    user_correction = st.text_area("🔧 Correct the transcription (if needed):", transcription)
    if st.button("Fine-Tune with Correction"):
        if user_correction:
            corrected_input = processor.tokenizer(user_correction).input_ids
            # Dynamically add the new example to the dataset
            dataset = dataset.add_item({
                "input_values": input_tensor.squeeze(0).cpu().numpy().tolist(),
                "labels": corrected_input,
            })
            # add_item returns a new Dataset, so point the Trainer at it
            trainer.train_dataset = dataset
            # Perform a quick re-training pass (1 epoch)
            trainer.args.num_train_epochs = 1
            trainer.train()
            st.success("✅ Model fine-tuned with new correction! Try another audio file.")