# SkinGPT / app.py
import streamlit as st
import torchvision.transforms as transforms
import torch
import io
import os
from fpdf import FPDF
import nest_asyncio
nest_asyncio.apply()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
from torch import nn
from PIL import Image
from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig
from eva_vit import create_eva_vit_g
import requests
from io import BytesIO
from huggingface_hub import hf_hub_download
token = os.getenv("HF_TOKEN")
if not token:
raise ValueError("Hugging Face token not found in environment variables")
import warnings
warnings.filterwarnings("ignore")
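# Pipeline overview: a frozen EVA-ViT-G encoder turns a 224x224 image into
# patch features, a frozen BLIP-2 Q-Former compresses them into 32 query
# embeddings, a linear projection maps those into the LLaMA embedding space,
# and a frozen LLaMA-2-13b-chat model generates the dermatology response.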
class Blip2QFormer(nn.Module):
def __init__(self, num_query_tokens=32, vision_width=1408):
super().__init__()
# Load pre-trained Q-Former config
self.bert_config = BertConfig(
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
position_embedding_type="absolute",
use_cache=True,
classifier_dropout=None,
)
self.bert = BertModel(self.bert_config, add_pooling_layer=False).to(torch.float16)
        # Bypass position embeddings entirely; the visual inputs carry no text positions
        self.bert.embeddings.position_embeddings = nn.Identity()
# Disable word embeddings
self.bert.embeddings.word_embeddings = None
# Initialize query tokens
self.query_tokens = nn.Parameter(
torch.zeros(1, num_query_tokens, self.bert_config.hidden_size, dtype=torch.float16)
)
self.vision_proj = nn.Sequential(
nn.Linear(vision_width, self.bert_config.hidden_size),
nn.LayerNorm(self.bert_config.hidden_size)
).to(torch.float16)
def load_from_pretrained(self, url_or_filename):
if url_or_filename.startswith('http'):
response = requests.get(url_or_filename)
checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
else:
checkpoint = torch.load(url_or_filename, map_location='cpu')
# Load Q-Former weights only
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
msg = self.load_state_dict(state_dict, strict=False)
# print(f"Loaded Q-Former weights with message: {msg}")
def forward(self, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
if query_embeds is None:
query_embeds = self.query_tokens.expand(encoder_hidden_states.shape[0], -1, -1)
# Project visual features
visual_embeds = self.vision_proj(encoder_hidden_states)
# Create proper attention mask
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(
visual_embeds.size()[:-1],
dtype=torch.long,
device=visual_embeds.device
)
        # Cross-attend the query tokens to the projected visual features
        encoder_outputs = self.bert.encoder(
            hidden_states=query_embeds,
            attention_mask=None,
            encoder_hidden_states=visual_embeds,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=True
        )
return encoder_outputs.last_hidden_state
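# Usage sketch for Blip2QFormer (assumes a 224x224 input with 14x14 patches,
# i.e. 256 patch tokens plus one CLS token of width 1408 from the ViT):
#   qformer = Blip2QFormer(num_query_tokens=32, vision_width=1408)
#   vis = torch.randn(1, 257, 1408, dtype=torch.float16)
#   out = qformer(encoder_hidden_states=vis)  # -> (1, 32, 768)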
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class ViTClassifier(nn.Module):
def __init__(self, vit, ln_vision, num_labels):
super(ViTClassifier, self).__init__()
self.vit = vit # Pretrained ViT from MiniGPT-4
self.ln_vision = ln_vision # LayerNorm from MiniGPT-4
self.classifier = nn.Linear(vit.num_features, num_labels)
def forward(self, x):
features = self.ln_vision(self.vit(x)) # [batch, seq_len, dim]
cls_token = features[:, 0, :] # Extract CLS token
return self.classifier(cls_token)
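# Note: ViTClassifier is not called directly in this app; it presumably mirrors
# the layout of the fine-tuned DermNet checkpoint (keys prefixed "vit."), from
# which only the ViT weights are reused in SkinGPT4._init_vit below.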
class SkinGPT4(nn.Module):
def __init__(self, vit_checkpoint_path,
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
super().__init__()
# Image encoder parameters from paper
self.dtype = torch.float16
self.H, self.W, self.C = 224, 224, 3
self.P = 14 # Patch size
self.D = 1408 # ViT embedding dimension
self.num_query_tokens = 32
self.tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf",
token=token, padding_side="right")
print("Loaded tokenizer")
self.tokenizer.add_special_tokens({'additional_special_tokens': ['<ImageHere>']})
# Initialize components
self.vit = self._init_vit(vit_checkpoint_path)
print("Loaded ViT")
self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
self.q_former = Blip2QFormer(
num_query_tokens=self.num_query_tokens,
vision_width=self.D
).to(self.dtype)
self.q_former.load_from_pretrained(q_former_model)
for param in self.q_former.parameters():
param.requires_grad = False
self.q_former.eval()
print("Loaded QFormer")
self.llama = self._init_llama()
self.llama.resize_token_embeddings(len(self.tokenizer))
self.llama_proj = nn.Linear(
self.q_former.bert_config.hidden_size,
self.llama.config.hidden_size
).to(self.dtype)
self._init_alignment_projection()
print("Loaded Llama")
        # Initialize learnable query tokens
        self.query_tokens = nn.Parameter(
            torch.zeros(1, self.num_query_tokens, self.q_former.bert_config.hidden_size)
        )
        nn.init.normal_(self.query_tokens, std=0.02)
        # Conversation state used by add_to_history() / get_full_context()
        self.conversation_history = []
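        # Frozen here: the ViT, the Q-Former and LLaMA. Only ln_vision,
        # llama_proj and the query tokens above remain trainable.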
def _init_vit(self, vit_checkpoint_path):
"""Initialize EVA-ViT-G with paper specifications"""
vit = create_eva_vit_g(
img_size=(self.H, self.W),
patch_size=self.P,
embed_dim=self.D,
depth=39,
num_heads=16,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
init_values=1e-5
).to(self.dtype)
if not hasattr(vit, 'norm'):
vit.norm = nn.LayerNorm(self.D)
checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
        # Filter weights for the ViT components only
        vit_weights = {k.replace("vit.", ""): v
                       for k, v in checkpoint.items()
                       if k.startswith("vit.")}
        # Load weights, ignoring the classifier head
        vit.load_state_dict(vit_weights, strict=False)
        # Freeze the encoder, following the paper
for param in vit.parameters():
param.requires_grad = False
return vit.eval()
def _init_llama(self):
"""Initialize frozen LLaMA-2-13b-chat with proper error handling"""
try:
device_map = {
"": 0 if torch.cuda.is_available() else "cpu"
}
            # First try loading with an explicit device_map (requires accelerate)
            try:
                model = LlamaForCausalLM.from_pretrained(
                    "meta-llama/Llama-2-13b-chat-hf",
                    token=token,
                    torch_dtype=torch.float16,
                    device_map=device_map,
                    low_cpu_mem_usage=True
                )
            except ImportError:
                # Fallback: load without a device_map (e.g. if accelerate is
                # missing) and place the model manually
                model = LlamaForCausalLM.from_pretrained(
                    "meta-llama/Llama-2-13b-chat-hf",
                    token=token,
                    torch_dtype=torch.float16
                )
                model = model.to(device)
# Freeze all parameters
for param in model.parameters():
param.requires_grad = False
return model.eval()
except Exception as e:
raise ImportError(
f"Failed to load LLaMA model. Please ensure:\n"
f"1. You have accepted the license at: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf\n"
f"2. Your Hugging Face token is correct\n"
f"3. Required packages are installed: pip install accelerate bitsandbytes transformers\n"
f"Original error: {str(e)}"
)
def _init_alignment_projection(self):
"""Paper specifies Xavier initialization for alignment layer"""
nn.init.xavier_normal_(self.llama_proj.weight)
nn.init.constant_(self.llama_proj.bias, 0)
def _create_patches(self, x):
"""Convert image to patch embeddings following Eq. (1)"""
# x: (B, C, H, W)
x = x.to(self.dtype)
print(f"Shape of x : {x.shape}")
if x.dim() == 3:
x = x.unsqueeze(0) # Add batch dimension if missing
if x.dim() != 4:
raise ValueError(f"Input must be 4D tensor (got {x.dim()}D)")
B, C, H, W = x.shape
N = (H * W) // (self.P ** 2)
x = self.vit.patch_embed(x) # (B, N, D)
num_patches = x.shape[1]
pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :] # Adjust for exact match
x = x + pos_embed
# Add class token
class_token = self.vit.cls_token.expand(B, -1, -1)
x = torch.cat([class_token, x], dim=1) # (B, N+1, D)
print(f"Final output shape: {x.shape}")
return x
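        # With H = W = 224 and P = 14 this produces N = (224 * 224) / 14**2 = 256
        # patch embeddings, plus the class token -> a (B, 257, 1408) sequence.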
def forward_encoder(self, x):
"""ViT encoder from Eqs. (2)-(3)"""
# x: (B, N+1, D)
for blk in self.vit.blocks:
x = blk(x)
x = self.vit.norm(x)
x = self.ln_vision(x)
return x # (B, N+1, D)
def forward(self, images):
images = images.to(self.dtype)
# Convert images to patches
x = self._create_patches(images) # (B, N+1, D)
# ViT processing
x = x.to(self.dtype)
self.vit = self.vit.to(self.dtype)
vit_output = self.forward_encoder(x) # (B, N+1, D)
# Q-Former processing
query_tokens = self.query_tokens.expand(x.size(0), -1, -1).to(torch.float16)
qformer_output = self.q_former(
query_embeds=query_tokens,
encoder_hidden_states=vit_output.to(torch.float16),
encoder_attention_mask=torch.ones_like(vit_output[:, :, 0])
).to(self.dtype)
# Alignment projection
aligned_features = self.llama_proj(qformer_output.to(self.dtype))
return aligned_features
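    # Shape flow per image: (3, 224, 224) -> ViT tokens (257, 1408)
    # -> Q-Former queries (32, 768) -> LLaMA space (32, 5120),
    # where 5120 is the hidden size of LLaMA-2-13b.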
def add_to_history(self, role, content):
self.conversation_history.append({"role": role, "content": content})
def get_full_context(self):
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in self.conversation_history])
def build_prompt(self, image_embeds, user_question=None):
# Base prompt for initial diagnosis
if not user_question:
prompt = (
"### Instruction: <Img ><Image ></Img> "
"Could you describe the skin disease in this image for me? "
"### Response:"
)
else:
# Follow-up prompt with conversation history
history = self.get_full_context()
prompt = (
f"### Instruction: <Img ><Image ></Img> "
f"Based on our previous conversation:\n{history}\n"
f"User asks: {user_question}\n"
"### Response:"
)
return prompt
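    # The <ImageHere> placeholder tokenizes to the special token registered in
    # __init__; generate() swaps its embedding for the pooled visual features.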
def generate(self, images, user_input=None, max_length=300):
# Get aligned features
aligned_features = self.forward(images)
prompt = self.build_prompt(aligned_features, user_input)
inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
        # Embed the prompt tokens, then splice the pooled visual features into
        # the <ImageHere> slot
        prompt_embeds = self.llama.model.embed_tokens(inputs.input_ids)
        image_token_index = torch.where(
            inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>")
        )
        prompt_embeds[image_token_index] = aligned_features.mean(dim=1)  # Pool the 32 query tokens
        # Generate response
        outputs = self.llama.generate(
            inputs_embeds=prompt_embeds,
max_length=max_length,
temperature=0.7,
top_p=0.9,
do_sample=True
)
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
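# Usage sketch for SkinGPT4 (assumes a local path to the fine-tuned ViT
# checkpoint; SkinGPTClassifier below handles the 224x224 preprocessing):
#   model = SkinGPT4(vit_checkpoint_path="dermnet_finetuned_version1.pth")
#   x = torch.randn(1, 3, 224, 224)
#   print(model.generate(x))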
class SkinGPTClassifier:
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
self.device = torch.device(device)
self.conversation_history = []
with st.spinner("Loading AI models (this may take several minutes)..."):
self.meta_model = self.load_models()
self.resnet_feature_extractor = None
# Image transformations
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
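        # The transform resizes to the ViT input size and applies the standard
        # ImageNet normalization statistics.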
def load_models(self):
model_path = hf_hub_download(
repo_id="KeerthiVM/SkinCancerDiagnosis",
filename="dermnet_finetuned_version1.pth",
)
meta_model = SkinGPT4(vit_checkpoint_path=model_path)
return meta_model
def predict(self, image):
image = image.convert('RGB')
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
diagnosis = self.meta_model.generate(
image_tensor
)
return {
"top_predictions": diagnosis,
}
@st.cache_resource
def get_classifier():
return SkinGPTClassifier()
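# st.cache_resource keeps a single SkinGPTClassifier per Streamlit process, so
# the ViT / Q-Former / LLaMA weights are loaded only once across reruns.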
classifier = get_classifier()
# === Session Init ===
if "messages" not in st.session_state:
st.session_state.messages = []
# === PDF Export ===
def export_chat_to_pdf(messages):
pdf = FPDF()
pdf.add_page()
pdf.set_font("Arial", size=12)
for msg in messages:
role = "You" if msg["role"] == "user" else "AI"
pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
    # fpdf2 returns a bytearray from output(); classic PyFPDF returns a latin-1 str
    pdf_bytes = pdf.output(dest="S")
    if isinstance(pdf_bytes, str):
        pdf_bytes = pdf_bytes.encode("latin-1")
    buf = io.BytesIO(pdf_bytes)
    buf.seek(0)
    return buf
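# export_chat_to_pdf expects [{"role": ..., "content": ...}, ...] and returns
# an in-memory buffer suitable for st.download_button.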
# === App UI ===
st.title("🧬 DermBOT – Skin AI Assistant")
st.caption("🧠 Using model: SkinGPT")
uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
if "conversation" not in st.session_state:
st.session_state.conversation = []
if uploaded_file:
st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
image = Image.open(uploaded_file).convert("RGB")
if not st.session_state.conversation:
# First message - diagnosis
        with st.spinner("Analyzing image..."):
            result = classifier.predict(image)
        diagnosis = result["top_predictions"]
        st.session_state.conversation.append(("assistant", diagnosis))
        with st.chat_message("assistant"):
            st.markdown(diagnosis)
else:
# Follow-up questions
if user_query := st.chat_input("Ask a follow-up question..."):
st.session_state.conversation.append(("user", user_query))
with st.chat_message("user"):
st.markdown(user_query)
            # Generate a response using the running conversation as context
            context = "\n".join([f"{role}: {msg}" for role, msg in st.session_state.conversation])
            image_tensor = classifier.transform(image).unsqueeze(0).to(classifier.device)
            response = classifier.meta_model.generate(image_tensor, user_input=context)
st.session_state.conversation.append(("assistant", response))
with st.chat_message("assistant"):
st.markdown(response)
# === PDF Button ===
if st.button("📄 Download Chat as PDF"):
    chat = [{"role": role, "content": str(msg)} for role, msg in st.session_state.conversation]
    pdf_file = export_chat_to_pdf(chat)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")