tahirsher committed
Commit 3a79217 · verified · 1 Parent(s): 78855a4

Update app.py

Files changed (1): app.py (+19, -15)
app.py CHANGED
@@ -29,10 +29,10 @@ print(f"✅ Model loaded on {device}")
 # ================================
 # 2️⃣ Load Dataset (Recursively from Extracted Path)
 # ================================
-DATASET_TAR_PATH = "dev-clean.tar.gz"  # Dataset stored in Hugging Face Space
-EXTRACT_PATH = "./librispeech_dev_clean"  # Extracted dataset folder
+DATASET_TAR_PATH = "dev-clean.tar.gz"
+EXTRACT_PATH = "./librispeech_dev_clean"
 
-# Extract dataset only if not already extracted
+# Extract dataset if not already extracted
 if not os.path.exists(EXTRACT_PATH):
     print("🔄 Extracting dataset...")
     with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
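Note (not part of the commit): the extraction above trusts the archive's member paths. On Python 3.12+, tarfile accepts an extraction filter that rejects absolute paths, ".." traversal, and special files; a minimal hardened sketch of the same step:

import os
import tarfile

DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

if not os.path.exists(EXTRACT_PATH):
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        # filter="data" (Python 3.12+) blocks unsafe archive members.
        tar.extractall(EXTRACT_PATH, filter="data")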
@@ -41,7 +41,7 @@ if not os.path.exists(EXTRACT_PATH):
 else:
     print("✅ Dataset already extracted.")
 
-# Define the base directory where audio files are stored
+# Base directory where audio files are stored
 AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
 
 # Recursively find all `.flac` files inside the dataset directory
@@ -57,7 +57,6 @@ def find_audio_files(base_folder):
 # Get all audio files
 audio_files = find_audio_files(AUDIO_FOLDER)
 
-# Check if audio files were found
 if not audio_files:
     raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
 
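The body of find_audio_files falls outside the hunk context. A plausible implementation of the recursive .flac search it performs (a hypothetical sketch, not the file's actual code):

import os

def find_audio_files(base_folder):
    # Walk the extracted LibriSpeech tree and collect every .flac file.
    flac_files = []
    for root, _dirs, files in os.walk(base_folder):
        for name in files:
            if name.endswith(".flac"):
                flac_files.append(os.path.join(root, name))
    return flac_files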
@@ -73,22 +72,27 @@ def load_and_process_audio(audio_path):
     # Resample to 16kHz
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
-    # Convert to model input format (Fixed key: use input_features instead of input_values)
+    # Convert to model input format
     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
 
     return input_features
 
 # Manually create dataset structure
-dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]  # Load first 100
+dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]
 
-print(f"✅ Dataset Loaded! Processed {len(dataset)} audio files.")
+# Split dataset into train and eval (Recommended Fix)
+train_size = int(0.9 * len(dataset))
+train_dataset = dataset[:train_size]
+eval_dataset = dataset[train_size:]
+
+print(f"✅ Dataset Loaded! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
 
 # ================================
 # 4️⃣ Training Arguments & Trainer
 # ================================
 training_args = TrainingArguments(
     output_dir="./asr_model_finetuned",
-    evaluation_strategy="epoch",
+    eval_strategy="epoch",  # Fix: Proper evaluation
     save_strategy="epoch",
     learning_rate=5e-5,
     per_device_train_batch_size=8,
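The evaluation_strategy to eval_strategy change matches the keyword rename in recent transformers releases; older versions accept only the old name. A version-robust sketch (an illustration, not part of the commit):

import inspect
from transformers import TrainingArguments

params = inspect.signature(TrainingArguments.__init__).parameters
eval_key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    save_strategy="epoch",
    **{eval_key: "epoch"},  # evaluate once per epoch on either version
)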
@@ -102,15 +106,15 @@ training_args = TrainingArguments(
 )
 
 # Data collator (for dynamic padding)
-data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
+data_collator = DataCollatorForSeq2Seq(processor, model=model)
 
-# Define Trainer
+# Define Trainer (Fixed `processing_class` warning)
 trainer = Trainer(
     model=model,
     args=training_args,
-    train_dataset=dataset,
-    eval_dataset=None,  # No validation dataset for now
-    tokenizer=processor.feature_extractor,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,  # Fix: Providing eval_dataset
+    processing_class=processor,  # Fix: Replacing deprecated `tokenizer`
     data_collator=data_collator,
 )
 
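Two caveats on this hunk. processing_class does replace the deprecated tokenizer argument in recent transformers releases, but DataCollatorForSeq2Seq is built around tokenized text (input_ids), so padding input_features with empty labels may not behave as intended. The usual Whisper fine-tuning pattern uses a dedicated collator; a minimal sketch, assuming a WhisperProcessor-style processor and label token ids in "labels":

from dataclasses import dataclass

@dataclass
class SpeechSeq2SeqCollator:
    processor: object  # assumed WhisperProcessor-like

    def __call__(self, features):
        # Pad the log-mel features into one batch tensor.
        audio = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(audio, return_tensors="pt")
        # Pad label ids, masking padding with -100 so the loss ignores it.
        labels = self.processor.tokenizer.pad(
            [{"input_ids": f["labels"]} for f in features], return_tensors="pt"
        )
        batch["labels"] = labels["input_ids"].masked_fill(
            labels["attention_mask"].ne(1), -100
        )
        return batch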
@@ -140,7 +144,7 @@ if audio_file:
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
-    # Convert audio to model input (Fixed key: use input_features)
+    # Convert audio to model input
     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
 
     # Perform ASR inference
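The inference step itself is truncated above; for a Whisper-style checkpoint it would look roughly like this sketch (model, processor, and device come from the surrounding script):

import torch

with torch.no_grad():
    # Add a batch dimension, move to the model's device, then decode.
    predicted_ids = model.generate(input_features.unsqueeze(0).to(device))
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]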
 