# SkinGPT / app.py
import os
import io
from io import BytesIO
import warnings

import requests
import nest_asyncio
import streamlit as st
import torch
from torch import nn
from torch.cuda.amp import autocast
from torchvision import transforms
from PIL import Image
from fpdf import FPDF
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    BertModel,
    BertConfig,
    BitsAndBytesConfig,
    logging,
)
from huggingface_hub import hf_hub_download
from accelerate import init_empty_weights
from eva_vit import create_eva_vit_g

nest_asyncio.apply()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")

# Set default dtypes: main computations run in float32, model weights in float16 on GPU
torch.set_default_dtype(torch.float32)
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

warnings.filterwarnings("ignore")
logging.set_verbosity_error()

token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found in environment variables")
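
# Q-Former bridge between the frozen vision encoder and the language model: a set of
# learnable query tokens is passed through a BERT encoder that cross-attends to the
# projected ViT features (1408 -> 768). Weights come from the BLIP-2 pretrained
# checkpoint referenced in SkinGPT4 below and stay frozen at inference time.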
class Blip2QFormer(nn.Module):
def __init__(self, num_query_tokens=32, vision_width=1408):
super().__init__()
# Load pre-trained Q-Former config
self.bert_config = BertConfig(
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
position_embedding_type="absolute",
            use_cache=True,
            classifier_dropout=None,
            # Enable cross-attention so the query tokens can actually attend to the visual
            # features passed as encoder_hidden_states in forward() below.
            is_decoder=True,
            add_cross_attention=True,
        )
self.bert = BertModel(self.bert_config, add_pooling_layer=False)
self.query_tokens = nn.Parameter(
torch.zeros(1, num_query_tokens, self.bert_config.hidden_size)
)
self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)
# Initialize weights
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.query_tokens, std=0.02)
nn.init.xavier_uniform_(self.vision_proj.weight)
nn.init.constant_(self.vision_proj.bias, 0)
def load_from_pretrained(self, url_or_filename):
if url_or_filename.startswith('http'):
response = requests.get(url_or_filename)
checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
else:
checkpoint = torch.load(url_or_filename, map_location='cpu')
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
msg = self.load_state_dict(state_dict, strict=False)
def forward(self, visual_features):
# Project visual features
with autocast(enabled=False):
visual_embeds = self.vision_proj(visual_features.float())
visual_attention_mask = torch.ones(
visual_embeds.size()[:-1],
dtype=torch.long,
device=visual_embeds.device
)
# Expand query tokens
query_tokens = self.query_tokens.expand(visual_embeds.shape[0], -1, -1)
# Forward through BERT
outputs = self.bert(
input_ids=None, # No text input
attention_mask=None,
inputs_embeds=query_tokens,
encoder_hidden_states=visual_embeds,
encoder_attention_mask=visual_attention_mask,
return_dict=True
)
return outputs.last_hidden_state
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
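
# Classification head over the frozen ViT: LayerNorm'd features -> CLS token -> linear layer.
# It appears to mirror the fine-tuning setup whose "vit."-prefixed checkpoint is loaded in
# SkinGPT4._init_vit below; it is not called directly in this app's inference path.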
class ViTClassifier(nn.Module):
def __init__(self, vit, ln_vision, num_labels):
super(ViTClassifier, self).__init__()
self.vit = vit # Pretrained ViT from MiniGPT-4
self.ln_vision = ln_vision # LayerNorm from MiniGPT-4
self.classifier = nn.Linear(vit.num_features, num_labels)
def forward(self, x):
features = self.ln_vision(self.vit(x)) # [batch, seq_len, dim]
cls_token = features[:, 0, :] # Extract CLS token
return self.classifier(cls_token)
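
# SkinGPT4 pipeline: 224x224 image -> EVA-ViT-G patch embeddings -> frozen ViT blocks + LayerNorm
# -> frozen Q-Former (32 query tokens) -> linear projection into the LLaMA-2-13B embedding space.
# The pooled visual embedding replaces the <ImageHere> placeholder in the prompt before generation.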
class SkinGPT4(nn.Module):
def __init__(self, vit_checkpoint_path,
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
super().__init__()
# Image encoder parameters from paper
self.device = device
# self.dtype = torch.float16
self.dtype = MODEL_DTYPE
self.H, self.W, self.C = 224, 224, 3
self.P = 14 # Patch size
self.D = 1408 # ViT embedding dimension
        self.num_query_tokens = 32
        self.conversation_history = []  # dialogue history used by add_to_history / build_prompt
self.vit = self._init_vit(vit_checkpoint_path).to(self.dtype)
print("Loaded ViT")
self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
self.q_former = Blip2QFormer(
num_query_tokens=self.num_query_tokens,
vision_width=self.D
)
self.q_former.load_from_pretrained(q_former_model)
for param in self.q_former.parameters():
param.requires_grad = False
for module in [self.vit, self.ln_vision, self.q_former]:
for param in module.parameters():
param.requires_grad = False
module.eval()
print("Loaded QFormer")
self.tokenizer = LlamaTokenizer.from_pretrained(
"meta-llama/Llama-2-13b-chat-hf",
token=token,
padding_side="right"
)
self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
self.llama = self._init_llama()
self.llama.resize_token_embeddings(len(self.tokenizer))
self.llama_proj = nn.Linear(
self.q_former.bert_config.hidden_size,
self.llama.config.hidden_size
).to(self.dtype)
print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}")
print(f"LLaMA input dim: {self.llama.config.hidden_size}")
for param in self.llama_proj.parameters():
param.requires_grad = False
def _init_vit(self, vit_checkpoint_path):
"""Initialize EVA-ViT-G with paper specifications"""
vit = create_eva_vit_g(
img_size=(self.H, self.W),
patch_size=self.P,
embed_dim=self.D,
depth=39,
num_heads=16,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
init_values=1e-5
).to(self.dtype)
if not hasattr(vit, 'norm'):
vit.norm = nn.LayerNorm(self.D)
checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
# 3. Filter weights for ViT components only
vit_weights = {k.replace("vit.", ""): v
for k, v in checkpoint.items()
if k.startswith("vit.")}
# 4. Load weights while ignoring classifier head
vit.load_state_dict(vit_weights, strict=False)
# 5. Freeze according to paper specs
for param in vit.parameters():
param.requires_grad = False
return vit.eval()
def _init_llama(self):
"""Initialize frozen LLaMA-2-13b-chat with proper error handling"""
try:
device_map = {
"": 0 if torch.cuda.is_available() else "cpu"
}
# First try loading with device_map="auto"
try:
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-13b-chat-hf",
token=token,
torch_dtype=torch.float16,
device_map=device_map,
low_cpu_mem_usage=True
)
            except (ImportError, RuntimeError):
                # Fallback: plain CPU load (e.g. accelerate unavailable or GPU memory
                # insufficient for the device map), then move to the target device explicitly.
                model = LlamaForCausalLM.from_pretrained(
                    "meta-llama/Llama-2-13b-chat-hf",
                    token=token,
                    torch_dtype=torch.float16
                )
                model = model.to(self.device)
# Freeze all parameters
for param in model.parameters():
param.requires_grad = False
return model.eval()
except Exception as e:
raise ImportError(
f"Failed to load LLaMA model. Please ensure:\n"
f"1. You have accepted the license at: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf\n"
f"2. Your Hugging Face token is correct\n"
f"3. Required packages are installed: pip install accelerate bitsandbytes transformers\n"
f"Original error: {str(e)}"
)
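
    # Note: this Xavier initialisation helper is not invoked in __init__, so llama_proj keeps
    # PyTorch's default Linear initialisation unless it is called explicitly.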
def _init_alignment_projection(self):
"""Paper specifies Xavier initialization for alignment layer"""
nn.init.xavier_normal_(self.llama_proj.weight)
nn.init.constant_(self.llama_proj.bias, 0)
def _create_patches(self, x):
"""Convert image to patch embeddings following Eq. (1)"""
# x: (B, C, H, W)
x = x.to(self.dtype)
if x.dim() == 3:
x = x.unsqueeze(0) # Add batch dimension if missing
if x.dim() != 4:
raise ValueError(f"Input must be 4D tensor (got {x.dim()}D)")
B, C, H, W = x.shape
N = (H * W) // (self.P ** 2)
x = self.vit.patch_embed(x) # (B, N, D)
num_patches = x.shape[1]
pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :] # Adjust for exact match
x = x + pos_embed
# Add class token
class_token = self.vit.cls_token.expand(B, -1, -1)
x = torch.cat([class_token, x], dim=1) # (B, N+1, D)
return x
def forward_encoder(self, x):
"""ViT encoder from Eqs. (2)-(3)"""
# x: (B, N+1, D)
for blk in self.vit.blocks:
x = blk(x)
x = self.vit.norm(x)
x = self.ln_vision(x)
return x # (B, N+1, D)
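
    # Full visual pathway: patch embeddings -> ViT encoder -> Q-Former (run in fp32 with
    # autocast disabled) -> projection to the LLaMA hidden size.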
def forward(self, images):
images = images.to(self.dtype)
x = self._create_patches(images)
vit_output = self.forward_encoder(x)
with torch.cuda.amp.autocast(enabled=False):
qformer_output = self.q_former(vit_output.float())
aligned_features = self.llama_proj(qformer_output.to(self.dtype))
return aligned_features
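
    # Lightweight conversation-history helpers used to build prompts for follow-up questions.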
def add_to_history(self, role, content):
self.conversation_history.append({"role": role, "content": content})
def get_full_context(self):
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in self.conversation_history])
def build_prompt(self, image_embeds, user_question=None):
# Base prompt for initial diagnosis
if not user_question:
prompt = (
"### Instruction: <Img><ImageHere></Img> "
"Could you describe the skin disease in this image for me? "
"### Response:"
)
else:
# Follow-up prompt with conversation history
history = self.get_full_context()
prompt = (
f"### Instruction: <Img><ImageHere></Img> "
f"Based on our previous conversation:\n{history}\n"
f"User asks: {user_question}\n"
"### Response:"
)
return prompt
    def generate(self, images, user_input=None, max_length=300):
        print("Analysing the image to generate the diagnosis")
        aligned_features = self.forward(images)
        print(f"Aligned feature shape: {aligned_features.shape}")
        print("Generated the aligned features with ViT and Qformer")
        if user_input:
            # Follow-up turn: pass the user's question/context alongside the image
            prompt = f"<Img><ImageHere></Img> {user_input} [/INST]"
        else:
            # Initial turn: ask for a description of the skin condition
            prompt = (
                "<Img><ImageHere></Img> Could you describe the skin disease in this image for me? [/INST]"
            )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
        image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
        image_token_pos = torch.where(inputs.input_ids == image_token_id)
        if len(image_token_pos[0]) == 0:
            raise ValueError("Image token not found in prompt")
        # Prepare text embeddings and splice the projected visual feature in place of <ImageHere>
        input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
        visual_embeds = aligned_features.mean(dim=1)  # (B, llama_hidden), pooled over the 32 query tokens
        visual_embeds = visual_embeds.to(input_embeddings.dtype)
        print(f"Visual embedding shape: {visual_embeds.shape}")
        input_embeddings[image_token_pos] = visual_embeds
outputs = self.llama.generate(
inputs_embeds=input_embeddings,
max_new_tokens=max_length,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.2, # Prevent repetition
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Output from llama : {full_output}")
return full_output.split("[/INST]")[-1].strip()
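
# Streamlit-facing wrapper: downloads the fine-tuned ViT checkpoint from the Hub, builds the
# frozen SkinGPT4 model once, and applies ImageNet-style resizing/normalisation before inference.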
class SkinGPTClassifier:
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
self.device = torch.device(device)
self.conversation_history = []
with st.spinner("Loading AI models (this may take several minutes)..."):
self.model = self._load_model()
# print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
# print(f"Projection layer: {self.model.llama_proj}")
# Image transformations
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def _load_model(self):
model_path = hf_hub_download(
repo_id="KeerthiVM/SkinCancerDiagnosis",
filename="dermnet_finetuned_version1.pth",
)
model = SkinGPT4(vit_checkpoint_path=model_path).eval()
model = model.to(self.device)
model.eval()
return model
def predict(self, image):
image = image.convert('RGB')
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
with torch.no_grad():
diagnosis = self.model.generate(image_tensor)
return {
"diagnosis": diagnosis,
"visual_features": None # Can return features if needed
}
@st.cache_resource
def get_classifier():
    # Cache the wrapper so the 13B model is loaded once per process, not on every Streamlit rerun
    return SkinGPTClassifier()
classifier = get_classifier()
# === Session Init ===
if "messages" not in st.session_state:
st.session_state.messages = []
# === PDF Export ===
def export_chat_to_pdf(messages):
pdf = FPDF()
pdf.add_page()
pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # FPDF's built-in fonts are latin-1 only, so replace unsupported characters
        text = f"{role}: {msg['content']}\n".encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, text)
    # output(dest="S") returns the document in memory (str in classic fpdf, bytearray in fpdf2)
    pdf_bytes = pdf.output(dest="S")
    if isinstance(pdf_bytes, str):
        pdf_bytes = pdf_bytes.encode("latin-1")
    return io.BytesIO(bytes(pdf_bytes))
# === App UI ===
st.title("🧬 DermBOT - Skin AI Assistant")
st.caption("🧠 Using model: SkinGPT")
uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
if "conversation" not in st.session_state:
st.session_state.conversation = []
if uploaded_file:
st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
image = Image.open(uploaded_file).convert("RGB")
if not st.session_state.conversation:
with st.spinner("Analyzing image..."):
result = classifier.predict(image)
        if "error" in result:
            st.error(result["error"])
        else:
            diagnosis = result["diagnosis"]
            st.session_state.conversation.append(("assistant", diagnosis))
            with st.chat_message("assistant"):
                st.markdown(diagnosis)
else:
# Follow-up questions
if user_query := st.chat_input("Ask a follow-up question..."):
st.session_state.conversation.append(("user", user_query))
with st.chat_message("user"):
st.markdown(user_query)
            # Generate a response conditioned on the image and the conversation so far
            context = "\n".join([f"{role}: {msg}" for role, msg in st.session_state.conversation])
            image_tensor = classifier.transform(image).unsqueeze(0).to(classifier.device)
            with torch.no_grad():
                response = classifier.model.generate(image_tensor, user_input=context)
            st.session_state.conversation.append(("assistant", response))
            with st.chat_message("assistant"):
                st.markdown(response)
# === PDF Button ===
if st.button("📄 Download Chat as PDF"):
    # The chat lives in st.session_state.conversation as (role, message) tuples
    chat_messages = [{"role": role, "content": str(msg)} for role, msg in st.session_state.conversation]
    pdf_file = export_chat_to_pdf(chat_messages)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")