KeerthiVM committed
Commit a60c293 · 1 Parent(s): a7194aa
Files changed (1)
  1. app.py +31 -15
app.py CHANGED
@@ -357,22 +357,38 @@ class SkinGPT4(nn.Module):
         #     "### Response:"
         # )
 
-        prompt_parts = [
-            "### Instruction: <Img>",
-            "<Image>",
-            "</Img> Could you describe the skin disease in this image for me? ### Response:"
-        ]
+        # prompt_parts = [
+        #     "### Instruction: <Img>",
+        #     "<Image>",
+        #     "</Img> Could you describe the skin disease in this image for me? ### Response:"
+        # ]
+        prompt = "### Instruction: <Img><Image></Img> Could you describe the skin disease in this image for me? ### Response:"
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
+
         # Tokenize each part separately
-        tokens_before = self.tokenizer(prompt_parts[0], return_tensors="pt").input_ids.to(images.device)
-        tokens_after = self.tokenizer(prompt_parts[2], return_tensors="pt").input_ids.to(images.device)
-        input_ids = torch.cat([
-            tokens_before[:, :-1],  # Remove EOS from first part
-            torch.full((1, 1), self.tokenizer.convert_tokens_to_ids("<Image>")).to(images.device),
-            tokens_after[:, 1:]  # Remove BOS from second part
-        ], dim=1)
-        embeddings = self.llama.model.embed_tokens(input_ids)
-        image_token_pos = (input_ids == self.tokenizer.convert_tokens_to_ids("<Image>")).nonzero()
-        embeddings[image_token_pos] = aligned_features.mean(dim=1)
+        # tokens_before = self.tokenizer(prompt_parts[0], return_tensors="pt").input_ids.to(images.device)
+        # tokens_after = self.tokenizer(prompt_parts[2], return_tensors="pt").input_ids.to(images.device)
+        # input_ids = torch.cat([
+        #     tokens_before[:, :-1],  # Remove EOS from first part
+        #     torch.full((1, 1), self.tokenizer.convert_tokens_to_ids("<Image>")).to(images.device),
+        #     tokens_after[:, 1:]  # Remove BOS from second part
+        # ], dim=1)
+        # embeddings = self.llama.model.embed_tokens(input_ids)
+        # image_token_pos = (input_ids == self.tokenizer.convert_tokens_to_ids("<Image>")).nonzero()
+        # embeddings[image_token_pos] = aligned_features.mean(dim=1)
+
+        image_token_id = self.tokenizer.convert_tokens_to_ids("<Image>")
+        image_token_pos = (inputs.input_ids == image_token_id).nonzero()
+
+        if image_token_pos.shape[0] != 1:
+            raise ValueError(f"Expected 1 image token, found {image_token_pos.shape[0]}")
+
+        # Prepare embeddings
+        embeddings = self.llama.model.embed_tokens(inputs.input_ids)
+
+        row, col = image_token_pos[0]
+        embeddings[row, col] = aligned_features.mean(dim=1)
+
         outputs = self.llama.generate(
             inputs_embeds=embeddings,
             max_length=max_length,
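
For context on the change: instead of tokenizing the prompt in pieces and splicing token ids around the <Image> placeholder, the new code tokenizes the full prompt once, locates the single <Image> token, and overwrites its embedding with the pooled image feature before generation. Below is a minimal sketch of that pattern; the helper name inject_image_embedding, the get_input_embeddings() call, and the assumption that "<Image>" was registered as a special token (with a matching resize_token_embeddings on the model) are illustrative, not part of this commit.

import torch

def inject_image_embedding(tokenizer, llama, aligned_features, device="cpu"):
    """Sketch of the pattern used in the new code: tokenize the full prompt,
    find the single <Image> placeholder, and overwrite its embedding with the
    pooled, already-projected image feature before generation.

    Assumes "<Image>" is a registered special token (added via
    tokenizer.add_special_tokens, with llama.resize_token_embeddings called
    afterwards) and that aligned_features has shape
    (1, num_image_tokens, hidden_size) in the LLM's embedding space.
    """
    prompt = ("### Instruction: <Img><Image></Img> Could you describe the "
              "skin disease in this image for me? ### Response:")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    image_token_id = tokenizer.convert_tokens_to_ids("<Image>")
    positions = (inputs.input_ids == image_token_id).nonzero()
    if positions.shape[0] != 1:
        raise ValueError(f"Expected 1 image token, found {positions.shape[0]}")

    # Token embeddings for the whole prompt: (1, seq_len, hidden_size).
    embeddings = llama.get_input_embeddings()(inputs.input_ids)

    # Pool the image tokens to one vector and splice it into the placeholder slot.
    row, col = positions[0]
    embeddings[row, col] = aligned_features.mean(dim=1).squeeze(0)
    return embeddings

# Hypothetical call site, mirroring the variables used in app.py:
# embeddings = inject_image_embedding(self.tokenizer, self.llama, aligned_features, images.device)
# outputs = self.llama.generate(inputs_embeds=embeddings, max_length=max_length)

Tokenizing the prompt in a single pass avoids the BOS/EOS bookkeeping the old torch.cat splice needed, and the explicit one-match check guards against "<Image>" being split into several sub-tokens when it is not registered as a special token.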