KeerthiVM commited on
Commit
546c625
·
1 Parent(s): 400535e
Files changed (1) hide show
  1. SkinGPT.py +47 -15
SkinGPT.py CHANGED
@@ -225,14 +225,36 @@ class SkinGPT4(nn.Module):
225
  def generate(self, images, user_input=None, max_new_tokens=300):
226
 
227
  image_embeds = self.encode_image(images)
 
228
  print(f"Aligned features : {image_embeds}")
 
 
 
 
 
 
 
229
 
230
- prompt = (
231
- "### Instruction: <Img><ImageHere></Img> "
232
- "Could you describe the skin condition in this image? "
233
- "Please provide a detailed analysis including possible diagnoses. "
234
- "### Response:"
235
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
  print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
238
 
@@ -241,7 +263,8 @@ class SkinGPT4(nn.Module):
241
  token=token,
242
  padding_side="right"
243
  )
244
- self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
 
245
  self.llama.resize_token_embeddings(len(self.tokenizer))
246
 
247
  inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
@@ -249,17 +272,24 @@ class SkinGPT4(nn.Module):
249
  print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
250
  print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
251
 
252
- image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
 
 
 
 
 
 
 
 
253
  image_token_pos = torch.where(inputs.input_ids == image_token_id)
 
254
  if len(image_token_pos[0]) == 0:
255
  raise ValueError("Image token not found in prompt")
256
 
257
  print(f"\n[DEBUG] Image token found at position: {image_token_pos}")
258
 
259
-
260
- # Prepare embeddings
261
- input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
262
- visual_embeds = image_embeds.mean(dim=1, keepdim=True)
263
 
264
  print(f"\n[DEBUG] Before replacement:")
265
  print(f"Text embeddings shape: {input_embeddings.shape}")
@@ -269,7 +299,7 @@ class SkinGPT4(nn.Module):
269
 
270
  if visual_embeds.dtype != input_embeddings.dtype:
271
  visual_embeds = visual_embeds.to(input_embeddings.dtype)
272
- input_embeddings[image_token_pos] = visual_embeds
273
 
274
  print(f"\n[DEBUG] After replacement:")
275
  print(f"Image token embedding (after):\n{input_embeddings[0, image_token_pos[1], :5]}...")
@@ -309,11 +339,13 @@ class SkinGPT4(nn.Module):
309
  outputs = self.llama.generate(
310
  inputs_embeds=input_embeddings,
311
  max_new_tokens=max_new_tokens,
312
- temperature=0.7,
 
313
  top_p=0.9,
314
  repetition_penalty=1.1,
315
  do_sample=True,
316
- pad_token_id=self.tokenizer.eos_token_id
 
317
  )
318
 
319
 
 
225
  def generate(self, images, user_input=None, max_new_tokens=300):
226
 
227
  image_embeds = self.encode_image(images)
228
+
229
  print(f"Aligned features : {image_embeds}")
230
+ print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
231
+ if image_embeds.shape[-1] != self.llama.config.hidden_size:
232
+ raise ValueError(
233
+ f"Feature dimension mismatch. "
234
+ f"Q-Former output: {image_embeds.shape[-1]}, "
235
+ f"LLaMA expected: {self.llama.config.hidden_size}"
236
+ )
237
 
238
+
239
+ # prompt = (
240
+ # "### Instruction: <Img><ImageHere></Img> "
241
+ # "Could you describe the skin condition in this image? "
242
+ # "Please provide a detailed analysis including possible diagnoses. "
243
+ # "### Response:"
244
+ # )
245
+
246
+ prompt = """### Skin Diagnosis Protocol ###
247
+ <IMAGE>
248
+ Patient Presentation: [Describe visible symptoms]
249
+ Primary Differential Diagnosis:
250
+ 1.
251
+ 2.
252
+ 3.
253
+ Recommended Diagnostic Tests:
254
+ -
255
+ Treatment Options:
256
+ -
257
+ <|endoftext|>"""
258
 
259
  print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
260
 
 
263
  token=token,
264
  padding_side="right"
265
  )
266
+ # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
267
+ self.tokenizer.add_tokens(["<IMAGE>"])
268
  self.llama.resize_token_embeddings(len(self.tokenizer))
269
 
270
  inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
 
272
  print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
273
  print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
274
 
275
+ # image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
276
+ image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
277
+
278
+
279
+ # Prepare embeddings
280
+ input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
281
+ visual_embeds = image_embeds.mean(dim=1, keepdim=True)
282
+
283
+ visual_embeds = F.layer_norm(visual_embeds, [visual_embeds.size(-1)])
284
  image_token_pos = torch.where(inputs.input_ids == image_token_id)
285
+
286
  if len(image_token_pos[0]) == 0:
287
  raise ValueError("Image token not found in prompt")
288
 
289
  print(f"\n[DEBUG] Image token found at position: {image_token_pos}")
290
 
291
+ for pos in zip(*image_token_pos):
292
+ input_embeddings[pos] = visual_embeds[0, 0]
 
 
293
 
294
  print(f"\n[DEBUG] Before replacement:")
295
  print(f"Text embeddings shape: {input_embeddings.shape}")
 
299
 
300
  if visual_embeds.dtype != input_embeddings.dtype:
301
  visual_embeds = visual_embeds.to(input_embeddings.dtype)
302
+ # input_embeddings[image_token_pos] = visual_embeds
303
 
304
  print(f"\n[DEBUG] After replacement:")
305
  print(f"Image token embedding (after):\n{input_embeddings[0, image_token_pos[1], :5]}...")
 
339
  outputs = self.llama.generate(
340
  inputs_embeds=input_embeddings,
341
  max_new_tokens=max_new_tokens,
342
+ temperature=0.3,
343
+ top_k=40,
344
  top_p=0.9,
345
  repetition_penalty=1.1,
346
  do_sample=True,
347
+ pad_token_id = self.tokenizer.eos_token_id,
348
+ eos_token_id = self.tokenizer.eos_token_id
349
  )
350
 
351