import sys
import subprocess

def upgrade_packages():
    try:
        print("Upgrading transformers and accelerate...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "--upgrade",
            "transformers>=4.31.0", "accelerate>=0.20.0"
        ])
        print("Upgrade complete.")
    except Exception as e:
        print("Error upgrading packages:", e)

# Attempt to upgrade the packages
upgrade_packages()
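# Note: on Hugging Face Spaces, dependencies are normally pinned in requirements.txt;
# the runtime upgrade above can serve as a fallback when that file is not used.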

# Now import the libraries (after the upgrade, so the fresh versions are picked up)
import torch
from transformers import (
    CLIPProcessor, CLIPModel, WhisperProcessor,
    WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM,
)
import gradio as gr
import numpy as np  # used below to resample uploaded audio for Whisper
import soundfile as sf

# ------------------------------
# Load Pretrained Models & Processors
# ------------------------------
print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("Loading Whisper model...")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

print("Loading GPT-2 model (placeholder for your text model)...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text_model = AutoModelForCausalLM.from_pretrained("gpt2")

# ------------------------------
# Define Projection Layers
# ------------------------------
print("Initializing image projection layer...")
image_projection = torch.nn.Linear(512, 768)
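# CLIP ViT-B/32 image features are 512-dimensional and GPT-2's hidden size is 768,
# hence the 512 -> 768 projection. This layer is randomly initialized and never
# trained in this demo, so its output is only a stand-in for a learned projection.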

# ------------------------------
# Multi-Modal Inference Function
# ------------------------------
def multimodal_inference(text_input, image_input, audio_input):
    """
    Processes text, image, and audio inputs.
    - Text is added directly.
    - The image is processed via CLIP, its embedding is projected, and a placeholder tag is appended.
    - Audio is transcribed using Whisper and appended.
    The final prompt is sent to the text model for generation.
    """
    prompt = ""

    # Process text input
    if text_input:
        prompt += text_input.strip()
    # Process image input if provided
    if image_input is not None:
        try:
            clip_inputs = clip_processor(images=image_input, return_tensors="pt")
            with torch.no_grad():
                image_features = clip_model.get_image_features(**clip_inputs)
            # Normalize image features
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            # Project image embedding into GPT-2's embedding space
            # (not actually fed to the model here; see the fusion sketch after this function)
            projected_image = image_projection(image_features)
            # For demo purposes, we append a placeholder tag.
            prompt += " [IMAGE_EMBEDDING]"
        except Exception as e:
            print("Error processing image:", e)
            prompt += " [IMAGE_ERROR]"
    # Process audio input if provided
    if audio_input is not None:
        try:
            audio, sr = sf.read(audio_input)
        except Exception as e:
            print("Error reading audio file:", e)
            return "Error processing audio input."
        try:
            # Whisper expects 16 kHz mono audio: collapse stereo and resample if needed.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            if sr != 16000:
                # Crude linear resample to avoid an extra dependency.
                new_len = int(len(audio) * 16000 / sr)
                audio = np.interp(
                    np.linspace(0, len(audio), new_len, endpoint=False),
                    np.arange(len(audio)),
                    audio,
                )
                sr = 16000
            whisper_inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = whisper_model.generate(whisper_inputs.input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            prompt += " " + transcription.strip()
        except Exception as e:
            print("Error during audio transcription:", e)
            prompt += " [AUDIO_ERROR]"
print("Final fused prompt:", prompt) | |
# Generate text response using the text model | |
inputs = tokenizer(prompt, return_tensors="pt") | |
with torch.no_grad(): | |
outputs = text_model.generate(**inputs, max_length=200) | |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return generated_text | |
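
# The sketch below (not called anywhere in this demo) shows one common way to actually
# use `projected_image` instead of the "[IMAGE_EMBEDDING]" placeholder: prepend the
# projected vector to the prompt's token embeddings and generate via `inputs_embeds`
# (supported by transformers >= 4.27 for decoder-only models). The function name and
# defaults are illustrative assumptions; with an untrained projection the output
# would not be meaningful.
def fuse_image_and_text(projected_image, prompt_text, max_new_tokens=50):
    token_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    token_embeds = text_model.get_input_embeddings()(token_ids)             # [1, T, 768]
    fused = torch.cat([projected_image.unsqueeze(1), token_embeds], dim=1)  # [1, T+1, 768]
    with torch.no_grad():
        generated_ids = text_model.generate(
            inputs_embeds=fused,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)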

# ------------------------------
# Gradio Interface for Hugging Face Spaces
# ------------------------------
# Uses the Gradio 4+ component API (the old gr.inputs.* namespace has been removed).
iface = gr.Interface(
    fn=multimodal_inference,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter your text here...", label="Text Input"),
        gr.Image(type="pil", label="Image Input (Optional)"),
        gr.Audio(sources=["upload"], type="filepath", label="Audio Input (Optional)"),
    ],
    outputs="text",
    title="Multi-Modal LLM Demo",
    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response.",
)
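
# On Hugging Face Spaces, Gradio runs this script (typically app.py) and launch()
# starts the web server. For a quick check without the UI, you could call the
# function directly, e.g.:
#   multimodal_inference("Describe a sunset over the ocean.", None, None)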
if __name__ == "__main__":
    iface.launch()