import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
import gradio as gr
import time
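# Likely runtime requirements (an assumption based on the imports and the
# 4-bit load below): torch, torchaudio, transformers, gradio, plus
# bitsandbytes and accelerate for BitsAndBytesConfig / device_map.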
# Load model and processor (runs once on startup)
model_name = "ibm-granite/granite-speech-3.2-8b"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print("Loading processor...")
speech_granite_processor = AutoProcessor.from_pretrained(
    model_name, trust_remote_code=True
)
tokenizer = speech_granite_processor.tokenizer
print("Loading model with 4-bit quantization...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
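# NF4 with double quantization stores weights in roughly 4 bits each, so the
# 8B model's weights drop from ~16 GB in fp16 to on the order of 5 GB.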
speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
)
print("Model loaded successfully")
def transcribe_audio(audio_input):
"""Process audio input and return transcription"""
start_time = time.time()
logs = [f"Audio input received: {type(audio_input)}"]
if audio_input is None:
return "Error: No audio provided.", 0.0
    try:
        # Handle different audio input formats
        if isinstance(audio_input, tuple) and len(audio_input) == 2:
            # Raw microphone input: (sample_rate, numpy_array)
            logs.append("Processing microphone input")
            sr, wav_np = audio_input
            wav = torch.from_numpy(wav_np).float()
            # Gradio delivers raw microphone audio as int16 PCM; rescale to [-1, 1]
            if wav.abs().max() > 1.0:
                wav = wav / 32768.0
            # Make sure we have the right dimensions [channels, time]
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            else:
                # Gradio stacks samples as (time, channels); transpose to torchaudio's layout
                wav = wav.t()
        else:
            # File input: a filepath string (with type="filepath" below, both
            # uploads and microphone recordings arrive this way)
            logs.append(f"Processing file input: {audio_input}")
            wav, sr = torchaudio.load(audio_input, normalize=True)
            logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")
        # Convert to mono if stereo
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
            logs.append("Converted stereo to mono")
        # Resample to 16kHz if needed
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            wav = resampler(wav)
            sr = 16000
            logs.append(f"Resampled to {sr}Hz")
        logs.append(
            f"Final audio: sample rate {sr}Hz, shape {wav.shape}, "
            f"min: {wav.min().item()}, max: {wav.max().item()}"
        )
        # Verify audio format matches what the model expects
        assert wav.shape[0] == 1 and sr == 16000, "Audio must be mono and 16kHz"
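        # (The Granite Speech model card calls for 16 kHz mono input, so other
        # formats are normalized above rather than rejected.)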
        # Create the text prompt
        chat = [
            {
                "role": "system",
                "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
            },
            {
                "role": "user",
                "content": "<|audio|>can you transcribe the speech into a written format?",
            },
        ]
        text = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
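        # The <|audio|> marker in the user turn is where the processor splices
        # in the audio features; apply_chat_template only renders the text.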
        # Pass text and waveform positionally; the Granite processor does not
        # accept the audio as a named parameter
        logs.append("Preparing model inputs")
        model_inputs = speech_granite_processor(
            text,
            wav,
            device=device,  # run audio feature extraction on this device
            return_tensors="pt",
        ).to(device)
        # Generate transcription
        logs.append("Generating transcription")
        model_outputs = speech_granite.generate(
            **model_inputs,
            max_new_tokens=1000,
            num_beams=4,
            do_sample=False,
            min_length=1,
            top_p=1.0,
            repetition_penalty=3.0,
            length_penalty=1.0,
            temperature=1.0,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
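        # With do_sample=False and num_beams=4 this is deterministic beam
        # search; top_p and temperature are inert under greedy/beam decoding
        # and appear to be kept for parity with the upstream Granite example.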
        # Extract the generated text, skipping the input (prompt) tokens
        logs.append("Processing output")
        num_input_tokens = model_inputs["input_ids"].shape[-1]
        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
        output_text = tokenizer.batch_decode(
            new_tokens, add_special_tokens=False, skip_special_tokens=True
        )
        transcription = output_text[0].strip().upper()  # upper-cased for display
        logs.append(f"Transcription complete: {transcription[:50]}...")
    except Exception as e:
        import traceback

        error_trace = traceback.format_exc()
        print(error_trace)
        print("\n".join(logs))
        return f"Error: {str(e)}\n\nLogs:\n" + "\n".join(logs), 0.0
    processing_time = round(time.time() - start_time, 2)
    return transcription, processing_time
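# Quick local smoke test (a sketch; assumes a speech recording "sample.wav"
# exists next to this script):
#
#   text, secs = transcribe_audio("sample.wav")
#   print(f"[{secs}s] {text}")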
# Create Gradio interface
title = "IBM Granite Speech-to-Text (8B Quantized)"
description = """
Transcribe speech using IBM's Granite Speech 3.2 8B model (loaded in 4-bit).
Upload an audio file or use your microphone to record speech.
"""
iface = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs=[
        gr.Textbox(label="Transcription", lines=5),
        gr.Number(label="Processing Time (seconds)"),
    ],
    title=title,
    description=description,
)
if __name__ == "__main__":
    iface.launch()