import os
import tarfile

import torch
import torchaudio
import numpy as np
import streamlit as st
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)

# ================================
# 1️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"

# Load ASR model and processor
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Model loaded on {device}")

# ================================
# 2️⃣ Load Dataset (Recursively from Extracted Path)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

# Extract dataset if not already extracted
if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# Base directory where audio files are stored
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")


# Recursively find all `.flac` files inside the dataset directory
def find_audio_files(base_folder):
    """Recursively search for all .flac files in subdirectories."""
    audio_files = []
    for root, _, files in os.walk(base_folder):
        for file in files:
            if file.endswith(".flac"):
                audio_files.append(os.path.join(root, file))
    return audio_files


# Get all audio files
audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")

print(f"✅ Found {len(audio_files)} audio files in dataset!")

# ================================
# 3️⃣ Preprocess Dataset (Fixed input_features)
# ================================
def load_and_process_audio(audio_path):
    """Loads and processes a single audio file into model format."""
    waveform, sample_rate = torchaudio.load(audio_path)
    # Resample to 16 kHz, the rate the processor expects
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Convert to model input format
    input_features = processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features[0]
    return input_features


# Manually create dataset structure (first 100 files only).
# NOTE: labels are left empty here; see the optional transcript-loading sketch
# after this section for one way to fill them.
dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]

# Split dataset into train and eval (Recommended Fix)
train_size = int(0.9 * len(dataset))
train_dataset = dataset[:train_size]
eval_dataset = dataset[train_size:]

print(f"✅ Dataset Loaded! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
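# ================================
# (Optional) Fill labels from LibriSpeech transcripts
# ================================
# The dataset above keeps "labels" empty, so the Trainer has no targets to learn
# from. Below is a minimal sketch for filling them, assuming the standard
# LibriSpeech layout where each chapter folder contains a
# "<speaker>-<chapter>.trans.txt" file with one "<utterance-id> TRANSCRIPT" line
# per utterance; the helper name and the filename-to-utterance-id mapping are
# illustrative, not part of the original script.
def load_transcripts(base_folder):
    """Map utterance id -> transcript text from all .trans.txt files (sketch)."""
    transcripts = {}
    for root, _, files in os.walk(base_folder):
        for file in files:
            if file.endswith(".trans.txt"):
                with open(os.path.join(root, file), encoding="utf-8") as f:
                    for line in f:
                        utt_id, _, text = line.strip().partition(" ")
                        transcripts[utt_id] = text
    return transcripts


# Example usage: tokenize the matching transcript as labels for each prepared example
# (the .flac filename without extension is assumed to be the utterance id).
# transcripts = load_transcripts(AUDIO_FOLDER)
# for item, path in zip(dataset, audio_files[:100]):
#     utt_id = os.path.splitext(os.path.basename(path))[0]
#     item["labels"] = processor.tokenizer(transcripts.get(utt_id, "")).input_ids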
# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    eval_strategy="epoch",  # Fix: proper evaluation
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,
)

# Data collator (for dynamic padding)
data_collator = DataCollatorForSeq2Seq(processor, model=model)

# Define Trainer (fixes the `processing_class` warning)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # Fix: provide eval_dataset
    processing_class=processor,  # Fix: replaces the deprecated `tokenizer` argument
    data_collator=data_collator,
)

# ================================
# 5️⃣ Fine-Tuning Execution
# ================================
if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    st.success("✅ Fine-Tuning Completed! Model updated.")

# ================================
# 6️⃣ Streamlit ASR Web App
# ================================
st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎶")

# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Save uploaded file temporarily
    audio_path = "temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())

    # Load and process audio
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Convert audio to model input (keep the batch dimension for generation)
    input_features = processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features

    # Perform ASR inference: a seq2seq model decodes with generate(),
    # not by taking the argmax of encoder logits
    with torch.no_grad():
        predicted_ids = model.generate(input_features.to(device))
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("📄 Transcription:")
    st.write(transcription)

    # ================================
    # 7️⃣ Fine-Tune Model with User Correction
    # ================================
    user_correction = st.text_area("🔧 Correct the transcription (if needed):", transcription)

    if st.button("Fine-Tune with Correction"):
        if user_correction:
            corrected_input = processor.tokenizer(user_correction).input_ids

            # Dynamically add the new example to the dataset (drop the batch dimension)
            dataset.append({"input_features": input_features[0], "labels": corrected_input})

            # Perform a quick re-training pass (1 epoch)
            trainer.args.num_train_epochs = 1
            trainer.train()

            st.success("✅ Model fine-tuned with new correction! Try another audio file.")
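# ================================
# (Optional) Persist the fine-tuned model
# ================================
# A minimal sketch, assuming the local "./asr_model_finetuned" directory already
# used as `output_dir` above is acceptable storage; the button label is
# illustrative. Without an explicit save, weights updated by correction-based
# fine-tuning only live in memory for the current Streamlit session.
if st.button("💾 Save Fine-Tuned Model"):
    model.save_pretrained("./asr_model_finetuned")
    processor.save_pretrained("./asr_model_finetuned")
    st.success("✅ Model and processor saved to ./asr_model_finetuned")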