import torch
from transformers import CLIPProcessor, CLIPModel, WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import soundfile as sf

# ------------------------------
# Load Pretrained Models & Processors
# ------------------------------
print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("Loading Whisper model...")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

print("Loading Flan-T5 model (instruction-tuned for better responses)...")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
text_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
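# Optional device placement (sketch only, kept commented out): if the Space has a GPU,
# the models could be moved to it, but every tensor built inside multimodal_inference()
# below would then also need a matching .to(device) call.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# clip_model.to(device); whisper_model.to(device); text_model.to(device)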
# ------------------------------
# Define Projection Layers
# ------------------------------
print("Initializing image projection layer...")
# This linear layer projects CLIP's 512-dimensional image embeddings to Flan-T5-large's
# hidden size (d_model = 1024). In this demo the projected vector is computed but never
# fed to the language model; in a real system this layer would be trained end to end
# (a fusion sketch follows just below).
image_projection = torch.nn.Linear(512, 1024)
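# Illustrative sketch of what a real fusion step could look like: instead of replacing the
# image with a text placeholder (as the demo below does), the projected CLIP vector could be
# prepended to Flan-T5's token embeddings and passed to generate() via `inputs_embeds`.
# This helper is hypothetical and is not called anywhere in this app; it assumes the
# 512 -> 1024 projection above and a transformers version whose generate() accepts
# `inputs_embeds` for encoder-decoder models. Without training the projection, the output
# would be meaningless.
def fuse_image_with_prompt(projected_image, prompt_text):
    tok = tokenizer(prompt_text, return_tensors="pt")
    text_embeds = text_model.get_input_embeddings()(tok.input_ids)         # (1, seq_len, 1024)
    fused = torch.cat([projected_image.unsqueeze(1), text_embeds], dim=1)  # (1, 1 + seq_len, 1024)
    attention_mask = torch.ones(fused.shape[:2], dtype=torch.long)
    generated = text_model.generate(inputs_embeds=fused, attention_mask=attention_mask, max_length=200)
    return tokenizer.decode(generated[0], skip_special_tokens=True)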
# ------------------------------
# Multi-Modal Inference Function
# ------------------------------
def multimodal_inference(text_input, image_input, audio_input):
    """
    Processes text, image, and audio inputs:
      - Text is used directly.
      - Image is run through CLIP; its embedding is projected and a placeholder tag is appended.
      - Audio is transcribed with Whisper.
    The combined prompt is then fed to Flan-T5 to generate a text response.
    """
    prompt = ""

    # Process text input
    if text_input:
        prompt += text_input.strip()

    # Process image input if provided
    if image_input is not None:
        try:
            clip_inputs = clip_processor(images=image_input, return_tensors="pt")
            with torch.no_grad():
                image_features = clip_model.get_image_features(**clip_inputs)
                # Normalize and project image features (illustrative only; the projected
                # vector is not consumed by Flan-T5 in this demo).
                image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
                projected_image = image_projection(image_features)
            # For this demo, we append a placeholder tag to indicate image information.
            prompt += " [IMAGE_EMBEDDING]"
        except Exception as e:
            print("Error processing image:", e)
            prompt += " [IMAGE_ERROR]"

    # Process audio input if provided
    if audio_input is not None:
        try:
            audio, sr = sf.read(audio_input)
        except Exception as e:
            print("Error reading audio file:", e)
            return "Error processing audio input."
        try:
            # Mix stereo down to mono; Whisper expects a 1-D waveform.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            # Note: Whisper's feature extractor expects 16 kHz audio. Recordings at other
            # sample rates should be resampled first (see the resample_to_16k sketch below).
            whisper_inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = whisper_model.generate(whisper_inputs.input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            prompt += " " + transcription.strip()
        except Exception as e:
            print("Error during audio transcription:", e)
            prompt += " [AUDIO_ERROR]"

    if not prompt.strip():
        return "Please provide at least one input (text, image, or audio)."

    print("Final fused prompt:", prompt)

    # Tokenize and generate text with Flan-T5
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        generated_ids = text_model.generate(
            **inputs,
            max_length=200,
            temperature=0.7,        # Moderate randomness
            top_p=0.9,              # Nucleus sampling to limit token choices
            repetition_penalty=1.2, # Penalize repeated tokens
            do_sample=True          # Enable sampling for more varied responses
        )
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text
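# Sketch of the resampling step referenced above. This helper is not wired into the demo and
# assumes librosa is available in the Space (it is not imported at the top of this file);
# any other high-quality resampler would work equally well.
def resample_to_16k(audio, sr):
    if sr == 16000:
        return audio, sr
    import librosa  # assumed dependency
    return librosa.resample(audio, orig_sr=sr, target_sr=16000), 16000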
# ------------------------------
# Gradio Interface for Hugging Face Spaces
# ------------------------------
iface = gr.Interface(
    fn=multimodal_inference,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter your text here...", label="Text Input"),
        gr.Image(type="pil", label="Image Input (Optional)"),
        gr.Audio(type="filepath", label="Audio Input (Optional)")
    ],
    outputs="text",
    title="Multi-Modal LLM Demo with Flan-T5",
    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response using an instruction-tuned model."
)
if __name__ == "__main__":
    iface.launch()
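# For reference, a Space running this app would need roughly the following requirements.txt
# (package names are standard; exact versions and pinning are left to the reader):
#
#   torch
#   transformers
#   gradio
#   soundfile
#   sentencepiece   # commonly required for T5 tokenizers
#   Pillow          # for the PIL image input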