import torch from torch import nn from torchvision import transforms from PIL import Image from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig from eva_vit import create_eva_vit_g import requests from io import BytesIO import os from huggingface_hub import hf_hub_download from transformers import BitsAndBytesConfig from accelerate import init_empty_weights import torch from torch.cuda.amp import autocast import warnings MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 token = os.getenv("HF_TOKEN") import streamlit as st import torch.nn.functional as F device = 'cuda' if torch.cuda.is_available() else 'cpu' class Blip2QFormer(nn.Module): def __init__(self, num_query_tokens=32, vision_width=1408): super().__init__() # Load pre-trained Q-Former config self.bert_config = BertConfig( vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, ) self.bert = BertModel(self.bert_config, add_pooling_layer=False) self.query_tokens = nn.Parameter( torch.zeros(1, num_query_tokens, self.bert_config.hidden_size) ) self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size) # Initialize weights self._init_weights() def _init_weights(self): nn.init.normal_(self.query_tokens, std=0.02) nn.init.xavier_uniform_(self.vision_proj.weight) nn.init.constant_(self.vision_proj.bias, 0) def load_from_pretrained(self, url_or_filename): if url_or_filename.startswith('http'): response = requests.get(url_or_filename) checkpoint = torch.load(BytesIO(response.content), map_location='cpu') else: checkpoint = torch.load(url_or_filename, map_location='cpu') state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint msg = self.load_state_dict(state_dict, strict=False) def forward(self, visual_features): # Project visual features with autocast(enabled=False): visual_embeds = self.vision_proj(visual_features.float()) # visual_embeds = self.vision_proj(visual_features.float()) visual_attention_mask = torch.ones( visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device ) # Expand query tokens query_tokens = self.query_tokens.expand(visual_embeds.shape[0], -1, -1) # Forward through BERT outputs = self.bert( input_ids=None, # No text input attention_mask=None, inputs_embeds=query_tokens, encoder_hidden_states=visual_embeds, encoder_attention_mask=visual_attention_mask, return_dict=True ) return outputs.last_hidden_state class SkinGPT4(nn.Module): def __init__(self, vit_checkpoint_path, q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"): super().__init__() # Image encoder parameters from paper self.device = device # self.dtype = torch.float16 self.dtype = MODEL_DTYPE self.H, self.W, self.C = 224, 224, 3 self.P = 14 # Patch size self.D = 1408 # ViT embedding dimension self.num_query_tokens = 32 self.vit = self._init_vit(vit_checkpoint_path).to(self.dtype) print("Loaded ViT") self.ln_vision = nn.LayerNorm(self.D).to(self.dtype) self.q_former = Blip2QFormer( num_query_tokens=self.num_query_tokens, vision_width=self.D ) self.q_former.load_from_pretrained(q_former_model) for param in self.q_former.parameters(): param.requires_grad = False print("Loaded QFormer") self.llama = self._init_llama() self.llama_proj = nn.Linear( self.q_former.bert_config.hidden_size, self.llama.config.hidden_size ).to(self.dtype) print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}") print(f"LLaMA input dim: {self.llama.config.hidden_size}") for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]: for param in module.parameters(): param.requires_grad = False module.eval() def _init_vit(self, vit_checkpoint_path): """Initialize EVA-ViT-G with paper specifications""" vit = create_eva_vit_g( img_size=(self.H, self.W), patch_size=self.P, embed_dim=self.D, depth=39, num_heads=16, mlp_ratio=4.3637, qkv_bias=True, drop_path_rate=0.1, norm_layer=nn.LayerNorm, init_values=1e-5 ).to(self.dtype) if not hasattr(vit, 'norm'): vit.norm = nn.LayerNorm(self.D) checkpoint = torch.load(vit_checkpoint_path, map_location='cpu') # 3. Filter weights for ViT components only vit_weights = {k.replace("vit.", ""): v for k, v in checkpoint.items() if k.startswith("vit.")} # 4. Load weights while ignoring classifier head vit.load_state_dict(vit_weights, strict=False) return vit.eval() def _init_llama(self): """Initialize frozen LLaMA-2-13b-chat with proper error handling""" try: device_map = { "": 0 if torch.cuda.is_available() else "cpu" } # First try loading with device_map="auto" model = LlamaForCausalLM.from_pretrained( "meta-llama/Llama-2-13b-chat-hf", token=token, torch_dtype=torch.float16, device_map=device_map, low_cpu_mem_usage=True ) return model.eval() except Exception as e: raise ImportError( f"Failed to load LLaMA model. Please ensure:\n" f"1. You have accepted the license at: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf\n" f"2. Your Hugging Face token is correct\n" f"3. Required packages are installed: pip install accelerate bitsandbytes transformers\n" f"Original error: {str(e)}" ) def encode_image(self, x): """Convert image to patch embeddings following Eq. (1)""" # x: (B, C, H, W) x = x.to(self.dtype) if x.dim() == 3: x = x.unsqueeze(0) # Add batch dimension if missing if x.dim() != 4: raise ValueError(f"Input must be 4D tensor (got {x.dim()}D)") B, C, H, W = x.shape N = (H * W) // (self.P ** 2) x = self.vit.patch_embed(x) num_patches = x.shape[1] pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :] x = x + pos_embed # Add class token class_token = self.vit.cls_token.expand(x.shape[0], -1, -1) x = torch.cat([class_token, x], dim=1) for blk in self.vit.blocks: x = blk(x) x = self.vit.norm(x) vit_features = self.ln_vision(x) # Q-Former forward pass with torch.no_grad(): qformer_output = self.q_former(vit_features.float()) image_embeds = self.llama_proj(qformer_output.to(self.dtype)) return image_embeds def generate(self, images, user_input=None, max_new_tokens=300): image_embeds = self.encode_image(images) print(f"Aligned features : {image_embeds}") print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}") print( f"\n[VALIDATION] Visual embeds - Mean: {image_embeds.mean().item():.4f}, Std: {image_embeds.std().item():.4f}") if image_embeds.shape[-1] != self.llama.config.hidden_size: raise ValueError( f"Feature dimension mismatch. " f"Q-Former output: {image_embeds.shape[-1]}, " f"LLaMA expected: {self.llama.config.hidden_size}" ) # prompt = ( # "### Instruction: " # "Could you describe the skin condition in this image? " # "Please provide a detailed analysis including possible diagnoses. " # "### Response:" # ) prompt = """### Skin Diagnosis Analysis ### Could you describe the skin condition in this image? Please provide a detailed analysis including possible diagnoses. ### Response:""" print(f"\n[DEBUG] Raw Prompt:\n{prompt}") self.tokenizer = LlamaTokenizer.from_pretrained( "meta-llama/Llama-2-13b-chat-hf", token=token, padding_side="right" ) # self.tokenizer.add_special_tokens({'additional_special_tokens': ['', '', '']}) num_added = self.tokenizer.add_special_tokens({ 'additional_special_tokens': [''] }) if num_added == 0: raise ValueError("Failed to add token!") self.llama.resize_token_embeddings(len(self.tokenizer)) inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device) print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}") print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}") # Prepare embeddings input_embeddings = self.llama.model.embed_tokens(inputs.input_ids) visual_embeds = image_embeds.mean(dim=1) # image_token_id = self.tokenizer.convert_tokens_to_ids("") image_token_id = self.tokenizer.convert_tokens_to_ids("") replace_positions = (inputs.input_ids == image_token_id).nonzero() if len(replace_positions) == 0: raise ValueError("No tokens found in prompt!") if len(replace_positions[0]) == 0: raise ValueError("Image token not found in prompt") print(f"\n[DEBUG] Image token found at position: {replace_positions}") print(f"\n[DEBUG] Before replacement:") print(f"Text embeddings shape: {input_embeddings.shape}") print(f"Visual embeddings shape: {visual_embeds.shape}") print(f"Image token at {replace_positions[0][1].item()}:") print(f"Image token embedding (before):\n{input_embeddings[0, replace_positions[0][1], :5]}...") for pos in replace_positions: input_embeddings[0, pos[1]] = visual_embeds[0] print(f"\n[DEBUG] After replacement:") print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...") outputs = self.llama.generate( inputs_embeds=input_embeddings, max_new_tokens=max_new_tokens, temperature=0.7, top_k=40, top_p=0.9, repetition_penalty=1.1, do_sample=True, pad_token_id = self.tokenizer.eos_token_id, eos_token_id = self.tokenizer.eos_token_id ) full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True) print(f"Full Output from llama : {full_output}") response = full_output.split("### Response:")[-1].strip() # print(f"Response from llama : {full_output}") return response class SkinGPTClassifier: def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'): self.device = torch.device(device) self.conversation_history = [] with st.spinner("Loading AI models (this may take several minutes)..."): self.model = self._load_model() # print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}") # print(f"Projection layer: {self.model.llama_proj}") # Image transformations self.transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def _load_model(self): model_path = hf_hub_download( repo_id="KeerthiVM/SkinCancerDiagnosis", filename="dermnet_finetuned_version1.pth", ) model = SkinGPT4(vit_checkpoint_path=model_path).eval() model = model.to(self.device) return model def predict(self, image): image = image.convert('RGB') image_tensor = self.transform(image).unsqueeze(0).to(self.device) with torch.no_grad(): diagnosis = self.model.generate(image_tensor) return { "diagnosis": diagnosis, "visual_features": None # Can return features if needed }