Update app.py
Browse files
app.py
CHANGED
@@ -131,12 +131,12 @@ batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value
|
|
131 |
# ================================
|
132 |
training_args = TrainingArguments(
|
133 |
output_dir="./asr_model_finetuned",
|
134 |
-
|
135 |
save_strategy="epoch",
|
136 |
-
learning_rate=learning_rate,
|
137 |
-
per_device_train_batch_size=batch_size,
|
138 |
-
per_device_eval_batch_size=batch_size,
|
139 |
-
num_train_epochs=num_epochs,
|
140 |
weight_decay=0.01,
|
141 |
logging_dir="./logs",
|
142 |
logging_steps=500,
|
@@ -157,27 +157,7 @@ trainer = Trainer(
|
|
157 |
)
|
158 |
|
159 |
# ================================
|
160 |
-
# 8️⃣
|
161 |
-
# ================================
|
162 |
-
if st.sidebar.button("π Start Fine-Tuning"):
|
163 |
-
with st.spinner("Fine-tuning in progress... Please wait!"):
|
164 |
-
trainer.train()
|
165 |
-
st.success("✅ Fine-Tuning Completed! Model updated.")
|
166 |
-
|
167 |
-
# ✅ Plot Training Loss
|
168 |
-
train_loss = trainer.state.log_history
|
169 |
-
losses = [entry['loss'] for entry in train_loss if 'loss' in entry]
|
170 |
-
|
171 |
-
plt.figure(figsize=(8, 5))
|
172 |
-
plt.plot(range(len(losses)), losses, label="Training Loss", color="blue")
|
173 |
-
plt.xlabel("Steps")
|
174 |
-
plt.ylabel("Loss")
|
175 |
-
plt.title("Training Loss Over Time")
|
176 |
-
plt.legend()
|
177 |
-
st.pyplot(plt)
|
178 |
-
|
179 |
-
# ================================
|
180 |
-
# 9️⃣ Streamlit ASR Web App (Proper Decoding)
|
181 |
# ================================
|
182 |
st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")
|
183 |
|
@@ -191,22 +171,21 @@ if audio_file:
|
|
191 |
waveform, sample_rate = torchaudio.load(audio_path)
|
192 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
193 |
|
194 |
-
input_features = processor(
|
195 |
-
|
196 |
-
|
197 |
|
198 |
-
|
199 |
-
with torch.no_grad():
|
200 |
generated_ids = model.generate(
|
201 |
-
|
202 |
-
max_length=
|
203 |
-
num_beams=
|
204 |
-
do_sample=
|
205 |
-
|
|
|
|
|
206 |
)
|
207 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
208 |
|
209 |
-
# Display transcription
|
210 |
st.success("π Transcription:")
|
211 |
st.write(transcription)
|
212 |
-
|
|
|
131 |
# ================================
|
132 |
training_args = TrainingArguments(
|
133 |
output_dir="./asr_model_finetuned",
|
134 |
+
eval_strategy="epoch",
|
135 |
save_strategy="epoch",
|
136 |
+
learning_rate=learning_rate,
|
137 |
+
per_device_train_batch_size=batch_size,
|
138 |
+
per_device_eval_batch_size=batch_size,
|
139 |
+
num_train_epochs=num_epochs,
|
140 |
weight_decay=0.01,
|
141 |
logging_dir="./logs",
|
142 |
logging_steps=500,
|
|
|
157 |
)
|
158 |
|
159 |
# ================================
|
160 |
+
# 8️⃣ Streamlit ASR Web App (Fast Decoding)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
# ================================
|
162 |
st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")
|
163 |
|
|
|
171 |
waveform, sample_rate = torchaudio.load(audio_path)
|
172 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
173 |
|
174 |
+
input_features = processor(
|
175 |
+
waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
|
176 |
+
).input_features.to(device)
|
177 |
|
178 |
+
with torch.inference_mode():
|
|
|
179 |
generated_ids = model.generate(
|
180 |
+
input_features,
|
181 |
+
max_length=200,
|
182 |
+
num_beams=2,
|
183 |
+
do_sample=False,
|
184 |
+
use_cache=True,
|
185 |
+
language="en",
|
186 |
+
attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
|
187 |
)
|
188 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
189 |
|
|
|
190 |
st.success("π Transcription:")
|
191 |
st.write(transcription)
|
|