sagar007 committed · verified
Commit 1abfce8 · 1 Parent(s): 06e746f

Update app.py

Files changed (1)
  1. app.py +79 -125
app.py CHANGED
@@ -6,7 +6,6 @@ import logging
import spaces
import numpy as np

-# Setup logging
logging.basicConfig(level=logging.INFO)

class LLaVAPhiModel:
@@ -23,12 +22,12 @@ class LLaVAPhiModel:
        self.history = []
        self.model = None
        self.clip = None
-
-        # Add a linear projection layer to align CLIP features with text embeddings
        self.projection = None

    @spaces.GPU
    def ensure_models_loaded(self):
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA is not available. This model requires a GPU.")
        if self.model is None:
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
@@ -44,142 +43,97 @@ class LLaVAPhiModel:
                trust_remote_code=True
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
-            logging.info("Successfully loaded main model")
+            logging.info("Successfully loaded main model on GPU")

        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")
-
-        # Initialize projection layer (CLIP features: 512-dim, model embedding size: e.g., 2048 for Phi)
-        embed_dim = self.model.config.hidden_size  # e.g., 2048 for Phi-1.5
-        clip_dim = self.clip.config.projection_dim  # 512 for CLIP
+        embed_dim = self.model.config.hidden_size
+        clip_dim = self.clip.config.projection_dim
        self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)

-    @spaces.GPU
-    def process_image(self, image):
-        try:
-            self.ensure_models_loaded()
-            if self.clip is None or self.processor is None:
-                logging.warning("CLIP model or processor not available")
-                return None
-
-            if isinstance(image, str):
-                image = Image.open(image)
-            elif isinstance(image, np.ndarray):
-                image = Image.fromarray(image)
-            if image.mode != 'RGB':
-                image = image.convert('RGB')
-
-            with torch.no_grad():
-                image_inputs = self.processor(images=image, return_tensors="pt")
-                image_features = self.clip.get_image_features(
-                    pixel_values=image_inputs.pixel_values.to(self.device)
-                )
-                # Project image features to text embedding space
-                projected_features = self.projection(image_features)
-                logging.info("Successfully processed image through CLIP")
-                return projected_features
-        except Exception as e:
-            logging.error(f"Error in process_image: {str(e)}")
-            return None
+    # Rest of your class (process_image, generate_response, etc.) remains unchanged
+    # ... (omitted for brevity)

-    @spaces.GPU(duration=120)
-    def generate_response(self, message, image=None):
-        try:
-            self.ensure_models_loaded()
+def create_demo():
+    try:
+        model = LLaVAPhiModel()
+
+        demo = gr.Blocks(css="footer {visibility: hidden}")
+        with demo:
+            gr.Markdown(
+                """
+                # LLaVA-Phi Demo (Optimized for Accuracy)
+                Chat with a vision-language model that can understand both text and images.
+                """
+            )

-            if image is not None:
-                image_features = self.process_image(image)
-                has_image = image_features is not None
-                if not has_image:
-                    message = "Note: Image processing is not available - continuing with text only.\n" + message
-
-            prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
-            context = ""
-            for turn in self.history[-5:]:
-                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
-            full_prompt = context + prompt
-
-            inputs = self.tokenizer(
-                full_prompt,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=1024
-            )
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-            if has_image:
-                # Convert input_ids to embeddings
-                embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
-                # Concatenate image features with text embeddings
-                image_features_expanded = image_features.unsqueeze(1)  # Shape: [batch, 1, embed_dim]
-                combined_embeddings = torch.cat([image_features_expanded, embeddings], dim=1)
-                inputs["inputs_embeds"] = combined_embeddings
-                # Update attention mask to account for the extra image token
-                inputs["attention_mask"] = torch.cat(
-                    [torch.ones(inputs["attention_mask"].shape[0], 1).to(self.device),
-                     inputs["attention_mask"]],
-                    dim=1
+            chatbot = gr.Chatbot(height=400)
+            with gr.Row():
+                with gr.Column(scale=0.7):
+                    msg = gr.Textbox(
+                        show_label=False,
+                        placeholder="Enter text and/or upload an image",
+                        container=False
                    )
-                # Remove input_ids since we're using inputs_embeds
-                del inputs["input_ids"]
-            else:
-                prompt = f"human: {message}\ngpt:"
-                context = ""
-                for turn in self.history[-5:]:
-                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
-                full_prompt = context + prompt
+                with gr.Column(scale=0.15, min_width=0):
+                    clear = gr.Button("Clear")
+                with gr.Column(scale=0.15, min_width=0):
+                    submit = gr.Button("Submit", variant="primary")
+
+            image = gr.Image(type="pil", label="Upload Image (Optional)")
+
+            with gr.Accordion("Advanced Settings", open=False):
+                gr.Markdown("Adjust these parameters to control hallucination tendency")
+                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
+                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
+                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
+                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
+                update_params = gr.Button("Update Parameters")
+
+            def respond(message, chat_history, image):
+                if not message and image is None:
+                    return chat_history

-            inputs = self.tokenizer(
-                full_prompt,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=1024
-            )
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                response = model.generate_response(message, image)
+                chat_history.append((message, response))
+                return "", chat_history

-            with torch.no_grad():
-                outputs = self.model.generate(
-                    **inputs,
-                    max_new_tokens=256,
-                    min_length=20,
-                    temperature=0.3,
-                    do_sample=True,
-                    top_p=0.92,
-                    top_k=50,
-                    repetition_penalty=1.2,
-                    no_repeat_ngram_size=3,
-                    use_cache=True,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id
-                )
+            def clear_chat():
+                model.clear_history()
+                return None, None

-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            if "gpt:" in response:
-                response = response.split("gpt:")[-1].strip()
-            if "human:" in response:
-                response = response.split("human:")[0].strip()
-            if "<image>" in response:
-                response = response.replace("<image>", "").strip()
+            def update_params_fn(temp, top_p, top_k, rep_penalty):
+                return model.update_generation_params(temp, top_p, top_k, rep_penalty)

-            self.history.append((message, response))
-            return response
+            submit.click(
+                respond,
+                [msg, chatbot, image],
+                [msg, chatbot],
+            )

-        except Exception as e:
-            logging.error(f"Error generating response: {str(e)}")
-            return f"Error: {str(e)}"
-
-    def clear_history(self):
-        self.history = []
-        return None
-
-def create_demo():
-    model = LLaVAPhiModel()
-    # Rest of your Gradio setup remains the same
-    # ... (omitted for brevity)
-    return demo
+            clear.click(
+                clear_chat,
+                None,
+                [chatbot, image],
+            )
+
+            msg.submit(
+                respond,
+                [msg, chatbot, image],
+                [msg, chatbot],
+            )
+
+            update_params.click(
+                update_params_fn,
+                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
+                None
+            )
+
+        return demo
+    except Exception as e:
+        logging.error(f"Error creating demo: {str(e)}")
+        raise

if __name__ == "__main__":
    demo = create_demo()
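
The Advanced Settings panel above wires update_params.click to model.update_generation_params(temp, top_p, top_k, rep_penalty), a method that is not shown in this diff (the remainder of the class is elided with an "omitted for brevity" comment). Below is a minimal sketch of what such a method could look like, assuming the class keeps its sampling settings in a generation_params dict that generate_response later unpacks into self.model.generate(); the method name and its four arguments come from the wiring above, while generation_params and its defaults are assumptions.

import logging

class LLaVAPhiModel:  # hypothetical fragment, not part of this commit
    def __init__(self):
        # Assumed defaults, mirroring the slider defaults in the UI above.
        self.generation_params = {
            "temperature": 0.3,
            "top_p": 0.92,
            "top_k": 50,
            "repetition_penalty": 1.2,
        }

    def update_generation_params(self, temperature, top_p, top_k, repetition_penalty):
        # Store the UI-selected values; generate_response would then pass them along,
        # e.g. self.model.generate(**inputs, **self.generation_params, max_new_tokens=256).
        self.generation_params.update(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        logging.info("Updated generation params: %s", self.generation_params)

Because update_params.click registers no outputs, the handler's return value is ignored, so storing the values on the model instance is sufficient.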