Issue fix
app.py CHANGED
@@ -24,6 +24,14 @@ from transformers import BitsAndBytesConfig
 from accelerate import init_empty_weights
 import warnings
 from transformers import logging
+
+import torch
+from torch.cuda.amp import autocast
+
+# Set default dtypes
+torch.set_default_dtype(torch.float32)  # Main computations in float32
+MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+
 warnings.filterwarnings("ignore", category=FutureWarning, module="timm")
 warnings.filterwarnings("ignore", category=UserWarning)
 logging.set_verbosity_error()

@@ -83,7 +91,9 @@ class Blip2QFormer(nn.Module):
 
     def forward(self, visual_features):
         # Project visual features
-        visual_embeds = self.vision_proj(visual_features.float())
+        with autocast(enabled=False):
+            visual_embeds = self.vision_proj(visual_features.float())
+        # visual_embeds = self.vision_proj(visual_features.float())
         visual_attention_mask = torch.ones(
             visual_embeds.size()[:-1],
             dtype=torch.long,

@@ -133,20 +143,21 @@ class SkinGPT4(nn.Module):
         super().__init__()
         # Image encoder parameters from paper
         self.device = device
-        self.dtype = torch.float16
+        # self.dtype = torch.float16
+        self.dtype = MODEL_DTYPE
         self.H, self.W, self.C = 224, 224, 3
         self.P = 14  # Patch size
         self.D = 1408  # ViT embedding dimension
         self.num_query_tokens = 32
 
-        self.vit = self._init_vit(vit_checkpoint_path)
+        self.vit = self._init_vit(vit_checkpoint_path).to(self.dtype)
         print("Loaded ViT")
         self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
 
         self.q_former = Blip2QFormer(
             num_query_tokens=self.num_query_tokens,
             vision_width=self.D
-        )
+        )
         self.q_former.load_from_pretrained(q_former_model)
         for param in self.q_former.parameters():
             param.requires_grad = False

@@ -368,8 +379,8 @@ class SkinGPTClassifier:
 
         with st.spinner("Loading AI models (this may take several minutes)..."):
             self.model = self._load_model()
-            print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
-            print(f"Projection layer: {self.model.llama_proj}")
+            # print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
+            # print(f"Projection layer: {self.model.llama_proj}")
 
         # Image transformations
         self.transform = transforms.Compose([
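For context on the pattern this commit introduces: model weights stay in half precision when CUDA is available (MODEL_DTYPE), while the Q-Former's visual projection is forced to run in float32 via autocast(enabled=False) plus an explicit .float() upcast of the incoming ViT features. Below is a minimal, self-contained sketch of that pattern, not code from app.py; the VisionProj class name and the 768 hidden size are illustrative assumptions, while the 1x197x1408 feature shape and the autocast/.float() idiom mirror the diff above.

# Minimal sketch (illustrative, not app.py): half-precision features
# projected through a float32 linear layer, as in the fix above.
import torch
import torch.nn as nn
from torch.cuda.amp import autocast  # same import the commit adds

MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

class VisionProj(nn.Module):  # hypothetical stand-in for the Q-Former's vision_proj
    def __init__(self, vision_width=1408, hidden_size=768):
        super().__init__()
        # nn.Linear is created in float32 (the default dtype), so its weights
        # stay in full precision even when the rest of the model is float16.
        self.proj = nn.Linear(vision_width, hidden_size)

    def forward(self, visual_features):
        # Disable autocast so the matmul is not downcast to float16,
        # and upcast the (possibly half-precision) ViT features explicitly.
        with autocast(enabled=False):
            return self.proj(visual_features.float())

feats = torch.randn(1, 197, 1408).to(MODEL_DTYPE)  # ViT patch tokens; shape taken from the diff
out = VisionProj()(feats)
print(out.dtype)  # torch.float32 regardless of MODEL_DTYPE

The explicit torch.set_default_dtype(torch.float32) in the new header reinforces the same idea: newly constructed layers such as the projection are built in full precision, even though the ViT and LayerNorm are moved to MODEL_DTYPE.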