atharvasc27112001 committed
Commit d6298eb · verified · 1 Parent(s): a2b8602

Create app.py

Files changed (1)
1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
import torch
from transformers import CLIPProcessor, CLIPModel, WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import soundfile as sf

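# The imports above determine this Space's Python dependencies. A minimal
# requirements.txt would likely contain the following (the exact list and any
# version pins are an assumption, not part of this commit):
#     torch
#     transformers
#     gradio
#     soundfile
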
# ------------------------------
# Load Pretrained Models & Processors
# ------------------------------

print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

print("Loading Whisper model...")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")

print("Loading GPT-2 model (placeholder for your text model)...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text_model = AutoModelForCausalLM.from_pretrained("gpt2")

# ------------------------------
# Define Projection Layers
# ------------------------------
# A simple linear layer that projects CLIP's image embeddings (512 dims) into
# GPT-2's embedding dimension (768 dims). In a full project, this layer would be fine-tuned.
print("Initializing image projection layer...")
image_projection = torch.nn.Linear(512, 768)

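# One possible way to actually consume the projected image embedding, instead of the
# [IMAGE_EMBEDDING] placeholder used in the demo below: prepend it to the prompt's token
# embeddings and let GPT-2 generate from inputs_embeds. This helper is an illustrative
# sketch only (it is not called anywhere in this app), it assumes a transformers version
# whose generate() accepts inputs_embeds for decoder-only models, and because the
# projection layer is untrained its output would not be meaningful without fine-tuning.
def generate_with_image_prefix(prompt_text, projected_image, max_new_tokens=50):
    token_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    token_embeds = text_model.get_input_embeddings()(token_ids)   # (1, seq_len, 768)
    prefix = projected_image.unsqueeze(1)                         # (1, 1, 768)
    fused = torch.cat([prefix, token_embeds], dim=1)              # (1, seq_len + 1, 768)
    attention_mask = torch.ones(fused.shape[:2], dtype=torch.long)
    with torch.no_grad():
        generated_ids = text_model.generate(
            inputs_embeds=fused,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
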
# ------------------------------
# Multi-Modal Inference Function
# ------------------------------

def multimodal_inference(text_input, image_input, audio_input):
    """
    Processes three modalities:
    - Text: used directly.
    - Image: processed via CLIP and projected.
    - Audio: transcribed using Whisper.

    The function fuses the outputs by concatenating their textual representations
    and then feeds the final prompt to the text model for generation.
    """
    prompt = ""

    # Process text input
    if text_input:
        prompt += text_input.strip()

    # Process image input if provided
    if image_input is not None:
        try:
            clip_inputs = clip_processor(images=image_input, return_tensors="pt")
            with torch.no_grad():
                image_features = clip_model.get_image_features(**clip_inputs)
            # Normalize the image features
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            # Project the image embedding into GPT-2's embedding space
            projected_image = image_projection(image_features)
            # For demo purposes, simply append a placeholder tag. In a full system,
            # the projected embedding would be fed to the language model directly
            # (one option is sketched above).
            prompt += " [IMAGE_EMBEDDING]"
        except Exception as e:
            print("Error processing image:", e)
            prompt += " [IMAGE_ERROR]"

    # Process audio input if provided
    if audio_input is not None:
        try:
            # Gradio provides a filepath for the audio file.
            audio, sr = sf.read(audio_input)
            # Whisper expects mono audio; average the channels if the file is multi-channel.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
        except Exception as e:
            print("Error reading audio file:", e)
            return "Error processing audio input."
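        # NOTE: the Whisper feature extractor only accepts 16 kHz audio and will reject
        # other sampling rates. If uploads may arrive at a different rate, one option
        # (assuming librosa were added to the requirements) is:
        #     audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
        #     sr = 16000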
        try:
            whisper_inputs = whisper_processor(audio, sampling_rate=sr, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = whisper_model.generate(whisper_inputs.input_features)
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            prompt += " " + transcription.strip()
        except Exception as e:
            print("Error during audio transcription:", e)
            prompt += " [AUDIO_ERROR]"

    # Debug: print the final fused prompt for verification
    print("Final fused prompt:", prompt)

    # Guard against the case where no usable input was provided.
    if not prompt.strip():
        return "Please provide at least one input."

    # Generate a text response using the text model
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = text_model.generate(**inputs, max_length=200, pad_token_id=tokenizer.eos_token_id)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return generated_text

# ------------------------------
# Gradio Interface for Hugging Face Spaces
# ------------------------------

iface = gr.Interface(
    fn=multimodal_inference,
    inputs=[
        gr.Textbox(lines=5, placeholder="Enter your text here...", label="Text Input"),
        gr.Image(type="pil", label="Image Input (Optional)"),
        gr.Audio(type="filepath", label="Audio Input (Optional)"),
    ],
    outputs="text",
    title="Multi-Modal LLM Demo",
    description="This demo accepts text, image, and audio inputs, processes each modality, and produces a text response.",
)
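# For a quick check without launching the UI, multimodal_inference can also be called
# directly, e.g. multimodal_inference("Describe a sunset over the ocean.", None, None)
# for a text-only prompt; the Interface above exercises the same code path.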

if __name__ == "__main__":
    iface.launch()