import os
import tarfile
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
import torchaudio
import streamlit as st
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
)

# ================================
# 1️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"

# Load ASR model and processor
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"βœ… Model loaded on {device}")

# ================================
# 2️⃣ Load Dataset (Manually from Extracted Path)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"  # LibriSpeech dev-clean subset (OpenSLR 12), stored in the Hugging Face Space
EXTRACT_PATH = "./librispeech_dev_clean"  # Extracted dataset folder

# Extract dataset only if not already extracted
if not os.path.exists(EXTRACT_PATH):
    print("πŸ”„ Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("βœ… Extraction complete.")
else:
    print("βœ… Dataset already extracted.")

# Load audio files manually. dev-clean.tar.gz unpacks to LibriSpeech/dev-clean,
# with audio nested as <speaker>/<chapter>/*.flac, so walk the tree recursively
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
audio_files = [os.path.join(root, f)
               for root, _, files in os.walk(AUDIO_FOLDER)
               for f in files if f.endswith(".flac")]

# ================================
# 3️⃣ Preprocess Dataset (Manually)
# ================================
def load_and_process_audio(audio_path):
    """Loads and processes a single audio file into model input format."""
    waveform, sample_rate = torchaudio.load(audio_path)

    # Resample to the 16 kHz rate the model expects
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # AutoModelForSpeechSeq2Seq implies a Whisper-style encoder-decoder, whose
    # processor returns "input_features" (CTC-style processors such as
    # Wav2Vec2's return "input_values" instead; adjust if your model differs)
    return processor(waveform.squeeze().numpy(), sampling_rate=16000).input_features[0]

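# LibriSpeech ships reference transcripts in per-chapter "*.trans.txt" files,
# one "<utterance-id> <TEXT>" pair per line. The labels below are built from
# these files; training on empty labels would give the Trainer nothing to learn.
def load_transcripts(root_folder):
    """Map LibriSpeech utterance IDs to their reference transcripts."""
    transcripts = {}
    for dirpath, _, filenames in os.walk(root_folder):
        for name in filenames:
            if name.endswith(".trans.txt"):
                with open(os.path.join(dirpath, name), encoding="utf-8") as f:
                    for line in f:
                        utt_id, _, text = line.strip().partition(" ")
                        transcripts[utt_id] = text
    return transcripts

transcripts = load_transcripts(AUDIO_FOLDER)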
# Manually create dataset structure, pairing each clip with its tokenized transcript
dataset = [
    {
        "input_features": load_and_process_audio(f),
        "labels": processor.tokenizer(
            transcripts.get(os.path.splitext(os.path.basename(f))[0], "")
        ).input_ids,
    }
    for f in audio_files[:100]  # Load first 100 files
]

print(f"βœ… Dataset Loaded! Processed {len(dataset)} audio files.")

# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="no",  # no eval_dataset is provided, so skip evaluation
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,  # requires a logged-in Hugging Face account / HF_TOKEN
)

# Data collator: audio features and token labels need different padding, so
# the text-to-text DataCollatorForSeq2Seq is not a good fit for speech input.
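# A minimal padding collator, following the pattern from Hugging Face's
# Whisper fine-tuning examples; it assumes the processor bundles a
# feature_extractor (for audio) and a tokenizer (for text labels).
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Pad audio features and token labels separately
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Mask padding with -100 so it is ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )

        # If the tokenizer prepended a BOS token, strip it; the model adds
        # its own decoder start token during training
        bos = self.processor.tokenizer.bos_token_id
        if bos is not None and (labels[:, 0] == bos).all():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)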

# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=None,  # No validation dataset for now
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)

# ================================
# 5️⃣ Fine-Tuning Execution
# ================================
if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    st.success("βœ… Fine-Tuning Completed! Model updated.")

# ================================
# 6️⃣ Streamlit ASR Web App
# ================================
st.title("πŸŽ™οΈ Speech-to-Text ASR with Fine-Tuning 🎢")

# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Save the uploaded file temporarily, keeping its original extension so
    # the audio backend can detect the format
    audio_path = "temp_audio." + audio_file.name.rsplit(".", 1)[-1]
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())

    # Load the audio and resample to 16 kHz if needed
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Convert audio to model input. As in preprocessing, a Whisper-style
    # processor returns log-mel "input_features" rather than raw "input_values"
    input_features = processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features

    # Perform ASR inference: seq2seq models transcribe via generate();
    # forward-pass argmax decoding only applies to CTC models
    with torch.no_grad():
        predicted_ids = model.generate(input_features.to(device))
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("πŸ“„ Transcription:")
    st.write(transcription)

    # ================================
    # 7️⃣ Fine-Tune Model with User Correction
    # ================================
    user_correction = st.text_area("πŸ”§ Correct the transcription (if needed):", transcription)

    if st.button("Fine-Tune with Correction"):
        if user_correction:
            corrected_ids = processor.tokenizer(user_correction).input_ids

            # Dynamically add the corrected example to the dataset
            # (keys must match the rest of the training examples)
            dataset.append({"input_features": input_features[0].numpy(), "labels": corrected_ids})

            # Perform quick re-training (1 epoch). Note that Streamlit reruns
            # the script on every interaction, so this in-memory dataset does
            # not persist across sessions
            trainer.args.num_train_epochs = 1
            trainer.train()

            st.success("βœ… Model fine-tuned with new correction! Try another audio file.")