Issue fix
Browse files
app.py
CHANGED
@@ -31,8 +31,8 @@ from torch.cuda.amp import autocast
|
|
31 |
# Set default dtypes
|
32 |
torch.set_default_dtype(torch.float32) # Main computations in float32
|
33 |
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
|
34 |
-
|
35 |
-
warnings.filterwarnings("ignore", category=FutureWarning
|
36 |
warnings.filterwarnings("ignore", category=UserWarning)
|
37 |
logging.set_verbosity_error()
|
38 |
token = os.getenv("HF_TOKEN")
|
@@ -354,6 +354,7 @@ class SkinGPT4(nn.Module):
|
|
354 |
input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
|
355 |
# projected_features = self.llama_proj(aligned_features.mean(dim=1, keepdim=True))
|
356 |
visual_embeds = aligned_features.mean(dim=1, keepdim=True) # [1, 1, 5120]
|
|
|
357 |
input_embeddings[image_token_pos] = visual_embeds
|
358 |
|
359 |
outputs = self.llama.generate(
|
|
|
31 |
# Set default dtypes
|
32 |
torch.set_default_dtype(torch.float32) # Main computations in float32
|
33 |
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
|
34 |
+
import warnings
|
35 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
36 |
warnings.filterwarnings("ignore", category=UserWarning)
|
37 |
logging.set_verbosity_error()
|
38 |
token = os.getenv("HF_TOKEN")
|
|
|
354 |
input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
|
355 |
# projected_features = self.llama_proj(aligned_features.mean(dim=1, keepdim=True))
|
356 |
visual_embeds = aligned_features.mean(dim=1, keepdim=True) # [1, 1, 5120]
|
357 |
+
visual_embeds = visual_embeds.to(input_embeddings.dtype)
|
358 |
input_embeddings[image_token_pos] = visual_embeds
|
359 |
|
360 |
outputs = self.llama.generate(
|