Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
|
|
2 |
import torch
|
3 |
import torchaudio
|
4 |
-
import tarfile
|
5 |
import numpy as np
|
6 |
import streamlit as st
|
7 |
from datasets import load_dataset
|
@@ -28,31 +28,33 @@ model.to(device)
|
|
28 |
print(f"β
Model loaded on {device}")
|
29 |
|
30 |
# ================================
|
31 |
-
# 2οΈβ£ Load Dataset (LibriSpeech)
|
32 |
# ================================
|
|
|
|
|
33 |
|
34 |
-
|
35 |
-
EXTRACT_PATH = "./librispeech_dev_clean" # Extracted folder
|
36 |
-
|
37 |
-
# Extract dataset if not already extracted
|
38 |
if not os.path.exists(EXTRACT_PATH):
|
39 |
print("π Extracting dataset...")
|
40 |
with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
|
41 |
tar.extractall(EXTRACT_PATH)
|
42 |
print("β
Extraction complete.")
|
|
|
|
|
43 |
|
44 |
-
# Load dataset from extracted
|
45 |
dataset = load_dataset("librispeech_asr", data_dir=EXTRACT_PATH, split="train", trust_remote_code=True)
|
46 |
-
print(f"β
Dataset Loaded! {dataset}")
|
47 |
|
48 |
# ================================
|
49 |
# 3οΈβ£ Preprocess Dataset
|
50 |
# ================================
|
51 |
def preprocess_audio(batch):
|
|
|
52 |
audio = batch["audio"]
|
53 |
waveform, sample_rate = torchaudio.load(audio["path"])
|
54 |
|
55 |
-
# Resample to 16kHz
|
56 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
57 |
|
58 |
# Convert to model input format
|
@@ -62,6 +64,7 @@ def preprocess_audio(batch):
|
|
62 |
|
63 |
# Apply preprocessing
|
64 |
dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
|
|
|
65 |
|
66 |
# ================================
|
67 |
# 4οΈβ£ Training Arguments & Trainer
|
|
|
1 |
import os
|
2 |
+
import tarfile
|
3 |
import torch
|
4 |
import torchaudio
|
|
|
5 |
import numpy as np
|
6 |
import streamlit as st
|
7 |
from datasets import load_dataset
|
|
|
28 |
print(f"β
Model loaded on {device}")
|
29 |
|
30 |
# ================================
|
31 |
+
# 2οΈβ£ Load Dataset (LibriSpeech) from Extracted Path
|
32 |
# ================================
|
33 |
+
DATASET_TAR_PATH = "dev-clean.tar.gz" # Uploaded dataset in your Hugging Face Space
|
34 |
+
EXTRACT_PATH = "./librispeech_dev_clean" # Extracted dataset folder
|
35 |
|
36 |
+
# Extract dataset only if not already extracted
|
|
|
|
|
|
|
37 |
if not os.path.exists(EXTRACT_PATH):
|
38 |
print("π Extracting dataset...")
|
39 |
with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
|
40 |
tar.extractall(EXTRACT_PATH)
|
41 |
print("β
Extraction complete.")
|
42 |
+
else:
|
43 |
+
print("β
Dataset already extracted.")
|
44 |
|
45 |
+
# β
Load dataset from extracted folder
|
46 |
dataset = load_dataset("librispeech_asr", data_dir=EXTRACT_PATH, split="train", trust_remote_code=True)
|
47 |
+
print(f"β
Dataset Loaded Successfully! Size: {len(dataset)}")
|
48 |
|
49 |
# ================================
|
50 |
# 3οΈβ£ Preprocess Dataset
|
51 |
# ================================
|
52 |
def preprocess_audio(batch):
|
53 |
+
"""Converts raw audio to a model-compatible format."""
|
54 |
audio = batch["audio"]
|
55 |
waveform, sample_rate = torchaudio.load(audio["path"])
|
56 |
|
57 |
+
# Resample to 16kHz (ASR models usually require this)
|
58 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
59 |
|
60 |
# Convert to model input format
|
|
|
64 |
|
65 |
# Apply preprocessing
|
66 |
dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
|
67 |
+
print(f"β
Dataset Preprocessed! Ready for Fine-Tuning.")
|
68 |
|
69 |
# ================================
|
70 |
# 4οΈβ£ Training Arguments & Trainer
|