import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)

# ================================
# 1️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"

# Load ASR model and processor
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Model loaded on {device}")

# ================================
# 2️⃣ Load Dataset (Manually from Extracted Path)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"     # Dataset stored in the Hugging Face Space
EXTRACT_PATH = "./librispeech_dev_clean"  # Extracted dataset folder

# Extract dataset only if not already extracted
if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# Collect audio files manually.
# dev-clean.tar.gz unpacks to LibriSpeech/dev-clean/<speaker>/<chapter>/*.flac,
# so walk the tree recursively instead of listing a single folder.
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")  # Adjust as per structure
audio_files = [
    os.path.join(root, f)
    for root, _, files in os.walk(AUDIO_FOLDER)
    for f in files
    if f.endswith(".flac")
]

# ================================
# 3️⃣ Preprocess Dataset (Manually)
# ================================
def load_and_process_audio(audio_path):
    """Loads and processes a single audio file into model input format."""
    waveform, sample_rate = torchaudio.load(audio_path)
    # Resample to 16 kHz
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Convert to model input format
    # (NOTE: some processors, e.g. Whisper's, expose `input_features` instead of `input_values`)
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
    return input_values

# Manually create the dataset structure (first 100 files only).
# `labels` is left empty here; see the transcript-loading sketch at the end of this file.
dataset = [{"input_values": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]
print(f"✅ Dataset loaded! Processed {len(dataset)} audio files.")

# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="no",   # no eval_dataset is passed below, so per-epoch evaluation is disabled
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,           # requires a Hugging Face token with write access
)

# Data collator (for dynamic padding).
# DataCollatorForSeq2Seq pads tokenized text; a padding-aware speech collator is
# sketched after section 5 below as an alternative.
data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)

# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=None,                      # No validation dataset for now
    tokenizer=processor.feature_extractor,  # saved alongside the model checkpoints
    data_collator=data_collator,
)

# ================================
# 5️⃣ Fine-Tuning Execution
# ================================
if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    st.success("✅ Fine-Tuning Completed! Model updated.")
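# Hedged sketch (not part of the original script): DataCollatorForSeq2Seq above expects
# tokenized text ("input_ids"), so batching raw "input_values" arrays of different
# lengths may fail. Assuming the examples keep the "input_values"/"labels" keys built in
# section 3 and that the processor bundles a feature extractor and tokenizer (as most HF
# ASR processors do), a minimal padding-aware collator could look like this; the class
# name `SpeechSeq2SeqCollator` is illustrative. Swap it in via
# `Trainer(..., data_collator=SpeechSeq2SeqCollator(processor))` if the stock collator fails.
class SpeechSeq2SeqCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, features):
        # Pad the raw audio arrays and the label token ids separately.
        audio = [{"input_values": f["input_values"]} for f in features]
        labels = [{"input_ids": f["labels"]} for f in features]
        batch = self.processor.feature_extractor.pad(audio, padding=True, return_tensors="pt")
        label_batch = self.processor.tokenizer.pad(labels, padding=True, return_tensors="pt")
        # Replace padding token ids with -100 so they are ignored by the loss.
        batch["labels"] = label_batch["input_ids"].masked_fill(
            label_batch["attention_mask"].eq(0), -100
        )
        return batch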
Model updated.") # ================================ # 6️⃣ Streamlit ASR Web App # ================================ st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎶") # Upload audio file audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"]) if audio_file: # Save uploaded file temporarily audio_path = "temp_audio.wav" with open(audio_path, "wb") as f: f.write(audio_file.read()) # Load and process audio waveform, sample_rate = torchaudio.load(audio_path) waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform) # Convert audio to model input input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0] # Perform ASR inference with torch.no_grad(): input_tensor = torch.tensor([input_values]).to(device) logits = model(input_tensor).logits predicted_ids = torch.argmax(logits, dim=-1) transcription = processor.batch_decode(predicted_ids)[0] # Display transcription st.success("📄 Transcription:") st.write(transcription) # ================================ # 7️⃣ Fine-Tune Model with User Correction # ================================ user_correction = st.text_area("🔧 Correct the transcription (if needed):", transcription) if st.button("Fine-Tune with Correction"): if user_correction: corrected_input = processor.tokenizer(user_correction).input_ids # Dynamically add new example to dataset dataset.append({"input_values": input_values, "labels": corrected_input}) # Perform quick re-training (1 epoch) trainer.args.num_train_epochs = 1 trainer.train() st.success("✅ Model fine-tuned with new correction! Try another audio file.")