import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
# ================================
# 1️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
# Load ASR model and processor
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"βœ… Model loaded on {device}")
# ================================
# 2️⃣ Load Dataset (Manually from Extracted Path)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz" # Dataset stored in Hugging Face Space
EXTRACT_PATH = "./librispeech_dev_clean" # Extracted dataset folder
# Extract dataset only if not already extracted
if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")
# Load audio files manually (dev-clean.tar.gz unpacks to LibriSpeech/dev-clean,
# with .flac files nested under speaker/chapter subdirectories)
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
audio_files = [os.path.join(root, f) for root, _, files in os.walk(AUDIO_FOLDER) for f in files if f.endswith(".flac")]
# ================================
# 3️⃣ Preprocess Dataset (Manually)
# ================================
def load_and_process_audio(audio_path):
    """Loads and processes a single audio file into model input format."""
    waveform, sample_rate = torchaudio.load(audio_path)
    # Resample to the 16 kHz rate the model expects
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Convert to model input format (note: Whisper-style processors return
    # `input_features` instead of `input_values`; adjust if needed)
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
    return input_values
# Manually create dataset structure (labels are left empty here; real
# fine-tuning needs tokenized transcripts — see the sketch below)
dataset = [{"input_values": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]  # Load first 100
print(f"✅ Dataset Loaded! Processed {len(dataset)} audio files.")
# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="no",  # no eval_dataset is provided below
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=False,  # requires a Hub token; enable once authentication is configured
)
# Data collator (for dynamic padding)
data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
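# --- Hedged sketch (assumption, not part of the original app) -----------
# DataCollatorForSeq2Seq pads tokenized text fields (input_ids/labels); it
# does not know how to pad the raw audio arrays stored under "input_values".
# Speech seq2seq fine-tuning typically uses a small custom collator along
# these lines (names are illustrative, modeled on the common Whisper
# fine-tuning recipe):
import dataclasses

@dataclasses.dataclass
class DataCollatorSpeechSeq2Seq:
    processor: AutoProcessor

    def __call__(self, features):
        # Pad the audio inputs and the tokenized labels separately
        audio = [{"input_values": f["input_values"]} for f in features]
        labels = [{"input_ids": f["labels"]} for f in features]
        batch = self.processor.feature_extractor.pad(audio, return_tensors="pt")
        labels_batch = self.processor.tokenizer.pad(labels, return_tensors="pt")
        # Mask padding positions with -100 so the loss ignores them
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        return batch

# To use it, swap in: data_collator = DataCollatorSpeechSeq2Seq(processor=processor)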
# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=None,  # No validation dataset for now
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)
# ================================
# 5️⃣ Fine-Tuning Execution
# ================================
if st.button("Start Fine-Tuning"):
with st.spinner("Fine-tuning in progress... Please wait!"):
trainer.train()
st.success("βœ… Fine-Tuning Completed! Model updated.")
# ================================
# 6️⃣ Streamlit ASR Web App
# ================================
st.title("πŸŽ™οΈ Speech-to-Text ASR with Fine-Tuning 🎢")
# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    # Save the uploaded file temporarily (original bytes, regardless of extension)
    audio_path = "temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())

    # Load and resample audio to 16 kHz
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Convert audio to model input (see the `input_features` note above)
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]

    # Perform ASR inference; seq2seq ASR models are decoded with generate()
    # rather than a single forward pass plus argmax over logits
    with torch.no_grad():
        input_tensor = torch.tensor([input_values]).to(device)
        predicted_ids = model.generate(input_tensor)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("📄 Transcription:")
    st.write(transcription)
    # ================================
    # 7️⃣ Fine-Tune Model with User Correction
    # ================================
    # This block needs `transcription` and `input_values`, so it lives inside
    # the `if audio_file:` branch above.
    user_correction = st.text_area("🔧 Correct the transcription (if needed):", transcription)
    if st.button("Fine-Tune with Correction"):
        if user_correction:
            corrected_input = processor.tokenizer(user_correction).input_ids
            # Dynamically add the corrected example to the dataset
            dataset.append({"input_values": input_values, "labels": corrected_input})
            # Perform a quick re-training pass (1 epoch); note the bulk examples
            # above still carry empty labels unless real transcripts are attached
            trainer.args.num_train_epochs = 1
            trainer.train()
            st.success("✅ Model fine-tuned with new correction! Try another audio file.")