RAG fix
- .streamlit/secrets.toml +4 -0
- app.py +33 -70
- rag_pipeline.py +142 -18
- requirements.txt +4 -1
- test.py +416 -0
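Taken together, these diffs move model selection, Qdrant retrieval, Cohere reranking, and the PairRM/gen-fuser ensemble out of app.py and into rag_pipeline.py. The snippet below is a minimal usage sketch of the new rag_pipeline entry points, not part of the commit itself: it assumes a running Streamlit app with the keys in .streamlit/secrets.toml configured, network access to the Qdrant collection, and a hypothetical example query.

# Usage sketch only (not part of this commit). Assumes a Streamlit runtime with
# .streamlit/secrets.toml populated and the Qdrant collection reachable.
import streamlit as st

from rag_pipeline import (
    available_models,
    initialize_llm,
    initialize_rag_components,
    get_reranked_response,
)

rag_components = initialize_rag_components()  # cached: Cohere client, PairRM ranker, gen-fuser, retriever
llm = initialize_llm(available_models[0])     # "OpenAI GPT-4o"; "Ensemble" would return AllModelsWrapper

# Only the Ensemble wrapper reads st.session_state.app_models; mirroring app.py keeps that path working.
st.session_state.app_models = {"llm": llm, "rag_components": rag_components}

query = "What are my treatment options for psoriasis?"  # hypothetical example query
answer = get_reranked_response(query, llm, rag_components)
st.markdown(answer)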
.streamlit/secrets.toml
CHANGED
@@ -1 +1,5 @@
 OPENAI_API_KEY = "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"
+QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
+COHERE_API_KEY = "PBEEQJ8n9sV2Xhpc7OMb8NRsBKtADvcEi9V0iPm5"
+GROQ_API_KEY = "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"
+GOOGLE_API_KEY = "AIzaSyD5pnSzkIuu86ByTPQewVKlh2zxOJI-f8M"
app.py
CHANGED
@@ -8,14 +8,23 @@ import torch.nn.functional as F
 from evo_vit import EvoViTModel
 import io
 import os
+import cohere
 from fpdf import FPDF
 from torchvision.models import resnet50
 import nest_asyncio
 from huggingface_hub import hf_hub_download
 from langchain_openai import OpenAIEmbeddings, ChatOpenAI
 from SkinCancerDiagnosis import initialize_classifier
-from rag_pipeline import
+from rag_pipeline import (
+    available_models,
+    initialize_llm,
+    load_rag_chain,
+    get_reranked_response,
+    initialize_rag_components
+)
 from langchain_core.messages import HumanMessage, AIMessage
+from groq import Groq
+import google.generativeai as genai
 
 nest_asyncio.apply()
 device='cuda' if torch.cuda.is_available() else 'cpu'
@@ -25,33 +34,15 @@ st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
 @st.cache_resource(show_spinner=False)
 def load_models():
     """Cache all models to load only once"""
-    with st.spinner("Loading AI models (one-time operation)..."):
-
-
-
-
-
-
-
-        return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=_api_key)
-    elif "LLaMA" in _model_name:
-        st.warning("LLaMA integration is not implemented yet.")
-        st.stop()
-    elif "Gemini" in _model_name:
-        st.warning("Gemini integration is not implemented yet.")
-        st.stop()
-    else:
-        st.error("Unsupported model selected.")
-        st.stop()
-
-@st.cache_resource(show_spinner=False)
-def load_rag_chain(_model_name, _api_key):
-    """Initialize RAG chain only once"""
-    llm = initialize_llm(_model_name, _api_key)
-    return invoke_rag_chain(llm)
+    with st.spinner("Loading all AI models (one-time operation)..."):
+        models = {
+            'classifier': initialize_classifier(),
+            'rag_components': initialize_rag_components(),
+            'llm': initialize_llm(st.session_state["selected_model"])
+        }
+        models['rag_chain'] = load_rag_chain(models['llm'])
+        return models
 
-# === Model Selection ===
-available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
 if "selected_model" not in st.session_state:
     st.session_state["selected_model"] = available_models[0]
 
@@ -63,7 +54,10 @@ st.session_state["selected_model"] = st.sidebar.selectbox(
     index=available_models.index(st.session_state["selected_model"])
 )
 
-
+if 'app_models' not in st.session_state:
+    st.session_state.app_models = load_models()
+classifier = st.session_state.app_models['classifier']
+llm = st.session_state.app_models['llm']
 
 if "model_change_confirmed" not in st.session_state:
     st.session_state.model_change_confirmed = False
@@ -88,15 +82,12 @@ if st.session_state["selected_model"] != previous_model:
     st.session_state.model_change_confirmed = True
 
 if "model_change_confirmed" not in st.session_state or st.session_state.model_change_confirmed:
-    llm = initialize_llm(st.session_state["selected_model"]
-    rag_chain = load_rag_chain(st.session_state[
-    st.session_state.llm
-    st.session_state.rag_chain = rag_chain
+    st.session_state.app_models['llm'] = initialize_llm(st.session_state["selected_model"])
+    st.session_state.app_models['rag_chain'] = load_rag_chain(st.session_state.app_models['llm'])
+    llm = st.session_state.app_models['llm']
 else:
-
-    rag_chain = st.session_state.get("rag_chain", load_rag_chain(previous_model, OPENAI_API_KEY))
+    pass
 
-classifier = load_models()
 
 # === Session Init ===
 if "messages" not in st.session_state:
@@ -149,10 +140,10 @@ if uploaded_file is not None and uploaded_file != st.session_state.current_image
 
     initial_query = f"What are my treatment options for {predicted_label}?"
     st.session_state.messages.append({"role": "user", "content": initial_query})
-
     with st.spinner("Retrieving medical information..."):
-        response =
-        st.session_state.messages.append({"role": "assistant", "content": response
+        response = get_reranked_response(initial_query, st.session_state.app_models['llm'], st.session_state.app_models['rag_components'])
+        st.session_state.messages.append({"role": "assistant", "content": response})
+
 
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
@@ -165,32 +156,6 @@ if prompt := st.chat_input("Ask a follow-up question..."):
     st.markdown(prompt)
 
 with st.chat_message("assistant"):
-    # with st.spinner("Thinking..."):
-    # Convert messages to LangChain format
-    # chat_history = []
-    # for msg in st.session_state.messages[:-1]:  # Exclude the current prompt
-    # if msg["role"] == "user":
-    # chat_history.append(HumanMessage(content=msg["content"]))
-    # else:
-    # chat_history.append(AIMessage(content=msg["content"]))
-    #
-    # # Get response
-    # response = llm.invoke([HumanMessage(content=prompt)] + chat_history)
-    # assistant_response = response.content
-    #
-    # st.markdown(assistant_response)
-    # st.session_state.messages.append({"role": "assistant", "content": assistant_response})
-    # with st.spinner("Thinking..."):
-    # if len(st.session_state.messages) > 1:
-    # response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages])
-    # response = response.content
-    # else:
-    # response = rag_chain.invoke(prompt)
-    # response = response['result']
-    #
-    # st.markdown(response)
-    # st.session_state.messages.append({"role": "assistant", "content": response})
-
     with st.spinner("Thinking..."):
         if len(st.session_state.messages) > 1:
             conversation_context = "\n".join(
@@ -201,14 +166,12 @@ if prompt := st.chat_input("Ask a follow-up question..."):
                 f"Conversation history:\n{conversation_context}\n\n"
                 f"Current question: {prompt}"
             )
-            response =
-            assistant_response = response['result']
+            response = get_reranked_response(augmented_prompt, st.session_state.app_models['llm'], st.session_state.app_models['rag_components'])
         else:
-            response =
-            assistant_response = response['result']
+            response = get_reranked_response(prompt, st.session_state.app_models['llm'], st.session_state.app_models['rag_components'])
 
-        st.markdown(
-        st.session_state.messages.append({"role": "assistant", "content":
+        st.markdown(response)
+        st.session_state.messages.append({"role": "assistant", "content": response})
 
 if st.session_state.messages and st.button("π Download Chat as PDF"):
     pdf_file = export_chat_to_pdf(st.session_state.messages)
rag_pipeline.py
CHANGED
@@ -1,3 +1,4 @@
+import streamlit as st
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from sentence_transformers import SentenceTransformer
@@ -5,10 +6,89 @@ from qdrant_client import QdrantClient
 from langchain_qdrant import Qdrant
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.embeddings import SentenceTransformerEmbeddings
+from transformers import pipeline
 import os
 import torch
+from groq import Groq
+import google.generativeai as genai
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+import cohere
 
-
+available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro", "Ensemble"]
+AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
+You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
+
+Guidelines:
+1. Symptoms - Explain in simple terms with proper medical definitions.
+2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
+3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
+4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
+5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
+
+Query: {question}
+Relevant Information: {context}
+Answer:
+"""
+
+
+@st.cache_resource(show_spinner=False)
+def initialize_rag_components():
+    components = {
+        'cohere_client': cohere.Client(st.secrets["COHERE_API_KEY"]),
+        'pair_ranker': pipeline("text-classification",
+                                model="llm-blender/PairRM",
+                                tokenizer="llm-blender/PairRM",
+                                return_all_scores=True
+                                ),
+        'gen_fuser': pipeline("text-generation",
+                              model="llm-blender/gen_fuser_3b",
+                              tokenizer="llm-blender/gen_fuser_3b",
+                              max_length=2048,
+                              do_sample=False
+                              ),
+        'retriever': get_retriever()
+    }
+    return components
+
+class AllModelsWrapper:
+    def invoke(self, messages):
+        prompt = messages[0]["content"]
+        rag_components = st.session_state.app_models['rag_components']  # Get components
+        responses = get_all_responses(prompt)
+        fused = rank_and_fuse(prompt, responses, rag_components)
+        return type('obj', (object,), {'content': fused})()
+
+def get_all_responses(prompt):
+    # Get responses from all models
+    openai_resp = ChatOpenAI(model="gpt-4o", temperature=0.2,
+                             api_key=st.secrets["OPENAI_API_KEY"]).invoke(
+        [{"role": "user", "content": prompt}]).content
+
+    gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+    gemini_resp = gemini.generate_content(prompt).text
+
+    llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
+    llama_resp = llama.chat.completions.create(
+        model="meta-llama/llama-4-maverick-17b-128e-instruct",
+        messages=[{"role": "user", "content": prompt}],
+        temperature=1, max_completion_tokens=1024, top_p=1, stream=False
+    ).choices[0].message.content
+
+    return [openai_resp, gemini_resp, llama_resp]
+
+
+def rank_and_fuse(prompt, responses, rag_components):
+    ranked = [(resp, rag_components['pair_ranker'](f"{prompt}\n\n{resp}")[0][1]['score'])
+              for resp in responses]
+    ranked.sort(key=lambda x: x[1], reverse=True)
+
+    # Fuse top responses
+    fusion_input = "\n\n".join([f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked[:2])])
+    return rag_components['gen_fuser'](f"Fuse these responses:\n{fusion_input}",
+                                       return_full_text=False)[0]['generated_text']
+
+
+def get_retriever():
     # === Qdrant DB Setup ===
     qdrant_client = QdrantClient(
         url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
@@ -30,30 +110,74 @@ def invoke_rag_chain(llm):
         collection_name=collection_name,
         embeddings=local_embedding
     )
-
+    return vector_store.as_retriever()
 
+def initialize_llm(_model_name):
+    """Initialize the LLM based on selection"""
+    print(f"Model name : {_model_name}")
+    if "OpenAI" in _model_name:
+        return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
+    elif "LLaMA" in _model_name:
+        client = Groq(api_key=st.secrets["GROQ_API_KEY"])
+        def get_llama_response(prompt):
+            completion = client.chat.completions.create(
+                model="meta-llama/llama-4-maverick-17b-128e-instruct",
+                messages=[{"role": "user", "content": prompt}],
+                temperature=1,
+                max_completion_tokens=1024,
+                top_p=1,
+                stream=False
+            )
+            return completion.choices[0].message.content
+        return type('obj', (object,), {'invoke': lambda self, x: get_llama_response(x[0]["content"])})()
+
+    elif "Gemini" in _model_name:
+        genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
+        gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+        def get_gemini_response(prompt):
+            response = gemini_model.generate_content(prompt)
+            return response.text
+        return type('obj', (object,), {'invoke': lambda self, x: get_gemini_response(x[0]["content"])})()
+
+    elif "Ensemble" in _model_name:
+        return AllModelsWrapper()
+    else:
+        raise ValueError("Unsupported model selected")
+
+
+def load_rag_chain(llm):
 
-    AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
-You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
-
-Guidelines:
-1. Symptoms - Explain in simple terms with proper medical definitions.
-2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
-3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
-4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
-5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
-
-Query: {question}
-Relevant Information: {context}
-Answer:
-"""
     prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
 
     rag_chain = RetrievalQA.from_chain_type(
         llm=llm,
-        retriever=
+        retriever=get_retriever(),
        chain_type="stuff",
         chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
     )
 
-    return rag_chain
+    return rag_chain
+
+def rerank_with_cohere(query, documents, co, top_n=5):
+    if not documents:
+        return []
+    raw_texts = [doc.page_content for doc in documents]
+    results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
+    return [documents[result.index] for result in results]
+
+
+def get_reranked_response(query, llm, rag_components):
+    """Get response with reranking"""
+    docs = rag_components['retriever'].get_relevant_documents(query)
+    reranked_docs = rerank_with_cohere(query, docs, rag_components['cohere_client'])
+    context = "\n\n".join([doc.page_content for doc in reranked_docs])
+
+    if isinstance(llm, (ChatOpenAI, AllModelsWrapper)):
+        return load_rag_chain(llm).invoke({"query": query, "context": context})['result']
+    else:
+        prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
+        return llm.invoke([{"role": "user", "content": prompt}]).content
+
+
+if __name__ == "__main__":
+    print("This is a module - import it instead of running directly")
requirements.txt
CHANGED
@@ -17,4 +17,7 @@ nest_asyncio
 sentence_transformers
 langchain-qdrant
 huggingface_hub
-langchain_core
+langchain_core
+groq
+google
+cohere
test.py
ADDED
@@ -0,0 +1,416 @@
+import streamlit as st
+from PIL import Image
+import torch
+import cohere
+import torch.nn as nn
+from torchvision import transforms
+from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
+import pandas as pd
+from huggingface_hub import hf_hub_download
+from langchain_huggingface import HuggingFaceEmbeddings
+import io
+import os
+import base64
+from fpdf import FPDF
+from sqlalchemy import create_engine
+from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import Distance, VectorParams
+from sentence_transformers import SentenceTransformer
+# from langchain_community.vectorstores.pgvector import PGVector
+# from langchain_postgres import PGVector
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from langchain_community.vectorstores import Qdrant
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import SentenceTransformerEmbeddings
+from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
+import nest_asyncio
+
+torch.cuda.empty_cache()
+nest_asyncio.apply()
+co = cohere.Client(st.secrets["COHERE_API_KEY"])
+
+st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
+
+# === Model Selection ===
+available_models = ["GPT-4o", "LLaMA 4 Maverick", "Gemini 2.5 Pro", "All"]
+st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
+
+# === Qdrant DB Setup ===
+qdrant_client = QdrantClient(
+    url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
+    api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
+)
+collection_name = "ks_collection_1.5BE"
+# embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
+# embedding_model.max_seq_length = 8192
+# local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def get_safe_embedding_model():
+    model_name = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
+
+    try:
+        print("Trying to load embedding model on CUDA...")
+        embedding = HuggingFaceEmbeddings(
+            model_name=model_name,
+            model_kwargs={
+                "trust_remote_code": True,
+                "device": "cuda"
+            }
+        )
+        print("Loaded embedding model on GPU.")
+        return embedding
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            print("CUDA OOM. Falling back to CPU.")
+        else:
+            print(" Error loading model on CUDA:", str(e))
+        print("Loading embedding model on CPU...")
+        return HuggingFaceEmbeddings(
+            model_name=model_name,
+            model_kwargs={
+                "trust_remote_code": True,
+                "device": "cpu"
+            }
+        )
+
+
+# Replace your old local_embedding line with this
+local_embedding = get_safe_embedding_model()
+
+print(" Qwen2-1.5B local embedding model loaded.")
+
+vector_store = Qdrant(
+    client=qdrant_client,
+    collection_name=collection_name,
+    embeddings=local_embedding
+)
+retriever = vector_store.as_retriever()
+
+pair_ranker = pipeline(
+    "text-classification",
+    model="llm-blender/PairRM",
+    tokenizer="llm-blender/PairRM",
+    return_all_scores=True
+)
+
+gen_fuser = pipeline(
+    "text-generation",
+    model="llm-blender/gen_fuser_3b",
+    tokenizer="llm-blender/gen_fuser_3b",
+    max_length=2048,
+    do_sample=False
+)
+
+# selected_model = st.session_state["selected_model"]
+
+if "OpenAI" in selected_model:
+    from langchain_openai import ChatOpenAI
+
+    llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
+
+elif "LLaMA" in selected_model:
+    from groq import Groq
+
+    client = Groq(api_key=st.secrets["GROQ_API_KEY"])  # Store in `.streamlit/secrets.toml`
+
+
+    def get_llama_response(prompt):
+        completion = client.chat.completions.create(
+            model="meta-llama/llama-4-maverick-17b-128e-instruct",
+            messages=[{"role": "user", "content": prompt}],
+            temperature=1,
+            max_completion_tokens=1024,
+            top_p=1,
+            stream=False
+        )
+        return completion.choices[0].message.content
+
+
+    llm = get_llama_response  # use this in place of llm.invoke()
+
+elif "Gemini" in selected_model:
+    import google.generativeai as genai
+
+    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])  # Store in `.streamlit/secrets.toml`
+
+    gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+
+
+    def get_gemini_response(prompt):
+        response = gemini_model.generate_content(prompt)
+        return response.text
+
+
+    llm = get_gemini_response
+
+elif "All" in selected_model:
+
+    from groq import Groq
+    import google.generativeai as genai
+
+    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
+
+
+    def get_all_model_responses(prompt):
+        openai_resp = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"]).invoke(
+            [{"role": "system", "content": prompt}]).content
+
+        gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+        gemini_resp = gemini.generate_content(prompt).text
+
+        llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
+        llama_resp = llama.chat.completions.create(
+            model="meta-llama/llama-4-maverick-17b-128e-instruct",
+            messages=[{"role": "user", "content": prompt}],
+            temperature=1, max_completion_tokens=1024, top_p=1, stream=False
+        ).choices[0].message.content
+
+        return [openai_resp, gemini_resp, llama_resp]
+
+
+    def rank_and_fuse(prompt, responses):
+        ranked = [(resp, pair_ranker(f"{prompt}\n\n{resp}")[0][1]['score']) for resp in responses]
+        ranked.sort(key=lambda x: x[1], reverse=True)
+        fusion_input = "\n\n".join([f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked)])
+        return gen_fuser(f"Fuse these responses:\n{fusion_input}", return_full_text=False)[0]['generated_text']
+
+
+else:
+    st.error("Unsupported model selected.")
+    st.stop()
+
+# retriever = vector_store.as_retriever()
+
+AI_PROMPT_TEMPLATE = """
+You are DermBOT, a compassionate and knowledgeable AI Dermatology Assistant designed to educate users about skin-related health concerns with clarity, empathy, and precision.
+
+Your goal is to respond like a well-informed human expert, balancing professionalism with warmth and reassurance.
+
+When crafting responses:
+- Begin with a clear, engaging summary of the condition or concern.
+- Use short paragraphs for readability.
+- Include bullet points or numbered lists where appropriate.
+- Avoid overly technical terms unless explained simply.
+- End with a helpful next step, such as lifestyle advice or when to see a doctor.
+
+🩺 Response Structure:
+1. **Overview** – Briefly introduce the condition or concern.
+2. **Common Symptoms** – Describe noticeable signs in simple terms.
+3. **Causes & Risk Factors** – Include genetic, lifestyle, and environmental aspects.
+4. **Treatment Options** – Outline common OTC and prescription treatments.
+5. **When to Seek Help** – Warn about symptoms that require urgent care.
+
+Always encourage consulting a licensed dermatologist for personal diagnosis and treatment. For any breathing difficulties, serious infections, or rapid symptom worsening, advise calling emergency services immediately.
+
+---
+
+Query: {question}
+Relevant Context: {context}
+
+Your Response:
+"""
+
+prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
+
+# rag_chain = RetrievalQA.from_chain_type(
+#     llm=llm,
+#     retriever=retriever,
+#     chain_type="stuff",
+#     chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
+# )
+
+# === Class Names ===
+multilabel_class_names = [
+    "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
+    "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
+    "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
+    "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
+    "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
+    "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
+    "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
+]
+
+multiclass_class_names = [
+    "systemic", "hair", "drug_reactions", "uriticaria", "acne", "light",
+    "autoimmune", "papulosquamous", "eczema", "skincancer",
+    "benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
+]
+
+
+# === Load Models ===
+class SkinViT(nn.Module):
+    def __init__(self, num_classes):
+        super(SkinViT, self).__init__()
+        self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+        in_features = self.model.heads.head.in_features
+        self.model.heads.head = nn.Linear(in_features, num_classes)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class DermNetViT(nn.Module):
+    def __init__(self, num_classes):
+        super(DermNetViT, self).__init__()
+        self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
+        in_features = self.model.heads[0].in_features
+        self.model.heads[0] = nn.Sequential(
+            nn.Dropout(0.3),
+            nn.Linear(in_features, num_classes)
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+
+# multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
+# multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
+
+# === Load Model State Dicts ===
+multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
+multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
+
+
+def load_model_with_fallback(model_class, weight_path, num_classes, model_name):
+    try:
+        print(f"π Loading {model_name} on GPU...")
+        model = model_class(num_classes)
+        model.load_state_dict(torch.load(weight_path, map_location="cuda"))
+        model.to("cuda")
+        print(f"✅ {model_name} loaded on GPU.")
+        return model
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            print(f"⚠️ {model_name} OOM. Falling back to CPU.")
+        else:
+            print(f"❌ Error loading {model_name} on CUDA: {e}")
+        print(f"π Loading {model_name} on CPU...")
+        model = model_class(num_classes)
+        model.load_state_dict(torch.load(weight_path, map_location="cpu"))
+        model.to("cpu")
+        return model
+
+
+# Load both models with fallback
+multilabel_model = load_model_with_fallback(SkinViT, multilabel_model_path, len(multilabel_class_names), "SkinViT")
+multiclass_model = load_model_with_fallback(DermNetViT, multiclass_model_path, len(multiclass_class_names),
+                                            "DermNetViT")
+
+multilabel_model.eval()
+multiclass_model.eval()
+
+# === Session Init ===
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+
+# === Image Processing Function ===
+def run_inference(image):
+    transform = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5])
+    ])
+    input_tensor = transform(image).unsqueeze(0)
+
+    # Automatically match model device (GPU or CPU)
+    model_device = next(multilabel_model.parameters()).device
+    input_tensor = input_tensor.to(model_device)
+
+    with torch.no_grad():
+        probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().cpu().numpy()
+        pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()
+    predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
+    predicted_single = multiclass_class_names[pred_idx]
+
+    return predicted_multi, predicted_single
+
+
+# === PDF Export ===
+def export_chat_to_pdf(messages):
+    pdf = FPDF()
+    pdf.add_page()
+    pdf.set_font("Arial", size=12)
+    for msg in messages:
+        role = "You" if msg["role"] == "user" else "AI"
+        pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
+    buf = io.BytesIO()
+    pdf.output(buf)
+    buf.seek(0)
+    return buf
+
+
+# Reranker utility
+def rerank_with_cohere(query, documents, top_n=5):
+    if not documents:
+        return []
+    raw_texts = [doc.page_content for doc in documents]
+    results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
+    return [documents[result.index] for result in results]
+
+
+# Final answer generation using reranked context
+def get_reranked_response(query):
+    docs = retriever.get_relevant_documents(query)
+    reranked_docs = rerank_with_cohere(query, docs)
+    context = "\n\n".join([doc.page_content for doc in reranked_docs])
+    prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
+
+    if selected_model == "All":
+        responses = get_all_model_responses(prompt)
+        fused = rank_and_fuse(prompt, responses)
+        return type("Obj", (), {"content": fused})
+
+    if callable(llm):
+        return type("Obj", (), {"content": llm(prompt)})
+    else:
+        return llm.invoke([{"role": "system", "content": prompt}])
+
+
+# === App UI ===
+
+st.title("🧬 DermBOT – Skin AI Assistant")
+st.caption(f"🧠 Using model: {selected_model}")
+uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
+
+if uploaded_file:
+    st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
+    image = Image.open(uploaded_file).convert("RGB")
+
+    predicted_multi, predicted_single = run_inference(image)
+
+    # Show predictions clearly to the user
+    st.markdown(f"🧾 **Skin Issues**: {', '.join(predicted_multi)}")
+    st.markdown(f"π **Most Likely Diagnosis**: {predicted_single}")
+
+    query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
+    st.session_state.messages.append({"role": "user", "content": query})
+
+    with st.spinner("π Analyzing and retrieving context..."):
+        response = get_reranked_response(query)
+        st.session_state.messages.append({"role": "assistant", "content": response.content})
+
+    with st.chat_message("assistant"):
+        st.markdown(response.content)
+
+# === Chat Interface ===
+if prompt := st.chat_input("Ask a follow-up..."):
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    with st.chat_message("user"):
+        st.markdown(prompt)
+
+    response = get_reranked_response(prompt)
+    st.session_state.messages.append({"role": "assistant", "content": response.content})
+    with st.chat_message("assistant"):
+        st.markdown(response.content)
+
+# === PDF Button ===
+if st.button("π Download Chat as PDF"):
+    pdf_file = export_chat_to_pdf(st.session_state.messages)
+    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")