import os
from io import BytesIO

import requests
import streamlit as st
import torch
from torch import nn
from torch.cuda.amp import autocast
from torchvision import transforms
from PIL import Image
from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig
from huggingface_hub import hf_hub_download

from eva_vit import create_eva_vit_g

# Run in half precision on GPU; fall back to full precision on CPU.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hugging Face access token (required for the gated LLaMA-2 weights).
token = os.getenv("HF_TOKEN")


class Blip2QFormer(nn.Module):
    """Querying Transformer that bridges the frozen ViT and the frozen LLaMA."""

    def __init__(self, num_query_tokens=32, vision_width=1408):
        super().__init__()

        self.bert_config = BertConfig(
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=0,
            position_embedding_type="absolute",
            use_cache=True,
            classifier_dropout=None,
            # Cross-attention over the visual features requires these two flags
            # on a stock Hugging Face BertModel; without them the
            # encoder_hidden_states passed in forward() are silently ignored.
            # (Setting is_decoder also makes the query self-attention causal,
            # a limitation of reusing the stock BertModel here.)
            is_decoder=True,
            add_cross_attention=True,
        )

        self.bert = BertModel(self.bert_config, add_pooling_layer=False)
        # Learnable query tokens that cross-attend to the image features.
        self.query_tokens = nn.Parameter(
            torch.zeros(1, num_query_tokens, self.bert_config.hidden_size)
        )
        # Projects ViT features (1408-d) into the BERT hidden size (768-d).
        self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.query_tokens, std=0.02)
        nn.init.xavier_uniform_(self.vision_proj.weight)
        nn.init.constant_(self.vision_proj.bias, 0)

    def load_from_pretrained(self, url_or_filename):
        """Load BLIP-2 Q-Former weights from a URL or a local checkpoint."""
        if url_or_filename.startswith('http'):
            response = requests.get(url_or_filename)
            response.raise_for_status()
            checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
        else:
            checkpoint = torch.load(url_or_filename, map_location='cpu')
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        # LAVIS BLIP-2 checkpoints prefix Q-Former weights with "Qformer.";
        # strip it so the keys line up with this module's "bert." parameters.
        state_dict = {k.replace("Qformer.", ""): v for k, v in state_dict.items()}
        msg = self.load_state_dict(state_dict, strict=False)
        print(f"Loaded Q-Former weights (missing keys: {len(msg.missing_keys)}, "
              f"unexpected keys: {len(msg.unexpected_keys)})")

    def forward(self, visual_features):
        # Project the ViT features into the BERT hidden size in fp32 for
        # numerical stability, regardless of any surrounding autocast context.
        with autocast(enabled=False):
            visual_embeds = self.vision_proj(visual_features.float())

        visual_attention_mask = torch.ones(
            visual_embeds.size()[:-1],
            dtype=torch.long,
            device=visual_embeds.device
        )

        # One copy of the learnable query tokens per image in the batch.
        query_tokens = self.query_tokens.expand(visual_embeds.shape[0], -1, -1)

        # The query tokens attend to each other via self-attention and to the
        # visual embeddings via cross-attention.
        outputs = self.bert(
            input_ids=None,
            attention_mask=None,
            inputs_embeds=query_tokens,
            encoder_hidden_states=visual_embeds,
            encoder_attention_mask=visual_attention_mask,
            return_dict=True
        )

        return outputs.last_hidden_state
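

# A minimal, hypothetical smoke test for the Q-Former wiring above (not part
# of the original code): it feeds random "visual features" of the expected
# ViT shape through a freshly initialised Blip2QFormer and checks that the
# output has one 768-d vector per query token. No pretrained weights needed.
def _qformer_shape_smoke_test(batch_size: int = 2, num_vit_tokens: int = 257):
    q_former = Blip2QFormer(num_query_tokens=32, vision_width=1408)
    fake_vit_features = torch.randn(batch_size, num_vit_tokens, 1408)
    out = q_former(fake_vit_features)
    assert out.shape == (batch_size, 32, 768), out.shape
    return out.shape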


class SkinGPT4(nn.Module):
    def __init__(self, vit_checkpoint_path,
                 q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
        super().__init__()

        self.device = device
        self.dtype = MODEL_DTYPE

        # Image and patch geometry (224x224 RGB input, 14x14 patches) and the
        # EVA-ViT-G embedding width.
        self.H, self.W, self.C = 224, 224, 3
        self.P = 14
        self.D = 1408
        self.num_query_tokens = 32

        # Frozen EVA-ViT-G visual encoder.
        self.vit = self._init_vit(vit_checkpoint_path).to(self.dtype)
        print("Loaded ViT")
        self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)

        # Frozen Q-Former initialised from the BLIP-2 checkpoint.
        self.q_former = Blip2QFormer(
            num_query_tokens=self.num_query_tokens,
            vision_width=self.D
        )
        self.q_former.load_from_pretrained(q_former_model)
        for param in self.q_former.parameters():
            param.requires_grad = False
        print("Loaded QFormer")

        # Frozen LLaMA-2-13b-chat language model.
        self.llama = self._init_llama()

        # Projects Q-Former outputs (768-d) into the LLaMA embedding space.
        self.llama_proj = nn.Linear(
            self.q_former.bert_config.hidden_size,
            self.llama.config.hidden_size
        ).to(self.dtype)

        print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}")
        print(f"LLaMA input dim: {self.llama.config.hidden_size}")

        # Inference-only model: freeze every component.
        for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]:
            for param in module.parameters():
                param.requires_grad = False
            module.eval()

    def _init_vit(self, vit_checkpoint_path):
        """Initialize EVA-ViT-G with paper specifications."""
        vit = create_eva_vit_g(
            img_size=(self.H, self.W),
            patch_size=self.P,
            embed_dim=self.D,
            depth=39,
            num_heads=16,
            mlp_ratio=4.3637,
            qkv_bias=True,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            init_values=1e-5
        ).to(self.dtype)

        # Some EVA-ViT builds strip the final norm layer; add one if missing.
        if not hasattr(vit, 'norm'):
            vit.norm = nn.LayerNorm(self.D)

        # Load only the ViT weights from the fine-tuned checkpoint
        # (keys are stored with a "vit." prefix).
        checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
        vit_weights = {k.replace("vit.", ""): v
                       for k, v in checkpoint.items()
                       if k.startswith("vit.")}
        vit.load_state_dict(vit_weights, strict=False)

        return vit.eval()

    def _init_llama(self):
        """Initialize frozen LLaMA-2-13b-chat with proper error handling."""
        try:
            # Put the whole model on the first GPU if one is available,
            # otherwise keep it on CPU.
            device_map = {
                "": 0 if torch.cuda.is_available() else "cpu"
            }

            model = LlamaForCausalLM.from_pretrained(
                "meta-llama/Llama-2-13b-chat-hf",
                token=token,
                torch_dtype=torch.float16,
                device_map=device_map,
                low_cpu_mem_usage=True
            )
            return model.eval()

        except Exception as e:
            raise RuntimeError(
                f"Failed to load LLaMA model. Please ensure:\n"
                f"1. You have accepted the license at: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf\n"
                f"2. Your Hugging Face token is correct\n"
                f"3. Required packages are installed: pip install accelerate bitsandbytes transformers\n"
                f"Original error: {str(e)}"
            ) from e

    def encode_image(self, x):
        """Convert an image into LLaMA-space embeddings (ViT -> Q-Former -> projection), following Eq. (1)."""
        x = x.to(self.dtype)
        if x.dim() == 3:
            x = x.unsqueeze(0)
        if x.dim() != 4:
            raise ValueError(f"Input must be 4D tensor (got {x.dim()}D)")

        # Patch embedding: a 224x224 image yields (H * W) / P^2 = 256 patch
        # tokens of width D = 1408.
        x = self.vit.patch_embed(x)

        # Add positional embeddings for the patch tokens (index 0 of pos_embed
        # belongs to the [CLS] token), then prepend the [CLS] token itself.
        num_patches = x.shape[1]
        pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]
        x = x + pos_embed
        class_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([class_token, x], dim=1)

        # Run the frozen transformer blocks and the final layer norms.
        for blk in self.vit.blocks:
            x = blk(x)
        x = self.vit.norm(x)
        vit_features = self.ln_vision(x)

        # Compress the ViT tokens into 32 query tokens, then project them
        # into the LLaMA embedding space.
        with torch.no_grad():
            qformer_output = self.q_former(vit_features.float())
        image_embeds = self.llama_proj(qformer_output.to(self.dtype))

        return image_embeds
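
    # Hypothetical helper (not in the original code): spells out the token
    # arithmetic that encode_image relies on, assuming the 224x224 / 14x14
    # geometry configured in __init__. For those values it returns
    # (224 // 14) ** 2 = 256 patch tokens, i.e. 257 ViT tokens with [CLS].
    @staticmethod
    def _expected_vit_tokens(img_size: int = 224, patch_size: int = 14) -> int:
        num_patches = (img_size // patch_size) ** 2  # 16 * 16 = 256
        return num_patches + 1                       # +1 for the [CLS] token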

    def generate(self, images, user_input=None, max_new_tokens=300):
        image_embeds = self.encode_image(images)

        print(f"Aligned features : {image_embeds}")
        print(f"\n Image embeddings shape : {image_embeds.shape} \n LLaMA hidden size : {self.llama.config.hidden_size}")
        print(
            f"\n[VALIDATION] Visual embeds - Mean: {image_embeds.mean().item():.4f}, Std: {image_embeds.std().item():.4f}")

        if image_embeds.shape[-1] != self.llama.config.hidden_size:
            raise ValueError(
                f"Feature dimension mismatch. "
                f"Q-Former output: {image_embeds.shape[-1]}, "
                f"LLaMA expected: {self.llama.config.hidden_size}"
            )

        prompt = """### Skin Diagnosis Analysis ###
<IMAGE>
Could you describe the skin condition in this image?
Please provide a detailed analysis including possible diagnoses.
### Response:"""

        print(f"\n[DEBUG] Raw Prompt:\n{prompt}")

        # Tokenizer with a dedicated <IMAGE> placeholder token whose embedding
        # is later replaced by the projected visual features.
        self.tokenizer = LlamaTokenizer.from_pretrained(
            "meta-llama/Llama-2-13b-chat-hf",
            token=token,
            padding_side="right"
        )
        num_added = self.tokenizer.add_special_tokens({
            'additional_special_tokens': ['<IMAGE>']
        })
        if num_added == 0:
            raise ValueError("Failed to add <IMAGE> token!")
        self.llama.resize_token_embeddings(len(self.tokenizer))

        inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)

        print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
        print(f"[DEBUG] Special tokens: {self.tokenizer.all_special_tokens}")

        # Embed the prompt tokens, then splice the visual information into the
        # <IMAGE> position. The 32 query outputs are mean-pooled into a single
        # visual token so exactly one text embedding is replaced.
        input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
        visual_embeds = image_embeds.mean(dim=1)

        image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        replace_positions = (inputs.input_ids == image_token_id).nonzero()
        if len(replace_positions) == 0:
            raise ValueError("No <IMAGE> tokens found in prompt!")

        print(f"\n[DEBUG] Image token found at position: {replace_positions}")

        print("\n[DEBUG] Before replacement:")
        print(f"Text embeddings shape: {input_embeddings.shape}")
        print(f"Visual embeddings shape: {visual_embeds.shape}")
        print(f"Image token at {replace_positions[0][1].item()}:")
        print(f"Image token embedding (before):\n{input_embeddings[0, replace_positions[0][1], :5]}...")

        # Overwrite the placeholder embedding(s) with the pooled visual embedding.
        for pos in replace_positions:
            input_embeddings[0, pos[1]] = visual_embeds[0]

        print("\n[DEBUG] After replacement:")
        print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...")

        # Sample a response conditioned on the spliced prompt embeddings.
        outputs = self.llama.generate(
            inputs_embeds=input_embeddings,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_k=40,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id
        )

        full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Full output from LLaMA : {full_output}")
        # Keep only the text after the "### Response:" marker if it is echoed.
        response = full_output.split("### Response:")[-1].strip()

        return response
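

# A toy, self-contained illustration of the embedding-splice step used in
# SkinGPT4.generate above (hypothetical helper, not part of the original
# pipeline): given token embeddings and the id of a placeholder token,
# overwrite the placeholder positions with a visual embedding of the same
# hidden size.
def _splice_visual_embedding(token_embeddings: torch.Tensor,
                             token_ids: torch.Tensor,
                             placeholder_id: int,
                             visual_embedding: torch.Tensor) -> torch.Tensor:
    """token_embeddings: (1, seq, hidden); visual_embedding: (hidden,)."""
    spliced = token_embeddings.clone()
    positions = (token_ids == placeholder_id).nonzero()
    for pos in positions:  # pos is (batch_index, sequence_index)
        spliced[pos[0], pos[1]] = visual_embedding
    return spliced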


class SkinGPTClassifier:
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)
        self.conversation_history = []

        with st.spinner("Loading AI models (this may take several minutes)..."):
            self.model = self._load_model()

        # Standard ImageNet preprocessing at the ViT input resolution.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _load_model(self):
        model_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="dermnet_finetuned_version1.pth",
        )
        model = SkinGPT4(vit_checkpoint_path=model_path).eval()
        model = model.to(self.device)
        return model

    def predict(self, image):
        image = image.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            diagnosis = self.model.generate(image_tensor)

        return {
            "diagnosis": diagnosis,
            "visual_features": None
        }
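

# Minimal usage sketch (hypothetical, not part of the original app): run the
# classifier on a local image from the command line. It assumes HF_TOKEN is
# set, access to meta-llama/Llama-2-13b-chat-hf has been granted, and that
# "example_lesion.jpg" is a placeholder path you substitute yourself.
if __name__ == "__main__":
    classifier = SkinGPTClassifier()
    test_image = Image.open("example_lesion.jpg")  # hypothetical example path
    result = classifier.predict(test_image)
    print(result["diagnosis"])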