KeerthiVM committed on
Commit
97233f9
·
1 Parent(s): 6cbf851
Files changed (1) hide show
  1. SkinGPT.py +35 -29
SkinGPT.py CHANGED
@@ -229,6 +229,10 @@ class SkinGPT4(nn.Module):
229
 
230
  print(f"Aligned features : {image_embeds}")
231
  print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
 
 
 
 
232
  if image_embeds.shape[-1] != self.llama.config.hidden_size:
233
  raise ValueError(
234
  f"Feature dimension mismatch. "
@@ -238,24 +242,20 @@ class SkinGPT4(nn.Module):
238
 
239
 
240
  # prompt = (
241
- # "### Instruction: <Img><ImageHere></Img> "
242
  # "Could you describe the skin condition in this image? "
243
  # "Please provide a detailed analysis including possible diagnoses. "
244
  # "### Response:"
245
  # )
246
 
247
- prompt = """### Skin Diagnosis Protocol ###
248
  <IMAGE>
249
- Patient Presentation: [Describe visible symptoms]
250
- Primary Differential Diagnosis:
251
- 1.
252
- 2.
253
- 3.
254
- Recommended Diagnostic Tests:
255
- -
256
- Treatment Options:
257
- -
258
- <|endoftext|>"""
259
 
260
  print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
261
 
@@ -265,7 +265,13 @@ class SkinGPT4(nn.Module):
265
  padding_side="right"
266
  )
267
  # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
268
- self.tokenizer.add_tokens(["<IMAGE>"])
 
 
 
 
 
 
269
  self.llama.resize_token_embeddings(len(self.tokenizer))
270
 
271
  inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
@@ -273,37 +279,37 @@ class SkinGPT4(nn.Module):
273
  print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
274
  print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
275
 
276
- # image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
277
- image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
278
-
279
-
280
  # Prepare embeddings
281
  input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
282
- visual_embeds = image_embeds.mean(dim=1, keepdim=True)
 
 
 
 
283
 
284
- visual_embeds = F.layer_norm(visual_embeds, [visual_embeds.size(-1)])
285
- image_token_pos = torch.where(inputs.input_ids == image_token_id)
286
 
287
- if len(image_token_pos[0]) == 0:
288
  raise ValueError("Image token not found in prompt")
289
 
290
- print(f"\n[DEBUG] Image token found at position: {image_token_pos}")
291
 
292
- for pos in zip(*image_token_pos):
293
- input_embeddings[pos] = visual_embeds[0, 0]
294
 
295
  print(f"\n[DEBUG] Before replacement:")
296
  print(f"Text embeddings shape: {input_embeddings.shape}")
297
  print(f"Visual embeddings shape: {visual_embeds.shape}")
298
- print(f"Image token embedding (before):\n{input_embeddings[0, image_token_pos[1], :5]}...")
299
 
 
 
300
 
301
- if visual_embeds.dtype != input_embeddings.dtype:
302
- visual_embeds = visual_embeds.to(input_embeddings.dtype)
303
  # input_embeddings[image_token_pos] = visual_embeds
304
 
305
  print(f"\n[DEBUG] After replacement:")
306
- print(f"Image token embedding (after):\n{input_embeddings[0, image_token_pos[1], :5]}...")
307
 
308
  # outputs = self.llama.generate(
309
  # inputs_embeds=input_embeddings,
@@ -340,7 +346,7 @@ class SkinGPT4(nn.Module):
340
  outputs = self.llama.generate(
341
  inputs_embeds=input_embeddings,
342
  max_new_tokens=max_new_tokens,
343
- temperature=0.3,
344
  top_k=40,
345
  top_p=0.9,
346
  repetition_penalty=1.1,
 
229
 
230
  print(f"Aligned features : {image_embeds}")
231
  print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
232
+
233
+ print(
234
+ f"\n[VALIDATION] Visual embeds - Mean: {image_embeds.mean().item():.4f}, Std: {image_embeds.std().item():.4f}")
235
+
236
  if image_embeds.shape[-1] != self.llama.config.hidden_size:
237
  raise ValueError(
238
  f"Feature dimension mismatch. "
 
242
 
243
 
244
  # prompt = (
245
+ # "### Instruction: <Img><IMAGE></Img> "
246
  # "Could you describe the skin condition in this image? "
247
  # "Please provide a detailed analysis including possible diagnoses. "
248
  # "### Response:"
249
  # )
250
 
251
+ prompt = """### Skin Diagnosis Analysis ###
252
  <IMAGE>
253
+ Describe the skin condition shown and provide:
254
+ 1. Primary diagnosis (with confidence)
255
+ 2. Three differential diagnoses
256
+ 3. Recommended tests
257
+ 4. Treatment options"""
258
+
 
 
 
 
259
 
260
  print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
261
 
 
265
  padding_side="right"
266
  )
267
  # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
268
+ num_added = self.tokenizer.add_special_tokens({
269
+ 'additional_special_tokens': ['<IMAGE>']
270
+ })
271
+
272
+ if num_added == 0:
273
+ raise ValueError("Failed to add <IMAGE> token!")
274
+
275
  self.llama.resize_token_embeddings(len(self.tokenizer))
276
 
277
  inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
 
279
  print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
280
  print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
281
 
 
 
 
 
282
  # Prepare embeddings
283
  input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
284
+ visual_embeds = image_embeds.mean(dim=1)
285
+
286
+ # image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
287
+ image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
288
+ replace_positions = (inputs.input_ids == image_token_id).nonzero()
289
 
290
+ if len(replace_positions) == 0:
291
+ raise ValueError("No <IMAGE> tokens found in prompt!")
292
 
293
+ if len(replace_positions[0]) == 0:
294
  raise ValueError("Image token not found in prompt")
295
 
296
+ print(f"\n[DEBUG] Image token found at position: {replace_positions}")
297
 
 
 
298
 
299
  print(f"\n[DEBUG] Before replacement:")
300
  print(f"Text embeddings shape: {input_embeddings.shape}")
301
  print(f"Visual embeddings shape: {visual_embeds.shape}")
302
+ print(f"Image token embedding (before):\n{input_embeddings[0, replace_positions[1], :5]}...")
303
 
304
+ for pos in replace_positions:
305
+ input_embeddings[0, pos[1]] = visual_embeds[0]
306
 
307
+ # if visual_embeds.dtype != input_embeddings.dtype:
308
+ # visual_embeds = visual_embeds.to(input_embeddings.dtype)
309
  # input_embeddings[image_token_pos] = visual_embeds
310
 
311
  print(f"\n[DEBUG] After replacement:")
312
+ print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[1], :5]}...")
313
 
314
  # outputs = self.llama.generate(
315
  # inputs_embeds=input_embeddings,
 
346
  outputs = self.llama.generate(
347
  inputs_embeds=input_embeddings,
348
  max_new_tokens=max_new_tokens,
349
+ temperature=0.7,
350
  top_k=40,
351
  top_p=0.9,
352
  repetition_penalty=1.1,