Update app.py
app.py
CHANGED
@@ -4,7 +4,6 @@ import torch
 import torchaudio
 import numpy as np
 import streamlit as st
-from datasets import load_dataset
 from transformers import (
     AutoProcessor,
     AutoModelForSpeechSeq2Seq,
@@ -28,9 +27,9 @@ model.to(device)
 print(f"✅ Model loaded on {device}")
 
 # ================================
-# 2️⃣ Load Dataset (
+# 2️⃣ Load Dataset (Manually from Extracted Path)
 # ================================
-DATASET_TAR_PATH = "dev-clean.tar.gz" #
+DATASET_TAR_PATH = "dev-clean.tar.gz" # Dataset stored in Hugging Face Space
 EXTRACT_PATH = "./librispeech_dev_clean" # Extracted dataset folder
 
 # Extract dataset only if not already extracted
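The extraction block this hunk's context mentions (old lines 37–41) falls outside the diff. For orientation, a minimal sketch of what such a guard typically looks like with Python's tarfile module; the paths and messages are taken from the visible context, the rest is an assumption:

import os
import tarfile

DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

# Extract dataset only if not already extracted (sketch; the actual guard in app.py is not shown)
if not os.path.exists(EXTRACT_PATH):
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)  # unpack into the target folder
else:
    print("✅ Dataset already extracted.")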
@@ -42,29 +41,29 @@ if not os.path.exists(EXTRACT_PATH):
 else:
     print("✅ Dataset already extracted.")
 
-#
-
-
+# Load audio files manually
+AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "train-clean-100") # Adjust as per structure
+audio_files = [os.path.join(AUDIO_FOLDER, f) for f in os.listdir(AUDIO_FOLDER) if f.endswith(".flac")]
 
 # ================================
-# 3️⃣ Preprocess Dataset
+# 3️⃣ Preprocess Dataset (Manually)
 # ================================
-def
-    """
-
-    waveform, sample_rate = torchaudio.load(audio["path"])
+def load_and_process_audio(audio_path):
+    """Loads and processes a single audio file into model format."""
+    waveform, sample_rate = torchaudio.load(audio_path)
 
-    # Resample to 16kHz
+    # Resample to 16kHz
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
     # Convert to model input format
-
-
-    return
+    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
+
+    return input_values
+
+# Manually create dataset structure
+dataset = [{"input_values": load_and_process_audio(f), "labels": []} for f in audio_files[:100]] # Load first 100
 
-
-dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
-print(f"✅ Dataset Preprocessed! Ready for Fine-Tuning.")
+print(f"✅ Dataset Loaded! Processed {len(dataset)} audio files.")
 
 # ================================
 # 4️⃣ Training Arguments & Trainer
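One caveat with the new helper: AutoModelForSpeechSeq2Seq typically resolves to an encoder-decoder checkpoint such as Whisper, whose processor returns input_features; input_values is the Wav2Vec2-style CTC interface. A hedged variant of load_and_process_audio that works with either processor type, assuming the processor object created earlier in app.py:

def load_and_process_audio(audio_path):
    """Load one FLAC file and convert it to whatever input format the processor expects."""
    waveform, sample_rate = torchaudio.load(audio_path)

    # Resample only when needed; both Whisper and Wav2Vec2 checkpoints expect 16 kHz
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    features = processor(waveform.squeeze().numpy(), sampling_rate=16000)

    # Wav2Vec2-style processors expose input_values; Whisper-style expose input_features
    if hasattr(features, "input_values"):
        return features.input_values[0]
    return features.input_features[0]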
@@ -82,10 +81,6 @@ training_args = TrainingArguments(
     logging_steps=500,
     save_total_limit=2,
     push_to_hub=True,
-    metric_for_best_model="wer",
-    greater_is_better=False,
-    save_on_each_node=True, # Improves stability during multi-GPU training
-    load_best_model_at_end=True, # Saves best model
 )
 
 # Data collator (for dynamic padding)
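The four removed TrainingArguments travel together: load_best_model_at_end=True only works when the evaluation strategy matches the save strategy and a compute_metrics function reports the metric named by metric_for_best_model. A sketch of the configuration the removed lines would have needed; the argument names follow the standard transformers API, but output_dir and the step counts are assumptions:

training_args = TrainingArguments(
    output_dir="./whisper-finetuned",  # assumed; not visible in the diff
    evaluation_strategy="steps",       # must match save_strategy for best-model tracking
    save_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,
    metric_for_best_model="wer",       # requires compute_metrics to return a "wer" key
    greater_is_better=False,           # lower WER is better
    load_best_model_at_end=True,
)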
@@ -151,7 +146,7 @@ if audio_file:
     corrected_input = processor.tokenizer(user_correction).input_ids
 
     # Dynamically add new example to dataset
-    dataset
+    dataset.append({"input_values": input_values, "labels": corrected_input})
 
     # Perform quick re-training (1 epoch)
     trainer.args.num_train_epochs = 1
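A note on the new append: dataset is now a plain Python list, so appending a corrected example does not reach the Trainer by itself; the trainer keeps whatever train_dataset it was constructed with. A minimal sketch of one way to re-point it before the quick one-epoch pass; trainer and dataset come from the file, the wrapper class is an assumption:

import torch

class ListDataset(torch.utils.data.Dataset):
    """Thin wrapper so a plain list of feature dicts behaves like a torch Dataset."""
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

# Re-point the trainer at the updated list, then run the quick pass
trainer.train_dataset = ListDataset(dataset)
trainer.args.num_train_epochs = 1
trainer.train()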