KeerthiVM committed
Commit 682ada8 · 1 Parent(s): af6df0b
Files changed (1):
  1. app.py +7 -4
app.py CHANGED
@@ -24,6 +24,7 @@ from transformers import BitsAndBytesConfig
 from accelerate import init_empty_weights
 import warnings
 from transformers import logging
+warnings.filterwarnings("ignore", category=FutureWarning, module="timm")
 warnings.filterwarnings("ignore", category=UserWarning)
 logging.set_verbosity_error()
 token = os.getenv("HF_TOKEN")
@@ -57,7 +58,7 @@ class Blip2QFormer(nn.Module):
             classifier_dropout=None,
         )
 
-        self.bert = BertModel(self.bert_config, add_pooling_layer=False).to(torch.float16)
+        self.bert = BertModel(self.bert_config, add_pooling_layer=False)
         self.query_tokens = nn.Parameter(
             torch.zeros(1, num_query_tokens, self.bert_config.hidden_size)
         )
@@ -82,7 +83,7 @@ class Blip2QFormer(nn.Module):
 
     def forward(self, visual_features):
         # Project visual features
-        visual_embeds = self.vision_proj(visual_features)
+        visual_embeds = self.vision_proj(visual_features.float())
         visual_attention_mask = torch.ones(
             visual_embeds.size()[:-1],
             dtype=torch.long,
@@ -285,10 +286,12 @@ class SkinGPT4(nn.Module):
         return x  # (B, N+1, D)
 
     def forward(self, images):
+        images = images.to(self.dtype)
         x = self._create_patches(images)
         vit_output = self.forward_encoder(x)
-        qformer_output = self.q_former(vit_output)
-        aligned_features = self.llama_proj(qformer_output)
+        with torch.cuda.amp.autocast(enabled=False):
+            qformer_output = self.q_former(vit_output.float())
+            aligned_features = self.llama_proj(qformer_output.to(self.dtype))
         return aligned_features
 
 
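Taken together, the changes move the BERT-based Q-Former out of fp16: its weights stay in float32, its inputs are cast with .float(), and the whole call runs with autocast disabled before the output is cast back to the model dtype for the LLaMA projection. Below is a minimal, self-contained sketch of that dtype-handling pattern, not the app's actual code: the helper name project_visual_features and the model_dtype parameter are illustrative, while q_former, llama_proj, and self.dtype mirror names from the diff.

import torch

def project_visual_features(vit_output, q_former, llama_proj, model_dtype=torch.float16):
    """Run the Q-Former in float32 with autocast disabled, then cast back."""
    # Disabling autocast keeps the Q-Former's BERT layers in full precision,
    # avoiding fp16 overflow/underflow in attention and layer norm.
    with torch.cuda.amp.autocast(enabled=False):
        qformer_output = q_former(vit_output.float())           # force fp32 input
        aligned = llama_proj(qformer_output.to(model_dtype))    # cast back for the half-precision projection
    return aligned

Casting the Q-Former output back to the model dtype before llama_proj keeps the projection and the downstream LLaMA path in the model's half precision, so only the numerically sensitive Q-Former pays the fp32 cost.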