Issue fix
app.py CHANGED
@@ -24,6 +24,7 @@ from transformers import BitsAndBytesConfig
 from accelerate import init_empty_weights
 import warnings
 from transformers import logging
+warnings.filterwarnings("ignore", category=FutureWarning, module="timm")
 warnings.filterwarnings("ignore", category=UserWarning)
 logging.set_verbosity_error()
 token = os.getenv("HF_TOKEN")
@@ -57,7 +58,7 @@ class Blip2QFormer(nn.Module):
             classifier_dropout=None,
         )

-        self.bert = BertModel(self.bert_config, add_pooling_layer=False)
+        self.bert = BertModel(self.bert_config, add_pooling_layer=False)
         self.query_tokens = nn.Parameter(
             torch.zeros(1, num_query_tokens, self.bert_config.hidden_size)
         )
@@ -82,7 +83,7 @@ class Blip2QFormer(nn.Module):

     def forward(self, visual_features):
         # Project visual features
-        visual_embeds = self.vision_proj(visual_features)
+        visual_embeds = self.vision_proj(visual_features.float())
         visual_attention_mask = torch.ones(
             visual_embeds.size()[:-1],
             dtype=torch.long,
@@ -285,10 +286,12 @@ class SkinGPT4(nn.Module):
         return x  # (B, N+1, D)

     def forward(self, images):
+        images = images.to(self.dtype)
         x = self._create_patches(images)
         vit_output = self.forward_encoder(x)
-        qformer_output = self.q_former(vit_output)
-        aligned_features = self.llama_proj(qformer_output)
+        with torch.cuda.amp.autocast(enabled=False):
+            qformer_output = self.q_former(vit_output.float())
+        aligned_features = self.llama_proj(qformer_output.to(self.dtype))
         return aligned_features

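The change in SkinGPT4.forward follows a standard mixed-precision pattern: run one numerically sensitive submodule (here the Q-Former) in float32 inside an autocast-disabled region while the rest of the model stays in its half-precision working dtype, then cast back before the next half-precision layer. Below is a minimal, self-contained sketch of that pattern; the class and attribute names (MixedPrecisionPipeline, encoder, fp32_block, out_proj) are placeholders for illustration, not the real app.py modules.

import torch
import torch.nn as nn

class MixedPrecisionPipeline(nn.Module):
    def __init__(self, dim=32, half_dtype=torch.float16):
        super().__init__()
        self.dtype = half_dtype
        self.encoder = nn.Linear(dim, dim).to(self.dtype)   # stands in for the fp16 ViT
        self.fp32_block = nn.Linear(dim, dim)                # stands in for the Q-Former, kept in fp32
        self.out_proj = nn.Linear(dim, dim).to(self.dtype)   # stands in for llama_proj

    def forward(self, x):
        x = x.to(self.dtype)                                  # match the half-precision encoder
        feats = self.encoder(x)
        with torch.cuda.amp.autocast(enabled=False):          # keep autocast from re-casting to fp16
            fp32_out = self.fp32_block(feats.float())         # explicit float32 input
        return self.out_proj(fp32_out.to(self.dtype))         # cast back for the fp16 projection

if torch.cuda.is_available():
    model = MixedPrecisionPipeline().cuda()
    print(model(torch.randn(2, 32, device="cuda")).dtype)    # torch.float16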