KeerthiVM committed on
Commit
af6df0b
·
1 Parent(s): 5882944
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -22,6 +22,10 @@ import os
22
  from huggingface_hub import hf_hub_download
23
  from transformers import BitsAndBytesConfig
24
  from accelerate import init_empty_weights
 
 
 
 
25
  token = os.getenv("HF_TOKEN")
26
  if not token:
27
  raise ValueError("Hugging Face token not found in environment variables")
@@ -166,6 +170,9 @@ class SkinGPT4(nn.Module):
166
  self.llama.config.hidden_size
167
  ).to(self.dtype)
168
 
 
 
 
169
  for param in self.llama_proj.parameters():
170
  param.requires_grad = False
171
 
@@ -331,8 +338,9 @@ class SkinGPT4(nn.Module):
331
  raise ValueError("Image token not found in prompt")
332
  # Prepare embeddings
333
  input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
334
- projected_features = self.llama_proj(aligned_features.mean(dim=1, keepdim=True))
335
- input_embeddings[image_token_pos] = projected_features
 
336
 
337
  outputs = self.llama.generate(
338
  inputs_embeds=input_embeddings,
@@ -357,6 +365,8 @@ class SkinGPTClassifier:
357
 
358
  with st.spinner("Loading AI models (this may take several minutes)..."):
359
  self.model = self._load_model()
 
 
360
 
361
  # Image transformations
362
  self.transform = transforms.Compose([
 
22
  from huggingface_hub import hf_hub_download
23
  from transformers import BitsAndBytesConfig
24
  from accelerate import init_empty_weights
25
+ import warnings
26
+ from transformers import logging
27
+ warnings.filterwarnings("ignore", category=UserWarning)
28
+ logging.set_verbosity_error()
29
  token = os.getenv("HF_TOKEN")
30
  if not token:
31
  raise ValueError("Hugging Face token not found in environment variables")
 
170
  self.llama.config.hidden_size
171
  ).to(self.dtype)
172
 
173
+ print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}")
174
+ print(f"LLaMA input dim: {self.llama.config.hidden_size}")
175
+
176
  for param in self.llama_proj.parameters():
177
  param.requires_grad = False
178
 
 
338
  raise ValueError("Image token not found in prompt")
339
  # Prepare embeddings
340
  input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
341
+ # projected_features = self.llama_proj(aligned_features.mean(dim=1, keepdim=True))
342
+ visual_embeds = aligned_features.mean(dim=1, keepdim=True) # [1, 1, 5120]
343
+ input_embeddings[image_token_pos] = visual_embeds
344
 
345
  outputs = self.llama.generate(
346
  inputs_embeds=input_embeddings,
 
365
 
366
  with st.spinner("Loading AI models (this may take several minutes)..."):
367
  self.model = self._load_model()
368
+ print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
369
+ print(f"Projection layer: {self.model.llama_proj}")
370
 
371
  # Image transformations
372
  self.transform = transforms.Compose([