"""DermaBot: a Streamlit app that classifies a skin-condition photo and answers questions about it.

The image classifier is a fine-tuned DINOv2 checkpoint loaded with Hugging Face
``transformers``; explanatory text and chat answers come from a Groq-hosted LLM.

Configuration:
    GROQ_API_KEY -- environment variable holding the Groq API key (required).
"""

import os

import streamlit as st
import torch
from groq import APIError, Groq
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Set page config (must be the first Streamlit command the script executes).
st.set_page_config(page_title="DermaBot - AI Skin Disease Detector", page_icon="🩺", layout="wide")

MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
GROQ_MODEL = "llama-3.3-70b-versatile"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def load_model_and_processor(model_name: str = MODEL_NAME, device: str = DEVICE):
    """Load the image classifier and its processor once per server process.

    Streamlit re-executes this whole script on every widget interaction; without
    caching, the weights would be reloaded (and moved to ``device``) on each click.

    Returns:
        ``(model, processor)`` tuple.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    loaded_model.eval()  # inference only
    loaded_processor = AutoImageProcessor.from_pretrained(model_name)
    return loaded_model, loaded_processor


@st.cache_resource
def get_groq_client(api_key: str) -> Groq:
    """Return a Groq client that is reused across reruns instead of rebuilt each time."""
    return Groq(api_key=api_key)


# The API key comes from the environment rather than source control.
# NOTE: a key used to be hard-coded here; treat it as compromised and revoke it.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is not set. Configure it as an environment variable (or Space secret) and restart the app.")
    st.stop()

client = get_groq_client(GROQ_API_KEY)
model, processor = load_model_and_processor()

# Detection results are kept in session state so they survive reruns
# (e.g. the rerun triggered by clicking "Ask").
if "disease_name" not in st.session_state:
    st.session_state.disease_name = None
if "disease_info" not in st.session_state:
    st.session_state.disease_info = None


def reset_detection() -> None:
    """Forget the previous diagnosis so it is never shown next to a different image."""
    st.session_state.disease_name = None
    st.session_state.disease_info = None


def predict_skin_disease(image):
    """Classify a PIL image and return the predicted label.

    Args:
        image: PIL image. It is converted to RGB first so palette/RGBA PNG
            uploads are normalised to three channels.

    Returns:
        The ``model.config.id2label`` entry for the highest-scoring class.
    """
    rgb_image = image.convert("RGB")
    inputs = processor(images=rgb_image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def _ask_groq(prompt: str) -> str:
    """Send ``prompt`` as a single user message to the Groq chat API and return the reply text.

    Raises:
        groq.APIError: on connection or API failures (handled by the UI code).
    """
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=GROQ_MODEL,
    )
    return chat_completion.choices[0].message.content


def get_disease_info(disease_name):
    """Ask the LLM for a detailed explanation of ``disease_name``."""
    prompt = (
        f"Provide a detailed explanation about the skin disease '{disease_name}', "
        "including description of disease, causes, precautions, risks and treatment options."
    )
    return _ask_groq(prompt)


def chatbot_response(disease_name, user_query):
    """Answer a free-form question about the detected disease.

    Returns a user-facing hint instead of calling the API when no disease has
    been detected yet or the question is blank.
    """
    if not disease_name:
        return "Please upload an image and detect the disease first."
    if not user_query or not user_query.strip():
        return "Please type a question first."
    prompt = f"The detected skin disease is '{disease_name}'.\n{user_query.strip()}"
    return _ask_groq(prompt)


# Streamlit UI
# TODO: the logo URL below is a placeholder and will not resolve; point it at a real asset.
st.image("https://huggingface.co/spaces/your-huggingface-space/logo.png", width=200)
st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")
st.caption("⚠️ For informational purposes only. This is not a medical diagnosis; consult a dermatologist.")

# Step 1: Upload image. Changing or removing the file clears the old diagnosis.
uploaded_image = st.file_uploader(
    "Upload an image", type=["jpg", "jpeg", "png"], on_change=reset_detection
)

image = None
if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except UnidentifiedImageError:
        st.error("The uploaded file could not be read as an image.")

if image is not None:
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Step 2: Detect disease
    if st.button("Detect Disease"):
        with st.spinner("Analyzing..."):
            disease_name = predict_skin_disease(image)
            try:
                disease_info = get_disease_info(disease_name)
            except APIError as exc:
                disease_info = None
                st.error(f"Could not fetch disease details: {exc}")
        # Store results in session state so they persist across reruns.
        st.session_state.disease_name = disease_name
        st.session_state.disease_info = disease_info

# Display detected disease information if available
if st.session_state.disease_name:
    st.success(f"**Detected Disease:** {st.session_state.disease_name}")
    if st.session_state.disease_info:
        st.write(f"**Details:** {st.session_state.disease_info}")

# Step 3: Chatbot
st.subheader("💬 Ask DermaBot about this disease")
user_query = st.text_input("Ask about the detected disease:")
if st.button("Ask"):
    response = None
    with st.spinner("Thinking..."):
        try:
            response = chatbot_response(st.session_state.disease_name, user_query)
        except APIError as exc:
            st.error(f"DermaBot could not answer right now: {exc}")
    if response:
        st.write(response)

st.markdown("---")
st.write("🔍 Powered by **AI & Groq API** | © 2025 DermaBot")