KeerthiVM committed on
Commit
6c7e865
·
1 Parent(s): 51b26bc
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -31,8 +31,8 @@ from torch.cuda.amp import autocast
31
  # Set default dtypes
32
  torch.set_default_dtype(torch.float32) # Main computations in float32
33
  MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
34
-
35
- warnings.filterwarnings("ignore", category=FutureWarning, module="timm")
36
  warnings.filterwarnings("ignore", category=UserWarning)
37
  logging.set_verbosity_error()
38
  token = os.getenv("HF_TOKEN")
@@ -354,6 +354,7 @@ class SkinGPT4(nn.Module):
354
  input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
355
  # projected_features = self.llama_proj(aligned_features.mean(dim=1, keepdim=True))
356
  visual_embeds = aligned_features.mean(dim=1, keepdim=True) # [1, 1, 5120]
 
357
  input_embeddings[image_token_pos] = visual_embeds
358
 
359
  outputs = self.llama.generate(
 
31
  # Set default dtypes
32
  torch.set_default_dtype(torch.float32) # Main computations in float32
33
  MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
34
+ import warnings
35
+ warnings.filterwarnings("ignore", category=FutureWarning)
36
  warnings.filterwarnings("ignore", category=UserWarning)
37
  logging.set_verbosity_error()
38
  token = os.getenv("HF_TOKEN")
 
354
  input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
355
  # projected_features = self.llama_proj(aligned_features.mean(dim=1, keepdim=True))
356
  visual_embeds = aligned_features.mean(dim=1, keepdim=True) # [1, 1, 5120]
357
+ visual_embeds = visual_embeds.to(input_embeddings.dtype)
358
  input_embeddings[image_token_pos] = visual_embeds
359
 
360
  outputs = self.llama.generate(