import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
import gradio as gr
import time

# Load model and processor (runs once on startup)
model_name = "ibm-granite/granite-speech-3.2-8b"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Loading processor...")
speech_granite_processor = AutoProcessor.from_pretrained(
    model_name, trust_remote_code=True)
tokenizer = speech_granite_processor.tokenizer

print("Loading model with 4-bit quantization...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
)
print("Model loaded successfully")


def transcribe_audio(audio_input):
    """Process audio input and return (transcription, processing_time)."""
    start_time = time.time()
    logs = [f"Audio input received: {type(audio_input)}"]

    if audio_input is None:
        return "Error: No audio provided.", 0.0

    try:
        # Handle different audio input formats
        if isinstance(audio_input, tuple) and len(audio_input) == 2:
            # Microphone input: (sample_rate, numpy_array). Only reached if the
            # Audio component is set to type="numpy"; with type="filepath"
            # (used below), recordings also arrive as file paths.
            logs.append("Processing microphone input")
            sr, wav_np = audio_input
            wav = torch.from_numpy(wav_np).float()
            # Gradio delivers raw microphone audio as integers (usually int16);
            # rescale to [-1.0, 1.0] to match torchaudio's normalized output
            if np.issubdtype(wav_np.dtype, np.integer):
                wav = wav / np.iinfo(wav_np.dtype).max
            # Make sure we have the right dimensions: [channels, time]
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
        else:
            # File input: filepath string
            logs.append(f"Processing file input: {audio_input}")
            wav, sr = torchaudio.load(audio_input, normalize=True)
            logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")

        # Convert to mono if stereo
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
            logs.append("Converted stereo to mono")

        # Resample to 16kHz if needed
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            wav = resampler(wav)
            sr = 16000
            logs.append(f"Resampled to {sr}Hz")

        logs.append(
            f"Final audio: sample rate {sr}Hz, shape {wav.shape}, "
            f"min: {wav.min().item()}, max: {wav.max().item()}"
        )

        # Verify audio format matches what the model expects
        assert wav.shape[0] == 1 and sr == 16000, "Audio must be mono and 16kHz"

        # Create text prompt
        chat = [
            {
                "role": "system",
                "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
            },
            {
                "role": "user",
                "content": "<|audio|>can you transcribe the speech into a written format?",
            },
        ]
        text = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # Pass text and waveform positionally to the processor
        # (don't pass the audio as a named parameter)
        logs.append("Preparing model inputs")
        model_inputs = speech_granite_processor(
            text,
            wav,
            device=device,  # Explicitly set device
            return_tensors="pt",
        ).to(device)

        # Generate transcription
        logs.append("Generating transcription")
        model_outputs = speech_granite.generate(
            **model_inputs,
            max_new_tokens=1000,
            num_beams=4,
            do_sample=False,
            min_length=1,
            top_p=1.0,
            repetition_penalty=3.0,
            length_penalty=1.0,
            temperature=1.0,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        # Extract the generated text, skipping the input (prompt) tokens
        logs.append("Processing output")
        num_input_tokens = model_inputs["input_ids"].shape[-1]
        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
        output_text = tokenizer.batch_decode(
            new_tokens, add_special_tokens=False, skip_special_tokens=True
        )
        # Upper-cased for display; drop .upper() to keep the model's casing
        transcription = output_text[0].strip().upper()
        logs.append(f"Transcription complete: {transcription[:50]}...")
    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        print(error_trace)
        print("\n".join(logs))
        return f"Error: {str(e)}\n\nLogs:\n" + "\n".join(logs), 0.0

    processing_time = round(time.time() - start_time, 2)
    return transcription, processing_time


# Create Gradio interface
title = "IBM Granite Speech-to-Text (8B Quantized)"
description = """
Transcribe speech using IBM's Granite Speech 3.2 8B model (loaded in 4-bit).
Upload an audio file or use your microphone to record speech.
"""

iface = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs=[
        gr.Textbox(label="Transcription", lines=5),
        gr.Number(label="Processing Time (seconds)"),
    ],
    title=title,
    description=description,
)

if __name__ == "__main__":
    iface.launch()
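# Usage (a sketch; "granite_app.py" and "sample.wav" are placeholder names,
# not files shipped with this script):
#
#   python granite_app.py        # starts the Gradio server; open the printed URL
#
# The transcription function can also be called directly once the module has
# loaded the model, e.g. from a Python shell:
#
#   >>> from granite_app import transcribe_audio
#   >>> text, seconds = transcribe_audio("sample.wav")  # any file torchaudio can read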