import sys
import subprocess

def upgrade_packages():
    try:
        print("Upgrading transformers and accelerate...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "--upgrade",
            "transformers>=4.31.0", "accelerate>=0.20.0"
        ])
        print("Upgrade complete.")
    except Exception as e:
        print("Error upgrading packages:", e)

# Attempt to upgrade the packages
upgrade_packages()

# Now import the libraries
import torch
from transformers import CLIPProcessor, CLIPModel, WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import soundfile as sf

# ------------------------------
# Load Pretrained Models & Processors
# ------------------------------

print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("Loading Whisper model...")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

print("Loading GPT-2 model (placeholder for your text model)...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text_model = AutoModelForCausalLM.from_pretrained("gpt2")

# ------------------------------
# Define Projection Layers
# ------------------------------
print("Initializing image projection layer...")
image_projection = torch.nn.Linear(512, 768)
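
# The inference function below never feeds the projected embedding to GPT-2;
# it only appends a placeholder tag to the prompt. As a minimal, illustrative
# sketch (not part of the app; the function name is hypothetical), one way to
# actually condition GPT-2 on the image would be to prepend the projected
# vector to the token embeddings and generate from inputs_embeds
# (assumes transformers >= 4.31, which the upgrade step above targets):
def generate_with_image_embedding(prompt_text, projected_image, max_new_tokens=50):
    token_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    token_embeds = text_model.transformer.wte(token_ids)  # (1, T, 768)
    # Prepend the single projected image vector as a pseudo-token
    fused = torch.cat([projected_image.unsqueeze(1), token_embeds], dim=1)
    with torch.no_grad():
        output_ids = text_model.generate(
            inputs_embeds=fused,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)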

# ------------------------------
# Multi-Modal Inference Function
# ------------------------------

def multimodal_inference(text_input, image_input, audio_input):
    """
    Processes text, image, and audio inputs.
      - Text is added directly.
      - The image is processed via CLIP, its embedding is projected, and a placeholder tag is appended.
      - Audio is transcribed using Whisper and appended.
    The final prompt is sent to the text model for generation.
    """
    prompt = ""
    
    # Process text input
    if text_input:
        prompt += text_input.strip()
    
    # Process image input if provided
    if image_input is not None:
        try:
            clip_inputs = clip_processor(images=image_input, return_tensors="pt")
            with torch.no_grad():
                image_features = clip_model.get_image_features(**clip_inputs)
            # Normalize image features
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            # Project the image embedding into GPT-2's embedding space
            projected_image = image_projection(image_features)
            # The projected embedding is not injected into GPT-2 in this demo;
            # only a placeholder tag is appended (see the illustrative sketch above).
            prompt += " [IMAGE_EMBEDDING]"
        except Exception as e:
            print("Error processing image:", e)
            prompt += " [IMAGE_ERROR]"
    
    # Process audio input if provided
    if audio_input is not None:
        try:
            audio, sr = sf.read(audio_input)
            # Downmix stereo to mono; Whisper's feature extractor expects a 1-D, 16 kHz signal
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
        except Exception as e:
            print("Error reading audio file:", e)
            return "Error processing audio input."
        try:
            whisper_inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = whisper_model.generate(whisper_inputs.input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            prompt += " " + transcription.strip()
        except Exception as e:
            print("Error during audio transcription:", e)
            prompt += " [AUDIO_ERROR]"
    
    print("Final fused prompt:", prompt)
    
    # Generate text response using the text model
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = text_model.generate(
            **inputs,
            max_new_tokens=100,  # budget new tokens separately so a long fused prompt isn't cut off
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return generated_text
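
# Example (illustrative): the fusion pipeline can also be exercised directly,
# e.g. with text only, before wiring it into the Gradio UI:
#   print(multimodal_inference("Describe a sunset over the ocean.", None, None))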

# ------------------------------
# Gradio Interface for Hugging Face Spaces
# ------------------------------

iface = gr.Interface(
    fn=multimodal_inference,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter your text here...", label="Text Input"),
        gr.Image(type="pil", label="Image Input (Optional)"),
        gr.Audio(type="filepath", label="Audio Input (Optional)")
    ],
    outputs="text",
    title="Multi-Modal LLM Demo",
    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response."
)

if __name__ == "__main__":
    iface.launch()