fix added

SkinGPT.py (+4, -39)
--- a/SkinGPT.py
+++ b/SkinGPT.py
@@ -296,49 +296,14 @@ class SkinGPT4(nn.Module):
         print(f"\n[DEBUG] Before replacement:")
         print(f"Text embeddings shape: {input_embeddings.shape}")
         print(f"Visual embeddings shape: {visual_embeds.shape}")
-
+        replaced_pos = replace_positions[1][0]
+        print(f"Image token embedding (before):\n{input_embeddings[0, replaced_pos, :5]}...")

         for pos in replace_positions:
             input_embeddings[0, pos[1]] = visual_embeds[0]

-        # if visual_embeds.dtype != input_embeddings.dtype:
-        #     visual_embeds = visual_embeds.to(input_embeddings.dtype)
-        # input_embeddings[image_token_pos] = visual_embeds
-
         print(f"\n[DEBUG] After replacement:")
-        print(f"Image token embedding (after):\n{input_embeddings}...")
-
-        # outputs = self.llama.generate(
-        #     inputs_embeds=input_embeddings,
-        #     max_new_tokens=max_length,
-        #     temperature=0.7,
-        #     top_p=0.9,
-        #     repetition_penalty=1.2,  # Prevent repetition
-        #     do_sample=True,
-        #     pad_token_id=self.tokenizer.eos_token_id,
-        #     eos_token_id=self.tokenizer.eos_token_id
-        # )
-
-        # outputs = self.llama.generate(
-        #     inputs_embeds=input_embeddings,
-        #     max_new_tokens=max_new_tokens,
-        #     num_beams=1,
-        #     do_sample=True,
-        #     min_length=1,
-        #     top_p=0.9,
-        #     repetition_penalty=1.1,
-        #     length_penalty=1,
-        #     temperature=1.0,
-        #     pad_token_id=self.tokenizer.eos_token_id
-        # )
-
-        with torch.no_grad():
-            # Test forward pass without generation
-            test_outputs = self.llama(
-                inputs_embeds=input_embeddings,
-                output_hidden_states=True
-            )
-        print(f"\n[DEBUG] First 5 output logits:\n{test_outputs.logits[0, :5, :5]}")
+        print(f"Image token embedding (after):\n{input_embeddings[0, replaced_pos, :5]}...")

         outputs = self.llama.generate(
             inputs_embeds=input_embeddings,
@@ -356,7 +321,7 @@ class SkinGPT4(nn.Module):
         full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         print(f"Full Output from llama : {full_output}")
         response = full_output.split("### Response:")[-1].strip()
-        print(f"Response from llama : {full_output}")
+        # print(f"Response from llama : {full_output}")

         return response
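For context, the change above adjusts debug prints around the image-token embedding swap: the model builds text embeddings for the prompt, overwrites the placeholder positions with projected visual embeddings, and then generates with `inputs_embeds`. Below is a minimal, self-contained sketch of that technique. The placeholder string "<ImageHere>", the function name `generate_with_image`, and the generation kwargs are illustrative assumptions, not taken from SkinGPT.py.

import torch

@torch.no_grad()
def generate_with_image(llama, tokenizer, prompt, visual_embeds, max_new_tokens=300):
    # Tokenize the prompt; "<ImageHere>" (hypothetical token name) marks
    # where the visual features should be spliced in.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids    # (1, seq_len)

    # Look up the text embeddings for every token.
    input_embeddings = llama.get_input_embeddings()(input_ids)      # (1, seq_len, hidden)

    # Find placeholder positions; nonzero() yields (batch, position) rows,
    # mirroring `replace_positions` in the diff, so pos[1] is the sequence index.
    image_token_id = tokenizer.convert_tokens_to_ids("<ImageHere>")
    replace_positions = (input_ids == image_token_id).nonzero()

    # Match dtypes before splicing (the guard the commit deletes as a comment).
    if visual_embeds.dtype != input_embeddings.dtype:
        visual_embeds = visual_embeds.to(input_embeddings.dtype)

    # Overwrite each placeholder embedding with a visual embedding,
    # as the diff's `input_embeddings[0, pos[1]] = visual_embeds[0]` does.
    for pos in replace_positions:
        input_embeddings[0, pos[1]] = visual_embeds[0]

    # Generate directly from the mixed embeddings.
    outputs = llama.generate(
        inputs_embeds=input_embeddings,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode and keep only the text after the response marker,
    # matching `full_output.split("### Response:")[-1]` in the source.
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return full_output.split("### Response:")[-1].strip()

The swap relies on the visual features having already been projected to the LLaMA hidden size; otherwise the in-place assignment at the placeholder positions would fail on a shape mismatch.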