tahirsher committed
Commit f0a5b40 · verified · 1 Parent(s): 7db1356

Create app.py

Files changed (1)
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+
+ # Load the processor and the seq2seq ASR model
+ processor = AutoProcessor.from_pretrained("AqeelShafy7/AudioSangraha-Audio_to_Text")
+ model = AutoModelForSpeechSeq2Seq.from_pretrained("AqeelShafy7/AudioSangraha-Audio_to_Text")
+
+ # Move the model to GPU if one is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ print(f"Model loaded on {device}")
+
+ from datasets import load_dataset
+ import torchaudio
+
+ # Load the "clean" LibriSpeech training data (this config exposes train.100 / train.360 splits)
+ dataset = load_dataset("librispeech_asr", "clean", split="train.100")
+
+ # Load and resample each audio example, then build model inputs and labels
+ def preprocess_audio(batch):
+     audio = batch["audio"]
+     waveform, sample_rate = torchaudio.load(audio["path"])
+
+     # Resample to 16 kHz (the rate most ASR models expect)
+     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+     # Seq2seq speech processors return input_features (log-mel), not input_values
+     batch["input_features"] = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_features[0]
+     batch["labels"] = processor.tokenizer(batch["text"]).input_ids
+     return batch
+
+ # Apply preprocessing to the whole dataset
+ dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
+
+ from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
+
+ # Define training arguments
+ training_args = TrainingArguments(
+     output_dir="./asr_model_finetuned",
+     evaluation_strategy="no",  # no eval_dataset is passed to the Trainer below
+     save_strategy="epoch",
+     learning_rate=5e-5,
+     per_device_train_batch_size=8,
+     per_device_eval_batch_size=8,
+     num_train_epochs=3,
+     weight_decay=0.01,
+     logging_dir="./logs",
+     logging_steps=500,
+     save_total_limit=2,
+     push_to_hub=True,  # upload checkpoints to the Hugging Face Hub (requires a logged-in token)
+ )
+
+ # Data collator: pads the tokenized labels. NOTE: padded audio features usually need
+ # a dedicated speech collator; a sketch of one follows the file listing.
+ data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
+
+ # Define the Trainer
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=dataset,
+     eval_dataset=None,  # only training data is used here
+     tokenizer=processor.feature_extractor,  # saved alongside the model checkpoints
+     data_collator=data_collator,
+ )
+
+ # Start fine-tuning
+ trainer.train()
+
+ # Streamlit front end for transcription with the fine-tuned model
+ import streamlit as st
+
+ st.title("🎙️ Automatic Speech Recognition with Fine-Tuning 🎶")
+
+ # Upload an audio file
+ audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
+
+ if audio_file:
+     # Save the upload to disk, then load it with torchaudio
+     with open("temp_audio.wav", "wb") as f:
+         f.write(audio_file.read())
+
+     waveform, sample_rate = torchaudio.load("temp_audio.wav")
+
+     # Resample to 16 kHz
+     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+     # Convert to model input features
+     inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
+     input_features = inputs.input_features.to(device)
+
+     # Transcribe: seq2seq ASR models are decoded with generate(), not argmax over logits
+     with torch.no_grad():
+         predicted_ids = model.generate(input_features)
+     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+     # Display the transcription
+     st.success("Transcription:")
+     st.write(transcription)
+
+     # Fine-tune with user feedback
+     user_correction = st.text_area("Correct the transcription (if needed):")
+
+     if st.button("Fine-Tune Model"):
+         if user_correction:
+             # Tokenize the corrected transcript as new labels
+             corrected_labels = processor.tokenizer(user_correction).input_ids
+
+             # Append the corrected example (add_item returns a new dataset object)
+             dataset = dataset.add_item({
+                 "input_features": input_features.squeeze(0).cpu().numpy().tolist(),
+                 "labels": corrected_labels,
+             })
+             trainer.train_dataset = dataset  # point the Trainer at the updated dataset
+
+             # Re-run training on the updated data (note: this repeats the full training loop)
+             trainer.train()
+
+             st.success("Model fine-tuned successfully! Try another audio file.")
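Note on the data collator: DataCollatorForSeq2Seq in app.py only pads the tokenized labels, while batches of padded log-mel input_features usually need a dedicated speech collator. Below is a minimal sketch under that assumption, written for a Whisper-style processor that bundles a feature extractor and a tokenizer; the class name SpeechSeq2SeqCollator is illustrative and not part of the committed file.

from dataclasses import dataclass
from typing import Any, Dict, List
import torch

@dataclass
class SpeechSeq2SeqCollator:
    # Illustrative helper, not part of app.py: pads audio features and label ids separately.
    processor: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Pad the audio features into a uniform batch tensor
        audio_inputs = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Pad the tokenized transcripts and mask padding out of the loss with -100
        label_inputs = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_inputs, return_tensors="pt")
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].eq(0), -100
        )
        return batch

If this were used, data_collator = SpeechSeq2SeqCollator(processor) would replace the DataCollatorForSeq2Seq instance passed to the Trainer.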
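To try the app locally, the dependencies implied by the imports are torch, torchaudio, transformers, datasets, and streamlit (plus soundfile, which datasets relies on for audio decoding); the script is then launched with streamlit run app.py. As written, the LibriSpeech download and the trainer.train() call run before the Streamlit UI is served, so the first launch is expensive.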