tahirsher committed on
Commit a312467 · verified · 1 Parent(s): 2e48e3c

Update app.py

Files changed (1)
  1. app.py +17 -38
app.py CHANGED
@@ -131,12 +131,12 @@ batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value
 # ================================
 training_args = TrainingArguments(
     output_dir="./asr_model_finetuned",
-    evaluation_strategy="epoch",
+    eval_strategy="epoch",
     save_strategy="epoch",
-    learning_rate=learning_rate,
-    per_device_train_batch_size=batch_size,
-    per_device_eval_batch_size=batch_size,
-    num_train_epochs=num_epochs,
+    learning_rate=learning_rate,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    num_train_epochs=num_epochs,
     weight_decay=0.01,
     logging_dir="./logs",
     logging_steps=500,
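
In recent transformers releases the TrainingArguments keyword is spelled eval_strategy; the older evaluation_strategy spelling is deprecated, which is what this hunk renames. A minimal sketch of the renamed argument, assuming a current transformers version; the values shown are placeholders rather than the sidebar-driven values used above:

    from transformers import TrainingArguments

    # "eval_strategy" replaces the deprecated "evaluation_strategy" keyword
    training_args = TrainingArguments(
        output_dir="./asr_model_finetuned",
        eval_strategy="epoch",   # evaluate once per epoch
        save_strategy="epoch",   # checkpoint on the same schedule
        learning_rate=5e-5,      # placeholder; the app takes this from the sidebar slider
    )
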
@@ -157,27 +157,7 @@ trainer = Trainer(
 )
 
 # ================================
-# 8️⃣ Fine-Tuning Execution & Training Stats
-# ================================
-if st.sidebar.button("🚀 Start Fine-Tuning"):
-    with st.spinner("Fine-tuning in progress... Please wait!"):
-        trainer.train()
-        st.success("✅ Fine-Tuning Completed! Model updated.")
-
-    # ✅ Plot Training Loss
-    train_loss = trainer.state.log_history
-    losses = [entry['loss'] for entry in train_loss if 'loss' in entry]
-
-    plt.figure(figsize=(8, 5))
-    plt.plot(range(len(losses)), losses, label="Training Loss", color="blue")
-    plt.xlabel("Steps")
-    plt.ylabel("Loss")
-    plt.title("Training Loss Over Time")
-    plt.legend()
-    st.pyplot(plt)
-
-# ================================
-# 9️⃣ Streamlit ASR Web App (Proper Decoding)
+# 8️⃣ Streamlit ASR Web App (Fast Decoding)
 # ================================
 st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")
 
@@ -191,22 +171,21 @@ if audio_file:
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
-    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features
-
-    input_tensor = input_features.to(device)
+    input_features = processor(
+        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
+    ).input_features.to(device)
 
-    # ✅ FIX: Use `generate()` for Proper Transcription
-    with torch.no_grad():
+    with torch.inference_mode():
         generated_ids = model.generate(
-            input_tensor,
-            max_length=500,
-            num_beams=5,
-            do_sample=True,
-            top_k=50
+            input_features,
+            max_length=200,
+            num_beams=2,
+            do_sample=False,
+            use_cache=True,
+            language="en",
+            attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
         )
         transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-    # Display transcription
     st.success("📄 Transcription:")
     st.write(transcription)
-
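
The new decoding path replaces sampled generation under torch.no_grad() with deterministic beam search under torch.inference_mode(), which disables autograd bookkeeping entirely and makes the transcription repeatable. For reference, a minimal sketch of that path outside Streamlit, assuming processor, model, and device are the Whisper-style processor and model loaded earlier in app.py (their definitions are not part of this diff) and that sample.wav is any local recording:

    import torch
    import torchaudio

    # Load a test clip and resample to the 16 kHz rate the model expects
    waveform, sample_rate = torchaudio.load("sample.wav")   # hypothetical local file
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Same preprocessing as the app: log-mel features moved to the model's device
    input_features = processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)

    # Deterministic beam search: no sampling, so repeated runs give the same text
    with torch.inference_mode():
        generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False)

    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])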