|
import os |
|
import torch |
|
import streamlit as st |
|
from PIL import Image |
|
from transformers import AutoModelForImageClassification, AutoImageProcessor |
|
from groq import Groq |
|
|
|
|
|
# Configure the browser-tab title, the page icon and the wide layout.
st.set_page_config(
    page_title="DermaBot - AI Skin Disease Detector",
    page_icon="🩺",
    layout="wide",
)
|
|
|
|
|
# Hugging Face Hub identifier of the fine-tuned skin-disease classifier.
MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"

# Run inference on the GPU when one is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_model_and_processor(model_name, device):
    """Load the classifier and its image processor once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    loading the weights with plain module-level statements would repeat the
    load on each rerun. ``st.cache_resource`` keeps the loaded objects alive
    and hands the same instances back on later reruns.

    Args:
        model_name: Hub identifier passed to ``from_pretrained``.
        device: Torch device string the model is moved to.

    Returns:
        A ``(model, processor)`` tuple.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    loaded_processor = AutoImageProcessor.from_pretrained(model_name)
    return loaded_model, loaded_processor


model, processor = _load_model_and_processor(MODEL_NAME, DEVICE)
|
|
|
|
|
# Security: never commit the Groq API key to source control. It is read from
# the GROQ_API_KEY environment variable instead. Any key that was previously
# hard-coded in this file should be treated as compromised and revoked.
_groq_api_key = os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    # Fail with an actionable message instead of an opaque client error.
    st.error(
        "The GROQ_API_KEY environment variable is not set. "
        "Set it to a valid Groq API key and restart the app."
    )
    st.stop()

client = Groq(api_key=_groq_api_key)
|
|
|
|
|
# Seed the per-session slots that hold the latest detection result so later
# reads never hit a missing key; existing values are left untouched.
for _state_key in ("disease_name", "disease_info"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
|
|
|
|
|
def predict_skin_disease(image):
    """Classify a skin image and return the predicted label.

    Args:
        image: An object exposing ``convert("RGB")`` (the caller passes a
            PIL image).

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class.
    """
    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**batch).logits

    class_index = logits.argmax(-1).item()
    return model.config.id2label[class_index]
|
|
|
|
|
# Groq chat model used for every DermaBot prompt.
_GROQ_CHAT_MODEL = "llama-3.3-70b-versatile"


def _ask_groq(prompt):
    """Send ``prompt`` as a single user message and return the reply text."""
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=_GROQ_CHAT_MODEL,
    )
    return chat_completion.choices[0].message.content


def get_disease_info(disease_name):
    """Ask the language model for an overview of a skin disease.

    Args:
        disease_name: Name of the disease, interpolated into the prompt.

    Returns:
        The text content of the model's first reply choice.
    """
    prompt = (
        f"Provide a detailed explanation about the skin disease '{disease_name}', "
        "including description of disease, causes, precautions, risk and treatment options."
    )
    return _ask_groq(prompt)


def chatbot_response(disease_name, user_query):
    """Answer a follow-up question about the detected disease.

    Args:
        disease_name: Name of the detected disease; when falsy, no request is
            made and a hint to run detection first is returned.
        user_query: The user's question; when empty or only whitespace, no
            request is made and a hint to enter a question is returned.

    Returns:
        Either one of the hint messages above or the text content of the
        model's first reply choice.
    """
    if not disease_name:
        return "Please upload an image and detect the disease first."

    # A blank question would send a prompt with nothing to answer.
    question = (user_query or "").strip()
    if not question:
        return "Please enter a question about the detected disease."

    prompt = f"The detected skin disease is '{disease_name}'. {question}"
    return _ask_groq(prompt)
|
|
|
|
|
# ---- Page header ----
# NOTE(review): this logo URL looks like a template placeholder
# ("your-huggingface-space") - confirm it points at a real asset.
st.image("https://huggingface.co/spaces/your-huggingface-space/logo.png", width=200)
st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")

# ---- Image upload ----
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# Results live in session state across reruns. Tie them to the upload they
# were computed for, so a diagnosis for a previous image is never shown next
# to a different (or removed) one.
current_upload = (uploaded_image.name, uploaded_image.size) if uploaded_image else None
if st.session_state.get("analyzed_upload") != current_upload:
    st.session_state.disease_name = None
    st.session_state.disease_info = None
    st.session_state.analyzed_upload = None

image = None
if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except OSError:
        # PIL raises an OSError subclass when the file is not a readable image.
        st.error("The uploaded file could not be read as an image. Please upload a valid JPG or PNG file.")

if image is not None:
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # ---- Detection ----
    if st.button("Detect Disease"):
        with st.spinner("Analyzing..."):
            disease_name = predict_skin_disease(image)
            disease_info = get_disease_info(disease_name)

        # Persist the result so it survives the reruns triggered by the
        # widgets below, and remember which upload it belongs to.
        st.session_state.disease_name = disease_name
        st.session_state.disease_info = disease_info
        st.session_state.analyzed_upload = current_upload

# ---- Result and follow-up chat ----
if st.session_state.disease_name:
    st.success(f"**Detected Disease:** {st.session_state.disease_name}")
    st.write(f"**Details:** {st.session_state.disease_info}")

    st.subheader("💬 Ask DermaBot about this disease")

    user_query = st.text_input("Ask about the detected disease:")

    if st.button("Ask"):
        if not user_query.strip():
            # Avoid an API round-trip for an empty question.
            st.warning("Please enter a question first.")
        else:
            with st.spinner("Thinking..."):
                response = chatbot_response(st.session_state.disease_name, user_query)
            st.write(response)

# ---- Footer ----
st.markdown("---")
st.write("🔍 Powered by **AI & Groq API** | © 2025 DermaBot")
|
|