KeerthiVM committed
Commit c05c346 · 1 Parent(s): 5661612
Files changed (1)
  1. SkinGPT.py +4 -39
SkinGPT.py CHANGED
@@ -296,49 +296,14 @@ class SkinGPT4(nn.Module):
         print(f"\n[DEBUG] Before replacement:")
         print(f"Text embeddings shape: {input_embeddings.shape}")
         print(f"Visual embeddings shape: {visual_embeds.shape}")
-        print(f"Image token embedding (before):\n{input_embeddings}...")
+        replaced_pos = replace_positions[1][0]
+        print(f"Image token embedding (before):\n{input_embeddings[0, replaced_pos, :5]}...")
 
         for pos in replace_positions:
             input_embeddings[0, pos[1]] = visual_embeds[0]
 
-        # if visual_embeds.dtype != input_embeddings.dtype:
-        #     visual_embeds = visual_embeds.to(input_embeddings.dtype)
-        # input_embeddings[image_token_pos] = visual_embeds
-
         print(f"\n[DEBUG] After replacement:")
-        print(f"Image token embedding (after):\n{input_embeddings}...")
+        print(f"Image token embedding (after):\n{input_embeddings[0, replaced_pos, :5]}...")
-
-        # outputs = self.llama.generate(
-        #     inputs_embeds=input_embeddings,
-        #     max_new_tokens=max_length,
-        #     temperature=0.7,
-        #     top_p=0.9,
-        #     repetition_penalty=1.2,  # Prevent repetition
-        #     do_sample=True,
-        #     pad_token_id=self.tokenizer.eos_token_id,
-        #     eos_token_id=self.tokenizer.eos_token_id
-        # )
-
-        # outputs = self.llama.generate(
-        #     inputs_embeds=input_embeddings,
-        #     max_new_tokens=max_new_tokens,
-        #     num_beams=1,
-        #     do_sample=True,
-        #     min_length=1,
-        #     top_p=0.9,
-        #     repetition_penalty=1.1,
-        #     length_penalty=1,
-        #     temperature=1.0,
-        #     pad_token_id=self.tokenizer.eos_token_id
-        # )
-
-        with torch.no_grad():
-            # Test forward pass without generation
-            test_outputs = self.llama(
-                inputs_embeds=input_embeddings,
-                output_hidden_states=True
-            )
-            print(f"\n[DEBUG] First 5 output logits:\n{test_outputs.logits[0, :5, :5]}")
 
         outputs = self.llama.generate(
             inputs_embeds=input_embeddings,
@@ -356,7 +321,7 @@ class SkinGPT4(nn.Module):
         full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         print(f"Full Output from llama : {full_output}")
         response = full_output.split("### Response:")[-1].strip()
-        print(f"Response from llama : {full_output}")
+        # print(f"Response from llama : {full_output}")
 
         return response
 
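For context, below is a minimal, self-contained sketch of the splice-and-debug pattern this commit settles on: replace each image-token placeholder embedding with the projected visual embedding, and print only a 5-dim slice at one replaced position instead of the full tensor. The tensor names follow the diff; the toy shapes and the nonzero() construction are assumptions, not taken from SkinGPT.py (note the diff itself reads replace_positions[1][0], which presumes a different layout for replace_positions than the (batch, seq) rows used here).

import torch

# Assumed toy sizes; in SkinGPT.py these come from the tokenizer and projector.
hidden = 8
input_ids = torch.tensor([[1, 2, 99, 3, 99, 4]])   # 99 stands in for the image token id
input_embeddings = torch.randn(1, input_ids.shape[1], hidden)
visual_embeds = torch.randn(1, hidden)             # projected image feature

# Assumption: placeholder positions as (batch_idx, seq_idx) rows.
replace_positions = (input_ids == 99).nonzero()

# Truncated debug print: first 5 dims of one replaced slot, not the whole tensor.
replaced_pos = replace_positions[0][1].item()
print(f"Image token embedding (before):\n{input_embeddings[0, replaced_pos, :5]}...")

# Overwrite each placeholder embedding with the visual embedding, as in the diff.
for pos in replace_positions:
    input_embeddings[0, pos[1]] = visual_embeds[0]

print(f"Image token embedding (after):\n{input_embeddings[0, replaced_pos, :5]}...")

Truncating the debug prints to a short slice keeps the [DEBUG] output readable while still confirming that the placeholder row actually changed, which the removed full-tensor prints did far more verbosely.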