KeerthiVM committed on
Commit 51b26bc · 1 Parent(s): 682ada8
Files changed (1)
  1. app.py +17 -6
app.py CHANGED
@@ -24,6 +24,14 @@ from transformers import BitsAndBytesConfig
 from accelerate import init_empty_weights
 import warnings
 from transformers import logging
+
+import torch
+from torch.cuda.amp import autocast
+
+# Set default dtypes
+torch.set_default_dtype(torch.float32)  # Main computations in float32
+MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+
 warnings.filterwarnings("ignore", category=FutureWarning, module="timm")
 warnings.filterwarnings("ignore", category=UserWarning)
 logging.set_verbosity_error()
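As a side note on the dtype policy introduced here (a minimal sketch, not part of the commit): modules created after torch.set_default_dtype(torch.float32) get float32 parameters, while MODEL_DTYPE only falls back to float16 when a CUDA device is actually present, so CPU-only runs stay entirely in float32. The layer sizes below are illustrative.

import torch
import torch.nn as nn

torch.set_default_dtype(torch.float32)            # new modules default to float32
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

proj = nn.Linear(1408, 768)                       # built in float32 by default
half_proj = nn.Linear(1408, 768).to(MODEL_DTYPE)  # float16 on GPU, float32 on CPU
print(proj.weight.dtype, half_proj.weight.dtype)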
@@ -83,7 +91,9 @@ class Blip2QFormer(nn.Module):
 
     def forward(self, visual_features):
         # Project visual features
-        visual_embeds = self.vision_proj(visual_features.float())
+        with autocast(enabled=False):
+            visual_embeds = self.vision_proj(visual_features.float())
+        # visual_embeds = self.vision_proj(visual_features.float())
         visual_attention_mask = torch.ones(
             visual_embeds.size()[:-1],
             dtype=torch.long,
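The autocast(enabled=False) guard added in this hunk keeps the visual projection out of mixed precision, and visual_features.float() upcasts half-precision ViT features to match the float32 vision_proj weights. A minimal sketch of the same pattern, with a dummy linear layer and feature shape standing in for the real Q-Former:

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

vision_proj = nn.Linear(1408, 768)                # float32 weights (default dtype)
feature_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
visual_features = torch.randn(1, 257, 1408).to(feature_dtype)

with autocast(enabled=False):                     # exclude this block from mixed precision
    visual_embeds = vision_proj(visual_features.float())  # upcast to match fp32 weights

print(visual_embeds.dtype)                        # torch.float32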
@@ -133,20 +143,21 @@ class SkinGPT4(nn.Module):
         super().__init__()
         # Image encoder parameters from paper
         self.device = device
-        self.dtype = torch.float16
+        # self.dtype = torch.float16
+        self.dtype = MODEL_DTYPE
         self.H, self.W, self.C = 224, 224, 3
         self.P = 14  # Patch size
         self.D = 1408  # ViT embedding dimension
         self.num_query_tokens = 32
 
-        self.vit = self._init_vit(vit_checkpoint_path)
+        self.vit = self._init_vit(vit_checkpoint_path).to(self.dtype)
         print("Loaded ViT")
         self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
 
         self.q_former = Blip2QFormer(
             num_query_tokens=self.num_query_tokens,
             vision_width=self.D
-        ).to(self.dtype)
+        )
         self.q_former.load_from_pretrained(q_former_model)
         for param in self.q_former.parameters():
             param.requires_grad = False
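Here the ViT and its LayerNorm are cast to self.dtype (float16 when CUDA is available) while the Q-Former is left in float32, so the encoder output has to be upcast at the boundary, which is what the .float() call in Blip2QFormer.forward provides. A rough sketch of that boundary with stand-in modules (the linear layers, patch count, and shapes below are assumptions, not the real classes):

import torch
import torch.nn as nn

MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

vit = nn.Linear(3 * 14 * 14, 1408).to(device, MODEL_DTYPE)  # stand-in encoder, half on GPU
ln_vision = nn.LayerNorm(1408).to(device, MODEL_DTYPE)
vision_proj = nn.Linear(1408, 768).to(device)               # stays float32, like the Q-Former

patches = torch.randn(1, 257, 3 * 14 * 14, device=device, dtype=MODEL_DTYPE)
feats = ln_vision(vit(patches))                              # fp16 on GPU, fp32 on CPU
out = vision_proj(feats.float())                             # upcast before the fp32 module
print(feats.dtype, "->", out.dtype)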
@@ -368,8 +379,8 @@ class SkinGPTClassifier:
 
         with st.spinner("Loading AI models (this may take several minutes)..."):
             self.model = self._load_model()
-            print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
-            print(f"Projection layer: {self.model.llama_proj}")
+            # print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
+            # print(f"Projection layer: {self.model.llama_proj}")
 
         # Image transformations
         self.transform = transforms.Compose([
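The shape-check prints are commented out rather than deleted; if that sanity check is still wanted, it can live in a small helper that is only called on demand. A hedged sketch (the q_former and llama_proj attribute names come from this file, but model, the helper name, and the dummy feature shape are illustrative):

import torch

def debug_qformer_shapes(model, device="cpu"):
    # Dummy ViT-style features; the Q-Former path runs in float32 after this commit.
    dummy = torch.randn(1, 197, 1408, device=device)
    with torch.no_grad():
        q_out = model.q_former(dummy)
    print(f"Q-Former output shape: {q_out.shape}")
    print(f"Projection layer: {model.llama_proj}")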
 