KeerthiVM committed
Commit 89a06bf · Parent(s): a60c293
Files changed (1):
  1. app.py  +71 -149

app.py CHANGED
@@ -54,22 +54,18 @@ class Blip2QFormer(nn.Module):
         )
 
         self.bert = BertModel(self.bert_config, add_pooling_layer=False).to(torch.float16)
-
-        # Replace position embeddings with a dummy implementation
-        self.bert.embeddings.position_embeddings = nn.Identity()  # Completely bypass position embeddings
-
-        # Disable word embeddings
-        self.bert.embeddings.word_embeddings = None
-
-        # Initialize query tokens
         self.query_tokens = nn.Parameter(
-            torch.zeros(1, num_query_tokens, self.bert_config.hidden_size, dtype=torch.float16)
+            torch.zeros(1, num_query_tokens, self.bert_config.hidden_size)
         )
-        self.vision_proj = nn.Sequential(
-            nn.Linear(vision_width, self.bert_config.hidden_size),
-            nn.LayerNorm(self.bert_config.hidden_size)
-        ).to(torch.float16)
+        self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)
+
+        # Initialize weights
+        self._init_weights()
 
+    def _init_weights(self):
+        nn.init.normal_(self.query_tokens, std=0.02)
+        nn.init.xavier_uniform_(self.vision_proj.weight)
+        nn.init.constant_(self.vision_proj.bias, 0)
 
     def load_from_pretrained(self, url_or_filename):
         if url_or_filename.startswith('http'):
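Note: this hunk and the next replace the Q-Former's ad-hoc BERT surgery with a plain Linear vision projection, explicit weight initialization, and a forward() that takes raw ViT features, projects them, builds an all-ones cross-attention mask, and expands the learnable query tokens to the batch size. A minimal stand-alone sketch of that projection/mask/query-expansion pattern follows; the hidden size of 768 and the 257-token ViT sequence length are illustrative assumptions, not values taken from this commit.

    import torch
    import torch.nn as nn

    vision_width, hidden_size, num_query_tokens = 1408, 768, 32   # 768 is an assumed BERT hidden size
    vision_proj = nn.Linear(vision_width, hidden_size)
    query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_size))
    nn.init.normal_(query_tokens, std=0.02)
    nn.init.xavier_uniform_(vision_proj.weight)
    nn.init.constant_(vision_proj.bias, 0)

    visual_features = torch.randn(2, 257, vision_width)              # (B, N+1, D) patch features from the ViT
    visual_embeds = vision_proj(visual_features)                     # (B, N+1, hidden_size)
    visual_attention_mask = torch.ones(visual_embeds.shape[:-1],     # (B, N+1): attend to every visual token
                                       dtype=torch.long)
    queries = query_tokens.expand(visual_embeds.size(0), -1, -1)     # (B, 32, hidden_size)
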
@@ -77,38 +73,31 @@ class Blip2QFormer(nn.Module):
             checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
         else:
             checkpoint = torch.load(url_or_filename, map_location='cpu')
-
-        # Load Q-Former weights only
         state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
         msg = self.load_state_dict(state_dict, strict=False)
-        # print(f"Loaded Q-Former weights with message: {msg}")
-
-    def forward(self, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
-        if query_embeds is None:
-            query_embeds = self.query_tokens.expand(encoder_hidden_states.shape[0], -1, -1)
 
+    def forward(self, visual_features):
         # Project visual features
-        visual_embeds = self.vision_proj(encoder_hidden_states)
-
-        # Create proper attention mask
-        if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(
-                visual_embeds.size()[:-1],
-                dtype=torch.long,
-                device=visual_embeds.device
-            )
-        batch_size = query_embeds.size(0)
-        extended_attention_mask = encoder_attention_mask.unsqueeze(1).expand(-1, query_embeds.size(1), -1)
+        visual_embeds = self.vision_proj(visual_features)
+        visual_attention_mask = torch.ones(
+            visual_embeds.size()[:-1],
+            dtype=torch.long,
+            device=visual_embeds.device
+        )
+
+        # Expand query tokens
+        query_tokens = self.query_tokens.expand(visual_embeds.shape[0], -1, -1)
 
-        encoder_outputs = self.bert.encoder(
-            hidden_states=query_embeds,
+        # Forward through BERT
+        outputs = self.bert(
+            None,  # No text input
             attention_mask=None,
             encoder_hidden_states=visual_embeds,
-            encoder_attention_mask=encoder_attention_mask,
+            encoder_attention_mask=visual_attention_mask,
             return_dict=True
         )
-        return encoder_outputs.last_hidden_state
 
+        return outputs.last_hidden_state
 
 class LayerNorm(nn.LayerNorm):
     """Subclass torch's LayerNorm to handle fp16."""
@@ -137,19 +126,13 @@ class SkinGPT4(nn.Module):
                  q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
         super().__init__()
         # Image encoder parameters from paper
+        self.device = device
         self.dtype = torch.float16
         self.H, self.W, self.C = 224, 224, 3
         self.P = 14  # Patch size
         self.D = 1408  # ViT embedding dimension
         self.num_query_tokens = 32
 
-        # self.tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf",
-        #                                                 token=token, padding_side="right")
-        #
-        # print("Loaded tokenizer")
-        # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<ImageHere>']})
-
-        # Initialize components
         self.vit = self._init_vit(vit_checkpoint_path)
         print("Loaded ViT")
         self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
@@ -161,7 +144,10 @@ class SkinGPT4(nn.Module):
         self.q_former.load_from_pretrained(q_former_model)
         for param in self.q_former.parameters():
             param.requires_grad = False
-        self.q_former.eval()
+        for module in [self.vit, self.ln_vision, self.q_former]:
+            for param in module.parameters():
+                param.requires_grad = False
+            module.eval()
         print("Loaded QFormer")
 
         self.tokenizer = LlamaTokenizer.from_pretrained(
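Note: the new loop freezes the ViT, the vision LayerNorm, and the Q-Former and switches them to eval mode, so only the remaining modules can receive gradient updates. A small self-contained sketch of this freeze-and-eval pattern, using toy modules rather than the models from app.py:

    import torch.nn as nn

    frozen_modules = [nn.Linear(8, 8), nn.LayerNorm(8)]   # stand-ins for vit / ln_vision / q_former
    for module in frozen_modules:
        for param in module.parameters():
            param.requires_grad = False    # exclude from optimizer updates
        module.eval()                      # fix dropout and normalization behaviour

    assert all(not p.requires_grad for m in frozen_modules for p in m.parameters())
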
@@ -169,24 +155,18 @@ class SkinGPT4(nn.Module):
             token=token,
             padding_side="right"
         )
-        self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<Image>']})
+        self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
 
         self.llama = self._init_llama()
-        # self.llama.resize_token_embeddings(len(self.tokenizer))
         self.llama.resize_token_embeddings(len(self.tokenizer))
 
         self.llama_proj = nn.Linear(
             self.q_former.bert_config.hidden_size,
             self.llama.config.hidden_size
         ).to(self.dtype)
-        self._init_alignment_projection()
-        print("Loaded Llama")
-        # Initialize learnable query tokens
 
-        self.query_tokens = nn.Parameter(
-            torch.zeros(1, self.num_query_tokens, self.q_former.bert_config.hidden_size)
-        )
-        nn.init.normal_(self.query_tokens, std=0.02)
+        for param in self.llama_proj.parameters():
+            param.requires_grad = False
 
     def _init_vit(self, vit_checkpoint_path):
         """Initialize EVA-ViT-G with paper specifications"""
@@ -297,28 +277,13 @@ class SkinGPT4(nn.Module):
         return x  # (B, N+1, D)
 
     def forward(self, images):
-        images = images.to(self.dtype)
-        # Convert images to patches
-        x = self._create_patches(images)  # (B, N+1, D)
-
-        # ViT processing
-        x = x.to(self.dtype)
-        self.vit = self.vit.to(self.dtype)
-        vit_output = self.forward_encoder(x)  # (B, N+1, D)
-
-        # Q-Former processing
-        query_tokens = self.query_tokens.expand(x.size(0), -1, -1).to(torch.float16)
-        qformer_output = self.q_former(
-            query_embeds=query_tokens,
-            encoder_hidden_states=vit_output.to(torch.float16),
-            encoder_attention_mask=torch.ones_like(vit_output[:, :, 0])
-        ).to(self.dtype)
-
-        # Alignment projection
-        aligned_features = self.llama_proj(qformer_output.to(self.dtype))
-
+        x = self._create_patches(images)
+        vit_output = self.forward_encoder(x)
+        qformer_output = self.q_former(vit_output)
+        aligned_features = self.llama_proj(qformer_output)
         return aligned_features
 
+
     def add_to_history(self, role, content):
         self.conversation_history.append({"role": role, "content": content})
 
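Note: the simplified forward() keeps the pipeline images -> patches -> ViT encoder -> Q-Former -> llama_proj. With the constants set in __init__ (H = W = 224, P = 14, D = 1408), patchification gives (224 / 14)^2 = 256 patch tokens plus one class token, matching the (B, N+1, D) shape in the context lines. The sketch below illustrates that arithmetic with a conventional Conv2d patch embedding; it is an assumption about what _create_patches does, not code from this commit.

    import torch
    import torch.nn as nn

    B, C, H, W, P, D = 2, 3, 224, 224, 14, 1408
    patch_embed = nn.Conv2d(C, D, kernel_size=P, stride=P)       # each 14x14 patch -> one D-dim token
    cls_token = nn.Parameter(torch.zeros(1, 1, D))

    images = torch.randn(B, C, H, W)
    x = patch_embed(images)                                      # (B, D, 16, 16): (224 / 14)**2 = 256 patches
    x = x.flatten(2).transpose(1, 2)                             # (B, 256, D)
    x = torch.cat([cls_token.expand(B, -1, -1), x], dim=1)       # (B, 257, D) == (B, N+1, D)
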
 
@@ -347,85 +312,42 @@ class SkinGPT4(nn.Module):
 
     def generate(self, images, user_input=None, max_length=300):
         print("Analysing the image to generate the diagnosis")
-        # Get aligned features
         aligned_features = self.forward(images)
         print("Generated the aligned features with ViT and Qformer")
-        # prompt = self.build_prompt(aligned_features, user_input)
-        # prompt = (
-        #     "### Instruction: <Img><ImageHere></Img> "
-        #     "Could you describe the skin disease in this image for me? "
-        #     "### Response:"
-        # )
-
-        # prompt_parts = [
-        #     "### Instruction: <Img>",
-        #     "<Image>",
-        #     "</Img> Could you describe the skin disease in this image for me? ### Response:"
-        # ]
-        prompt = "### Instruction: <Img><Image></Img> Could you describe the skin disease in this image for me? ### Response:"
+        prompt = (
+            "[INST] <<SYS>>\n"
+            "You are a dermatology AI assistant. Analyze this skin image carefully and provide:\n"
+            "1. A description of visible features\n"
+            "2. Potential diagnoses\n"
+            "3. Recommendations for next steps\n"
+            "<</SYS>>\n\n"
+            "<Img><ImageHere></Img> Could you describe the skin disease in this image for me? [/INST]"
+        )
         inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
-
-        # Tokenize each part separately
-        # tokens_before = self.tokenizer(prompt_parts[0], return_tensors="pt").input_ids.to(images.device)
-        # tokens_after = self.tokenizer(prompt_parts[2], return_tensors="pt").input_ids.to(images.device)
-        # input_ids = torch.cat([
-        #     tokens_before[:, :-1],  # Remove EOS from first part
-        #     torch.full((1, 1), self.tokenizer.convert_tokens_to_ids("<Image>")).to(images.device),
-        #     tokens_after[:, 1:]  # Remove BOS from second part
-        # ], dim=1)
-        # embeddings = self.llama.model.embed_tokens(input_ids)
-        # image_token_pos = (input_ids == self.tokenizer.convert_tokens_to_ids("<Image>")).nonzero()
-        # embeddings[image_token_pos] = aligned_features.mean(dim=1)
-
-        image_token_id = self.tokenizer.convert_tokens_to_ids("<Image>")
-        image_token_pos = (inputs.input_ids == image_token_id).nonzero()
-
-        if image_token_pos.shape[0] != 1:
-            raise ValueError(f"Expected 1 image token, found {image_token_pos.shape[0]}")
-
+        image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
+        image_token_pos = torch.where(inputs.input_ids == image_token_id)
+        if len(image_token_pos[0]) == 0:
+            raise ValueError("Image token not found in prompt")
         # Prepare embeddings
-        embeddings = self.llama.model.embed_tokens(inputs.input_ids)
-
-        row, col = image_token_pos[0]
-        embeddings[row, col] = aligned_features.mean(dim=1)
+        input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
+        projected_features = self.llama_proj(aligned_features.mean(dim=1, keepdim=True))
+        input_embeddings[image_token_pos] = projected_features
 
         outputs = self.llama.generate(
-            inputs_embeds=embeddings,
-            max_length=max_length,
+            inputs_embeds=input_embeddings,
+            max_new_tokens=max_length,
             temperature=0.7,
             top_p=0.9,
             do_sample=True,
-            pad_token_id=self.tokenizer.eos_token_id
+            pad_token_id=self.tokenizer.eos_token_id,
+            attention_mask=inputs.attention_mask,
+            num_return_sequences=1
        )
+
         print(f"Output from llama : {outputs}")
         full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        response_start = full_output.find("### Response:") + len("### Response:")
-        return full_output[response_start:].strip()
-
-
-        # self.tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf",
-        #                                                 token=token, padding_side="right")
-        # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<ImageHere>']})
-        # self.llama.resize_token_embeddings(len(self.tokenizer))
-        # tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf",
-        #                                            token=token, padding_side="right")
-        # tokenizer.add_special_tokens({'additional_special_tokens': ['<ImageHere>']})
-        # self.llama.resize_token_embeddings(len(tokenizer))
-        # inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
-        # image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
-        # image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
-        # image_embeddings[image_token_index] = aligned_features.mean(dim=1)  # Pool query tokens
-        # print("Generating the diagnosis with llama")
-        # # Generate response
-        # outputs = self.llama.generate(
-        #     inputs_embeds=image_embeddings,
-        #     max_length=max_length,
-        #     temperature=0.7,
-        #     top_p=0.9,
-        #     do_sample=True
-        # )
-        # print("Generated diagnosis")
-        # return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return full_output.split("[/INST]")[-1].strip()
+
 
 class SkinGPTClassifier:
     def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
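Note: generate() now builds a Llama-2 style [INST]/<<SYS>> prompt, locates the <ImageHere> placeholder with torch.where, and overwrites that position in the token embeddings with the projected image feature before calling generate(inputs_embeds=...). The substitution step is sketched below with a toy embedding table; the vocabulary size, hidden size, and placeholder id are made up for illustration.

    import torch
    import torch.nn as nn

    vocab_size, hidden = 10, 16
    image_token_id = 7                               # stand-in for the id of <ImageHere>
    embed_tokens = nn.Embedding(vocab_size, hidden)

    input_ids = torch.tensor([[1, 2, 7, 3]])         # prompt with one placeholder token
    image_feature = torch.randn(1, hidden)           # stand-in for the projected, pooled image feature

    embeddings = embed_tokens(input_ids)             # (1, 4, hidden)
    pos = torch.where(input_ids == image_token_id)   # (row indices, column indices) of the placeholder
    embeddings[pos] = image_feature                  # splice the image embedding into the sequence
    # embeddings would then be passed as inputs_embeds to the language model's generate()
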
@@ -433,8 +355,7 @@ class SkinGPTClassifier:
         self.conversation_history = []
 
         with st.spinner("Loading AI models (this may take several minutes)..."):
-            self.meta_model = self.load_models()
-            self.resnet_feature_extractor = None
+            self.model = self._load_model()
 
         # Image transformations
         self.transform = transforms.Compose([
@@ -443,26 +364,27 @@ class SkinGPTClassifier:
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
 
-    def load_models(self):
+    def _load_model(self):
         model_path = hf_hub_download(
             repo_id="KeerthiVM/SkinCancerDiagnosis",
             filename="dermnet_finetuned_version1.pth",
         )
-        meta_model = SkinGPT4(vit_checkpoint_path=model_path)
-        return meta_model
+        model = SkinGPT4(vit_checkpoint_path=model_path).eval()
+        model = model.to(self.device)
+        model.eval()
+        return model
 
     def predict(self, image):
         image = image.convert('RGB')
         image_tensor = self.transform(image).unsqueeze(0).to(self.device)
-        diagnosis = self.meta_model.generate(
-            image_tensor
-        )
+        with torch.no_grad():
+            diagnosis = self.model.generate(image_tensor)
 
         return {
-            "top_predictions": diagnosis,
+            "diagnosis": diagnosis,
+            "visual_features": None  # Can return features if needed
         }
 
-# @st.cache_resource
 def get_classifier():
     return SkinGPTClassifier()
 
 