import torch
import numpy as np
import soundfile as sf
import gradio as gr
from transformers import (
    CLIPProcessor,
    CLIPModel,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoTokenizer,
    AutoModelForCausalLM,
)
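# Note: when this script runs as a Hugging Face Space, the packages imported
# above (torch, transformers, gradio, soundfile, numpy) are assumed to be
# listed in the Space's requirements.txt; exact version pins are not specified here.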
# ------------------------------
# Load Pretrained Models & Processors
# ------------------------------
print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("Loading Whisper model...")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

print("Loading GPT-2 model (placeholder for your text model)...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text_model = AutoModelForCausalLM.from_pretrained("gpt2")
# ------------------------------
# Define Projection Layers
# ------------------------------
# A simple linear layer projects CLIP's image embeddings (512 dims) into
# GPT-2's embedding dimension (768 dims). It is randomly initialized here;
# in a full project, this layer would be fine-tuned.
print("Initializing image projection layer...")
image_projection = torch.nn.Linear(512, 768)
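
# ------------------------------
# (Sketch) Feeding projected image features into GPT-2
# ------------------------------
# The demo below only appends an "[IMAGE_EMBEDDING]" placeholder tag to the text
# prompt. The helper here is an illustrative sketch (not called by the demo) of
# how the projected image feature could instead be prepended to GPT-2's token
# embeddings via `inputs_embeds`. It assumes a recent transformers release in
# which `generate()` accepts `inputs_embeds` for decoder-only models; since the
# projection layer is untrained, the output would not be meaningful without
# fine-tuning. The helper name is hypothetical.
def fuse_image_with_prompt(image_features, prompt_text, max_new_tokens=50):
    projected = image_projection(image_features)           # shape (1, 768)
    token_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    token_embeds = text_model.transformer.wte(token_ids)    # shape (1, seq_len, 768)
    # Prepend the image "token" to the text token embeddings.
    fused = torch.cat([projected.unsqueeze(1), token_embeds], dim=1)
    attention_mask = torch.ones(fused.shape[:2], dtype=torch.long)
    with torch.no_grad():
        output_ids = text_model.generate(
            inputs_embeds=fused,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)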
# ------------------------------
# Multi-Modal Inference Function
# ------------------------------
def multimodal_inference(text_input, image_input, audio_input):
    """
    Processes three modalities:
      - Text: used directly.
      - Image: processed via CLIP and projected.
      - Audio: transcribed using Whisper.
    The function fuses the outputs by concatenating their textual representations,
    and then feeds the final prompt to the text model for generation.
    """
    prompt = ""

    # Process text input
    if text_input:
        prompt += text_input.strip()
    # Process image input if provided
    if image_input is not None:
        try:
            clip_inputs = clip_processor(images=image_input, return_tensors="pt")
            with torch.no_grad():
                image_features = clip_model.get_image_features(**clip_inputs)
            # Normalize image features
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            # Project image embedding into GPT-2's embedding space
            projected_image = image_projection(image_features)
            # For demo purposes, we simply append a placeholder tag.
            # In a full system, you would integrate these embeddings into your model.
            prompt += " [IMAGE_EMBEDDING]"
        except Exception as e:
            print("Error processing image:", e)
            prompt += " [IMAGE_ERROR]"
    # Process audio input if provided
    if audio_input is not None:
        try:
            # Gradio provides a filepath for the audio file.
            audio, sr = sf.read(audio_input)
        except Exception as e:
            print("Error reading audio file:", e)
            return "Error processing audio input."
        try:
            # Whisper's feature extractor expects 16 kHz mono audio.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)  # downmix stereo to mono
            if sr != 16000:
                # Simple linear-interpolation resample to 16 kHz.
                num_samples = int(len(audio) * 16000 / sr)
                audio = np.interp(
                    np.linspace(0.0, len(audio), num=num_samples, endpoint=False),
                    np.arange(len(audio)),
                    audio,
                )
                sr = 16000
            whisper_inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = whisper_model.generate(whisper_inputs.input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            prompt += " " + transcription.strip()
        except Exception as e:
            print("Error during audio transcription:", e)
            prompt += " [AUDIO_ERROR]"
    # Debug: print the final fused prompt for verification
    print("Final fused prompt:", prompt)

    if not prompt.strip():
        return "Please provide at least one input (text, image, or audio)."

    # Generate a text response using the text model
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = text_model.generate(
            **inputs,
            max_length=200,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
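
# Example (hypothetical local test): the function can also be called directly, e.g.
#   print(multimodal_inference("Describe a sunset over the ocean.", None, None))
# Image and audio arguments may be None, a PIL image, or an audio filepath respectively.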
# ------------------------------
# Gradio Interface for Hugging Face Spaces
# ------------------------------
# Note: the old gr.inputs.* aliases have been removed from recent Gradio
# releases; the components below assume the current (Gradio 4.x) top-level API.
iface = gr.Interface(
    fn=multimodal_inference,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter your text here...", label="Text Input"),
        gr.Image(type="pil", label="Image Input (Optional)"),
        gr.Audio(sources=["upload"], type="filepath", label="Audio Input (Optional)"),
    ],
    outputs="text",
    title="Multi-Modal LLM Demo",
    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response.",
)
if __name__ == "__main__":
    iface.launch()