KeerthiVM committed
Commit 8afe134
1 Parent(s): cca45c8
Files changed (1)
  1. app.py +57 -121
app.py CHANGED
@@ -115,27 +115,6 @@ class Blip2QFormer(nn.Module):
 
         return outputs.last_hidden_state
 
-class LayerNorm(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16."""
-
-    def forward(self, x: torch.Tensor):
-        orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
-        return ret.type(orig_type)
-
-
-class ViTClassifier(nn.Module):
-    def __init__(self, vit, ln_vision, num_labels):
-        super(ViTClassifier, self).__init__()
-        self.vit = vit  # Pretrained ViT from MiniGPT-4
-        self.ln_vision = ln_vision  # LayerNorm from MiniGPT-4
-        self.classifier = nn.Linear(vit.num_features, num_labels)
-
-    def forward(self, x):
-        features = self.ln_vision(self.vit(x))  # [batch, seq_len, dim]
-        cls_token = features[:, 0, :]  # Extract CLS token
-        return self.classifier(cls_token)
-
 
 class SkinGPT4(nn.Module):
     def __init__(self, vit_checkpoint_path,
@@ -161,10 +140,7 @@ class SkinGPT4(nn.Module):
         self.q_former.load_from_pretrained(q_former_model)
         for param in self.q_former.parameters():
             param.requires_grad = False
-        for module in [self.vit, self.ln_vision, self.q_former]:
-            for param in module.parameters():
-                param.requires_grad = False
-            module.eval()
+
         print("Loaded QFormer")
 
         self.tokenizer = LlamaTokenizer.from_pretrained(
@@ -185,8 +161,10 @@ class SkinGPT4(nn.Module):
         print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}")
         print(f"LLaMA input dim: {self.llama.config.hidden_size}")
 
-        for param in self.llama_proj.parameters():
-            param.requires_grad = False
+        for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]:
+            for param in module.parameters():
+                param.requires_grad = False
+            module.eval()
 
     def _init_vit(self, vit_checkpoint_path):
         """Initialize EVA-ViT-G with paper specifications"""
@@ -213,9 +191,6 @@ class SkinGPT4(nn.Module):
         # 4. Load weights while ignoring classifier head
         vit.load_state_dict(vit_weights, strict=False)
 
-        # 5. Freeze according to paper specs
-        for param in vit.parameters():
-            param.requires_grad = False
 
         return vit.eval()
 
@@ -226,27 +201,13 @@ class SkinGPT4(nn.Module):
             "": 0 if torch.cuda.is_available() else "cpu"
         }
        # First try loading with device_map="auto"
-        try:
-            model = LlamaForCausalLM.from_pretrained(
-                "meta-llama/Llama-2-13b-chat-hf",
-                token=token,
-                torch_dtype=torch.float16,
-                device_map=device_map,
-                low_cpu_mem_usage=True
-            )
-        except ImportError:
-            # Fallback to CPU-offloading if GPU memory is insufficient
-            with init_empty_weights():
-                model = LlamaForCausalLM.from_pretrained(
-                    "meta-llama/Llama-2-13b-chat-hf",
-                    token=token,
-                    torch_dtype=torch.float16
-                )
-            model = model.to(self.device)
-
-        # Freeze all parameters
-        for param in model.parameters():
-            param.requires_grad = False
+        model = LlamaForCausalLM.from_pretrained(
+            "meta-llama/Llama-2-13b-chat-hf",
+            token=token,
+            torch_dtype=torch.float16,
+            device_map=device_map,
+            low_cpu_mem_usage=True
+        )
 
         return model.eval()
 
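Review note: dropping the `except ImportError` fallback leaves the `device_map` path as the only loading route, which requires `accelerate` to be installed. A runnable sketch of the same loading pattern against a small public checkpoint (the tiny model id below is only a stand-in for the gated Llama-2 weights):

    import torch
    from transformers import AutoModelForCausalLM

    device_map = {"": 0 if torch.cuda.is_available() else "cpu"}
    model = AutoModelForCausalLM.from_pretrained(
        "sshleifer/tiny-gpt2",  # stand-in; the app loads meta-llama/Llama-2-13b-chat-hf
        torch_dtype=torch.float16,
        device_map=device_map,
        low_cpu_mem_usage=True,
    )
    print(next(model.parameters()).device, next(model.parameters()).dtype)

If GPU memory becomes a concern again, `device_map="auto"` restores CPU/GPU offloading via `accelerate` without the manual `init_empty_weights` fallback that was removed here.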
@@ -259,12 +220,7 @@ class SkinGPT4(nn.Module):
                 f"Original error: {str(e)}"
             )
 
-    def _init_alignment_projection(self):
-        """Paper specifies Xavier initialization for alignment layer"""
-        nn.init.xavier_normal_(self.llama_proj.weight)
-        nn.init.constant_(self.llama_proj.bias, 0)
-
-    def _create_patches(self, x):
+    def encode_image(self, x):
         """Convert image to patch embeddings following Eq. (1)"""
         # x: (B, C, H, W)
         x = x.to(self.dtype)
@@ -276,69 +232,39 @@ class SkinGPT4(nn.Module):
         B, C, H, W = x.shape
         N = (H * W) // (self.P ** 2)
 
-        x = self.vit.patch_embed(x)  # (B, N, D)
+        x = self.vit.patch_embed(x)
 
         num_patches = x.shape[1]
-        pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]  # Adjust for exact match
+        pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]
         x = x + pos_embed
 
         # Add class token
-        class_token = self.vit.cls_token.expand(B, -1, -1)
-        x = torch.cat([class_token, x], dim=1)  # (B, N+1, D)
-        return x
-
-    def forward_encoder(self, x):
-        """ViT encoder from Eqs. (2)-(3)"""
-        # x: (B, N+1, D)
+        class_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat([class_token, x], dim=1)
         for blk in self.vit.blocks:
             x = blk(x)
         x = self.vit.norm(x)
-        x = self.ln_vision(x)
-        return x  # (B, N+1, D)
-
-    def forward(self, images):
-        images = images.to(self.dtype)
-        x = self._create_patches(images)
-        vit_output = self.forward_encoder(x)
-        with torch.cuda.amp.autocast(enabled=False):
-            qformer_output = self.q_former(vit_output.float())
-        aligned_features = self.llama_proj(qformer_output.to(self.dtype))
-        return aligned_features
-
-
-    def add_to_history(self, role, content):
-        self.conversation_history.append({"role": role, "content": content})
-
-    def get_full_context(self):
-        return "\n".join([f"{msg['role']}: {msg['content']}" for msg in self.conversation_history])
-
-    def build_prompt(self, image_embeds, user_question=None):
-        # Base prompt for initial diagnosis
-        if not user_question:
-            prompt = (
-                "### Instruction: <Img><ImageHere></Img> "
-                "Could you describe the skin disease in this image for me? "
-                "### Response:"
-            )
-        else:
-            # Follow-up prompt with conversation history
-            history = self.get_full_context()
-            prompt = (
-                f"### Instruction: <Img><ImageHere></Img> "
-                f"Based on our previous conversation:\n{history}\n"
-                f"User asks: {user_question}\n"
-                "### Response:"
-            )
+        vit_features = self.ln_vision(x)
+
+        # Q-Former forward pass
+        with torch.no_grad():
+            qformer_output = self.q_former(vit_features.float())
+        image_embeds = self.llama_proj(qformer_output.to(self.dtype))
 
-        return prompt
+        return image_embeds
 
-    def generate(self, images, user_input=None, max_length=300):
+    def generate(self, images, user_input=None, max_new_tokens=300):
         print("Analysing the image to generate the diagnosis")
-        aligned_features = self.forward(images)
-        print(f"Aligned features : {aligned_features}")
+
+        image_embeds = self.encode_image(images)
+        print(f"Aligned features : {image_embeds}")
         print("Generated the aligned features with ViT and Qformer")
+
         prompt = (
-            "<Img><ImageHere></Img> Could you describe the skin disease in this image for me? [/INST]"
+            "### Instruction: <Img><ImageHere></Img> "
+            "Could you describe the skin condition in this image? "
+            "Please provide a detailed analysis including possible diagnoses. "
+            "### Response:"
        )
         inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
         image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
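Review note: `encode_image` now folds the former `_create_patches`, `forward_encoder` and `forward` into one pass: patch embedding, positional embedding, CLS token, ViT blocks, `ln_vision`, the frozen Q-Former, and the `llama_proj` alignment layer. A rough shape walkthrough with dummy tensors, assuming the usual MiniGPT-4-style dimensions (224x224 input, patch size 14, ViT width 1408, BERT-base Q-Former with 32 query tokens, LLaMA hidden size 5120; these are assumptions, not values read from this diff):

    import torch

    B, H, W, P = 1, 224, 224, 14
    D_vit, D_qformer, D_llama, n_query = 1408, 768, 5120, 32

    patches = torch.randn(B, (H // P) * (W // P), D_vit)        # after patch_embed: (1, 256, 1408)
    tokens = torch.cat([torch.randn(B, 1, D_vit), patches], 1)  # + CLS token:       (1, 257, 1408)
    # ViT blocks, vit.norm and ln_vision all preserve (1, 257, 1408)
    qformer_out = torch.randn(B, n_query, D_qformer)            # Q-Former compresses to 32 query tokens
    image_embeds = torch.randn(B, n_query, D_llama)             # llama_proj maps 768 -> 5120
    print(tokens.shape, qformer_out.shape, image_embeds.shape)

Note also that swapping `torch.cuda.amp.autocast(enabled=False)` for `torch.no_grad()` changes the intent of the guard: the old context forced fp32 compute, the new one only disables gradient tracking, so the explicit `.float()` cast is now what keeps the Q-Former in fp32.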
@@ -347,28 +273,39 @@ class SkinGPT4(nn.Module):
             raise ValueError("Image token not found in prompt")
         # Prepare embeddings
         input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
-        # projected_features = self.llama_proj(aligned_features.mean(dim=1, keepdim=True))
-        visual_embeds = aligned_features.mean(dim=1, keepdim=True)  # [1, 1, 5120]
-        visual_embeds = visual_embeds.to(input_embeddings.dtype)
-        print(f"Visual embeddings : {visual_embeds}")
+        visual_embeds = image_embeds.mean(dim=1, keepdim=True)
         input_embeddings[image_token_pos] = visual_embeds
-        print(f"input embeddings : {input_embeddings}")
+
+        # outputs = self.llama.generate(
+        #     inputs_embeds=input_embeddings,
+        #     max_new_tokens=max_length,
+        #     temperature=0.7,
+        #     top_p=0.9,
+        #     repetition_penalty=1.2,  # Prevent repetition
+        #     do_sample=True,
+        #     pad_token_id=self.tokenizer.eos_token_id,
+        #     eos_token_id=self.tokenizer.eos_token_id
+        # )
 
         outputs = self.llama.generate(
             inputs_embeds=input_embeddings,
-            max_new_tokens=max_length,
-            temperature=0.7,
-            top_p=0.9,
-            repetition_penalty=1.2,  # Prevent repetition
+            max_new_tokens=max_new_tokens,
+            num_beams=1,
             do_sample=True,
-            pad_token_id=self.tokenizer.eos_token_id,
-            eos_token_id=self.tokenizer.eos_token_id
+            min_length=1,
+            top_p=0.9,
+            repetition_penalty=1.1,
+            length_penalty=1,
+            temperature=1.0,
+            pad_token_id=self.tokenizer.eos_token_id
         )
 
 
         full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         print(f"Output from llama : {full_output}")
-        return full_output.split("[/INST]")[-1].strip()
+        response = full_output.split("### Response:")[-1].strip()
+
+        return response
 
 
 class SkinGPTClassifier:
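Review note: generation works by splicing the pooled image embedding into the token-embedding sequence at the `<ImageHere>` placeholder and then calling `generate(inputs_embeds=...)`; the new sampling settings (`temperature=1.0`, `top_p=0.9`, `repetition_penalty=1.1`) replace the old 0.7/1.2 combination, and the response is now recovered by splitting on `### Response:` to match the new prompt template. A self-contained toy of the embedding splice, with a made-up four-token vocabulary standing in for the LLaMA tokenizer:

    import torch
    import torch.nn as nn

    vocab = {"<s>": 0, "describe": 1, "<ImageHere>": 2, "please": 3}
    embed = nn.Embedding(len(vocab), 8)       # stand-in for llama.model.embed_tokens
    embed.weight.requires_grad_(False)        # frozen, as in the app
    input_ids = torch.tensor([[0, 1, 2, 3]])  # prompt containing the placeholder

    input_embeddings = embed(input_ids)       # (1, 4, 8)
    visual_embeds = torch.randn(1, 8)         # stand-in for image_embeds.mean(dim=1): one pooled visual token

    # Locate the placeholder and overwrite its embedding with the visual features.
    image_token_pos = (input_ids == vocab["<ImageHere>"]).nonzero(as_tuple=True)
    input_embeddings[image_token_pos] = visual_embeds
    print(input_embeddings.shape)             # torch.Size([1, 4, 8]); position 2 now carries image features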
@@ -395,7 +332,6 @@ class SkinGPTClassifier:
         )
         model = SkinGPT4(vit_checkpoint_path=model_path).eval()
         model = model.to(self.device)
-        model.eval()
         return model
 
     def predict(self, image):
@@ -450,7 +386,7 @@ if uploaded_file:
         else:
             st.session_state.conversation.append(("assistant", result))
             with st.chat_message("assistant"):
-                st.markdown(result)
+                st.markdown(result["diagnosis"])
     else:
         # Follow-up questions
         if user_query := st.chat_input("Ask a follow-up question..."):
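Review note: the last hunk implies that the prediction result handed to the chat UI is now a mapping with a "diagnosis" key rather than the plain string it was before. A small rendering helper that tolerates both shapes, purely illustrative and not part of this commit:

    import streamlit as st

    def render_assistant(result) -> None:
        """Render either the new dict-shaped result or a legacy plain-string result."""
        text = result.get("diagnosis", "") if isinstance(result, dict) else str(result)
        with st.chat_message("assistant"):
            st.markdown(text)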
 