tahirsher committed
Commit d2d38cf · verified · 1 Parent(s): 5e01fef

Update app.py

Files changed (1)
  1. app.py +18 -23
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
 import torchaudio
 import numpy as np
 import streamlit as st
-from datasets import load_dataset
 from transformers import (
     AutoProcessor,
     AutoModelForSpeechSeq2Seq,
@@ -28,9 +27,9 @@ model.to(device)
 print(f"✅ Model loaded on {device}")
 
 # ================================
-# 2️⃣ Load Dataset (LibriSpeech) from Extracted Path
+# 2️⃣ Load Dataset (Manually from Extracted Path)
 # ================================
-DATASET_TAR_PATH = "dev-clean.tar.gz"  # Uploaded dataset in your Hugging Face Space
+DATASET_TAR_PATH = "dev-clean.tar.gz"  # Dataset stored in Hugging Face Space
 EXTRACT_PATH = "./librispeech_dev_clean"  # Extracted dataset folder
 
 # Extract dataset only if not already extracted
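The next hunk patches around an extraction guard that already exists in app.py; only its first and last lines appear as diff context (the "if not os.path.exists(EXTRACT_PATH):" header and the else branch). A minimal sketch of what that guard presumably looks like, with the intermediate lines assumed rather than shown in this commit:

import os
import tarfile

DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

if not os.path.exists(EXTRACT_PATH):
    # Assumed body: unpack the uploaded archive once; later runs reuse the folder.
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
else:
    print("✅ Dataset already extracted.")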
@@ -42,29 +41,29 @@ if not os.path.exists(EXTRACT_PATH):
 else:
     print("✅ Dataset already extracted.")
 
-# Load dataset from extracted folder
-dataset = load_dataset("librispeech_asr", data_dir=EXTRACT_PATH, split="train", trust_remote_code=True)
-print(f"✅ Dataset Loaded Successfully! Size: {len(dataset)}")
+# Load audio files manually
+AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "train-clean-100")  # Adjust as per structure
+audio_files = [os.path.join(AUDIO_FOLDER, f) for f in os.listdir(AUDIO_FOLDER) if f.endswith(".flac")]
 
 # ================================
-# 3️⃣ Preprocess Dataset
+# 3️⃣ Preprocess Dataset (Manually)
 # ================================
-def preprocess_audio(batch):
-    """Converts raw audio to a model-compatible format."""
-    audio = batch["audio"]
-    waveform, sample_rate = torchaudio.load(audio["path"])
+def load_and_process_audio(audio_path):
+    """Loads and processes a single audio file into model format."""
+    waveform, sample_rate = torchaudio.load(audio_path)
 
-    # Resample to 16kHz (ASR models usually require this)
+    # Resample to 16kHz
     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
     # Convert to model input format
-    batch["input_values"] = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
-    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
-    return batch
+    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
+
+    return input_values
+
+# Manually create dataset structure
+dataset = [{"input_values": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]  # Load first 100
 
-# Apply preprocessing
-dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
-print(f"✅ Dataset Preprocessed! Ready for Fine-Tuning.")
+print(f"✅ Dataset Loaded! Processed {len(dataset)} audio files.")
 
 # ================================
 # 4️⃣ Training Arguments & Trainer
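Three caveats for anyone reading the hunk above: dev-clean.tar.gz unpacks to LibriSpeech/dev-clean rather than train-clean-100 (the in-line comment "Adjust as per structure" acknowledges this), the .flac files sit two levels deep in speaker/chapter subfolders where a flat os.listdir will not find them, and every example is created with empty "labels". LibriSpeech stores transcripts in per-chapter *.trans.txt files next to the audio, so a recursive walk can recover both; the helper below is an illustrative sketch, not part of this commit:

import os

def iter_librispeech(audio_folder):
    """Yield (flac_path, transcript) pairs by walking speaker/chapter folders."""
    for root, _, files in os.walk(audio_folder):
        # Each chapter folder holds one *.trans.txt mapping utterance IDs to text.
        transcripts = {}
        for name in files:
            if name.endswith(".trans.txt"):
                with open(os.path.join(root, name)) as fh:
                    for line in fh:
                        utt_id, text = line.strip().split(" ", 1)
                        transcripts[utt_id] = text
        for name in files:
            if name.endswith(".flac"):
                utt_id = name[: -len(".flac")]
                yield os.path.join(root, name), transcripts.get(utt_id, "")

With pairs like these, the empty "labels": [] could be filled with processor.tokenizer(text).input_ids, mirroring what the interactive-correction hunk at the bottom of this diff already does.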
@@ -82,10 +81,6 @@ training_args = TrainingArguments(
     logging_steps=500,
     save_total_limit=2,
     push_to_hub=True,
-    metric_for_best_model="wer",
-    greater_is_better=False,
-    save_on_each_node=True,  # Improves stability during multi-GPU training
-    load_best_model_at_end=True,  # Saves best model
 )
 
 # Data collator (for dynamic padding)
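The four options removed above are interdependent in transformers: load_best_model_at_end=True requires the evaluation strategy to match the save strategy, and metric_for_best_model="wer" only works with a compute_metrics function that actually returns a "wer" value. With no eval dataset wired up, dropping the whole group avoids a configuration error at startup. For reference, re-enabling best-model selection would need roughly the following sketch (output_dir and the strategies are assumptions, not values from this file):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./asr_finetuned",  # assumed; the real value lives elsewhere in app.py
    evaluation_strategy="steps",   # must equal save_strategy for load_best_model_at_end
    save_strategy="steps",
    metric_for_best_model="wer",   # needs compute_metrics returning {"wer": ...}
    greater_is_better=False,       # lower WER is better
    load_best_model_at_end=True,
)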
@@ -151,7 +146,7 @@ if audio_file:
     corrected_input = processor.tokenizer(user_correction).input_ids
 
     # Dynamically add new example to dataset
-    dataset = dataset.add_item({"input_values": input_values, "labels": corrected_input})
+    dataset.append({"input_values": input_values, "labels": corrected_input})
 
     # Perform quick re-training (1 epoch)
     trainer.args.num_train_epochs = 1
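Since dataset is now a plain Python list, the append above works, but the "# Data collator (for dynamic padding)" shown as context earlier has to pad these raw dicts itself before they reach the model. A minimal collator consistent with this structure might look like the following (an illustration; the actual collator in app.py is not shown in this diff):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Pad variable-length input_values and labels to the longest in the batch."""
    inputs = [torch.as_tensor(ex["input_values"], dtype=torch.float32) for ex in batch]
    labels = [torch.as_tensor(ex["labels"], dtype=torch.long) for ex in batch]
    return {
        "input_values": pad_sequence(inputs, batch_first=True),
        # -100 is the index ignored by the cross-entropy loss on padded labels
        "labels": pad_sequence(labels, batch_first=True, padding_value=-100),
    }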