|
import streamlit as st |
|
import torchvision.transforms as transforms |
|
import torch |
|
import io |
|
import os |
|
from fpdf import FPDF |
|
import nest_asyncio |
|
# Allow re-entrant asyncio event loops (nest_asyncio). NOTE(review): presumably
# required by an async dependency imported below — confirm it is still needed.
nest_asyncio.apply()

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Keep this ahead of any other st.* call: Streamlit expects set_page_config to
# run before other Streamlit commands. The icon was previously stored as
# mis-decoded UTF-8 ("π§¬"); it is the DNA emoji.
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
|
|
|
import torch |
|
from torch import nn |
|
from torchvision import transforms |
|
from PIL import Image |
|
from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig |
|
from eva_vit import create_eva_vit_g |
|
import requests |
|
from io import BytesIO |
|
import os |
|
from huggingface_hub import hf_hub_download |
|
from transformers import BitsAndBytesConfig |
|
from accelerate import init_empty_weights |
|
import warnings |
|
from transformers import logging |
|
|
|
import torch |
|
from torch.cuda.amp import autocast |
|
from SkinGPT import SkinGPTClassifier |
|
|
|
# Tensors default to float32 globally; MODEL_DTYPE selects half precision only
# when a CUDA device is available. NOTE(review): MODEL_DTYPE is not read in
# this script — presumably consumed by model code elsewhere; verify.
torch.set_default_dtype(torch.float32)
MODEL_DTYPE = torch.float32 if not torch.cuda.is_available() else torch.float16

import warnings

# Quiet library chatter: drop FutureWarning / UserWarning and restrict the
# transformers logger to errors only.
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", category=UserWarning)
logging.set_verbosity_error()

# A Hugging Face token is mandatory: abort when HF_TOKEN is unset or empty.
token = os.environ.get("HF_TOKEN")
if token is None or token == "":
    raise ValueError("Hugging Face token not found in environment variables")

import warnings

# Blanket filter, inserted in front of the targeted ones above: from this point
# every warning category is ignored.
warnings.simplefilter("ignore")
|
|
|
|
|
|
|
@st.cache_resource
def get_classifier():
    """Return the shared SkinGPT classifier, constructing it on first use.

    Streamlit re-executes this entire script on every widget interaction.
    Without caching, ``SkinGPTClassifier()`` (and whatever model loading it
    performs) ran again on each rerun. ``st.cache_resource`` builds the object
    once per server process and returns that same instance on later calls.

    NOTE(review): the instance is shared across all user sessions — confirm
    that ``predict`` / ``generate`` are safe to call concurrently.
    """
    return SkinGPTClassifier()
|
|
|
# Obtain the classifier used by the chat flow below.
classifier = get_classifier()

# Per-session message list: created once, then kept across script reruns.
st.session_state.setdefault("messages", [])
|
|
|
|
|
|
|
def export_chat_to_pdf(messages):
    """Render a chat transcript to PDF and return it as an in-memory file.

    Args:
        messages: iterable of dicts with ``"role"`` and ``"content"`` keys.
            Messages whose role is ``"user"`` are labelled "You"; every other
            role is labelled "AI".

    Returns:
        io.BytesIO holding the PDF bytes, rewound to offset 0.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        line = f"{role}: {msg['content']}\n"
        # The built-in Arial core font only covers Latin-1. Emoji, em dashes,
        # smart quotes, etc. in model output would make FPDF raise, so swap
        # any such characters for "?" instead of failing the whole export.
        line = line.encode("latin-1", "replace").decode("latin-1")
        # Start each message at the left margin. In fpdf2, multi_cell defaults
        # to new_x=XPos.RIGHT, which leaves a full-width (w=0) cell's cursor at
        # the right margin, so the next multi_cell(0, ...) would have no room.
        pdf.set_x(pdf.l_margin)
        pdf.multi_cell(0, 10, line)
    buf = io.BytesIO()
    pdf.output(buf)
    buf.seek(0)
    return buf
|
|
|
|
|
|
|
# Page header. These strings were previously stored as mis-decoded UTF-8
# ("π§¬", "β", "π§"); they are restored to the intended emoji and dash.
st.title("🧬 DermBOT — Skin AI Assistant")
st.caption("🧠 Using model: SkinGPT")

uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])

# Chat history for the current image, as (role, content) tuples.
if "conversation" not in st.session_state:
    st.session_state.conversation = []
|
if uploaded_file:
    st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
    # Rewind before decoding in case st.image read from the upload stream.
    # NOTE(review): harmless if it did not.
    uploaded_file.seek(0)
    image = Image.open(uploaded_file).convert("RGB")

    # A different upload starts a fresh consultation. Previously a new image
    # was never diagnosed, and follow-ups were answered using the old image's
    # history as context.
    upload_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("conversation_image") != upload_key:
        st.session_state.conversation = []
        st.session_state.conversation_image = upload_key

    # Streamlit re-executes the script on every interaction and redraws the
    # page from scratch, so earlier turns must be re-rendered on each run or
    # they (including the diagnosis) vanish.
    for role, text in st.session_state.conversation:
        with st.chat_message(role):
            st.markdown(text)

    if not st.session_state.conversation:
        # First pass for this image: ask the classifier for a diagnosis.
        with st.spinner("Analyzing image..."):
            result = classifier.predict(image)
        if "error" in result:
            st.error(result["error"])
        else:
            diagnosis = result["diagnosis"]
            # Store the diagnosis text rather than the whole result dict, so
            # history can be redrawn and sent back to the model as plain
            # "role: text" lines instead of a dict repr.
            st.session_state.conversation.append(("assistant", diagnosis))
            with st.chat_message("assistant"):
                st.markdown(diagnosis)

    # Offer follow-up questions as soon as a diagnosis exists, including on the
    # very run that produced it (previously only on the run after).
    if st.session_state.conversation:
        if user_query := st.chat_input("Ask a follow-up question..."):
            st.session_state.conversation.append(("user", user_query))
            with st.chat_message("user"):
                st.markdown(user_query)

            # The whole transcript, newest question last, is the model prompt.
            context = "\n".join(
                f"{role}: {text}" for role, text in st.session_state.conversation
            )
            response = classifier.generate(image, user_input=context)

            st.session_state.conversation.append(("assistant", response))
            with st.chat_message("assistant"):
                st.markdown(response)
|
|
|
|
|
if st.button("📄 Download Chat as PDF"):
    # The transcript lives in `conversation` as (role, content) tuples. The
    # `messages` list is never appended to, so exporting it always produced an
    # empty PDF. Convert to the {"role", "content"} dicts the exporter expects;
    # an assistant entry may be the raw classifier result dict, whose
    # "diagnosis" field is the text shown in the chat.
    transcript = []
    for role, content in st.session_state.conversation:
        if isinstance(content, dict):
            content = content.get("diagnosis", content)
        transcript.append({"role": role, "content": content})
    pdf_file = export_chat_to_pdf(transcript)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")