import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, BitsAndBytesConfig
import gradio as gr
import os
import time

# Load model and processor (runs once on startup)
model_name = "ibm-granite/granite-speech-3.2-8b"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Loading processor...")
speech_granite_processor = AutoProcessor.from_pretrained(
    model_name, trust_remote_code=True)
tokenizer = speech_granite_processor.tokenizer

print("Loading model with 4-bit quantization...")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
)
print("Model loaded successfully")

def transcribe_audio(audio_input):
    """Process audio input and return transcription"""
    start_time = time.time()
    logs = [f"Audio input received: {type(audio_input)}"]

    if audio_input is None:
        return "Error: No audio provided.", 0.0

    try:
        # Handle different audio input formats
        if isinstance(audio_input, tuple) and len(audio_input) == 2:
            # Microphone input: (sample_rate, numpy_array)
            logs.append("Processing microphone input")
            sr, wav_np = audio_input
            wav = torch.from_numpy(wav_np).float()
            # Gradio delivers microphone audio as int16 PCM; scale to [-1, 1]
            if wav.abs().max() > 1.0:
                wav = wav / 32768.0
            # Make sure we have the right dimensions [channels, time]
            if len(wav.shape) == 1:
                wav = wav.unsqueeze(0)
        else:
            # File input: filepath string
            logs.append(f"Processing file input: {audio_input}")
            wav, sr = torchaudio.load(audio_input, normalize=True)
            logs.append(f"Loaded audio file with sample rate {sr}Hz and shape {wav.shape}")

        # Convert to mono if stereo
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
            logs.append("Converted stereo to mono")

        # Resample to 16kHz if needed
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
            wav = resampler(wav)
            sr = 16000
            logs.append(f"Resampled to {sr}Hz")

        logs.append(f"Final audio: sample rate {sr}Hz, shape {wav.shape}, "
                    f"min: {wav.min().item()}, max: {wav.max().item()}")

        # Verify audio format matches what the model expects
        assert wav.shape[0] == 1 and sr == 16000, "Audio must be mono and 16kHz"

        # Create text prompt
        chat = [
            {
                "role": "system",
                "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
            },
            {
                "role": "user",
                "content": "<|audio|>can you transcribe the speech into a written format?",
            },
        ]
        text = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # Pass text and waveform positionally to the processor
        # (the audio must not be passed as a named parameter)
        logs.append("Preparing model inputs")
        model_inputs = speech_granite_processor(
            text,
            wav,
            device=device,  # Explicitly set device
            return_tensors="pt",
        ).to(device)

        # Generate transcription
        logs.append("Generating transcription")
        model_outputs = speech_granite.generate(
            **model_inputs,
            max_new_tokens=1000,
            num_beams=4,
            do_sample=False,
            min_length=1,
            top_p=1.0,
            repetition_penalty=3.0,
            length_penalty=1.0,
            temperature=1.0,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
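
        # With do_sample=False this is deterministic beam search;
        # top_p and temperature have no effect in this mode.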

        # Extract the generated text (skipping input tokens)
        logs.append("Processing output")
        num_input_tokens = model_inputs["input_ids"].shape[-1]
        new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
        output_text = tokenizer.batch_decode(
            new_tokens, add_special_tokens=False, skip_special_tokens=True
        )
        transcription = output_text[0].strip().upper()
        logs.append(f"Transcription complete: {transcription[:50]}...")

    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        print(error_trace)
        print("\n".join(logs))
        return f"Error: {str(e)}\n\nLogs:\n" + "\n".join(logs), 0.0

    processing_time = round(time.time() - start_time, 2)
    return transcription, processing_time
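
# Quick local check (a sketch assuming an audio file named "sample.wav" exists
# next to this script; not part of the original Space):
#
#   transcription, seconds = transcribe_audio("sample.wav")
#   print(transcription, seconds)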

# Create Gradio interface
title = "IBM Granite Speech-to-Text (8B Quantized)"
description = """
Transcribe speech using IBM's Granite Speech 3.2 8B model (loaded in 4-bit).
Upload an audio file or use your microphone to record speech.
"""

iface = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
    outputs=[
        gr.Textbox(label="Transcription", lines=5),
        gr.Number(label="Processing Time (seconds)"),
    ],
    title=title,
    description=description,
)

if __name__ == "__main__":
    iface.launch()
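
# With type="filepath", both uploads and microphone recordings reach
# transcribe_audio as file paths. To expose a temporary public URL (an optional
# tweak, not part of the original Space), launch could be called as:
#
#   iface.launch(share=True)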