from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Load Processor & Model
processor = AutoProcessor.from_pretrained("AqeelShafy7/AudioSangraha-Audio_to_Text")
model = AutoModelForSpeechSeq2Seq.from_pretrained("AqeelShafy7/AudioSangraha-Audio_to_Text")

# Move model to GPU if available
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model loaded on {device}")

from datasets import load_dataset
import torchaudio

# Raise fsspec's download timeout so large LibriSpeech archives don't fail mid-transfer
import fsspec
fsspec.config.conf["timeout"] = 20000  # generous timeout for large downloads

# The "clean" config exposes train.100 / train.360 splits (there is no plain "train");
# trust_remote_code=True is needed because the dataset ships a loading script.
dataset = load_dataset("librispeech_asr", "clean", split="train.100", trust_remote_code=True)
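
# Optional: for quick experiments, slicing the split keeps the download and preprocessing small
# (an alternative to the full split loaded above, using the datasets slicing syntax):
# dataset = load_dataset("librispeech_asr", "clean", split="train.100[:1%]", trust_remote_code=True)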



# Function to load & resample audio
def preprocess_audio(batch):
    audio = batch["audio"]
    waveform, sample_rate = torchaudio.load(audio["path"])

    # Resample to 16 kHz (the rate most ASR feature extractors expect)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Convert to model inputs and tokenized labels.
    # NOTE: Whisper-style processors return `input_features` instead of `input_values`;
    # adjust the attribute and key below if this checkpoint wraps a Whisper model.
    batch["input_values"] = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    return batch

# Apply preprocessing
dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
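
# Optional: hold out a small validation set if epoch-level evaluation is wanted
# (a sketch using datasets' train_test_split; pass splits["test"] as eval_dataset below
# and switch evaluation_strategy back to "epoch").
# splits = dataset.train_test_split(test_size=0.05, seed=42)
# train_data, eval_data = splits["train"], splits["test"]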

from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

# Define Training Arguments
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="no",  # no eval_dataset is passed to the Trainer below
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,  # Enable uploading to Hugging Face Hub
)

# Define Data Collator
data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
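
# DataCollatorForSeq2Seq pads token ids, not raw audio, so it can choke on the float
# "input_values" arrays produced above. A minimal alternative collator (a sketch, assuming
# the "input_values"/"labels" keys from preprocess_audio) pads audio with the feature
# extractor and transcripts with the tokenizer, masking label padding with -100:
class SpeechSeq2SeqCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, features):
        # Pad raw audio inputs and tokenized transcripts separately
        audio = [{"input_values": f["input_values"]} for f in features]
        labels = [{"input_ids": f["labels"]} for f in features]
        batch = self.processor.feature_extractor.pad(audio, return_tensors="pt")
        labels_batch = self.processor.tokenizer.pad(labels, return_tensors="pt")
        # Replace padding token ids with -100 so they are ignored by the loss
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].eq(0), -100
        )
        return batch

# data_collator = SpeechSeq2SeqCollator(processor)  # swap in if the default collator mis-pads audio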

# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=None,  # We use only training data here
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)

# Start Fine-Tuning
trainer.train()
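
# Persist the fine-tuned weights and processor locally, then push the final model to the Hub
# repo created by push_to_hub=True above (assumes a prior `huggingface-cli login`).
trainer.save_model("./asr_model_finetuned")
processor.save_pretrained("./asr_model_finetuned")
trainer.push_to_hub()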

# Deploy the fine-tuned Hugging Face model with Streamlit
import streamlit as st
import soundfile as sf
import numpy as np

st.title("🎙️ Automatic Speech Recognition with Fine-Tuning 🎶")

# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Save and load audio file
    with open("temp_audio.wav", "wb") as f:
        f.write(audio_file.read())

    waveform, sample_rate = torchaudio.load("temp_audio.wav")

    # Resample to 16kHz
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Convert to model input
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]

    # Perform transcription (seq2seq ASR models decode autoregressively via generate())
    with torch.no_grad():
        input_tensor = torch.tensor([input_values]).to(device)
        predicted_ids = model.generate(input_tensor)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("Transcription:")
    st.write(transcription)

    # Fine-tune with user input
    user_correction = st.text_area("Correct the transcription (if needed):")
    
    if st.button("Fine-Tune Model"):
        if user_correction:
            # Convert correction to training format
            corrected_input = processor.tokenizer(user_correction).input_ids

            # Append the corrected example to the dataset (simple approach)
            dataset = dataset.add_item({"input_values": input_values, "labels": corrected_input})

            # Point the Trainer at the updated dataset and run another training pass
            trainer.train_dataset = dataset
            trainer.train()

            st.success("Model fine-tuned successfully! Try another audio file.")