fix added
Browse files- SkinGPT.py +35 -29
SkinGPT.py
CHANGED
@@ -229,6 +229,10 @@ class SkinGPT4(nn.Module):
|
|
229 |
|
230 |
print(f"Aligned features : {image_embeds}")
|
231 |
print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
|
|
|
|
|
|
|
|
|
232 |
if image_embeds.shape[-1] != self.llama.config.hidden_size:
|
233 |
raise ValueError(
|
234 |
f"Feature dimension mismatch. "
|
@@ -238,24 +242,20 @@ class SkinGPT4(nn.Module):
|
|
238 |
|
239 |
|
240 |
# prompt = (
|
241 |
-
# "### Instruction: <Img><IMAGE></Img> "
|
242 |
# "Could you describe the skin condition in this image? "
|
243 |
# "Please provide a detailed analysis including possible diagnoses. "
|
244 |
# "### Response:"
|
245 |
# )
|
246 |
|
247 |
-
prompt = """### Skin Diagnosis Analysis ###
|
248 |
<IMAGE>
|
249 |
-
|
250 |
-
Primary
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
-
|
256 |
-
Treatment Options:
|
257 |
-
-
|
258 |
-
<|endoftext|>"""
|
259 |
|
260 |
print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
|
261 |
|
@@ -265,7 +265,13 @@ class SkinGPT4(nn.Module):
|
|
265 |
padding_side="right"
|
266 |
)
|
267 |
# self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
|
268 |
-
num_added = self.tokenizer.add_special_tokens({
    'additional_special_tokens': ['<IMAGE>']
})
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
self.llama.resize_token_embeddings(len(self.tokenizer))
|
270 |
|
271 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
|
@@ -273,37 +279,37 @@ class SkinGPT4(nn.Module):
|
|
273 |
print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
|
274 |
print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
|
275 |
|
276 |
-
# image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
|
277 |
-
image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
|
278 |
-
|
279 |
-
|
280 |
# Prepare embeddings
|
281 |
input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
|
282 |
-
visual_embeds = image_embeds.mean(dim=1)
|
|
|
|
|
|
|
|
|
283 |
|
284 |
-
|
285 |
-
|
286 |
|
287 |
-
if len(replace_positions) == 0:
|
288 |
raise ValueError("Image token not found in prompt")
|
289 |
|
290 |
-
print(f"\n[DEBUG] Image token found at position: {replace_positions}")
|
291 |
|
292 |
-
for pos in zip(*image_token_pos):
|
293 |
-
input_embeddings[pos] = visual_embeds[0, 0]
|
294 |
|
295 |
print(f"\n[DEBUG] Before replacement:")
|
296 |
print(f"Text embeddings shape: {input_embeddings.shape}")
|
297 |
print(f"Visual embeddings shape: {visual_embeds.shape}")
|
298 |
-
print(f"Image token embedding (before):\n{input_embeddings[0, replace_positions[1], :5]}...")
|
299 |
|
|
|
|
|
300 |
|
301 |
-
if visual_embeds.dtype != input_embeddings.dtype:
|
302 |
-
|
303 |
# input_embeddings[image_token_pos] = visual_embeds
|
304 |
|
305 |
print(f"\n[DEBUG] After replacement:")
|
306 |
-
print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[1], :5]}...")
|
307 |
|
308 |
# outputs = self.llama.generate(
|
309 |
# inputs_embeds=input_embeddings,
|
@@ -340,7 +346,7 @@ class SkinGPT4(nn.Module):
|
|
340 |
outputs = self.llama.generate(
|
341 |
inputs_embeds=input_embeddings,
|
342 |
max_new_tokens=max_new_tokens,
|
343 |
-
temperature=0.7,
|
344 |
top_k=40,
|
345 |
top_p=0.9,
|
346 |
repetition_penalty=1.1,
|
|
|
229 |
|
230 |
print(f"Aligned features : {image_embeds}")
|
231 |
print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
|
232 |
+
|
233 |
+
print(
|
234 |
+
f"\n[VALIDATION] Visual embeds - Mean: {image_embeds.mean().item():.4f}, Std: {image_embeds.std().item():.4f}")
|
235 |
+
|
236 |
if image_embeds.shape[-1] != self.llama.config.hidden_size:
|
237 |
raise ValueError(
|
238 |
f"Feature dimension mismatch. "
|
|
|
242 |
|
243 |
|
244 |
# prompt = (
|
245 |
+
# "### Instruction: <Img><IMAGE></Img> "
|
246 |
# "Could you describe the skin condition in this image? "
|
247 |
# "Please provide a detailed analysis including possible diagnoses. "
|
248 |
# "### Response:"
|
249 |
# )
|
250 |
|
251 |
+
prompt = """### Skin Diagnosis Analysis ###
|
252 |
<IMAGE>
|
253 |
+
Describe the skin condition shown and provide:
|
254 |
+
1. Primary diagnosis (with confidence)
|
255 |
+
2. Three differential diagnoses
|
256 |
+
3. Recommended tests
|
257 |
+
4. Treatment options"""
|
258 |
+
|
|
|
|
|
|
|
|
|
259 |
|
260 |
print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
|
261 |
|
|
|
265 |
padding_side="right"
|
266 |
)
|
267 |
# self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
|
268 |
+
num_added = self.tokenizer.add_special_tokens({
|
269 |
+
'additional_special_tokens': ['<IMAGE>']
|
270 |
+
})
|
271 |
+
|
272 |
+
if num_added == 0:
|
273 |
+
raise ValueError("Failed to add <IMAGE> token!")
|
274 |
+
|
275 |
self.llama.resize_token_embeddings(len(self.tokenizer))
|
276 |
|
277 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
|
|
|
279 |
print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
|
280 |
print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
|
281 |
|
|
|
|
|
|
|
|
|
282 |
# Prepare embeddings
|
283 |
input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
|
284 |
+
visual_embeds = image_embeds.mean(dim=1)
|
285 |
+
|
286 |
+
# image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
|
287 |
+
image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
|
288 |
+
replace_positions = (inputs.input_ids == image_token_id).nonzero()
|
289 |
|
290 |
+
if len(replace_positions) == 0:
|
291 |
+
raise ValueError("No <IMAGE> tokens found in prompt!")
|
292 |
|
293 |
+
if len(replace_positions[0]) == 0:
|
294 |
raise ValueError("Image token not found in prompt")
|
295 |
|
296 |
+
print(f"\n[DEBUG] Image token found at position: {replace_positions}")
|
297 |
|
|
|
|
|
298 |
|
299 |
print(f"\n[DEBUG] Before replacement:")
|
300 |
print(f"Text embeddings shape: {input_embeddings.shape}")
|
301 |
print(f"Visual embeddings shape: {visual_embeds.shape}")
|
302 |
+
print(f"Image token embedding (before):\n{input_embeddings[0, replace_positions[1], :5]}...")
|
303 |
|
304 |
+
for pos in replace_positions:
|
305 |
+
input_embeddings[0, pos[1]] = visual_embeds[0]
|
306 |
|
307 |
+
# if visual_embeds.dtype != input_embeddings.dtype:
|
308 |
+
# visual_embeds = visual_embeds.to(input_embeddings.dtype)
|
309 |
# input_embeddings[image_token_pos] = visual_embeds
|
310 |
|
311 |
print(f"\n[DEBUG] After replacement:")
|
312 |
+
print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[1], :5]}...")
|
313 |
|
314 |
# outputs = self.llama.generate(
|
315 |
# inputs_embeds=input_embeddings,
|
|
|
346 |
outputs = self.llama.generate(
|
347 |
inputs_embeds=input_embeddings,
|
348 |
max_new_tokens=max_new_tokens,
|
349 |
+
temperature=0.7,
|
350 |
top_k=40,
|
351 |
top_p=0.9,
|
352 |
repetition_penalty=1.1,
|