import torch
import torchaudio  # assumed available alongside torch; used to resample audio to Whisper's 16 kHz
from transformers import CLIPProcessor, CLIPModel, WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import soundfile as sf

# ------------------------------
# Load Pretrained Models & Processors
# ------------------------------

print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("Loading Whisper model...")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

print("Loading Flan-T5 model (instruction-tuned for better responses)...")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
text_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")

# ------------------------------
# Define Projection Layers
# ------------------------------
print("Initializing image projection layer...")
# This linear layer projects CLIP's 512-dimensional image embeddings to Flan-T5's
# hidden size (d_model, which is 1024 for flan-t5-large). In a real system this
# layer would be fine-tuned so the language model can interpret the projected features.
image_projection = torch.nn.Linear(512, text_model.config.d_model)
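
# Illustrative helper (a hedged sketch, not called anywhere in this demo): one way to
# actually inject the projected CLIP embedding into Flan-T5 is to prepend it to the
# prompt's token embeddings and pass the result via `inputs_embeds` (recent
# transformers versions accept `inputs_embeds` in `generate()` for encoder-decoder
# models). Without fine-tuning `image_projection`, the text generated through this
# path will not be meaningful; it only shows the wiring.
def fuse_image_embedding(prompt_text, projected_image):
    token_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    token_embeds = text_model.get_input_embeddings()(token_ids)   # (1, seq_len, d_model)
    fused = torch.cat([projected_image.unsqueeze(1), token_embeds], dim=1)
    with torch.no_grad():
        generated = text_model.generate(inputs_embeds=fused, max_length=200)
    return tokenizer.decode(generated[0], skip_special_tokens=True)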

# ------------------------------
# Multi-Modal Inference Function
# ------------------------------

def multimodal_inference(text_input, image_input, audio_input):
    """
    Processes text, image, and audio inputs:
      - Text: is used directly.
      - Image: is processed via CLIP; its embedding is projected and a placeholder is appended.
      - Audio: is transcribed using Whisper.
    
    The combined prompt is then fed into Flan-T5 to generate a text response.
    """
    prompt = ""
    
    # Process text input
    if text_input:
        prompt += text_input.strip()
    
    # Process image input if provided
    if image_input is not None:
        try:
            clip_inputs = clip_processor(images=image_input, return_tensors="pt")
            with torch.no_grad():
                image_features = clip_model.get_image_features(**clip_inputs)
            # Normalize and project image features
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            projected_image = image_projection(image_features)
            # In this demo the projected embedding is not actually passed to Flan-T5;
            # a placeholder tag simply marks that an image was provided (see the
            # fuse_image_embedding sketch above for one way to inject it for real).
            prompt += " [IMAGE_EMBEDDING]"
        except Exception as e:
            print("Error processing image:", e)
            prompt += " [IMAGE_ERROR]"
    
    # Process audio input if provided
    if audio_input is not None:
        try:
            audio, sr = sf.read(audio_input)
        except Exception as e:
            print("Error reading audio file:", e)
            return "Error processing audio input."
        try:
            # Whisper's feature extractor expects mono 16 kHz audio, so downmix
            # stereo and resample other sample rates (torchaudio assumed available).
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            if sr != 16000:
                audio = torchaudio.functional.resample(
                    torch.from_numpy(audio).float(), orig_freq=sr, new_freq=16000
                ).numpy()
                sr = 16000
            whisper_inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = whisper_model.generate(whisper_inputs.input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            prompt += " " + transcription.strip()
        except Exception as e:
            print("Error during audio transcription:", e)
            prompt += " [AUDIO_ERROR]"
    
    print("Final fused prompt:", prompt)
    
    # Tokenize and generate text using Flan-T5
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        generated_ids = text_model.generate(
            **inputs,
            max_length=200,
            temperature=0.7,       # Moderate randomness
            top_p=0.9,             # Nucleus sampling to limit token choices
            repetition_penalty=1.2,# Penalize repeated tokens
            do_sample=True         # Enable sampling for more varied responses
        )
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    return generated_text
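
# Example of invoking the pipeline directly, without the Gradio UI (the file paths
# below are illustrative placeholders, not assets shipped with this demo):
#
#   from PIL import Image
#   reply = multimodal_inference(
#       "Summarize what you see and hear.",
#       Image.open("example.jpg"),   # or None to skip the image branch
#       "example.wav",               # or None to skip the audio branch
#   )
#   print(reply)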

# ------------------------------
# Gradio Interface for Hugging Face Spaces
# ------------------------------

iface = gr.Interface(
    fn=multimodal_inference,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter your text here...", label="Text Input"),
        gr.Image(type="pil", label="Image Input (Optional)"),
        gr.Audio(type="filepath", label="Audio Input (Optional)")
    ],
    outputs="text",
    title="Multi-Modal LLM Demo with Flan-T5",
    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response using an instruction-tuned model."
)

if __name__ == "__main__":
    iface.launch()