tahirsher committed
Commit cd7aa15 · verified · 1 Parent(s): 098a61e

Update app.py

Files changed (1): app.py (+83 −67)
app.py CHANGED
@@ -1,53 +1,61 @@
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 
- # Load Processor & Model
- processor = AutoProcessor.from_pretrained("AqeelShafy7/AudioSangraha-Audio_to_Text")
- model = AutoModelForSpeechSeq2Seq.from_pretrained("AqeelShafy7/AudioSangraha-Audio_to_Text")
 
  # Move model to GPU if available
- import torch
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)
- print(f"Model loaded on {device}")
-
- from datasets import load_dataset
- import torchaudio
- import torch
-
- # Fix: Add trust_remote_code=True
- import fsspec
- import os
- import tarfile
 
- # Define paths
- dataset_tar_path = "dev-clean.tar.gz"  # Path in your repo
- extract_path = "./librispeech_dev_clean"  # Extracted folder
 
- # Check if dataset is already extracted, if not, extract it
- if not os.path.exists(extract_path):
-     print("Extracting dataset...")
-     with tarfile.open(dataset_tar_path, "r:gz") as tar:
-         tar.extractall(extract_path)
-     print("Extraction complete.")
- else:
-     print("Dataset already extracted.")
-
- from datasets import load_dataset
 
- # Load extracted dataset
- dataset = load_dataset("librispeech_asr", data_dir=extract_path, split="train", trust_remote_code=True)
 
- print("Dataset loaded successfully!")
 
- # Function to load & resample audio
  def preprocess_audio(batch):
      audio = batch["audio"]
      waveform, sample_rate = torchaudio.load(audio["path"])
-
-     # Resample to 16kHz (ASR models usually require this)
      waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
-
-     # Convert to correct format
      batch["input_values"] = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
      batch["labels"] = processor.tokenizer(batch["text"]).input_ids
      return batch
@@ -55,9 +63,9 @@ def preprocess_audio(batch):
  # Apply preprocessing
  dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
 
- from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
-
- # Define Training Arguments
  training_args = TrainingArguments(
      output_dir="./asr_model_finetuned",
      evaluation_strategy="epoch",
@@ -70,10 +78,14 @@ training_args = TrainingArguments(
      logging_dir="./logs",
      logging_steps=500,
      save_total_limit=2,
-     push_to_hub=True,  # Enable uploading to Hugging Face Hub
  )
 
- # Define Data Collator
  data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
 
  # Define Trainer
@@ -81,38 +93,41 @@ trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=dataset,
-     eval_dataset=None,  # We use only training data here
      tokenizer=processor.feature_extractor,
      data_collator=data_collator,
  )
 
- # Start Fine-Tuning
- trainer.train()
 
- # Deployment of Huggingface using streamlit
- import streamlit as st
- import soundfile as sf
- import numpy as np
-
- st.title("🎙️ Automatic Speech Recognition with Fine-Tuning 🎶")
 
  # Upload audio file
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
 
  if audio_file:
-     # Save and load audio file
-     with open("temp_audio.wav", "wb") as f:
          f.write(audio_file.read())
 
-     waveform, sample_rate = torchaudio.load("temp_audio.wav")
-
-     # Resample to 16kHz
      waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
-     # Convert to model input
      input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
 
-     # Perform transcription
      with torch.no_grad():
          input_tensor = torch.tensor([input_values]).to(device)
          logits = model(input_tensor).logits
@@ -120,22 +135,23 @@ if audio_file:
      transcription = processor.batch_decode(predicted_ids)[0]
 
      # Display transcription
-     st.success("Transcription:")
      st.write(transcription)
 
-     # Fine-tune with user input
-     user_correction = st.text_area("Correct the transcription (if needed):")
-
-     if st.button("Fine-Tune Model"):
          if user_correction:
-             # Convert correction to training format
              corrected_input = processor.tokenizer(user_correction).input_ids
 
-             # Update dataset dynamically (simple approach)
              dataset = dataset.add_item({"input_values": input_values, "labels": corrected_input})
 
-             # Retrain for one step
              trainer.train()
 
-             st.success("Model fine-tuned successfully! Try another audio file.")
-
 
+ import os
+ import torch
+ import torchaudio
+ import tarfile
+ import numpy as np
+ import streamlit as st
+ from datasets import load_dataset
+ from transformers import (
+     AutoProcessor,
+     AutoModelForSpeechSeq2Seq,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForSeq2Seq,
+ )
 
+ # ================================
+ # 1️⃣ Load Model & Processor
+ # ================================
+ MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
+
+ # Load ASR model and processor
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
 
  # Move model to GPU if available
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)
+ print(f"✅ Model loaded on {device}")
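Note: Streamlit re-runs app.py top to bottom on every widget interaction, so the model above is reloaded on each rerun. A minimal sketch of the usual mitigation, assuming the same MODEL_NAME (st.cache_resource is Streamlit's standard cache for models):

import torch
import streamlit as st
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

@st.cache_resource  # cache the loaded model across Streamlit reruns
def load_asr(model_name: str):
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    return processor, model

processor, model = load_asr(MODEL_NAME)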
 
+ # ================================
+ # 2️⃣ Load Dataset (LibriSpeech)
+ # ================================
 
+ DATASET_TAR_PATH = "dev-clean.tar.gz"  # The uploaded dataset in the Hugging Face Space
+ EXTRACT_PATH = "./librispeech_dev_clean"  # Extracted folder
 
+ # Extract dataset if not already extracted
+ if not os.path.exists(EXTRACT_PATH):
+     print("🔄 Extracting dataset...")
+     with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
+         tar.extractall(EXTRACT_PATH)
+     print("✅ Extraction complete.")
 
+ # Load dataset from extracted path
+ dataset = load_dataset("librispeech_asr", data_dir=EXTRACT_PATH, split="train", trust_remote_code=True)
+ print(f"✅ Dataset Loaded! {dataset}")
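Note: the manual tar handling can be skipped when the Space has network access, since the librispeech_asr builder downloads its own archives. A sketch, assuming the Hub dataset's "clean" config (its "validation" split corresponds to dev-clean):

from datasets import load_dataset

# Let the builder download and cache dev-clean itself
dev_clean = load_dataset("librispeech_asr", "clean", split="validation", trust_remote_code=True)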
 
+ # ================================
+ # 3️⃣ Preprocess Dataset
+ # ================================
  def preprocess_audio(batch):
      audio = batch["audio"]
      waveform, sample_rate = torchaudio.load(audio["path"])
+
+     # Resample to 16kHz
      waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+     # Convert to model input format
      batch["input_values"] = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
      batch["labels"] = processor.tokenizer(batch["text"]).input_ids
      return batch
 
  # Apply preprocessing
  dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
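Note: which attribute the processor returns depends on the checkpoint family. Wav2Vec2-style feature extractors yield input_values, while Whisper-style processors (the usual pairing for AutoModelForSpeechSeq2Seq) yield input_features. A defensive sketch that handles both, assuming 16 kHz mono input (extract_features is a hypothetical helper, not part of the commit):

def extract_features(waveform_np):
    out = processor(waveform_np, sampling_rate=16000)
    # Whisper-style seq2seq checkpoints expose input_features,
    # Wav2Vec2-style CTC checkpoints expose input_values.
    if hasattr(out, "input_features"):
        return out.input_features[0]
    return out.input_values[0]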
 
+ # ================================
+ # 4️⃣ Training Arguments & Trainer
+ # ================================
  training_args = TrainingArguments(
      output_dir="./asr_model_finetuned",
      evaluation_strategy="epoch",
 
      logging_dir="./logs",
      logging_steps=500,
      save_total_limit=2,
+     push_to_hub=True,
+     metric_for_best_model="wer",
+     greater_is_better=False,
+     save_on_each_node=True,  # Save checkpoints on every node in multi-node training
+     load_best_model_at_end=True,  # Reload the best checkpoint when training ends
  )
 
+ # Data collator (for dynamic padding)
  data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
 
  # Define Trainer
  trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=dataset,
+     eval_dataset=None,  # No validation dataset for now
      tokenizer=processor.feature_extractor,
      data_collator=data_collator,
  )
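Note: as configured, evaluation_strategy="epoch", metric_for_best_model="wer", and load_best_model_at_end=True cannot take effect with eval_dataset=None; the Trainer needs a validation split plus a metric function that actually reports "wer". A minimal sketch using the evaluate package (an extra dependency, not part of this commit):

import numpy as np
import evaluate

wer_metric = evaluate.load("wer")  # word error rate (requires evaluate + jiwer)

def compute_metrics(pred):
    # Trainer passes logits; take the argmax before decoding
    pred_ids = np.argmax(pred.predictions, axis=-1)
    pred_str = processor.batch_decode(pred_ids)
    # Restore label padding (-100) to the pad token id before decoding
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, processor.tokenizer.pad_token_id)
    label_str = processor.batch_decode(label_ids)
    return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}

This would be wired in via Trainer(..., eval_dataset=some_split, compute_metrics=compute_metrics).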
 
+ # ================================
+ # 5️⃣ Fine-Tuning Execution
+ # ================================
+ if st.button("Start Fine-Tuning"):
+     with st.spinner("Fine-tuning in progress... Please wait!"):
+         trainer.train()
+     st.success("✅ Fine-Tuning Completed! Model updated.")
 
+ # ================================
+ # 6️⃣ Streamlit ASR Web App
+ # ================================
+ st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎶")
 
  # Upload audio file
  audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
 
  if audio_file:
+     # Save uploaded file temporarily
+     audio_path = "temp_audio.wav"
+     with open(audio_path, "wb") as f:
          f.write(audio_file.read())
 
+     # Load and process audio
+     waveform, sample_rate = torchaudio.load(audio_path)
      waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
+     # Convert audio to model input
      input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
 
+     # Perform ASR inference
      with torch.no_grad():
          input_tensor = torch.tensor([input_values]).to(device)
          logits = model(input_tensor).logits
 
      transcription = processor.batch_decode(predicted_ids)[0]
 
      # Display transcription
+     st.success("📄 Transcription:")
      st.write(transcription)
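Note: decoding with an argmax over logits is the CTC pattern; a checkpoint loaded through AutoModelForSpeechSeq2Seq is an encoder-decoder model whose transcription path normally goes through model.generate. A sketch of that route, assuming a Whisper-style processor exposing input_features:

inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    # generate() performs autoregressive decoding for seq2seq ASR models
    predicted_ids = model.generate(inputs.input_features.to(device))
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]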
 
+     # ================================
+     # 7️⃣ Fine-Tune Model with User Correction
+     # ================================
+     user_correction = st.text_area("🔧 Correct the transcription (if needed):", transcription)
+
+     if st.button("Fine-Tune with Correction"):
          if user_correction:
              corrected_input = processor.tokenizer(user_correction).input_ids
 
+             # Dynamically add new example to dataset
              dataset = dataset.add_item({"input_values": input_values, "labels": corrected_input})
 
+             # Perform quick re-training (1 epoch)
+             trainer.args.num_train_epochs = 1
              trainer.train()
 
+             st.success("✅ Model fine-tuned with new correction! Try another audio file.")
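Note: this correction loop re-runs trainer.train() over the trainer's original training set; dataset.add_item returns a new Dataset object, so the added example is never seen unless the trainer is pointed at it. A lighter sketch that trains a few steps on just the correction (finetune_on_correction is a hypothetical helper, not part of the commit):

from datasets import Dataset

def finetune_on_correction(input_values, label_ids, steps=1):
    # One-example dataset built from the user's correction
    single = Dataset.from_dict({"input_values": [input_values], "labels": [label_ids]})
    trainer.train_dataset = single
    trainer.args.max_steps = steps  # cap the update to a few gradient steps
    trainer.train()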