Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,6 @@ import numpy as np
|
|
6 |
import streamlit as st
|
7 |
import matplotlib.pyplot as plt
|
8 |
from huggingface_hub import login
|
9 |
-
from datasets import load_dataset, DatasetDict
|
10 |
from transformers import (
|
11 |
AutoProcessor,
|
12 |
AutoModelForSpeechSeq2Seq,
|
@@ -37,7 +36,7 @@ model.to(device)
|
|
37 |
print(f"β
Model loaded on {device}")
|
38 |
|
39 |
# ================================
|
40 |
-
# 3οΈβ£ Load Dataset (
|
41 |
# ================================
|
42 |
DATASET_TAR_PATH = "dev-clean.tar.gz"
|
43 |
EXTRACT_PATH = "./librispeech_dev_clean"
|
@@ -69,34 +68,75 @@ if not audio_files:
|
|
69 |
print(f"β
Found {len(audio_files)} audio files in dataset!")
|
70 |
|
71 |
# ================================
|
72 |
-
# 4οΈβ£
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
# ================================
|
74 |
def load_and_process_audio(audio_path):
|
75 |
"""Loads and processes a single audio file into model format."""
|
76 |
waveform, sample_rate = torchaudio.load(audio_path)
|
77 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
|
|
78 |
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
|
79 |
return input_features
|
80 |
|
81 |
-
dataset = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
train_size = int(0.8 * len(dataset))
|
84 |
train_dataset = dataset[:train_size]
|
85 |
eval_dataset = dataset[train_size:]
|
86 |
|
87 |
-
print(f"β
Dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
# ================================
|
90 |
-
#
|
91 |
# ================================
|
92 |
training_args = TrainingArguments(
|
93 |
output_dir="./asr_model_finetuned",
|
94 |
evaluation_strategy="epoch",
|
95 |
save_strategy="epoch",
|
96 |
-
learning_rate=
|
97 |
-
per_device_train_batch_size=
|
98 |
-
per_device_eval_batch_size=
|
99 |
-
num_train_epochs=
|
100 |
weight_decay=0.01,
|
101 |
logging_dir="./logs",
|
102 |
logging_steps=500,
|
@@ -117,14 +157,14 @@ trainer = Trainer(
|
|
117 |
)
|
118 |
|
119 |
# ================================
|
120 |
-
#
|
121 |
# ================================
|
122 |
-
if st.button("Start Fine-Tuning"):
|
123 |
with st.spinner("Fine-tuning in progress... Please wait!"):
|
124 |
trainer.train()
|
125 |
st.success("β
Fine-Tuning Completed! Model updated.")
|
126 |
|
127 |
-
# Plot Training Loss
|
128 |
train_loss = trainer.state.log_history
|
129 |
losses = [entry['loss'] for entry in train_loss if 'loss' in entry]
|
130 |
|
@@ -137,7 +177,7 @@ if st.button("Start Fine-Tuning"):
|
|
137 |
st.pyplot(plt)
|
138 |
|
139 |
# ================================
|
140 |
-
#
|
141 |
# ================================
|
142 |
st.title("ποΈ Speech-to-Text ASR Model with Fine-Tuning πΆ")
|
143 |
|
@@ -166,6 +206,5 @@ if audio_file:
|
|
166 |
)
|
167 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
168 |
|
169 |
-
# Display transcription
|
170 |
st.success("π Transcription:")
|
171 |
st.write(transcription)
|
|
|
6 |
import streamlit as st
|
7 |
import matplotlib.pyplot as plt
|
8 |
from huggingface_hub import login
|
|
|
9 |
from transformers import (
|
10 |
AutoProcessor,
|
11 |
AutoModelForSpeechSeq2Seq,
|
|
|
36 |
print(f"β
Model loaded on {device}")
|
37 |
|
38 |
# ================================
|
39 |
+
# 3οΈβ£ Load Dataset (From Extracted Folder)
|
40 |
# ================================
|
41 |
DATASET_TAR_PATH = "dev-clean.tar.gz"
|
42 |
EXTRACT_PATH = "./librispeech_dev_clean"
|
|
|
68 |
print(f"β
Found {len(audio_files)} audio files in dataset!")
|
69 |
|
70 |
# ================================
|
71 |
+
# 4οΈβ£ Load Transcripts
|
72 |
+
# ================================
|
73 |
+
def load_transcripts():
|
74 |
+
"""Loads transcript text files and maps them to audio files."""
|
75 |
+
transcript_dict = {}
|
76 |
+
for root, _, files in os.walk(AUDIO_FOLDER):
|
77 |
+
for file in files:
|
78 |
+
if file.endswith(".txt"): # Transcript files
|
79 |
+
with open(os.path.join(root, file), "r", encoding="utf-8") as f:
|
80 |
+
for line in f:
|
81 |
+
parts = line.strip().split(" ", 1)
|
82 |
+
if len(parts) == 2:
|
83 |
+
file_id, text = parts
|
84 |
+
transcript_dict[file_id] = text
|
85 |
+
return transcript_dict
|
86 |
+
|
87 |
+
transcripts = load_transcripts()
|
88 |
+
if not transcripts:
|
89 |
+
raise FileNotFoundError("β No transcripts found! Check dataset structure.")
|
90 |
+
|
91 |
+
print(f"β
Loaded {len(transcripts)} transcripts.")
|
92 |
+
|
93 |
+
# ================================
|
94 |
+
# 5οΈβ£ Preprocess Dataset (Fixing `input_ids` issue)
|
95 |
# ================================
|
96 |
def load_and_process_audio(audio_path):
|
97 |
"""Loads and processes a single audio file into model format."""
|
98 |
waveform, sample_rate = torchaudio.load(audio_path)
|
99 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
100 |
+
|
101 |
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
|
102 |
return input_features
|
103 |
|
104 |
+
dataset = []
|
105 |
+
for audio_file in audio_files[:100]: # Limit to 100 for faster processing
|
106 |
+
file_id = os.path.basename(audio_file).replace(".flac", "")
|
107 |
+
|
108 |
+
if file_id in transcripts:
|
109 |
+
input_features = load_and_process_audio(audio_file)
|
110 |
+
labels = processor.tokenizer(transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
|
111 |
+
|
112 |
+
dataset.append({"input_features": input_features, "labels": labels})
|
113 |
|
114 |
train_size = int(0.8 * len(dataset))
|
115 |
train_dataset = dataset[:train_size]
|
116 |
eval_dataset = dataset[train_size:]
|
117 |
|
118 |
+
print(f"β
Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
|
119 |
+
|
120 |
+
# ================================
|
121 |
+
# 6οΈβ£ Streamlit UI: Fine-Tuning Hyperparameter Selection
|
122 |
+
# ================================
|
123 |
+
st.sidebar.title("π§ Fine-Tuning Hyperparameters")
|
124 |
+
|
125 |
+
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
|
126 |
+
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
|
127 |
+
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
|
128 |
|
129 |
# ================================
|
130 |
+
# 7οΈβ£ Training Arguments & Trainer
|
131 |
# ================================
|
132 |
training_args = TrainingArguments(
|
133 |
output_dir="./asr_model_finetuned",
|
134 |
evaluation_strategy="epoch",
|
135 |
save_strategy="epoch",
|
136 |
+
learning_rate=learning_rate,
|
137 |
+
per_device_train_batch_size=batch_size,
|
138 |
+
per_device_eval_batch_size=batch_size,
|
139 |
+
num_train_epochs=num_epochs,
|
140 |
weight_decay=0.01,
|
141 |
logging_dir="./logs",
|
142 |
logging_steps=500,
|
|
|
157 |
)
|
158 |
|
159 |
# ================================
|
160 |
+
# 8οΈβ£ Fine-Tuning Execution & Training Stats
|
161 |
# ================================
|
162 |
+
if st.sidebar.button("π Start Fine-Tuning"):
|
163 |
with st.spinner("Fine-tuning in progress... Please wait!"):
|
164 |
trainer.train()
|
165 |
st.success("β
Fine-Tuning Completed! Model updated.")
|
166 |
|
167 |
+
# β
Plot Training Loss
|
168 |
train_loss = trainer.state.log_history
|
169 |
losses = [entry['loss'] for entry in train_loss if 'loss' in entry]
|
170 |
|
|
|
177 |
st.pyplot(plt)
|
178 |
|
179 |
# ================================
|
180 |
+
# 9οΈβ£ Streamlit ASR Web App (Proper Decoding)
|
181 |
# ================================
|
182 |
st.title("ποΈ Speech-to-Text ASR Model with Fine-Tuning πΆ")
|
183 |
|
|
|
206 |
)
|
207 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
208 |
|
|
|
209 |
st.success("π Transcription:")
|
210 |
st.write(transcription)
|