KeerthiVM committed
Commit dc9062b · 1 Parent(s): 1505823
Files changed (5)
  1. .streamlit/secrets.toml +4 -0
  2. app.py +33 -70
  3. rag_pipeline.py +142 -18
  4. requirements.txt +4 -1
  5. test.py +416 -0
.streamlit/secrets.toml CHANGED
@@ -1 +1,5 @@
OPENAI_API_KEY = "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"
+ QDRANT_API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
+ COHERE_API_KEY = "PBEEQJ8n9sV2Xhpc7OMb8NRsBKtADvcEi9V0iPm5"
+ GROQ_API_KEY = "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"
+ GOOGLE_API_KEY = "AIzaSyD5pnSzkIuu86ByTPQewVKlh2zxOJI-f8M"
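
These entries are read at runtime through Streamlit's st.secrets mapping. A minimal sketch of the lookup, for reference (note that rag_pipeline.py in this commit reads st.secrets["GEMINI_API_KEY"], for which this file defines no entry; the new GOOGLE_API_KEY is presumably the value intended for Gemini):

    import streamlit as st

    # Entries in .streamlit/secrets.toml are exposed as a dict-like object.
    cohere_key = st.secrets["COHERE_API_KEY"]
    groq_key = st.secrets["GROQ_API_KEY"]
    # rag_pipeline.py calls st.secrets["GEMINI_API_KEY"]; with only GOOGLE_API_KEY
    # defined above, that lookup would raise a KeyError at runtime.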
app.py CHANGED
@@ -8,14 +8,23 @@ import torch.nn.functional as F
from evo_vit import EvoViTModel
import io
import os
+ import cohere
from fpdf import FPDF
from torchvision.models import resnet50
import nest_asyncio
from huggingface_hub import hf_hub_download
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from SkinCancerDiagnosis import initialize_classifier
- from rag_pipeline import invoke_rag_chain
+ from rag_pipeline import (
+     available_models,
+     initialize_llm,
+     load_rag_chain,
+     get_reranked_response,
+     initialize_rag_components
+ )
from langchain_core.messages import HumanMessage, AIMessage
+ from groq import Groq
+ import google.generativeai as genai

nest_asyncio.apply()
device='cuda' if torch.cuda.is_available() else 'cpu'
@@ -25,33 +34,15 @@ st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
@st.cache_resource(show_spinner=False)
def load_models():
    """Cache all models to load only once"""
-     with st.spinner("Loading AI models (one-time operation)..."):
-         classifier = initialize_classifier()
-         return classifier
-
- def initialize_llm(_model_name, _api_key):
-     """Initialize the LLM based on selection"""
-     print(f"Model name : {_model_name}")
-     if "OpenAI" in _model_name:
-         return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=_api_key)
-     elif "LLaMA" in _model_name:
-         st.warning("LLaMA integration is not implemented yet.")
-         st.stop()
-     elif "Gemini" in _model_name:
-         st.warning("Gemini integration is not implemented yet.")
-         st.stop()
-     else:
-         st.error("Unsupported model selected.")
-         st.stop()
-
- @st.cache_resource(show_spinner=False)
- def load_rag_chain(_model_name, _api_key):
-     """Initialize RAG chain only once"""
-     llm = initialize_llm(_model_name, _api_key)
-     return invoke_rag_chain(llm)
+     with st.spinner("Loading all AI models (one-time operation)..."):
+         models = {
+             'classifier': initialize_classifier(),
+             'rag_components': initialize_rag_components(),
+             'llm': initialize_llm(st.session_state["selected_model"])
+         }
+         models['rag_chain'] = load_rag_chain(models['llm'])
+         return models

- # === Model Selection ===
- available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
if "selected_model" not in st.session_state:
    st.session_state["selected_model"] = available_models[0]

@@ -63,7 +54,10 @@ st.session_state["selected_model"] = st.sidebar.selectbox(
    index=available_models.index(st.session_state["selected_model"])
)

- OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
+ if 'app_models' not in st.session_state:
+     st.session_state.app_models = load_models()
+ classifier = st.session_state.app_models['classifier']
+ llm = st.session_state.app_models['llm']

if "model_change_confirmed" not in st.session_state:
    st.session_state.model_change_confirmed = False
@@ -88,15 +82,12 @@ if st.session_state["selected_model"] != previous_model:
    st.session_state.model_change_confirmed = True

if "model_change_confirmed" not in st.session_state or st.session_state.model_change_confirmed:
-     llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
-     rag_chain = load_rag_chain(st.session_state["selected_model"], OPENAI_API_KEY)
-     st.session_state.llm = llm
-     st.session_state.rag_chain = rag_chain
+     st.session_state.app_models['llm'] = initialize_llm(st.session_state["selected_model"])
+     st.session_state.app_models['rag_chain'] = load_rag_chain(st.session_state.app_models['llm'])
+     llm = st.session_state.app_models['llm']
else:
-     llm = st.session_state.get("llm", initialize_llm(previous_model, OPENAI_API_KEY))
-     rag_chain = st.session_state.get("rag_chain", load_rag_chain(previous_model, OPENAI_API_KEY))
+     pass

- classifier = load_models()

# === Session Init ===
if "messages" not in st.session_state:
@@ -149,10 +140,10 @@ if uploaded_file is not None and uploaded_file != st.session_state.current_image

    initial_query = f"What are my treatment options for {predicted_label}?"
    st.session_state.messages.append({"role": "user", "content": initial_query})
-
    with st.spinner("Retrieving medical information..."):
-         response = rag_chain.invoke(initial_query)
-         st.session_state.messages.append({"role": "assistant", "content": response['result']})
+         response = get_reranked_response(initial_query, st.session_state.app_models['llm'], st.session_state.app_models['rag_components'])
+         st.session_state.messages.append({"role": "assistant", "content": response})
+

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
@@ -165,32 +156,6 @@ if prompt := st.chat_input("Ask a follow-up question..."):
        st.markdown(prompt)

    with st.chat_message("assistant"):
-         # with st.spinner("Thinking..."):
-         #     # Convert messages to LangChain format
-         #     chat_history = []
-         #     for msg in st.session_state.messages[:-1]:  # Exclude the current prompt
-         #         if msg["role"] == "user":
-         #             chat_history.append(HumanMessage(content=msg["content"]))
-         #         else:
-         #             chat_history.append(AIMessage(content=msg["content"]))
-         #
-         #     # Get response
-         #     response = llm.invoke([HumanMessage(content=prompt)] + chat_history)
-         #     assistant_response = response.content
-         #
-         #     st.markdown(assistant_response)
-         #     st.session_state.messages.append({"role": "assistant", "content": assistant_response})
-         # with st.spinner("Thinking..."):
-         #     if len(st.session_state.messages) > 1:
-         #         response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages])
-         #         response = response.content
-         #     else:
-         #         response = rag_chain.invoke(prompt)
-         #         response = response['result']
-         #
-         #     st.markdown(response)
-         #     st.session_state.messages.append({"role": "assistant", "content": response})
-
        with st.spinner("Thinking..."):
            if len(st.session_state.messages) > 1:
                conversation_context = "\n".join(
@@ -201,14 +166,12 @@ if prompt := st.chat_input("Ask a follow-up question..."):
                    f"Conversation history:\n{conversation_context}\n\n"
                    f"Current question: {prompt}"
                )
-                 response = rag_chain.invoke({"query": augmented_prompt})
-                 assistant_response = response['result']
+                 response = get_reranked_response(augmented_prompt, st.session_state.app_models['llm'], st.session_state.app_models['rag_components'])
            else:
-                 response = rag_chain.invoke({"query": prompt})
-                 assistant_response = response['result']
+                 response = get_reranked_response(prompt, st.session_state.app_models['llm'], st.session_state.app_models['rag_components'])

-             st.markdown(assistant_response)
-             st.session_state.messages.append({"role": "assistant", "content": assistant_response})
+             st.markdown(response)
+             st.session_state.messages.append({"role": "assistant", "content": response})

if st.session_state.messages and st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
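
After this change, app.py builds all heavy objects once, caches them in st.session_state.app_models, and routes every chat turn through rag_pipeline.get_reranked_response. A minimal sketch of the resulting call path, assuming the repo's modules are importable and the app is running (the query string is an invented example):

    import streamlit as st
    from rag_pipeline import get_reranked_response

    models = st.session_state.app_models  # populated once by load_models()
    answer = get_reranked_response(
        "What are my treatment options for eczema?",  # hypothetical query
        models['llm'],             # ChatOpenAI, a Groq/Gemini wrapper, or AllModelsWrapper
        models['rag_components'],  # Qdrant retriever, Cohere client, PairRM, gen_fuser
    )
    st.markdown(answer)  # get_reranked_response returns a plain string, not a dict

One caveat worth noting: load_models() is decorated with @st.cache_resource but reads st.session_state["selected_model"] inside, so the cached 'llm' reflects whichever model was selected on the first run; the model-change block above compensates by reassigning app_models['llm'] after each confirmed switch.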
rag_pipeline.py CHANGED
@@ -1,3 +1,4 @@
+ import streamlit as st
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
@@ -5,10 +6,89 @@ from qdrant_client import QdrantClient
from langchain_qdrant import Qdrant
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import SentenceTransformerEmbeddings
+ from transformers import pipeline
import os
import torch
+ from groq import Groq
+ import google.generativeai as genai
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+ import cohere

- def invoke_rag_chain(llm):
+ available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro", "Ensemble"]
+ AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
+ You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
+
+ Guidelines:
+ 1. Symptoms - Explain in simple terms with proper medical definitions.
+ 2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
+ 3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
+ 4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
+ 5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
+
+ Query: {question}
+ Relevant Information: {context}
+ Answer:
+ """
+
+
+ @st.cache_resource(show_spinner=False)
+ def initialize_rag_components():
+     components = {
+         'cohere_client': cohere.Client(st.secrets["COHERE_API_KEY"]),
+         'pair_ranker': pipeline("text-classification",
+                                 model="llm-blender/PairRM",
+                                 tokenizer="llm-blender/PairRM",
+                                 return_all_scores=True
+                                 ),
+         'gen_fuser': pipeline("text-generation",
+                               model="llm-blender/gen_fuser_3b",
+                               tokenizer="llm-blender/gen_fuser_3b",
+                               max_length=2048,
+                               do_sample=False
+                               ),
+         'retriever': get_retriever()
+     }
+     return components
+
+ class AllModelsWrapper:
+     def invoke(self, messages):
+         prompt = messages[0]["content"]
+         rag_components = st.session_state.app_models['rag_components']  # Get components
+         responses = get_all_responses(prompt)
+         fused = rank_and_fuse(prompt, responses, rag_components)
+         return type('obj', (object,), {'content': fused})()
+
+ def get_all_responses(prompt):
+     # Get responses from all models
+     openai_resp = ChatOpenAI(model="gpt-4o", temperature=0.2,
+                              api_key=st.secrets["OPENAI_API_KEY"]).invoke(
+         [{"role": "user", "content": prompt}]).content
+
+     gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+     gemini_resp = gemini.generate_content(prompt).text
+
+     llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
+     llama_resp = llama.chat.completions.create(
+         model="meta-llama/llama-4-maverick-17b-128e-instruct",
+         messages=[{"role": "user", "content": prompt}],
+         temperature=1, max_completion_tokens=1024, top_p=1, stream=False
+     ).choices[0].message.content
+
+     return [openai_resp, gemini_resp, llama_resp]
+
+
+ def rank_and_fuse(prompt, responses, rag_components):
+     ranked = [(resp, rag_components['pair_ranker'](f"{prompt}\n\n{resp}")[0][1]['score'])
+               for resp in responses]
+     ranked.sort(key=lambda x: x[1], reverse=True)
+
+     # Fuse top responses
+     fusion_input = "\n\n".join([f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked[:2])])
+     return rag_components['gen_fuser'](f"Fuse these responses:\n{fusion_input}",
+                                        return_full_text=False)[0]['generated_text']
+
+
+ def get_retriever():
    # === Qdrant DB Setup ===
    qdrant_client = QdrantClient(
        url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
@@ -30,30 +110,74 @@ def invoke_rag_chain(llm):
        collection_name=collection_name,
        embeddings=local_embedding
    )
-     retriever = vector_store.as_retriever()
+     return vector_store.as_retriever()

+ def initialize_llm(_model_name):
+     """Initialize the LLM based on selection"""
+     print(f"Model name : {_model_name}")
+     if "OpenAI" in _model_name:
+         return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
+     elif "LLaMA" in _model_name:
+         client = Groq(api_key=st.secrets["GROQ_API_KEY"])
+         def get_llama_response(prompt):
+             completion = client.chat.completions.create(
+                 model="meta-llama/llama-4-maverick-17b-128e-instruct",
+                 messages=[{"role": "user", "content": prompt}],
+                 temperature=1,
+                 max_completion_tokens=1024,
+                 top_p=1,
+                 stream=False
+             )
+             return completion.choices[0].message.content
+         return type('obj', (object,), {'invoke': lambda self, x: get_llama_response(x[0]["content"])})()
+
+     elif "Gemini" in _model_name:
+         genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
+         gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+         def get_gemini_response(prompt):
+             response = gemini_model.generate_content(prompt)
+             return response.text
+         return type('obj', (object,), {'invoke': lambda self, x: get_gemini_response(x[0]["content"])})()
+
+     elif "Ensemble" in _model_name:
+         return AllModelsWrapper()
+     else:
+         raise ValueError("Unsupported model selected")
+
+
+ def load_rag_chain(llm):

-     AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
-     You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
-
-     Guidelines:
-     1. Symptoms - Explain in simple terms with proper medical definitions.
-     2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
-     3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
-     4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
-     5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
-
-     Query: {question}
-     Relevant Information: {context}
-     Answer:
-     """
    prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])

    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
-         retriever=retriever,
+         retriever=get_retriever(),
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
    )

-     return rag_chain
+     return rag_chain
+
+ def rerank_with_cohere(query, documents, co, top_n=5):
+     if not documents:
+         return []
+     raw_texts = [doc.page_content for doc in documents]
+     results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
+     return [documents[result.index] for result in results]
+
+
+ def get_reranked_response(query, llm, rag_components):
+     """Get response with reranking"""
+     docs = rag_components['retriever'].get_relevant_documents(query)
+     reranked_docs = rerank_with_cohere(query, docs, rag_components['cohere_client'])
+     context = "\n\n".join([doc.page_content for doc in reranked_docs])
+
+     if isinstance(llm, (ChatOpenAI, AllModelsWrapper)):
+         return load_rag_chain(llm).invoke({"query": query, "context": context})['result']
+     else:
+         prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
+         return llm.invoke([{"role": "user", "content": prompt}]).content
+
+
+ if __name__ == "__main__":
+     print("This is a module - import it instead of running directly")
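
Two behavioral notes on the new module. First, the Cohere-reranked context reaches the prompt only on the else branch of get_reranked_response; for ChatOpenAI and AllModelsWrapper the call is delegated to the RetrievalQA chain, whose own retriever fetches documents again, so the extra "context" key passed to invoke is not what the chain formats into the prompt. Second, a minimal standalone usage sketch, assuming .streamlit/secrets.toml is present and the Qdrant collection is reachable (the query is an invented example):

    # Hypothetical direct use of the module outside app.py.
    from rag_pipeline import initialize_llm, initialize_rag_components, get_reranked_response

    components = initialize_rag_components()  # retriever, Cohere client, PairRM, gen_fuser
    llm = initialize_llm("OpenAI GPT-4o")     # any entry from available_models except "Ensemble",
                                              # which needs st.session_state.app_models to exist
    print(get_reranked_response("What causes psoriasis flare-ups?", llm, components))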
requirements.txt CHANGED
@@ -17,4 +17,7 @@ nest_asyncio
sentence_transformers
langchain-qdrant
huggingface_hub
- langchain_core
+ langchain_core
+ groq
+ google
+ cohere
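
One note on the new pins: the module google.generativeai imported by app.py and rag_pipeline.py is distributed on PyPI as google-generativeai; the bare google package does not provide it, so this line presumably needs to be google-generativeai. A quick check in the target environment:

    # Hypothetical sanity check: raises ModuleNotFoundError if only "google" is installed.
    import google.generativeai as genai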
test.py ADDED
@@ -0,0 +1,416 @@
+ import streamlit as st
+ from PIL import Image
+ import torch
+ import cohere
+ import torch.nn as nn
+ from torchvision import transforms
+ from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
+ import pandas as pd
+ from huggingface_hub import hf_hub_download
+ from langchain_huggingface import HuggingFaceEmbeddings
+ import io
+ import os
+ import base64
+ from fpdf import FPDF
+ from sqlalchemy import create_engine
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from qdrant_client import QdrantClient
+ from qdrant_client.http.models import Distance, VectorParams
+ from sentence_transformers import SentenceTransformer
+ # from langchain_community.vectorstores.pgvector import PGVector
+ # from langchain_postgres import PGVector
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+ from langchain_community.vectorstores import Qdrant
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
+ import nest_asyncio
+
+ torch.cuda.empty_cache()
+ nest_asyncio.apply()
+ co = cohere.Client(st.secrets["COHERE_API_KEY"])
+
+ st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
+
+ # === Model Selection ===
+ available_models = ["GPT-4o", "LLaMA 4 Maverick", "Gemini 2.5 Pro", "All"]
+ st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
+
+ # === Qdrant DB Setup ===
+ qdrant_client = QdrantClient(
+     url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
+     api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
+ )
+ collection_name = "ks_collection_1.5BE"
+ # embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
+ # embedding_model.max_seq_length = 8192
+ # local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def get_safe_embedding_model():
+     model_name = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
+
+     try:
+         print("Trying to load embedding model on CUDA...")
+         embedding = HuggingFaceEmbeddings(
+             model_name=model_name,
+             model_kwargs={
+                 "trust_remote_code": True,
+                 "device": "cuda"
+             }
+         )
+         print("Loaded embedding model on GPU.")
+         return embedding
+     except RuntimeError as e:
+         if "CUDA out of memory" in str(e):
+             print("CUDA OOM. Falling back to CPU.")
+         else:
+             print("Error loading model on CUDA:", str(e))
+         print("Loading embedding model on CPU...")
+         return HuggingFaceEmbeddings(
+             model_name=model_name,
+             model_kwargs={
+                 "trust_remote_code": True,
+                 "device": "cpu"
+             }
+         )
+
+
+ # Replace your old local_embedding line with this
+ local_embedding = get_safe_embedding_model()
+
+ print("Qwen2-1.5B local embedding model loaded.")
+
+ vector_store = Qdrant(
+     client=qdrant_client,
+     collection_name=collection_name,
+     embeddings=local_embedding
+ )
+ retriever = vector_store.as_retriever()
+
+ pair_ranker = pipeline(
+     "text-classification",
+     model="llm-blender/PairRM",
+     tokenizer="llm-blender/PairRM",
+     return_all_scores=True
+ )
+
+ gen_fuser = pipeline(
+     "text-generation",
+     model="llm-blender/gen_fuser_3b",
+     tokenizer="llm-blender/gen_fuser_3b",
+     max_length=2048,
+     do_sample=False
+ )
+
+ # selected_model = st.session_state["selected_model"]
+
+ if "OpenAI" in selected_model:
+     from langchain_openai import ChatOpenAI
+
+     llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
+
+ elif "LLaMA" in selected_model:
+     from groq import Groq
+
+     client = Groq(api_key=st.secrets["GROQ_API_KEY"])  # Store in `.streamlit/secrets.toml`
+
+
+     def get_llama_response(prompt):
+         completion = client.chat.completions.create(
+             model="meta-llama/llama-4-maverick-17b-128e-instruct",
+             messages=[{"role": "user", "content": prompt}],
+             temperature=1,
+             max_completion_tokens=1024,
+             top_p=1,
+             stream=False
+         )
+         return completion.choices[0].message.content
+
+
+     llm = get_llama_response  # use this in place of llm.invoke()
+
+ elif "Gemini" in selected_model:
+     import google.generativeai as genai
+
+     genai.configure(api_key=st.secrets["GEMINI_API_KEY"])  # Store in `.streamlit/secrets.toml`
+
+     gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+
+
+     def get_gemini_response(prompt):
+         response = gemini_model.generate_content(prompt)
+         return response.text
+
+
+     llm = get_gemini_response
+
+ elif "All" in selected_model:
+
+     from groq import Groq
+     import google.generativeai as genai
+
+     genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
+
+
+     def get_all_model_responses(prompt):
+         openai_resp = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"]).invoke(
+             [{"role": "system", "content": prompt}]).content
+
+         gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+         gemini_resp = gemini.generate_content(prompt).text
+
+         llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
+         llama_resp = llama.chat.completions.create(
+             model="meta-llama/llama-4-maverick-17b-128e-instruct",
+             messages=[{"role": "user", "content": prompt}],
+             temperature=1, max_completion_tokens=1024, top_p=1, stream=False
+         ).choices[0].message.content
+
+         return [openai_resp, gemini_resp, llama_resp]
+
+
+     def rank_and_fuse(prompt, responses):
+         ranked = [(resp, pair_ranker(f"{prompt}\n\n{resp}")[0][1]['score']) for resp in responses]
+         ranked.sort(key=lambda x: x[1], reverse=True)
+         fusion_input = "\n\n".join([f"[Answer {i + 1}]: {ans}" for i, (ans, _) in enumerate(ranked)])
+         return gen_fuser(f"Fuse these responses:\n{fusion_input}", return_full_text=False)[0]['generated_text']
+
+
+ else:
+     st.error("Unsupported model selected.")
+     st.stop()
+
+ # retriever = vector_store.as_retriever()
+
+ AI_PROMPT_TEMPLATE = """
+ You are DermBOT, a compassionate and knowledgeable AI Dermatology Assistant designed to educate users about skin-related health concerns with clarity, empathy, and precision.
+
+ Your goal is to respond like a well-informed human expert—balancing professionalism with warmth and reassurance.
+
+ When crafting responses:
+ - Begin with a clear, engaging summary of the condition or concern.
+ - Use short paragraphs for readability.
+ - Include bullet points or numbered lists where appropriate.
+ - Avoid overly technical terms unless explained simply.
+ - End with a helpful next step, such as lifestyle advice or when to see a doctor.
+
+ 🩺 Response Structure:
+ 1. **Overview** — Briefly introduce the condition or concern.
+ 2. **Common Symptoms** — Describe noticeable signs in simple terms.
+ 3. **Causes & Risk Factors** — Include genetic, lifestyle, and environmental aspects.
+ 4. **Treatment Options** — Outline common OTC and prescription treatments.
+ 5. **When to Seek Help** — Warn about symptoms that require urgent care.
+
+ Always encourage consulting a licensed dermatologist for personal diagnosis and treatment. For any breathing difficulties, serious infections, or rapid symptom worsening, advise calling emergency services immediately.
+
+ ---
+
+ Query: {question}
+ Relevant Context: {context}
+
+ Your Response:
+ """
+
+ prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
+
+ # rag_chain = RetrievalQA.from_chain_type(
+ #     llm=llm,
+ #     retriever=retriever,
+ #     chain_type="stuff",
+ #     chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
+ # )
+
+ # === Class Names ===
+ multilabel_class_names = [
+     "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
+     "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
+     "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
+     "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
+     "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
+     "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
+     "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
+ ]
+
+ multiclass_class_names = [
+     "systemic", "hair", "drug_reactions", "uriticaria", "acne", "light",
+     "autoimmune", "papulosquamous", "eczema", "skincancer",
+     "benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
+ ]
+
+
+ # === Load Models ===
+ class SkinViT(nn.Module):
+     def __init__(self, num_classes):
+         super(SkinViT, self).__init__()
+         self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+         in_features = self.model.heads.head.in_features
+         self.model.heads.head = nn.Linear(in_features, num_classes)
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class DermNetViT(nn.Module):
+     def __init__(self, num_classes):
+         super(DermNetViT, self).__init__()
+         self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
+         in_features = self.model.heads[0].in_features
+         self.model.heads[0] = nn.Sequential(
+             nn.Dropout(0.3),
+             nn.Linear(in_features, num_classes)
+         )
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ # multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
+ # multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
+
+ # === Load Model State Dicts ===
+ multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
+ multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
+
+
+ def load_model_with_fallback(model_class, weight_path, num_classes, model_name):
+     try:
+         print(f"🔍 Loading {model_name} on GPU...")
+         model = model_class(num_classes)
+         model.load_state_dict(torch.load(weight_path, map_location="cuda"))
+         model.to("cuda")
+         print(f"✅ {model_name} loaded on GPU.")
+         return model
+     except RuntimeError as e:
+         if "CUDA out of memory" in str(e):
+             print(f"⚠️ {model_name} OOM. Falling back to CPU.")
+         else:
+             print(f"❌ Error loading {model_name} on CUDA: {e}")
+         print(f"🔄 Loading {model_name} on CPU...")
+         model = model_class(num_classes)
+         model.load_state_dict(torch.load(weight_path, map_location="cpu"))
+         model.to("cpu")
+         return model
+
+
+ # Load both models with fallback
+ multilabel_model = load_model_with_fallback(SkinViT, multilabel_model_path, len(multilabel_class_names), "SkinViT")
+ multiclass_model = load_model_with_fallback(DermNetViT, multiclass_model_path, len(multiclass_class_names),
+                                             "DermNetViT")
+
+ multilabel_model.eval()
+ multiclass_model.eval()
+
+ # === Session Init ===
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+
+ # === Image Processing Function ===
+ def run_inference(image):
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5])
+     ])
+     input_tensor = transform(image).unsqueeze(0)
+
+     # Automatically match model device (GPU or CPU)
+     model_device = next(multilabel_model.parameters()).device
+     input_tensor = input_tensor.to(model_device)
+
+     with torch.no_grad():
+         probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().cpu().numpy()
+         pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()
+     predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
+     predicted_single = multiclass_class_names[pred_idx]
+
+     return predicted_multi, predicted_single
+
+
+ # === PDF Export ===
+ def export_chat_to_pdf(messages):
+     pdf = FPDF()
+     pdf.add_page()
+     pdf.set_font("Arial", size=12)
+     for msg in messages:
+         role = "You" if msg["role"] == "user" else "AI"
+         pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
+     buf = io.BytesIO()
+     pdf.output(buf)
+     buf.seek(0)
+     return buf
+
+
+ # Reranker utility
+ def rerank_with_cohere(query, documents, top_n=5):
+     if not documents:
+         return []
+     raw_texts = [doc.page_content for doc in documents]
+     results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
+     return [documents[result.index] for result in results]
+
+
+ # Final answer generation using reranked context
+ def get_reranked_response(query):
+     docs = retriever.get_relevant_documents(query)
+     reranked_docs = rerank_with_cohere(query, docs)
+     context = "\n\n".join([doc.page_content for doc in reranked_docs])
+     prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
+
+     if selected_model == "All":
+         responses = get_all_model_responses(prompt)
+         fused = rank_and_fuse(prompt, responses)
+         return type("Obj", (), {"content": fused})
+
+     if callable(llm):
+         return type("Obj", (), {"content": llm(prompt)})
+     else:
+         return llm.invoke([{"role": "system", "content": prompt}])
+
+
+ # === App UI ===
+
+ st.title("🧬 DermBOT — Skin AI Assistant")
+ st.caption(f"🧠 Using model: {selected_model}")
+ uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file:
+     st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
+     image = Image.open(uploaded_file).convert("RGB")
+
+     predicted_multi, predicted_single = run_inference(image)
+
+     # Show predictions clearly to the user
+     st.markdown(f"🧾 **Skin Issues**: {', '.join(predicted_multi)}")
+     st.markdown(f"📌 **Most Likely Diagnosis**: {predicted_single}")
+
+     query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
+     st.session_state.messages.append({"role": "user", "content": query})
+
+     with st.spinner("🔎 Analyzing and retrieving context..."):
+         response = get_reranked_response(query)
+         st.session_state.messages.append({"role": "assistant", "content": response.content})
+
+     with st.chat_message("assistant"):
+         st.markdown(response.content)
+
+ # === Chat Interface ===
+ if prompt := st.chat_input("Ask a follow-up..."):
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     response = get_reranked_response(prompt)
+     st.session_state.messages.append({"role": "assistant", "content": response.content})
+     with st.chat_message("assistant"):
+         st.markdown(response.content)
+
+ # === PDF Button ===
+ if st.button("📄 Download Chat as PDF"):
+     pdf_file = export_chat_to_pdf(st.session_state.messages)
+     st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")
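
test.py is a self-contained Streamlit prototype of the rerank-and-fuse flow (streamlit run test.py). As committed it would stop with a NameError: the assignment selected_model = st.session_state["selected_model"] is commented out near the middle of the file, yet selected_model is referenced in the model-selection branches immediately below and again in the UI code. The presumably intended fix is a one-liner:

    # Hypothetical fix: restore the commented-out assignment before its first use.
    selected_model = st.session_state["selected_model"]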