import streamlit as st
from PIL import Image
import torch
import cohere
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
import pandas as pd
from huggingface_hub import hf_hub_download
from langchain_huggingface import HuggingFaceEmbeddings
import io
import os
import base64
from fpdf import FPDF
from sqlalchemy import create_engine
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from sentence_transformers import SentenceTransformer
# from langchain_community.vectorstores.pgvector import PGVector
# from langchain_postgres import PGVector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import SentenceTransformerEmbeddings
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
import nest_asyncio

torch.cuda.empty_cache()
nest_asyncio.apply()

co = cohere.Client(st.secrets["COHERE_API_KEY"])

st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")

# === Model Selection ===
available_models = ["GPT-4o", "LLaMA 4 Maverick", "Gemini 2.5 Pro", "All"]
st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)

# === Qdrant DB Setup ===
qdrant_client = QdrantClient(
    url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
    api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
)
collection_name = "ks_collection_1.5BE"

# embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
# embedding_model.max_seq_length = 8192
# local_embedding = SentenceTransformerEmbeddings(model=embedding_model)

device = "cuda" if torch.cuda.is_available() else "cpu"


def get_safe_embedding_model():
    """Load the GTE-Qwen2 embedding model on CUDA, falling back to CPU on OOM or other CUDA errors."""
    model_name = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
    try:
        print("Trying to load embedding model on CUDA...")
        embedding = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={"trust_remote_code": True, "device": "cuda"}
        )
        print("Loaded embedding model on GPU.")
        return embedding
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print("CUDA OOM. Falling back to CPU.")
        else:
            print("Error loading model on CUDA:", str(e))
        print("Loading embedding model on CPU...")
        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={"trust_remote_code": True, "device": "cpu"}
        )
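# The app expects its API keys in Streamlit secrets. A minimal `.streamlit/secrets.toml`
# sketch, assuming only the key names referenced in this file (the Qdrant URL and API key
# above are currently hardcoded rather than read from secrets):
#
#   COHERE_API_KEY = "..."
#   OPENAI_API_KEY = "..."
#   GROQ_API_KEY = "..."
#   GEMINI_API_KEY = "..."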
Falling back to CPU.") else: print(" Error loading model on CUDA:", str(e)) print("Loading embedding model on CPU...") return HuggingFaceEmbeddings( model_name=model_name, model_kwargs={ "trust_remote_code": True, "device": "cpu" } ) # Replace your old local_embedding line with this local_embedding = get_safe_embedding_model() print(" Qwen2-1.5B local embedding model loaded.") vector_store = Qdrant( client=qdrant_client, collection_name=collection_name, embeddings=local_embedding ) retriever = vector_store.as_retriever() pair_ranker = pipeline( "text-classification", model="llm-blender/PairRM", tokenizer="llm-blender/PairRM", return_all_scores=True ) gen_fuser = pipeline( "text-generation", model="llm-blender/gen_fuser_3b", tokenizer="llm-blender/gen_fuser_3b", max_length=2048, do_sample=False ) # selected_model = st.session_state["selected_model"] if "OpenAI" in selected_model: from langchain_openai import ChatOpenAI llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"]) elif "LLaMA" in selected_model: from groq import Groq client = Groq(api_key=st.secrets["GROQ_API_KEY"]) # Store in `.streamlit/secrets.toml` def get_llama_response(prompt): completion = client.chat.completions.create( model="meta-llama/llama-4-maverick-17b-128e-instruct", messages=[{"role": "user", "content": prompt}], temperature=1, max_completion_tokens=1024, top_p=1, stream=False ) return completion.choices[0].message.content llm = get_llama_response # use this in place of llm.invoke() elif "Gemini" in selected_model: import google.generativeai as genai genai.configure(api_key=st.secrets["GEMINI_API_KEY"]) # Store in `.streamlit/secrets.toml` gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25") def get_gemini_response(prompt): response = gemini_model.generate_content(prompt) return response.text llm = get_gemini_response elif "All" in selected_model: from groq import Groq import google.generativeai as genai genai.configure(api_key=st.secrets["GEMINI_API_KEY"]) def get_all_model_responses(prompt): openai_resp = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"]).invoke( [{"role": "system", "content": prompt}]).content gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25") gemini_resp = gemini.generate_content(prompt).text llama = Groq(api_key=st.secrets["GROQ_API_KEY"]) llama_resp = llama.chat.completions.create( model="meta-llama/llama-4-maverick-17b-128e-instruct", messages=[{"role": "user", "content": prompt}], temperature=1, max_completion_tokens=1024, top_p=1, stream=False ).choices[0].message.content return [openai_resp, gemini_resp, llama_resp] def rank_and_fuse(prompt, responses): ranked = [(resp, pair_ranker(f"{prompt}\n\n{resp}")[0][1]['score']) for resp in responses] ranked.sort(key=lambda x: x[1], reverse=True) fusion_input = "\n\n".join([f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked)]) return gen_fuser(f"Fuse these responses:\n{fusion_input}", return_full_text=False)[0]['generated_text'] else: st.error("Unsupported model selected.") st.stop() # retriever = vector_store.as_retriever() AI_PROMPT_TEMPLATE = """ You are DermBOT, a compassionate and knowledgeable AI Dermatology Assistant designed to educate users about skin-related health concerns with clarity, empathy, and precision. Your goal is to respond like a well-informed human expertβ€”balancing professionalism with warmth and reassurance. When crafting responses: - Begin with a clear, engaging summary of the condition or concern. 
AI_PROMPT_TEMPLATE = """
You are DermBOT, a compassionate and knowledgeable AI Dermatology Assistant designed to educate users about skin-related health concerns with clarity, empathy, and precision.
Your goal is to respond like a well-informed human expert, balancing professionalism with warmth and reassurance.

When crafting responses:
- Begin with a clear, engaging summary of the condition or concern.
- Use short paragraphs for readability.
- Include bullet points or numbered lists where appropriate.
- Avoid overly technical terms unless explained simply.
- End with a helpful next step, such as lifestyle advice or when to see a doctor.

🩺 Response Structure:
1. **Overview**: Briefly introduce the condition or concern.
2. **Common Symptoms**: Describe noticeable signs in simple terms.
3. **Causes & Risk Factors**: Include genetic, lifestyle, and environmental aspects.
4. **Treatment Options**: Outline common OTC and prescription treatments.
5. **When to Seek Help**: Warn about symptoms that require urgent care.

Always encourage consulting a licensed dermatologist for personal diagnosis and treatment.
For any breathing difficulties, serious infections, or rapid symptom worsening, advise calling emergency services immediately.

---
Query: {question}
Relevant Context: {context}

Your Response:
"""

prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])

# rag_chain = RetrievalQA.from_chain_type(
#     llm=llm,
#     retriever=retriever,
#     chain_type="stuff",
#     chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
# )

# === Class Names ===
multilabel_class_names = [
    "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
    "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate",
    "Purpura/Petechiae", "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale",
    "Scar", "Friable", "Sclerosis", "Pedunculated", "Exophytic/Fungating",
    "Warty/Papillomatous", "Dome-shaped", "Flat topped", "Brown(Hyperpigmentation)",
    "Translucent", "White(Hypopigmentation)", "Purple", "Yellow", "Black", "Erythema",
    "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma", "Salmon",
    "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
]

multiclass_class_names = [
    "systemic", "hair", "drug_reactions", "uriticaria", "acne", "light", "autoimmune",
    "papulosquamous", "eczema", "skincancer", "benign_tumors",
    "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
]


# === Load Models ===
class SkinViT(nn.Module):
    def __init__(self, num_classes):
        super(SkinViT, self).__init__()
        self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        in_features = self.model.heads.head.in_features
        self.model.heads.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)


class DermNetViT(nn.Module):
    def __init__(self, num_classes):
        super(DermNetViT, self).__init__()
        self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
        in_features = self.model.heads[0].in_features
        self.model.heads[0] = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.model(x)


# multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
# multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')

# === Load Model State Dicts ===
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")


def load_model_with_fallback(model_class, weight_path, num_classes, model_name):
    try:
        print(f"🔍 Loading {model_name} on GPU...")
        model = model_class(num_classes)
        model.load_state_dict(torch.load(weight_path, map_location="cuda"))
        model.to("cuda")
        print(f"✅ {model_name} loaded on GPU.")
        return model
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print(f"⚠️ {model_name} OOM. Falling back to CPU.")
        else:
            print(f"❌ Error loading {model_name} on CUDA: {e}")
        print(f"🔄 Loading {model_name} on CPU...")
        model = model_class(num_classes)
        model.load_state_dict(torch.load(weight_path, map_location="cpu"))
        model.to("cpu")
        return model
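# Usage sketch for the fallback loader: any nn.Module whose constructor takes `num_classes`
# can be loaded the same way, e.g. (hypothetical weights file):
#   demo_model = load_model_with_fallback(SkinViT, "some_weights_sd.pth", 10, "DemoViT")
# The `hf_hub_download` calls above cache the checkpoints in the local Hugging Face cache,
# so Streamlit reruns reuse them instead of re-downloading.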
# Load both models with fallback
multilabel_model = load_model_with_fallback(SkinViT, multilabel_model_path, len(multilabel_class_names), "SkinViT")
multiclass_model = load_model_with_fallback(DermNetViT, multiclass_model_path, len(multiclass_class_names), "DermNetViT")

multilabel_model.eval()
multiclass_model.eval()

# === Session Init ===
if "messages" not in st.session_state:
    st.session_state.messages = []


# === Image Processing Function ===
def run_inference(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    input_tensor = transform(image).unsqueeze(0)

    # Automatically match model device (GPU or CPU)
    model_device = next(multilabel_model.parameters()).device
    input_tensor = input_tensor.to(model_device)

    with torch.no_grad():
        probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().cpu().numpy()
        pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()

    predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
    predicted_single = multiclass_class_names[pred_idx]
    return predicted_multi, predicted_single


# === PDF Export ===
def export_chat_to_pdf(messages):
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
    buf = io.BytesIO()
    pdf.output(buf)
    buf.seek(0)
    return buf


# Reranker utility
def rerank_with_cohere(query, documents, top_n=5):
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    reranked = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
    # The rerank response exposes ranked hits on `.results`; map them back to the original documents
    return [documents[result.index] for result in reranked.results]


# Final answer generation using reranked context
def get_reranked_response(query):
    docs = retriever.get_relevant_documents(query)
    reranked_docs = rerank_with_cohere(query, docs)
    context = "\n\n".join([doc.page_content for doc in reranked_docs])
    prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)

    if selected_model == "All":
        responses = get_all_model_responses(prompt)
        fused = rank_and_fuse(prompt, responses)
        return type("Obj", (), {"content": fused})
    if callable(llm):
        return type("Obj", (), {"content": llm(prompt)})
    else:
        return llm.invoke([{"role": "system", "content": prompt}])
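# Illustrative call (not executed): whichever branch built `llm`, the returned object
# exposes `.content`, so the UI below can treat all models uniformly:
#   answer = get_reranked_response("How do I treat mild acne?")
#   st.markdown(answer.content)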
# === App UI ===
st.title("🧬 DermBOT - Skin AI Assistant")
st.caption(f"🧠 Using model: {selected_model}")

uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
    image = Image.open(uploaded_file).convert("RGB")
    predicted_multi, predicted_single = run_inference(image)

    # Show predictions clearly to the user
    st.markdown(f"🧾 **Skin Issues**: {', '.join(predicted_multi)}")
    st.markdown(f"📌 **Most Likely Diagnosis**: {predicted_single}")

    query = f"What are my treatment options for {', '.join(predicted_multi)} and {predicted_single}?"
    st.session_state.messages.append({"role": "user", "content": query})

    with st.spinner("🔎 Analyzing and retrieving context..."):
        response = get_reranked_response(query)
        st.session_state.messages.append({"role": "assistant", "content": response.content})
        with st.chat_message("assistant"):
            st.markdown(response.content)

# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    response = get_reranked_response(prompt)
    st.session_state.messages.append({"role": "assistant", "content": response.content})
    with st.chat_message("assistant"):
        st.markdown(response.content)

# === PDF Button ===
if st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")
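# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py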