Update app.py
app.py
CHANGED
@@ -1,53 +1,61 @@
+import os
+import torch
+import torchaudio
+import tarfile
+import numpy as np
+import streamlit as st
+from datasets import load_dataset
+from transformers import (
+    AutoProcessor,
+    AutoModelForSpeechSeq2Seq,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForSeq2Seq,
+)

+# ================================
+# 1️⃣ Load Model & Processor
+# ================================
+MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
+
+# Load ASR model and processor
+processor = AutoProcessor.from_pretrained(MODEL_NAME)
+model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
+print(f"✅ Model loaded on {device}")

+# ================================
+# 2️⃣ Load Dataset (LibriSpeech)
+# ================================

+DATASET_TAR_PATH = "dev-clean.tar.gz"  # The uploaded dataset in Hugging Face space
+EXTRACT_PATH = "./librispeech_dev_clean"  # Extracted folder

+# Extract dataset if not already extracted
+if not os.path.exists(EXTRACT_PATH):
+    print("Extracting dataset...")
+    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
+        tar.extractall(EXTRACT_PATH)
+    print("✅ Extraction complete.")

+# Load dataset from extracted path
+dataset = load_dataset("librispeech_asr", data_dir=EXTRACT_PATH, split="train", trust_remote_code=True)
+print(f"✅ Dataset Loaded! {dataset}")

+# ================================
+# 3️⃣ Preprocess Dataset
+# ================================
def preprocess_audio(batch):
    audio = batch["audio"]
    waveform, sample_rate = torchaudio.load(audio["path"])
+
+    # Resample to 16kHz
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+    # Convert to model input format
    batch["input_values"] = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    return batch
@@ -55,9 +63,9 @@ def preprocess_audio(batch):
# Apply preprocessing
dataset = dataset.map(preprocess_audio, remove_columns=["audio"])

+# ================================
+# 4️⃣ Training Arguments & Trainer
+# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="epoch",
@@ -70,10 +78,14 @@ training_args = TrainingArguments(
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
+    push_to_hub=True,
+    metric_for_best_model="wer",
+    greater_is_better=False,
+    save_on_each_node=True,  # Improves stability during multi-GPU training
+    load_best_model_at_end=True,  # Saves best model
)

+# Data collator (for dynamic padding)
data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)

# Define Trainer
@@ -81,38 +93,41 @@ trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
+    eval_dataset=None,  # No validation dataset for now
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)

+# ================================
+# 5️⃣ Fine-Tuning Execution
+# ================================
+if st.button("Start Fine-Tuning"):
+    with st.spinner("Fine-tuning in progress... Please wait!"):
+        trainer.train()
+        st.success("✅ Fine-Tuning Completed! Model updated.")

+# ================================
+# 6️⃣ Streamlit ASR Web App
+# ================================
+st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎶")

# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
+    # Save uploaded file temporarily
+    audio_path = "temp_audio.wav"
+    with open(audio_path, "wb") as f:
        f.write(audio_file.read())

+    # Load and process audio
+    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

+    # Convert audio to model input
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]

+    # Perform ASR inference
    with torch.no_grad():
        input_tensor = torch.tensor([input_values]).to(device)
        logits = model(input_tensor).logits
@@ -120,22 +135,23 @@ if audio_file:
    transcription = processor.batch_decode(predicted_ids)[0]

    # Display transcription
+    st.success("Transcription:")
    st.write(transcription)

+    # ================================
+    # 7️⃣ Fine-Tune Model with User Correction
+    # ================================
+    user_correction = st.text_area("Correct the transcription (if needed):", transcription)
+
+    if st.button("Fine-Tune with Correction"):
        if user_correction:
            corrected_input = processor.tokenizer(user_correction).input_ids

+            # Dynamically add new example to dataset
            dataset = dataset.add_item({"input_values": input_values, "labels": corrected_input})

+            # Perform quick re-training (1 epoch)
+            trainer.args.num_train_epochs = 1
            trainer.train()

+            st.success("✅ Model fine-tuned with new correction! Try another audio file.")
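
For quick local testing outside the Streamlit UI, the sketch below mirrors the transcription path of app.py as a standalone script. It is only a sketch: the "sample.wav" path is a placeholder, and the model.generate / input_features usage assumes the checkpoint is a Whisper-style encoder-decoder (which AutoModelForSpeechSeq2Seq suggests). app.py itself reads input_values and argmax-decodes the logits, so the exact field name depends on the feature extractor the checkpoint ships with; only MODEL_NAME is taken from the commit.

# Standalone transcription sketch (assumptions: Whisper-style checkpoint,
# processor returning input_features, placeholder path "sample.wav")
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Load audio and resample to 16 kHz, as app.py does
waveform, sample_rate = torchaudio.load("sample.wav")
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

# Encoder-decoder ASR checkpoints are typically decoded with generate()
inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
generated_ids = model.generate(inputs.input_features.to(device))
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

The app itself is launched the usual Streamlit way, e.g. streamlit run app.py.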