KeerthiVM committed on
Commit 3c32556 · 1 Parent(s): fbbfa8d

Initial commit

Files changed (1)
  1. app.py +491 -0
app.py ADDED
@@ -0,0 +1,491 @@
+import io
+import os
+import warnings
+
+import requests
+import streamlit as st
+import torch
+from torch import nn
+import torchvision.transforms as transforms
+from PIL import Image
+from io import BytesIO
+from fpdf import FPDF
+from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig
+from eva_vit import create_eva_vit_g
+import nest_asyncio
+
+nest_asyncio.apply()
+warnings.filterwarnings("ignore")
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# set_page_config must be the first Streamlit call in the script
+st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
+
+token = os.getenv("HF_TOKEN")
+if not token:
+    raise ValueError("Hugging Face token not found in environment variables")
+
+
+class Blip2QFormer(nn.Module):
+    def __init__(self, num_query_tokens=32, vision_width=1408):
+        super().__init__()
+        # Pre-trained Q-Former config.
+        # is_decoder / add_cross_attention are required so that BertEncoder actually
+        # cross-attends to the visual features passed via encoder_hidden_states.
+        self.bert_config = BertConfig(
+            vocab_size=30522,
+            hidden_size=768,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            intermediate_size=3072,
+            hidden_act="gelu",
+            hidden_dropout_prob=0.1,
+            attention_probs_dropout_prob=0.1,
+            max_position_embeddings=512,
+            type_vocab_size=2,
+            initializer_range=0.02,
+            layer_norm_eps=1e-12,
+            pad_token_id=0,
+            position_embedding_type="absolute",
+            use_cache=True,
+            classifier_dropout=None,
+            is_decoder=True,
+            add_cross_attention=True,
+        )
+
+        self.bert = BertModel(self.bert_config, add_pooling_layer=False).to(torch.float16)
+
+        # Replace position embeddings with a dummy implementation
+        self.bert.embeddings.position_embeddings = nn.Identity()  # Completely bypass position embeddings
+
+        # Disable word embeddings (queries are learned embeddings, not token ids)
+        self.bert.embeddings.word_embeddings = None
+
+        # Initialize query tokens
+        self.query_tokens = nn.Parameter(
+            torch.zeros(1, num_query_tokens, self.bert_config.hidden_size, dtype=torch.float16)
+        )
+        self.vision_proj = nn.Sequential(
+            nn.Linear(vision_width, self.bert_config.hidden_size),
+            nn.LayerNorm(self.bert_config.hidden_size)
+        ).to(torch.float16)
+
+    def load_from_pretrained(self, url_or_filename):
+        if url_or_filename.startswith('http'):
+            response = requests.get(url_or_filename)
+            checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
+        else:
+            checkpoint = torch.load(url_or_filename, map_location='cpu')
+
+        # Load Q-Former weights only (non-matching keys are ignored)
+        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
+        msg = self.load_state_dict(state_dict, strict=False)
+        # print(f"Loaded Q-Former weights with message: {msg}")
+
+    def forward(self, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
+        if query_embeds is None:
+            query_embeds = self.query_tokens.expand(encoder_hidden_states.shape[0], -1, -1)
+
+        # Project visual features into the Q-Former hidden size
+        visual_embeds = self.vision_proj(encoder_hidden_states)
+
+        # Attention mask over the visual tokens (all ones = nothing masked)
+        if encoder_attention_mask is None:
+            encoder_attention_mask = torch.ones(
+                visual_embeds.size()[:-1],
+                dtype=torch.long,
+                device=visual_embeds.device
+            )
+
+        encoder_outputs = self.bert.encoder(
+            hidden_states=query_embeds,
+            attention_mask=None,
+            encoder_hidden_states=visual_embeds,
+            encoder_attention_mask=encoder_attention_mask,
+            return_dict=True
+        )
+        return encoder_outputs.last_hidden_state
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class ViTClassifier(nn.Module):
+    def __init__(self, vit, ln_vision, num_labels):
+        super(ViTClassifier, self).__init__()
+        self.vit = vit  # Pretrained ViT from MiniGPT-4
+        self.ln_vision = ln_vision  # LayerNorm from MiniGPT-4
+        self.classifier = nn.Linear(vit.num_features, num_labels)
+
+    def forward(self, x):
+        features = self.ln_vision(self.vit(x))  # [batch, seq_len, dim]
+        cls_token = features[:, 0, :]  # Extract CLS token
+        return self.classifier(cls_token)
+
+
+class SkinGPT4(nn.Module):
+    def __init__(self, vit_checkpoint_path,
+                 q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
+        super().__init__()
+        # Image encoder parameters from paper
+        self.dtype = torch.float16
+        self.H, self.W, self.C = 224, 224, 3
+        self.P = 14  # Patch size
+        self.D = 1408  # ViT embedding dimension
+        self.num_query_tokens = 32
+        self.conversation_history = []  # Used by add_to_history() / get_full_context()
+
+        # Initialize components
+        self.vit = self._init_vit(vit_checkpoint_path)
+        print("Loaded ViT")
+        self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
+
+        self.q_former = Blip2QFormer(
+            num_query_tokens=self.num_query_tokens,
+            vision_width=self.D
+        ).to(self.dtype)
+        self.q_former.load_from_pretrained(q_former_model)
+        for param in self.q_former.parameters():
+            param.requires_grad = False
+        self.q_former.eval()
+        print("Loaded QFormer")
+
+        self.llama = self._init_llama()
+        self.llama_proj = nn.Linear(
+            self.q_former.bert_config.hidden_size,
+            self.llama.config.hidden_size
+        ).to(self.dtype)
+        self._init_alignment_projection()
+        print("Loaded Llama")
+
+        # Initialize learnable query tokens
+        self.query_tokens = nn.Parameter(
+            torch.zeros(1, self.num_query_tokens, self.q_former.bert_config.hidden_size)
+        )
+        nn.init.normal_(self.query_tokens, std=0.02)
+
+        self.tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf",
+                                                        token=token, padding_side="right")
+        print("Loaded tokenizer")
+
+    def _init_vit(self, vit_checkpoint_path):
+        """Initialize EVA-ViT-G with paper specifications"""
+        vit = create_eva_vit_g(
+            img_size=(self.H, self.W),
+            patch_size=self.P,
+            embed_dim=self.D,
+            depth=39,
+            num_heads=16,
+            mlp_ratio=4.3637,
+            qkv_bias=True,
+            drop_path_rate=0.1,
+            norm_layer=nn.LayerNorm,
+            init_values=1e-5
+        ).to(self.dtype)
+        if not hasattr(vit, 'norm'):
+            vit.norm = nn.LayerNorm(self.D)
+
+        checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
+        # Filter weights for ViT components only
+        vit_weights = {k.replace("vit.", ""): v
+                       for k, v in checkpoint.items()
+                       if k.startswith("vit.")}
+
+        # Load weights while ignoring the classifier head
+        vit.load_state_dict(vit_weights, strict=False)
+
+        # Freeze according to paper specs
+        for param in vit.parameters():
+            param.requires_grad = False
+
+        return vit.eval()
+
+    def _init_llama(self):
+        """Initialize frozen LLaMA-2-13b-chat with proper error handling"""
+        try:
+            from transformers import BitsAndBytesConfig
+
+            # Optional 4-bit quantization config to reduce memory usage
+            quant_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_quant_type="nf4",
+            )
+
+            # First try loading with device_map="auto"
+            try:
+                model = LlamaForCausalLM.from_pretrained(
+                    "meta-llama/Llama-2-13b-chat-hf",
+                    # quantization_config=quant_config,
+                    token=token,
+                    torch_dtype=torch.float16,
+                    device_map="auto",
+                    low_cpu_mem_usage=True
+                )
+            except ImportError:
+                # Fallback: load without accelerate's device_map and move the model manually
+                model = LlamaForCausalLM.from_pretrained(
+                    "meta-llama/Llama-2-13b-chat-hf",
+                    token=token,
+                    torch_dtype=torch.float16
+                )
+                model = model.to(device)
+
+            # Freeze all parameters
+            for param in model.parameters():
+                param.requires_grad = False
+
+            return model.eval()
+
+        except Exception as e:
+            raise ImportError(
+                f"Failed to load LLaMA model. Please ensure:\n"
+                f"1. You have accepted the license at: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf\n"
+                f"2. Your Hugging Face token is correct\n"
+                f"3. Required packages are installed: pip install accelerate bitsandbytes transformers\n"
+                f"Original error: {str(e)}"
+            )
+
+    def _init_alignment_projection(self):
+        """Paper specifies Xavier initialization for alignment layer"""
+        nn.init.xavier_normal_(self.llama_proj.weight)
+        nn.init.constant_(self.llama_proj.bias, 0)
+
+    def _create_patches(self, x):
+        """Convert image to patch embeddings following Eq. (1)"""
+        # x: (B, C, H, W)
+        x = x.to(self.dtype)
+        print(f"Shape of x : {x.shape}")
+        if x.dim() == 3:
+            x = x.unsqueeze(0)  # Add batch dimension if missing
+        if x.dim() != 4:
+            raise ValueError(f"Input must be 4D tensor (got {x.dim()}D)")
+
+        B, C, H, W = x.shape
+        N = (H * W) // (self.P ** 2)
+
+        x = self.vit.patch_embed(x)  # (B, N, D)
+
+        num_patches = x.shape[1]
+        pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]  # Adjust for exact match
+        x = x + pos_embed
+
+        # Add class token
+        class_token = self.vit.cls_token.expand(B, -1, -1)
+        x = torch.cat([class_token, x], dim=1)  # (B, N+1, D)
+        print(f"Final output shape: {x.shape}")
+        return x
+
+    def forward_encoder(self, x):
+        """ViT encoder from Eqs. (2)-(3)"""
+        # x: (B, N+1, D)
+        for blk in self.vit.blocks:
+            x = blk(x)
+        x = self.vit.norm(x)
+        x = self.ln_vision(x)
+        return x  # (B, N+1, D)
+
+    def forward(self, images):
+        images = images.to(self.dtype)
+        # Convert images to patches
+        x = self._create_patches(images)  # (B, N+1, D)
+
+        # ViT processing
+        x = x.to(self.dtype)
+        self.vit = self.vit.to(self.dtype)
+        vit_output = self.forward_encoder(x)  # (B, N+1, D)
+
+        # Q-Former processing
+        query_tokens = self.query_tokens.expand(x.size(0), -1, -1).to(torch.float16)
+        qformer_output = self.q_former(
+            query_embeds=query_tokens,
+            encoder_hidden_states=vit_output.to(torch.float16),
+            encoder_attention_mask=torch.ones_like(vit_output[:, :, 0])
+        ).to(self.dtype)
+
+        # Alignment projection into the LLaMA embedding space
+        aligned_features = self.llama_proj(qformer_output.to(self.dtype))
+
+        return aligned_features
+
+    def add_to_history(self, role, content):
+        self.conversation_history.append({"role": role, "content": content})
+
+    def get_full_context(self):
+        return "\n".join([f"{msg['role']}: {msg['content']}" for msg in self.conversation_history])
+
+    def build_prompt(self, image_embeds, user_question=None):
+        # The <ImageHere> placeholder must match the special token that generate()
+        # registers and replaces with the aligned image embedding.
+        if not user_question:
+            # Base prompt for initial diagnosis
+            prompt = (
+                "### Instruction: <Img><ImageHere></Img> "
+                "Could you describe the skin disease in this image for me? "
+                "### Response:"
+            )
+        else:
+            # Follow-up prompt with conversation history
+            history = self.get_full_context()
+            prompt = (
+                f"### Instruction: <Img><ImageHere></Img> "
+                f"Based on our previous conversation:\n{history}\n"
+                f"User asks: {user_question}\n"
+                "### Response:"
+            )
+
+        return prompt
+
+    def generate(self, images, user_input=None, max_length=300):
+        # Get aligned features
+        images = images.to(self.dtype)
+        aligned_features = self.forward(images)
+
+        prompt = self.build_prompt(aligned_features, user_input)
+
+        self.llama = self.llama.to(self.dtype)
+
+        # Register the image placeholder token and tokenize the prompt
+        self.tokenizer.add_special_tokens({'additional_special_tokens': ['<ImageHere>']})
+        self.llama.resize_token_embeddings(len(self.tokenizer))
+
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
+
+        # Replace <ImageHere> with the aligned visual features
+        image_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
+        image_token_index = torch.where(inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>"))
+        image_embeddings[image_token_index] = aligned_features.mean(dim=1)  # Pool query tokens
+
+        # Generate response
+        outputs = self.llama.generate(
+            inputs_embeds=image_embeddings,
+            max_length=max_length,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True
+        )
+
+        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+def load_model(model_path):
+    model = SkinGPT4(vit_checkpoint_path="dermnet_finetuned_version1.pth")
+    model.to(device)
+    model.eval()
+    return model
+
+
+class SkinGPTClassifier:
+    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
+        self.device = torch.device(device)
+        self.conversation_history = []
+        # Initialize models (they'll be loaded when needed)
+        self.base_models = None
+        self.meta_model = None
+        self.resnet_feature_extractor = None
+
+        # Image transformations
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    def load_models(self):
+        self.meta_model = SkinGPT4(vit_checkpoint_path="dermnet_finetuned_version1.pth")
+        # Move the fully initialised model to the target device and freeze it in eval mode
+        self.meta_model.to(self.device)
+        self.meta_model.eval()
+
+    def predict(self, image, top_k=3):
+        """Make prediction for a single image"""
+        if self.meta_model is None:
+            self.load_models()
+
+        # Load and preprocess image
+        try:
+            image = image.convert('RGB')
+        except Exception as e:
+            raise ValueError(f"Could not load image: {e}")
+
+        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
+        diagnosis = self.meta_model.generate(image_tensor)
+
+        return {
+            "top_predictions": diagnosis,
+        }
+
+    def generate(self, image, user_input=None):
+        """Generate a free-form response (e.g. a follow-up answer) for the given image and context."""
+        if self.meta_model is None:
+            self.load_models()
+        image_tensor = self.transform(image.convert('RGB')).unsqueeze(0).to(self.device)
+        return self.meta_model.generate(image_tensor, user_input=user_input)
+
+
+classifier = SkinGPTClassifier()
+
+
+# === Session Init ===
+if "conversation" not in st.session_state:
+    st.session_state.conversation = []
+
+# === Image Processing Function ===
+def run_inference(image):
+    result = classifier.predict(image, top_k=1)
+    return result
+
+# === PDF Export ===
+def export_chat_to_pdf(messages):
+    pdf = FPDF()
+    pdf.add_page()
+    pdf.set_font("Arial", size=12)
+    for msg in messages:
+        role = "You" if msg["role"] == "user" else "AI"
+        pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
+    buf = io.BytesIO()
+    pdf.output(buf)
+    buf.seek(0)
+    return buf
+
+# === App UI ===
+
+st.title("🧬 DermBOT — Skin AI Assistant")
+st.caption("🧠 Using model: SkinGPT")
+uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
+
+if uploaded_file:
+    st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
+    image = Image.open(uploaded_file).convert("RGB")
+    if not st.session_state.conversation:
+        # First message - diagnosis
+        diagnosis = classifier.predict(image, top_k=1)["top_predictions"]
+        st.session_state.conversation.append(("assistant", diagnosis))
+        with st.chat_message("assistant"):
+            st.markdown(diagnosis)
+    else:
+        # Follow-up questions
+        if user_query := st.chat_input("Ask a follow-up question..."):
+            st.session_state.conversation.append(("user", user_query))
+            with st.chat_message("user"):
+                st.markdown(user_query)
+
+            # Generate response with context
+            context = "\n".join([f"{role}: {msg}" for role, msg in st.session_state.conversation])
+            response = classifier.generate(image, user_input=context)
+
+            st.session_state.conversation.append(("assistant", response))
+            with st.chat_message("assistant"):
+                st.markdown(response)
+
+# === PDF Button ===
+if st.button("📄 Download Chat as PDF"):
+    chat = [{"role": role, "content": str(content)} for role, content in st.session_state.conversation]
+    pdf_file = export_chat_to_pdf(chat)
+    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")