import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
import gradio as gr
import time
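# Dependencies (an assumed minimal set; versions untested):
#   pip install torch torchaudio transformers accelerate bitsandbytes gradio
# 4-bit loading via bitsandbytes requires an NVIDIA GPU with CUDA.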

# Load model and processor (runs once on startup)
model_name = "ibm-granite/granite-speech-3.2-8b"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Loading processor...")
speech_granite_processor = AutoProcessor.from_pretrained(
    model_name, trust_remote_code=True)
tokenizer = speech_granite_processor.tokenizer

print("Loading model with 4-bit quantization...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
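# NF4 storage with double quantization keeps weights in 4 bits, cutting the
# 8B model's weight memory from ~16 GB in fp16 to roughly 5 GB; matmuls still
# run in fp16 via bnb_4bit_compute_dtype.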

speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True
)
print("Model loaded successfully")

def transcribe_audio(audio_input):
    """Process audio input and return transcription"""
    start_time = time.time()
    
    logs = [f"Audio input received: {type(audio_input)}"]
    
    if audio_input is None:
        return "Error: No audio provided.", 0.0
    
    try:
        # Handle different audio input formats
        if isinstance(audio_input, tuple) and len(audio_input) == 2:
            # Microphone input as (sample_rate, numpy_array); only produced
            # when the Audio component is configured with type="numpy"
            logs.append("Processing microphone input")
            sr, wav_np = audio_input
            wav = torch.from_numpy(wav_np)
            # Gradio records integer PCM; normalize int16 samples to [-1, 1]
            if wav.dtype == torch.int16:
                wav = wav.float() / 32768.0
            else:
                wav = wav.float()
            # Gradio arrays are [time] or [time, channels]; the model expects [channels, time]
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            else:
                wav = wav.t()
        else:
            # File input: filepath string
            logs.append(f"Processing file input: {audio_input}")
            wav, sr = torchaudio.load(audio_input, normalize=True)
            logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")
        
        # Convert to mono if stereo
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
            logs.append("Converted stereo to mono")
        
        # Resample to 16 kHz if needed (the Granite speech encoder expects 16 kHz input)
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            wav = resampler(wav)
            sr = 16000
            logs.append(f"Resampled to {sr}Hz")
        
        logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, min: {wav.min().item()}, max: {wav.max().item()}")
        
        # Verify audio format matches what the model expects
        assert wav.shape[0] == 1 and sr == 16000, "Audio must be mono and 16kHz"
        
        # Create text prompt
        chat = [
            {
                "role": "system",
                "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
            },
            {
                "role": "user",
                "content": "<|audio|>can you transcribe the speech into a written format?",
            }
        ]
        
        text = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
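        # The <|audio|> placeholder in the user turn marks where the processor
        # splices the audio features into the prompt.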
        
        # Pass the prompt text and waveform positionally; the Granite
        # processor does not accept the audio as a named parameter
        logs.append("Preparing model inputs")
        model_inputs = speech_granite_processor(
            text,
            wav,
            device=device,  # Explicitly set device
            return_tensors="pt",
        ).to(device)
        
        # Generate transcription
        logs.append("Generating transcription")
        model_outputs = speech_granite.generate(
            **model_inputs,
            max_new_tokens=1000,
            num_beams=4,
            do_sample=False,
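            # Note: top_p and temperature below are inert when do_sample=False;
            # they are kept here only to mirror the original settings.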
            min_length=1,
            top_p=1.0,
            repetition_penalty=3.0,
            length_penalty=1.0,
            temperature=1.0,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        
        # Extract the generated text (skipping input tokens)
        logs.append("Processing output")
        num_input_tokens = model_inputs["input_ids"].shape[-1]
        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
        
        output_text = tokenizer.batch_decode(
            new_tokens, add_special_tokens=False, skip_special_tokens=True
        )
        
        transcription = output_text[0].strip().upper()
        logs.append(f"Transcription complete: {transcription[:50]}...")
        
    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        print(error_trace)
        print("\n".join(logs))
        return f"Error: {str(e)}\n\nLogs:\n" + "\n".join(logs), 0.0
    
    processing_time = round(time.time() - start_time, 2)
    return transcription, processing_time
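
# Quick smoke test without the UI (hypothetical sample path):
#   text, secs = transcribe_audio("sample.wav")
#   print(f"{text} ({secs}s)")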

# Create Gradio interface
title = "IBM Granite Speech-to-Text (8B Quantized)"
description = """
Transcribe speech using IBM's Granite Speech 3.2 8B model (loaded in 4-bit).
Upload an audio file or use your microphone to record speech.
"""

iface = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs=[
        gr.Textbox(label="Transcription", lines=5),
        gr.Number(label="Processing Time (seconds)")
    ],
    title=title,
    description=description,
)
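# Optional: calling iface.queue() before launch() serializes requests and can
# avoid HTTP timeouts on long recordings.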

if __name__ == "__main__":
    iface.launch()