tahirsher committed · verified
Commit 3a9d859 · Parent(s): 771c2e9

Update app.py

Files changed (1):
  1. app.py +54 -15
app.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
 import streamlit as st
 import matplotlib.pyplot as plt
 from huggingface_hub import login
-from datasets import load_dataset, DatasetDict
 from transformers import (
     AutoProcessor,
     AutoModelForSpeechSeq2Seq,
@@ -37,7 +36,7 @@ model.to(device)
 print(f"✅ Model loaded on {device}")
 
 # ================================
-# 3️⃣ Load Dataset (Recursively from Extracted Path)
+# 3️⃣ Load Dataset (From Extracted Folder)
 # ================================
 DATASET_TAR_PATH = "dev-clean.tar.gz"
 EXTRACT_PATH = "./librispeech_dev_clean"
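Note: the tar extraction itself happens in unchanged context lines that this hunk does not display. For orientation only, a minimal sketch of that step, assuming the standard-library tarfile module and the two path constants shown above:

```python
import os
import tarfile

DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

# Extract once; on later runs the existing folder is reused.
if not os.path.exists(EXTRACT_PATH):
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
```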
@@ -69,34 +68,75 @@ if not audio_files:
 print(f"✅ Found {len(audio_files)} audio files in dataset!")
 
 # ================================
-# 4️⃣ Preprocess Dataset
+# 4️⃣ Load Transcripts
+# ================================
+def load_transcripts():
+    """Loads transcript text files and maps them to audio files."""
+    transcript_dict = {}
+    for root, _, files in os.walk(AUDIO_FOLDER):
+        for file in files:
+            if file.endswith(".txt"):  # Transcript files
+                with open(os.path.join(root, file), "r", encoding="utf-8") as f:
+                    for line in f:
+                        parts = line.strip().split(" ", 1)
+                        if len(parts) == 2:
+                            file_id, text = parts
+                            transcript_dict[file_id] = text
+    return transcript_dict
+
+transcripts = load_transcripts()
+if not transcripts:
+    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
+
+print(f"✅ Loaded {len(transcripts)} transcripts.")
+
+# ================================
+# 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
 # ================================
 def load_and_process_audio(audio_path):
     """Loads and processes a single audio file into model format."""
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
     return input_features
 
-dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]
+dataset = []
+for audio_file in audio_files[:100]:  # Limit to 100 for faster processing
+    file_id = os.path.basename(audio_file).replace(".flac", "")
+
+    if file_id in transcripts:
+        input_features = load_and_process_audio(audio_file)
+        labels = processor.tokenizer(transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
+
+        dataset.append({"input_features": input_features, "labels": labels})
 
 train_size = int(0.8 * len(dataset))
 train_dataset = dataset[:train_size]
 eval_dataset = dataset[train_size:]
 
-print(f"✅ Dataset Loaded! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
+print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
+
+# ================================
+# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
+# ================================
+st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
+
+num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
+learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
+batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
 
 # ================================
-# 4️⃣ Training Arguments & Trainer
+# 7️⃣ Training Arguments & Trainer
 # ================================
 training_args = TrainingArguments(
     output_dir="./asr_model_finetuned",
     evaluation_strategy="epoch",
     save_strategy="epoch",
-    learning_rate=5e-5,
-    per_device_train_batch_size=8,
-    per_device_eval_batch_size=8,
-    num_train_epochs=3,
+    learning_rate=learning_rate,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    num_train_epochs=num_epochs,
     weight_decay=0.01,
     logging_dir="./logs",
     logging_steps=500,
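Note: `load_transcripts()` above relies on the LibriSpeech layout, where each chapter folder contains a `*.trans.txt` file whose lines hold an utterance ID and its transcript separated by the first space, which is exactly what the `split(" ", 1)` expects. An illustrative line (the text here is invented for the example):

```python
# One line of e.g. 1272/128104/1272-128104.trans.txt (content invented):
line = "1272-128104-0000 MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES"

file_id, text = line.strip().split(" ", 1)
# file_id == "1272-128104-0000", matching the key built from the audio file:
# os.path.basename("1272-128104-0000.flac").replace(".flac", "")
```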
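Note: the plain list of dicts only works with `Trainer`'s default collator because every sample comes out at a constant size (the processor pads `input_features` to a fixed 30-second window, and `padding="max_length"` does the same for the labels). A common alternative is dynamic per-batch padding with padding tokens masked out of the loss as -100; a sketch, assuming the same `processor` object:

```python
import torch

def collate_batch(features):
    """Dynamically pad a batch instead of padding every label to max_length."""
    # Log-mel features already share a fixed shape, so they can be stacked directly.
    input_features = torch.stack([f["input_features"] for f in features])

    # Pad labels to the longest sequence in this batch.
    label_features = [{"input_ids": f["labels"]} for f in features]
    labels_batch = processor.tokenizer.pad(label_features, return_tensors="pt")

    # Mask padding positions with -100 so the loss ignores them.
    labels = labels_batch["input_ids"].masked_fill(
        labels_batch["attention_mask"].ne(1), -100
    )
    return {"input_features": input_features, "labels": labels}
```

Wired in via `Trainer(..., data_collator=collate_batch)`, this would also let the tokenizer call above drop `padding="max_length"`.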
@@ -117,14 +157,14 @@ trainer = Trainer(
 )
 
 # ================================
-# 5️⃣ Fine-Tuning Execution & Training Stats
+# 8️⃣ Fine-Tuning Execution & Training Stats
 # ================================
-if st.button("Start Fine-Tuning"):
+if st.sidebar.button("🚀 Start Fine-Tuning"):
     with st.spinner("Fine-tuning in progress... Please wait!"):
         trainer.train()
     st.success("✅ Fine-Tuning Completed! Model updated.")
 
-    # Plot Training Loss
+    # ✅ Plot Training Loss
     train_loss = trainer.state.log_history
     losses = [entry['loss'] for entry in train_loss if 'loss' in entry]
 
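Note: `trainer.state.log_history` is a plain list of dicts; entries logged during training carry a `loss` key, while evaluation entries carry `eval_loss`, so the same filter pattern recovers either series. Roughly:

```python
# Typical log_history entries (values illustrative):
#   {"loss": 1.82, "learning_rate": 4.6e-05, "epoch": 0.27, "step": 500}
#   {"eval_loss": 1.31, "epoch": 1.0, "step": 1860}
log_history = trainer.state.log_history
train_losses = [e["loss"] for e in log_history if "loss" in e]
eval_losses = [e["eval_loss"] for e in log_history if "eval_loss" in e]
```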
@@ -137,7 +177,7 @@ if st.button("Start Fine-Tuning"):
     st.pyplot(plt)
 
 # ================================
-# 6️⃣ Streamlit ASR Web App (Proper Decoding)
+# 9️⃣ Streamlit ASR Web App (Proper Decoding)
 # ================================
 st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")
 
@@ -166,6 +206,5 @@ if audio_file:
     )
     transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-    # Display transcription
     st.success("📄 Transcription:")
     st.write(transcription)
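Note: the `model.generate(...)` call that produces `generated_ids` sits in unchanged context lines above this hunk and is not shown. For readers without the full file, a self-contained sketch of the decoding path, assuming the same `processor`, `model`, and `device` objects (the actual generate arguments in app.py may differ):

```python
import torch
import torchaudio

# Load the uploaded audio and resample to the 16 kHz rate the processor expects.
waveform, sample_rate = torchaudio.load(audio_file)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

input_features = processor(
    waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
).input_features.to(device)

with torch.no_grad():
    generated_ids = model.generate(input_features)

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```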
 