import io
import os
import warnings

import nest_asyncio
import requests
import streamlit as st
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig, BitsAndBytesConfig
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from fpdf import FPDF
from io import BytesIO

from eva_vit import create_eva_vit_g

nest_asyncio.apply()
warnings.filterwarnings("ignore")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")

token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found in environment variables")

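# Model pipeline implemented below (SkinGPT-4 style):
#   image -> EVA-ViT-G patch features -> frozen BLIP-2 Q-Former -> linear
#   alignment projection -> frozen LLaMA-2-13b-chat, which generates the
#   dermatology description shown in the Streamlit chat UI.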
class Blip2QFormer(nn.Module):
    def __init__(self, num_query_tokens=32, vision_width=1408):
        super().__init__()

        self.bert_config = BertConfig(
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=0,
            position_embedding_type="absolute",
            use_cache=True,
            classifier_dropout=None,
        )

        self.bert = BertModel(self.bert_config, add_pooling_layer=False).to(torch.float16)

        # The Q-Former consumes learned query embeddings rather than token ids,
        # so the word/position embedding tables are not used.
        self.bert.embeddings.position_embeddings = nn.Identity()
        self.bert.embeddings.word_embeddings = None

        self.query_tokens = nn.Parameter(
            torch.zeros(1, num_query_tokens, self.bert_config.hidden_size, dtype=torch.float16)
        )
        self.vision_proj = nn.Sequential(
            nn.Linear(vision_width, self.bert_config.hidden_size),
            nn.LayerNorm(self.bert_config.hidden_size)
        ).to(torch.float16)

    def load_from_pretrained(self, url_or_filename):
        if url_or_filename.startswith('http'):
            response = requests.get(url_or_filename)
            checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
        else:
            checkpoint = torch.load(url_or_filename, map_location='cpu')

        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        self.load_state_dict(state_dict, strict=False)

    def forward(self, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
        if query_embeds is None:
            query_embeds = self.query_tokens.expand(encoder_hidden_states.shape[0], -1, -1)

        # Project ViT features into the Q-Former hidden size before cross-attention.
        visual_embeds = self.vision_proj(encoder_hidden_states)

        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(
                visual_embeds.size()[:-1],
                dtype=torch.long,
                device=visual_embeds.device
            )

        encoder_outputs = self.bert.encoder(
            hidden_states=query_embeds,
            attention_mask=None,
            encoder_hidden_states=visual_embeds,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=True
        )
        return encoder_outputs.last_hidden_state

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

class ViTClassifier(nn.Module):
    def __init__(self, vit, ln_vision, num_labels):
        super(ViTClassifier, self).__init__()
        self.vit = vit
        self.ln_vision = ln_vision
        self.classifier = nn.Linear(vit.num_features, num_labels)

    def forward(self, x):
        features = self.ln_vision(self.vit(x))
        cls_token = features[:, 0, :]
        return self.classifier(cls_token)

class SkinGPT4(nn.Module):
    def __init__(self, vit_checkpoint_path,
                 q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
        super().__init__()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float16
        self.H, self.W, self.C = 224, 224, 3
        self.P = 14       # patch size
        self.D = 1408     # ViT embedding dimension
        self.num_query_tokens = 32
        self.conversation_history = []

        self.tokenizer = LlamaTokenizer.from_pretrained(
            "meta-llama/Llama-2-13b-chat-hf", token=token, padding_side="right"
        )
        print("Loaded tokenizer")
        # Placeholder token marking where image embeddings are spliced into the prompt.
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['<ImageHere>']})

        self.vit = self._init_vit(vit_checkpoint_path)
        print("Loaded ViT")
        self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)

        self.q_former = Blip2QFormer(
            num_query_tokens=self.num_query_tokens,
            vision_width=self.D
        ).to(self.dtype)
        self.q_former.load_from_pretrained(q_former_model)
        for param in self.q_former.parameters():
            param.requires_grad = False
        self.q_former.eval()
        print("Loaded QFormer")

        self.llama = self._init_llama()
        self.llama.resize_token_embeddings(len(self.tokenizer))

        self.llama_proj = nn.Linear(
            self.q_former.bert_config.hidden_size,
            self.llama.config.hidden_size
        ).to(self.dtype)
        self._init_alignment_projection()
        print("Loaded Llama")

        self.query_tokens = nn.Parameter(
            torch.zeros(1, self.num_query_tokens, self.q_former.bert_config.hidden_size)
        )
        nn.init.normal_(self.query_tokens, std=0.02)

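    # The EVA-ViT-G configuration below (patch size 14, width 1408, depth 39)
    # mirrors the BLIP-2 vision encoder; weights are loaded from the fine-tuned
    # DermNet checkpoint, with the "vit." prefix stripped from its state-dict keys.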
    def _init_vit(self, vit_checkpoint_path):
        """Initialize EVA-ViT-G with paper specifications."""
        vit = create_eva_vit_g(
            img_size=(self.H, self.W),
            patch_size=self.P,
            embed_dim=self.D,
            depth=39,
            num_heads=16,
            mlp_ratio=4.3637,
            qkv_bias=True,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            init_values=1e-5
        ).to(self.dtype)
        if not hasattr(vit, 'norm'):
            vit.norm = nn.LayerNorm(self.D)
        checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')

        vit_weights = {k.replace("vit.", ""): v
                       for k, v in checkpoint.items()
                       if k.startswith("vit.")}
        vit.load_state_dict(vit_weights, strict=False)

        for param in vit.parameters():
            param.requires_grad = False

        return vit.eval()

    def _init_llama(self):
        """Initialize frozen LLaMA-2-13b-chat with proper error handling."""
        try:
            device_map = {
                "": 0 if torch.cuda.is_available() else "cpu"
            }

            try:
                model = LlamaForCausalLM.from_pretrained(
                    "meta-llama/Llama-2-13b-chat-hf",
                    token=token,
                    torch_dtype=torch.float16,
                    device_map=device_map,
                    low_cpu_mem_usage=True
                )
            except ImportError:
                # Fallback when `accelerate` is unavailable for device_map loading:
                # load normally and move the model to the target device.
                model = LlamaForCausalLM.from_pretrained(
                    "meta-llama/Llama-2-13b-chat-hf",
                    token=token,
                    torch_dtype=torch.float16
                )
                model = model.to(self.device)

            for param in model.parameters():
                param.requires_grad = False

            return model.eval()

        except Exception as e:
            raise ImportError(
                f"Failed to load LLaMA model. Please ensure:\n"
                f"1. You have accepted the license at: https://huggingface.co/meta-llama/Llama-2-13b-chat-hf\n"
                f"2. Your Hugging Face token is correct\n"
                f"3. Required packages are installed: pip install accelerate bitsandbytes transformers\n"
                f"Original error: {str(e)}"
            )

    def _init_alignment_projection(self):
        """Paper specifies Xavier initialization for the alignment layer."""
        nn.init.xavier_normal_(self.llama_proj.weight)
        nn.init.constant_(self.llama_proj.bias, 0)

    def _create_patches(self, x):
        """Convert an image to patch embeddings following Eq. (1)."""
        x = x.to(self.dtype)
        if x.dim() == 3:
            x = x.unsqueeze(0)
        if x.dim() != 4:
            raise ValueError(f"Input must be a 4D tensor (got {x.dim()}D)")

        B, C, H, W = x.shape

        # Patchify and add positional embeddings (skipping the cls position).
        x = self.vit.patch_embed(x)
        num_patches = x.shape[1]
        pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]
        x = x + pos_embed

        # Prepend the class token.
        class_token = self.vit.cls_token.expand(B, -1, -1)
        x = torch.cat([class_token, x], dim=1)
        return x

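    # Eqs. (2)-(3) correspond to the standard ViT encoder: the patch sequence is
    # passed through the stack of transformer blocks, normalized, and then run
    # through the vision LayerNorm applied before the Q-Former.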
    def forward_encoder(self, x):
        """ViT encoder from Eqs. (2)-(3)."""
        for blk in self.vit.blocks:
            x = blk(x)
        x = self.vit.norm(x)
        x = self.ln_vision(x)
        return x

    def forward(self, images):
        images = images.to(self.dtype)

        # Vision encoder: patch embeddings -> transformer blocks -> LayerNorm.
        x = self._create_patches(images)
        x = x.to(self.dtype)
        self.vit = self.vit.to(self.dtype)
        vit_output = self.forward_encoder(x)

        # Q-Former: learned query tokens cross-attend to the visual features.
        query_tokens = self.query_tokens.expand(x.size(0), -1, -1).to(torch.float16)
        qformer_output = self.q_former(
            query_embeds=query_tokens,
            encoder_hidden_states=vit_output.to(torch.float16),
            encoder_attention_mask=torch.ones_like(vit_output[:, :, 0])
        ).to(self.dtype)

        # Linear alignment layer projects query outputs into the LLaMA embedding space.
        aligned_features = self.llama_proj(qformer_output.to(self.dtype))
        return aligned_features

    def add_to_history(self, role, content):
        self.conversation_history.append({"role": role, "content": content})

    def get_full_context(self):
        return "\n".join([f"{msg['role']}: {msg['content']}" for msg in self.conversation_history])

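    # Prompts follow the SkinGPT-4 instruction format; the <ImageHere> placeholder
    # is replaced with the projected image embedding in generate() below.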
    def build_prompt(self, image_embeds, user_question=None):
        if not user_question:
            prompt = (
                "### Instruction: <Img><ImageHere></Img> "
                "Could you describe the skin disease in this image for me? "
                "### Response:"
            )
        else:
            history = self.get_full_context()
            prompt = (
                f"### Instruction: <Img><ImageHere></Img> "
                f"Based on our previous conversation:\n{history}\n"
                f"User asks: {user_question}\n"
                "### Response:"
            )

        return prompt

    def generate(self, images, user_input=None, max_length=300):
        aligned_features = self.forward(images)

        prompt = self.build_prompt(aligned_features, user_input)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)

        # Embed the prompt tokens, then splice the (mean-pooled) aligned image
        # features into the position of the <ImageHere> placeholder token.
        prompt_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
        image_token_index = torch.where(
            inputs.input_ids == self.tokenizer.convert_tokens_to_ids("<ImageHere>")
        )
        prompt_embeddings[image_token_index] = aligned_features.mean(dim=1)

        outputs = self.llama.generate(
            inputs_embeds=prompt_embeddings,
            max_length=max_length,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

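# Streamlit-facing wrapper around SkinGPT4: handles image preprocessing and
# prediction for the chat UI below.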
class SkinGPTClassifier:
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)
        self.conversation_history = []

        with st.spinner("Loading AI models (this may take several minutes)..."):
            self.meta_model = self.load_models()
        self.resnet_feature_extractor = None

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def load_models(self):
        model_path = hf_hub_download(
            repo_id="KeerthiVM/SkinCancerDiagnosis",
            filename="dermnet_finetuned_version1.pth",
        )
        # Move the full model to the target device so it matches the input tensors
        # produced in predict().
        meta_model = SkinGPT4(vit_checkpoint_path=model_path).to(self.device)
        return meta_model.eval()

    def predict(self, image):
        image = image.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        diagnosis = self.meta_model.generate(image_tensor)

        return {
            "top_predictions": diagnosis,
        }

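# st.cache_resource keeps a single classifier instance alive across Streamlit
# reruns, so the models are only loaded once per server process.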
@st.cache_resource
def get_classifier():
    return SkinGPTClassifier()


classifier = get_classifier()

if "messages" not in st.session_state: |
|
st.session_state.messages = [] |
|
|
|
|
|
|
|
def export_chat_to_pdf(messages):
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
    # fpdf2's output() returns the rendered PDF as a bytearray; wrap it in a
    # BytesIO so Streamlit's download_button can consume it.
    buf = io.BytesIO(bytes(pdf.output()))
    return buf

st.title("🧬 DermBOT – Skin AI Assistant")
st.caption("🧠 Using model: SkinGPT")
uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])

if "conversation" not in st.session_state:
    st.session_state.conversation = []

if uploaded_file:
    st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
    image = Image.open(uploaded_file).convert("RGB")

    if not st.session_state.conversation:
        # First turn: run the image through the model for an initial diagnosis.
        with st.spinner("Analyzing image..."):
            diagnosis = classifier.predict(image)["top_predictions"]
        st.session_state.conversation.append(("assistant", diagnosis))
        with st.chat_message("assistant"):
            st.markdown(diagnosis)
    else:
        # Follow-up turns: pass the running conversation back to the model.
        if user_query := st.chat_input("Ask a follow-up question..."):
            st.session_state.conversation.append(("user", user_query))
            with st.chat_message("user"):
                st.markdown(user_query)

            context = "\n".join([f"{role}: {msg}" for role, msg in st.session_state.conversation])
            image_tensor = classifier.transform(image).unsqueeze(0).to(classifier.device)
            response = classifier.meta_model.generate(image_tensor, user_input=context)

            st.session_state.conversation.append(("assistant", response))
            with st.chat_message("assistant"):
                st.markdown(response)

if st.button("📄 Download Chat as PDF"):
    # Export the running conversation as role/content pairs.
    chat_messages = [{"role": role, "content": str(msg)} for role, msg in st.session_state.conversation]
    pdf_file = export_chat_to_pdf(chat_messages)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")