fix added
SkinGPT.py  CHANGED  (+47 -15)
@@ -225,14 +225,36 @@ class SkinGPT4(nn.Module):
     def generate(self, images, user_input=None, max_new_tokens=300):
 
         image_embeds = self.encode_image(images)
+
         print(f"Aligned features : {image_embeds}")
+        print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
+        if image_embeds.shape[-1] != self.llama.config.hidden_size:
+            raise ValueError(
+                f"Feature dimension mismatch. "
+                f"Q-Former output: {image_embeds.shape[-1]}, "
+                f"LLaMA expected: {self.llama.config.hidden_size}"
+            )
 
-        prompt = (
-            "### Instruction: <Img><ImageHere></Img> "
-            "Could you describe the skin condition in this image? "
-            "Please provide a detailed analysis including possible diagnoses. "
-            "### Response:"
-        )
+
+        # prompt = (
+        #     "### Instruction: <Img><ImageHere></Img> "
+        #     "Could you describe the skin condition in this image? "
+        #     "Please provide a detailed analysis including possible diagnoses. "
+        #     "### Response:"
+        # )
+
+        prompt = """### Skin Diagnosis Protocol ###
+<IMAGE>
+Patient Presentation: [Describe visible symptoms]
+Primary Differential Diagnosis:
+1.
+2.
+3.
+Recommended Diagnostic Tests:
+-
+Treatment Options:
+-
+<|endoftext|>"""
 
         print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
 
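The new shape check only guards the contract between the vision side and the language model; if it ever fires, the usual remedy in MiniGPT-4-style pipelines is a learned linear projection from the Q-Former width into the LLaMA embedding space. A minimal sketch, with both dimensions assumed for illustration rather than taken from this repo:

    import torch
    import torch.nn as nn

    # Assumed sizes: 768 for the Q-Former output, 4096 for a LLaMA-7B hidden size.
    llama_proj = nn.Linear(768, 4096)

    def align_features(qformer_out: torch.Tensor) -> torch.Tensor:
        """Project (B, num_query_tokens, 768) features to (B, num_query_tokens, 4096)."""
        return llama_proj(qformer_out)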
@@ -241,7 +263,8 @@ class SkinGPT4(nn.Module):
             token=token,
             padding_side="right"
         )
-        self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
+        # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
+        self.tokenizer.add_tokens(["<IMAGE>"])
         self.llama.resize_token_embeddings(len(self.tokenizer))
 
         inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
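For context, this is how add_tokens plus resize_token_embeddings behave in a standalone Hugging Face setup (the model path is a placeholder). The newly added embedding row is randomly initialized, which is acceptable here only because the <IMAGE> position is overwritten with the visual feature before generation:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("path/to/llama")             # placeholder path
    model = AutoModelForCausalLM.from_pretrained("path/to/llama")    # placeholder path

    num_added = tok.add_tokens(["<IMAGE>"])        # returns 0 if the token already exists
    if num_added > 0:
        model.resize_token_embeddings(len(tok))    # grow the embedding matrix; new rows are random
    image_token_id = tok.convert_tokens_to_ids("<IMAGE>")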
@@ -249,17 +272,24 @@ class SkinGPT4(nn.Module):
         print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
         print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
 
-        image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
+        # image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
+        image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
+
+
+        # Prepare embeddings
+        input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
+        visual_embeds = image_embeds.mean(dim=1, keepdim=True)
+
+        visual_embeds = F.layer_norm(visual_embeds, [visual_embeds.size(-1)])
         image_token_pos = torch.where(inputs.input_ids == image_token_id)
+
         if len(image_token_pos[0]) == 0:
             raise ValueError("Image token not found in prompt")
 
         print(f"\n[DEBUG] Image token found at position: {image_token_pos}")
 
-
-
-        input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
-        visual_embeds = image_embeds.mean(dim=1, keepdim=True)
+        for pos in zip(*image_token_pos):
+            input_embeddings[pos] = visual_embeds[0, 0]
 
         print(f"\n[DEBUG] Before replacement:")
         print(f"Text embeddings shape: {input_embeddings.shape}")
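A self-contained illustration of the replacement loop above: every position holding the placeholder id has its text embedding swapped for the mean-pooled visual embedding. All shapes and ids below are made up for the example:

    import torch

    hidden = 4096                                   # assumed LLaMA hidden size
    image_token_id = 99                             # stands in for the <IMAGE> id
    input_ids = torch.tensor([[1, 42, 99, 7]])      # (batch=1, seq=4)
    input_embeddings = torch.randn(1, 4, hidden)    # stand-in for embed_tokens(input_ids)
    visual_embeds = torch.randn(1, 1, hidden)       # mean-pooled, layer-normed image feature

    image_token_pos = torch.where(input_ids == image_token_id)   # (batch_idx, seq_idx)
    for b, s in zip(*image_token_pos):
        input_embeddings[b, s] = visual_embeds[0, 0]             # in-place overwrite

Because mean-pooling collapses all query tokens into a single vector, every <IMAGE> occurrence receives the same embedding; keeping the per-token Q-Former outputs and splicing them in as a span is the alternative when more visual detail is needed.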
@@ -269,7 +299,7 @@ class SkinGPT4(nn.Module):
 
         if visual_embeds.dtype != input_embeddings.dtype:
             visual_embeds = visual_embeds.to(input_embeddings.dtype)
-        input_embeddings[image_token_pos] = visual_embeds
+        # input_embeddings[image_token_pos] = visual_embeds
 
         print(f"\n[DEBUG] After replacement:")
         print(f"Image token embedding (after):\n{input_embeddings[0, image_token_pos[1], :5]}...")
@@ -309,11 +339,13 @@ class SkinGPT4(nn.Module):
         outputs = self.llama.generate(
             inputs_embeds=input_embeddings,
             max_new_tokens=max_new_tokens,
-            temperature=0.
+            temperature=0.3,
+            top_k=40,
             top_p=0.9,
             repetition_penalty=1.1,
             do_sample=True,
-            pad_token_id=self.tokenizer.eos_token_id
+            pad_token_id = self.tokenizer.eos_token_id,
+            eos_token_id = self.tokenizer.eos_token_id
         )
 
 
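For reference, a minimal sampling call mirroring the new settings, reusing the placeholder model, tok, and input_embeddings from the sketches above. When only inputs_embeds is supplied, Hugging Face generate returns just the newly generated token ids, so the decode below yields only the model's answer:

    import torch

    with torch.no_grad():
        out_ids = model.generate(
            inputs_embeds=input_embeddings,
            attention_mask=torch.ones(input_embeddings.shape[:2], dtype=torch.long),
            max_new_tokens=300,
            do_sample=True,
            temperature=0.3,        # lower temperature keeps the report focused
            top_k=40,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    print(tok.decode(out_ids[0], skip_special_tokens=True))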