# SkinGPT/SkinGPT.py
import os
from io import BytesIO

import requests
import streamlit as st
import torch
from torch import nn
from torch.cuda.amp import autocast
from torchvision import transforms
from PIL import Image
from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig
from huggingface_hub import hf_hub_download

from eva_vit import create_eva_vit_g

# Run in half precision when a GPU is available; fall back to float32 on CPU.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hugging Face access token (needed for the gated LLaMA-2 checkpoint).
token = os.getenv("HF_TOKEN")
class Blip2QFormer(nn.Module):
def __init__(self, num_query_tokens=32, vision_width=1408):
super().__init__()
        # BERT-base-sized backbone for the Q-Former. `is_decoder` and
        # `add_cross_attention` are enabled so that the query tokens can
        # cross-attend to the projected image features passed in as
        # `encoder_hidden_states` (a plain BertModel ignores them otherwise).
        self.bert_config = BertConfig(
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=0,
            position_embedding_type="absolute",
            use_cache=True,
            classifier_dropout=None,
            is_decoder=True,
            add_cross_attention=True,
        )
self.bert = BertModel(self.bert_config, add_pooling_layer=False)
self.query_tokens = nn.Parameter(
torch.zeros(1, num_query_tokens, self.bert_config.hidden_size)
)
self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)
# Initialize weights
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.query_tokens, std=0.02)
nn.init.xavier_uniform_(self.vision_proj.weight)
nn.init.constant_(self.vision_proj.bias, 0)
    def load_from_pretrained(self, url_or_filename):
        """Load Q-Former weights from a local path or URL.

        `strict=False` is used because upstream BLIP-2 checkpoints prefix their
        keys differently (e.g. `Qformer.bert.*`), so only matching keys load.
        """
        if url_or_filename.startswith('http'):
            response = requests.get(url_or_filename)
            response.raise_for_status()
            checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
        else:
            checkpoint = torch.load(url_or_filename, map_location='cpu')
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        msg = self.load_state_dict(state_dict, strict=False)
        return msg
    def forward(self, visual_features):
        # Project ViT features (vision_width-d) to the BERT hidden size in
        # full precision to avoid fp16 overflow in the matmul.
        with autocast(enabled=False):
            visual_embeds = self.vision_proj(visual_features.float())
visual_attention_mask = torch.ones(
visual_embeds.size()[:-1],
dtype=torch.long,
device=visual_embeds.device
)
# Expand query tokens
query_tokens = self.query_tokens.expand(visual_embeds.shape[0], -1, -1)
# Forward through BERT
outputs = self.bert(
input_ids=None, # No text input
attention_mask=None,
inputs_embeds=query_tokens,
encoder_hidden_states=visual_embeds,
encoder_attention_mask=visual_attention_mask,
return_dict=True
)
return outputs.last_hidden_state
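    # Shape walk-through for the forward pass above (a reading aid; numbers
    # assume the defaults num_query_tokens=32, vision_width=1408):
    #   visual_features : (B, N, 1408)  ViT token embeddings
    #   visual_embeds   : (B, N, 768)   after vision_proj
    #   query_tokens    : (B, 32, 768)  learned queries, expanded per batch item
    #   return value    : (B, 32, 768)  one contextualized embedding per query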
class SkinGPT4(nn.Module):
def __init__(self, vit_checkpoint_path,
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
super().__init__()
        self.device = device
        self.dtype = MODEL_DTYPE
        # Image encoder hyper-parameters from the paper (EVA-ViT-G).
self.H, self.W, self.C = 224, 224, 3
self.P = 14 # Patch size
self.D = 1408 # ViT embedding dimension
self.num_query_tokens = 32
self.vit = self._init_vit(vit_checkpoint_path).to(self.dtype)
print("Loaded ViT")
self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
self.q_former = Blip2QFormer(
num_query_tokens=self.num_query_tokens,
vision_width=self.D
)
self.q_former.load_from_pretrained(q_former_model)
for param in self.q_former.parameters():
param.requires_grad = False
print("Loaded QFormer")
self.llama = self._init_llama()
self.llama_proj = nn.Linear(
self.q_former.bert_config.hidden_size,
self.llama.config.hidden_size
).to(self.dtype)
print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}")
print(f"LLaMA input dim: {self.llama.config.hidden_size}")
for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]:
for param in module.parameters():
param.requires_grad = False
module.eval()
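    # Inference path wired up above (every component frozen, eval mode):
    #   image -> EVA-ViT-G -> ln_vision -> Q-Former (32 query tokens, 768-d)
    #         -> llama_proj (768 -> LLaMA hidden size) -> LLaMA-2-13b-chat.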
def _init_vit(self, vit_checkpoint_path):
"""Initialize EVA-ViT-G with paper specifications"""
vit = create_eva_vit_g(
img_size=(self.H, self.W),
patch_size=self.P,
embed_dim=self.D,
depth=39,
num_heads=16,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
init_values=1e-5
).to(self.dtype)
if not hasattr(vit, 'norm'):
vit.norm = nn.LayerNorm(self.D)
        checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
        # Keep only the ViT backbone weights (keys prefixed with "vit.") from the
        # fine-tuned checkpoint; the classifier head is dropped via strict=False.
        vit_weights = {k.replace("vit.", ""): v
                       for k, v in checkpoint.items()
                       if k.startswith("vit.")}
        vit.load_state_dict(vit_weights, strict=False)
        return vit.eval()
def _init_llama(self):
"""Initialize frozen LLaMA-2-13b-chat with proper error handling"""
        try:
            # Put the whole model on GPU 0 when available, otherwise on the CPU.
            device_map = {
                "": 0 if torch.cuda.is_available() else "cpu"
            }
            model = LlamaForCausalLM.from_pretrained(
                "meta-llama/Llama-2-13b-chat-hf",
                token=token,
                torch_dtype=MODEL_DTYPE,
                device_map=device_map,
                low_cpu_mem_usage=True
            )
            return model.eval()
        except Exception as e:
            raise RuntimeError(
                f"Failed to load LLaMA model. Please ensure:\n"
                f"1. You have accepted the license at: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf\n"
                f"2. Your Hugging Face token is correct\n"
                f"3. Required packages are installed: pip install accelerate bitsandbytes transformers\n"
                f"Original error: {str(e)}"
            ) from e
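    # Rough sizing note: the 13B checkpoint in float16 takes about 26 GB for the
    # weights alone, so GPU loading assumes a large-memory card; the CPU fallback
    # (float32) needs roughly twice that in RAM and is very slow.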
def encode_image(self, x):
"""Convert image to patch embeddings following Eq. (1)"""
# x: (B, C, H, W)
x = x.to(self.dtype)
if x.dim() == 3:
x = x.unsqueeze(0) # Add batch dimension if missing
if x.dim() != 4:
raise ValueError(f"Input must be 4D tensor (got {x.dim()}D)")
        # Patchify, prepend the [CLS] token, then add positional embeddings
        # over the full token sequence ([CLS] + patches).
        x = self.vit.patch_embed(x)
        class_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([class_token, x], dim=1)
        x = x + self.vit.pos_embed[:, :x.shape[1], :]
for blk in self.vit.blocks:
x = blk(x)
x = self.vit.norm(x)
vit_features = self.ln_vision(x)
# Q-Former forward pass
with torch.no_grad():
qformer_output = self.q_former(vit_features.float())
image_embeds = self.llama_proj(qformer_output.to(self.dtype))
return image_embeds
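    # encode_image shape walk-through (with the defaults above: 224x224 input,
    # patch size 14, hence 16 * 16 = 256 patches):
    #   x             : (B, 3, 224, 224)
    #   ViT tokens    : (B, 257, 1408)   [CLS] + 256 patch embeddings
    #   Q-Former out  : (B, 32, 768)     one vector per learned query
    #   image_embeds  : (B, 32, 5120)    after llama_proj (LLaMA-2-13B hidden size)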
def generate(self, images, user_input=None, max_new_tokens=300):
image_embeds = self.encode_image(images)
print(f"Aligned features : {image_embeds}")
print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
print(
f"\n[VALIDATION] Visual embeds - Mean: {image_embeds.mean().item():.4f}, Std: {image_embeds.std().item():.4f}")
if image_embeds.shape[-1] != self.llama.config.hidden_size:
raise ValueError(
f"Feature dimension mismatch. "
f"Q-Former output: {image_embeds.shape[-1]}, "
f"LLaMA expected: {self.llama.config.hidden_size}"
)
prompt = """### Skin Diagnosis Analysis ###
<IMAGE>
Could you describe the skin condition in this image?
Please provide a detailed analysis including possible diagnoses.
### Response:"""
print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
        # The tokenizer and the <IMAGE> placeholder token are (re)created on each
        # call; the placeholder is later swapped for the projected visual embedding.
        self.tokenizer = LlamaTokenizer.from_pretrained(
            "meta-llama/Llama-2-13b-chat-hf",
            token=token,
            padding_side="right"
        )
        num_added = self.tokenizer.add_special_tokens({
            'additional_special_tokens': ['<IMAGE>']
        })
        if num_added == 0:
            raise ValueError("Failed to add <IMAGE> token!")
        self.llama.resize_token_embeddings(len(self.tokenizer))
inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
# Prepare embeddings
input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
        # Collapse the 32 query embeddings into a single visual token embedding.
        visual_embeds = image_embeds.mean(dim=1)
        image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        # nonzero() gives (row, position) index pairs for every <IMAGE> occurrence.
        replace_positions = (inputs.input_ids == image_token_id).nonzero()
        if len(replace_positions) == 0:
            raise ValueError("No <IMAGE> tokens found in prompt!")
print(f"\n[DEBUG] Image token found at position: {replace_positions}")
print(f"\n[DEBUG] Before replacement:")
print(f"Text embeddings shape: {input_embeddings.shape}")
print(f"Visual embeddings shape: {visual_embeds.shape}")
print(f"Image token at {replace_positions[0][1].item()}:")
print(f"Image token embedding (before):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
for pos in replace_positions:
input_embeddings[0, pos[1]] = visual_embeds[0]
print(f"\n[DEBUG] After replacement:")
print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
outputs = self.llama.generate(
inputs_embeds=input_embeddings,
max_new_tokens=max_new_tokens,
temperature=0.7,
top_k=40,
top_p=0.9,
repetition_penalty=1.1,
do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id
)
        full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Full Output from llama : {full_output}")
        # With `inputs_embeds`, generate() should return only newly generated
        # tokens, so the split below is a safeguard in case the header is echoed.
        response = full_output.split("### Response:")[-1].strip()
        return response
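    # Note on the vision-language interface above: the single <IMAGE> placeholder
    # is tokenized like any other special token, and its embedding row is then
    # overwritten with the mean-pooled, projected Q-Former output before calling
    # generate() with `inputs_embeds`; LLaMA therefore conditions on one soft
    # visual token embedded directly in the prompt rather than on raw pixels.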
class SkinGPTClassifier:
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
self.device = torch.device(device)
self.conversation_history = []
with st.spinner("Loading AI models (this may take several minutes)..."):
self.model = self._load_model()
# print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
# print(f"Projection layer: {self.model.llama_proj}")
# Image transformations
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def _load_model(self):
model_path = hf_hub_download(
repo_id="KeerthiVM/SkinCancerDiagnosis",
filename="dermnet_finetuned_version1.pth",
)
model = SkinGPT4(vit_checkpoint_path=model_path).eval()
model = model.to(self.device)
return model
def predict(self, image):
image = image.convert('RGB')
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
with torch.no_grad():
diagnosis = self.model.generate(image_tensor)
return {
"diagnosis": diagnosis,
"visual_features": None # Can return features if needed
}
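

# Minimal usage sketch (not executed by the Streamlit app). It assumes HF_TOKEN
# grants access to the gated LLaMA-2 weights and that "example_lesion.jpg" is a
# placeholder image path you supply yourself.
if __name__ == "__main__":
    classifier = SkinGPTClassifier()
    image = Image.open("example_lesion.jpg")
    result = classifier.predict(image)
    print(result["diagnosis"])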