Update app.py
app.py CHANGED
@@ -29,10 +29,10 @@ print(f"✅ Model loaded on {device}")
 # ================================
 # 2️⃣ Load Dataset (Recursively from Extracted Path)
 # ================================
-DATASET_TAR_PATH = "dev-clean.tar.gz"
-EXTRACT_PATH = "./librispeech_dev_clean"
+DATASET_TAR_PATH = "dev-clean.tar.gz"
+EXTRACT_PATH = "./librispeech_dev_clean"
 
-# Extract dataset
+# Extract dataset if not already extracted
 if not os.path.exists(EXTRACT_PATH):
     print("📦 Extracting dataset...")
     with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
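Note: the hunk cuts off before the body of the `with tarfile.open(...)` block. For readers following along, a minimal sketch of the guarded extraction this sets up; the `extractall` call and its target directory are assumptions, not shown in the diff:

import os
import tarfile

DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

# Extract once; skip if the target directory already exists.
if not os.path.exists(EXTRACT_PATH):
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(path=EXTRACT_PATH)  # assumed call; the hunk ends before this line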
@@ -41,7 +41,7 @@ if not os.path.exists(EXTRACT_PATH):
 else:
     print("✅ Dataset already extracted.")
 
-#
+# Base directory where audio files are stored
 AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
 
 # Recursively find all `.flac` files inside the dataset directory
@@ -57,7 +57,6 @@ def find_audio_files(base_folder):
 # Get all audio files
 audio_files = find_audio_files(AUDIO_FOLDER)
 
-# Check if audio files were found
 if not audio_files:
     raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
 
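`find_audio_files` is called here but its definition sits outside the changed hunks. A sketch consistent with the usage shown (recursive `.flac` discovery under the extracted tree); the exact body in app.py may differ:

import os

def find_audio_files(base_folder):
    # Walk the extracted LibriSpeech tree and collect every .flac file.
    flac_files = []
    for root, _dirs, files in os.walk(base_folder):
        for name in files:
            if name.endswith(".flac"):
                flac_files.append(os.path.join(root, name))
    return sorted(flac_files)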
@@ -73,22 +72,27 @@ def load_and_process_audio(audio_path):
     # Resample to 16kHz
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
-    # Convert to model input format
+    # Convert to model input format
     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
 
     return input_features
 
 # Manually create dataset structure
-dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]
+dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]
 
-
+# Split dataset into train and eval (Recommended Fix)
+train_size = int(0.9 * len(dataset))
+train_dataset = dataset[:train_size]
+eval_dataset = dataset[train_size:]
+
+print(f"✅ Dataset Loaded! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
 
 # ================================
 # 4️⃣ Training Arguments & Trainer
 # ================================
 training_args = TrainingArguments(
     output_dir="./asr_model_finetuned",
-
+    eval_strategy="epoch",  # Fix: Proper evaluation
     save_strategy="epoch",
     learning_rate=5e-5,
     per_device_train_batch_size=8,
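One caveat with this hunk: every example is built with "labels": [], so the seq2seq loss has no targets to learn from. LibriSpeech ships transcripts in per-chapter *.trans.txt files next to the audio; a hedged sketch of filling the labels from them (the helper name and the tokenizer call are illustrative, not part of this commit):

import os

def load_transcript(flac_path):
    # LibriSpeech keeps one "<speaker>-<chapter>.trans.txt" per chapter directory,
    # with lines like "84-121123-0000 GO DO YOU HEAR".
    utt_id = os.path.splitext(os.path.basename(flac_path))[0]
    speaker, chapter, _ = utt_id.split("-")
    trans_path = os.path.join(os.path.dirname(flac_path), f"{speaker}-{chapter}.trans.txt")
    with open(trans_path) as f:
        for line in f:
            key, _, text = line.partition(" ")
            if key == utt_id:
                return text.strip()
    raise KeyError(f"No transcript found for {utt_id}")

# Then, assuming a Whisper-style processor:
# labels = processor.tokenizer(load_transcript(f)).input_ids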
@@ -102,15 +106,15 @@ training_args = TrainingArguments(
 )
 
 # Data collator (for dynamic padding)
-data_collator = DataCollatorForSeq2Seq(processor
+data_collator = DataCollatorForSeq2Seq(processor, model=model)
 
-# Define Trainer
+# Define Trainer (Fixed `processing_class` warning)
 trainer = Trainer(
     model=model,
     args=training_args,
-    train_dataset=
-    eval_dataset=
-    tokenizer=processor,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,  # Fix: Providing eval_dataset
+    processing_class=processor,  # Fix: Replacing deprecated `tokenizer`
     data_collator=data_collator,
 )
 
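A note on the collator choice: DataCollatorForSeq2Seq pads token-id sequences and may not handle log-mel input_features; Whisper-style fine-tuning setups (e.g., the Hugging Face Whisper examples) usually use a small custom collator instead. A sketch under the assumption that processor bundles a feature extractor and a tokenizer, and that labels are token ids:

import torch

class SpeechDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, features):
        # Pad the spectrogram inputs with the feature extractor...
        batch = self.processor.feature_extractor.pad(
            [{"input_features": f["input_features"]} for f in features],
            return_tensors="pt",
        )
        # ...and the label token ids with the tokenizer, masking padding
        # with -100 so the loss ignores it.
        labels_batch = self.processor.tokenizer.pad(
            [{"input_ids": f["labels"]} for f in features],
            return_tensors="pt",
        )
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        return batch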
@@ -140,7 +144,7 @@ if audio_file:
     waveform, sample_rate = torchaudio.load(audio_path)
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
-    # Convert audio to model input
+    # Convert audio to model input
     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
 
     # Perform ASR inference
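The final hunk stops at the "# Perform ASR inference" comment. For completeness, a sketch of the generate-and-decode step that typically follows for a Whisper-style seq2seq model (model, processor, and device as defined earlier; not shown in this diff):

import torch

with torch.no_grad():
    predicted_ids = model.generate(input_features.unsqueeze(0).to(device))
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(f"Transcription: {transcription}")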