File size: 3,784 Bytes
df0f983
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55eb5f9
df0f983
b6c72f1
afe00a8
 
b6c72f1
 
afe00a8
df0f983
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b638e4f
df0f983
 
 
 
 
 
 
 
 
 
afe00a8
 
 
df0f983
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c72f1
df0f983
 
 
 
b6c72f1
df0f983
b6c72f1
df0f983
 
 
 
 
b6c72f1
 
 
 
 
 
 
 
 
 
 
 
df0f983
 
 
afe00a8
 
 
df0f983
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os

import streamlit as st
import torch
from groq import APIError
from groq import Groq
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Set page config. Streamlit requires this to be the first st.* call in the script.
st.set_page_config(page_title="DermaBot - AI Skin Disease Detector", page_icon="🩺", layout="wide")

# Hugging Face Hub checkpoint for the skin-disease classifier, and the device to run it on.
MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource(show_spinner="Loading skin-disease model...")
def _load_model_and_processor(model_name, device):
    """Load the image classifier and its preprocessor once per server process.

    Streamlit re-executes this whole script on every widget interaction; without
    caching, the checkpoint would be reloaded (and moved to ``device``) on each click.

    Args:
        model_name: Hugging Face Hub id of the image-classification checkpoint.
        device: Torch device string the model is moved to ("cuda" or "cpu").

    Returns:
        A ``(model, processor)`` tuple.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    loaded_model.eval()  # inference only; made explicit for clarity
    loaded_processor = AutoImageProcessor.from_pretrained(model_name)
    return loaded_model, loaded_processor


model, processor = _load_model_and_processor(MODEL_NAME, DEVICE)

# Groq client used for the LLM-generated explanations and chat answers.
# The API key is read from the GROQ_API_KEY environment variable (e.g. a
# Space / deployment secret) instead of being committed to source control.
# NOTE(review): the key previously hard-coded here should be treated as leaked
# and revoked in the Groq console.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error(
        "GROQ_API_KEY is not set. Configure it as an environment variable "
        "(or deployment secret) and restart the app."
    )
    st.stop()  # nothing below can work without an LLM client
client = Groq(api_key=GROQ_API_KEY)

# Ensure the per-session slots for the latest diagnosis exist before the UI
# reads them; a fresh browser session starts with both set to None.
for _state_key in ("disease_name", "disease_info"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# Image -> predicted disease label, using the module-level model/processor.
def predict_skin_disease(image):
    """Classify a skin image and return the label of the top-scoring class.

    Args:
        image: A PIL image; it is converted to RGB before preprocessing.

    Returns:
        ``model.config.id2label`` looked up at the argmax of the model's logits.
    """
    rgb_image = image.convert("RGB")
    model_inputs = processor(images=rgb_image, return_tensors="pt").to(DEVICE)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(**model_inputs).logits

    best_index = scores.argmax(-1).item()
    return model.config.id2label[best_index]

# Function to get disease details from Groq API
def get_disease_info(disease_name):
    """Ask the Groq-hosted LLM for an explanation of a skin disease.

    Args:
        disease_name: The label produced by ``predict_skin_disease``.

    Returns:
        The message content of the first completion choice.

    Errors raised by the Groq client propagate unchanged to the caller
    (presumably ``groq.APIError`` subclasses — confirm against the SDK docs).
    """
    prompt = (
        f"Provide a detailed explanation about the skin disease '{disease_name}', "
        "including a description of the disease, its causes, precautions, risks "
        "and treatment options."
    )

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )

    return chat_completion.choices[0].message.content

# Function to handle chatbot queries
def chatbot_response(disease_name, user_query):
    """Answer a free-form question about the detected disease via the Groq LLM.

    Args:
        disease_name: Label from ``predict_skin_disease``, or a falsy value when
            no detection has been run yet in this session.
        user_query: The user's question text (may be empty or None).

    Returns:
        The LLM's answer. Returns a short instruction message instead, without
        calling the API, when there is no detected disease or the question is blank.
    """
    if not disease_name:
        return "Please upload an image and detect the disease first."

    # Don't spend an API round-trip on an empty / whitespace-only question.
    question = (user_query or "").strip()
    if not question:
        return "Please type a question about the detected disease."

    prompt = f"The detected skin disease is '{disease_name}'. {question}"

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )

    return chat_completion.choices[0].message.content

# Streamlit UI
# The logo is optional: set DERMABOT_LOGO_URL to display one. The URL previously
# hard-coded here was a template placeholder ("your-huggingface-space") that
# rendered as a broken image.
logo_url = os.environ.get("DERMABOT_LOGO_URL")
if logo_url:
    st.image(logo_url, width=200)
st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")

# Step 1: Upload image
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# Stored results belong to one specific upload. When the user uploads a different
# file (or removes it), drop them so an old diagnosis is never shown next to a new image.
upload_key = (uploaded_image.name, uploaded_image.size) if uploaded_image else None
if st.session_state.get("upload_key") != upload_key:
    st.session_state.disease_name = None
    st.session_state.disease_info = None
    st.session_state.upload_key = upload_key

if uploaded_image:
    try:
        image = Image.open(uploaded_image)
        image.load()  # Image.open is lazy; force decoding so corrupt files fail here
    except OSError:  # PIL's UnidentifiedImageError is an OSError subclass
        image = None
        st.error("The uploaded file could not be read as an image. Please upload a valid JPG or PNG file.")

    if image is not None:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Step 2: Detect disease
        if st.button("Detect Disease"):
            with st.spinner("Analyzing..."):
                disease_name = predict_skin_disease(image)
                try:
                    disease_info = get_disease_info(disease_name)
                except APIError as err:
                    # Local classification succeeded; only the explanation service failed.
                    st.error(f"Could not fetch disease details from the Groq API: {err}")
                    disease_info = "Details are currently unavailable. Please try again later."

                # Store results in session state
                st.session_state.disease_name = disease_name
                st.session_state.disease_info = disease_info

# Show the stored diagnosis (label + LLM explanation), if one exists.
detected = st.session_state.disease_name
if detected:
    details = st.session_state.disease_info
    st.success(f"**Detected Disease:** {detected}")
    st.write(f"**Details:** {details}")

# Step 3: Chatbot
st.subheader("💬 Ask DermaBot about this disease")

user_query = st.text_input("Ask about the detected disease:")

if st.button("Ask"):
    with st.spinner("Thinking..."):
        try:
            response = chatbot_response(st.session_state.disease_name, user_query)
        except APIError as err:
            # Report LLM service failures inline instead of crashing the whole page.
            st.error(f"Could not get an answer from the Groq API: {err}")
        else:
            st.write(response)

# Footer: a horizontal rule followed by the attribution line.
footer_text = "🔍 Powered by **AI & Groq API** | © 2025 DermaBot"
st.markdown("---")
st.write(footer_text)