KeerthiVM committed on
Commit c808cd0 · 1 Parent(s): 79cab30
Files changed (1)
  1. app.py +80 -50
app.py CHANGED
@@ -20,48 +20,62 @@ nest_asyncio.apply()
 device='cuda' if torch.cuda.is_available() else 'cpu'
 st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
 
+
+@st.cache_resource(show_spinner=False)
+def load_models():
+    """Cache all models to load only once"""
+    with st.spinner("Loading AI models (one-time operation)..."):
+        classifier = initialize_classifier()
+    return classifier
+
+@st.cache_resource(show_spinner=False)
+def initialize_llm(selected_model, api_key):
+    """Initialize the LLM based on selection"""
+    if "OpenAI" in selected_model:
+        return ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=api_key)
+    elif "LLaMA" in selected_model:
+        st.warning("LLaMA integration is not implemented yet.")
+        st.stop()
+    elif "Gemini" in selected_model:
+        st.warning("Gemini integration is not implemented yet.")
+        st.stop()
+    else:
+        st.error("Unsupported model selected.")
+        st.stop()
+
+@st.cache_resource(show_spinner=False)
+def load_rag_chain(llm):
+    """Initialize RAG chain only once"""
+    return invoke_rag_chain(llm)
+
 # === Model Selection ===
 available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
-st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
+if "selected_model" not in st.session_state:
+    st.session_state["selected_model"] = available_models[0]
+
+st.session_state["selected_model"] = st.sidebar.selectbox(
+    "Select LLM Model",
+    available_models,
+    index=available_models.index(st.session_state["selected_model"])
+)
 
 # Dynamically initialize LLM based on selection
 OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
-selected_model = st.session_state["selected_model"]
-if "OpenAI" in selected_model:
-    llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=OPENAI_API_KEY)
-elif "LLaMA" in selected_model:
-    st.warning("LLaMA integration is not implemented yet.")
-    st.stop()
-elif "Gemini" in selected_model:
-    st.warning("Gemini integration is not implemented yet.")
-    st.stop()
-else:
-    st.error("Unsupported model selected.")
-    st.stop()
-
-
-with st.spinner("Loading AI models (one-time operation)..."):
-    classifier = initialize_classifier()
-    st.success("Models loaded successfully!")
+llm = initialize_llm(st.session_state["selected_model"], OPENAI_API_KEY)
+rag_chain = load_rag_chain(llm)
+
+classifier = load_models()
 
 # === Session Init ===
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+if "current_image" not in st.session_state:
+    st.session_state.current_image = None
 
 # === Image Processing Function ===
 def run_inference(image):
-
     result = classifier.predict(image, top_k=1)
-
-    # Display results
-    print("Top Predictions:")
-    for name, prob in result["top_predictions"]:
-        print(f"{name}: {prob:.2%}")
-
-    print("\nYou can also access all probabilities:")
-    print(result["all_probabilities"])
-
     predicted_label = result["top_predictions"][0][0]
     return predicted_label
 
@@ -82,43 +96,59 @@ def export_chat_to_pdf(messages):
 
 # === App UI ===
 
-print("=== Before App UI ===")
 st.title("🧬 DermBOT — Skin AI Assistant")
-st.caption(f"🧠 Using model: {selected_model}")
-uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
+st.caption(f"🧠 Using model: {st.session_state['selected_model']}")
+uploaded_file = st.file_uploader(
+    "Upload a skin image",
+    type=["jpg", "jpeg", "png"],
+    key="file_uploader"
+)
+
+if uploaded_file is not None and uploaded_file != st.session_state.current_image:
+    st.session_state.messages = []
+    st.session_state.current_image = uploaded_file
 
-if uploaded_file:
-    print("=== At start of uploading ===")
-    st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
     image = Image.open(uploaded_file).convert("RGB")
+    st.image(image, caption="Uploaded image", use_column_width=True)
+    with st.spinner("Analyzing the image..."):
+        predicted_label = run_inference(image)
 
-    predicted_label = run_inference(image)
-
-    # Show predictions clearly to the user
     st.markdown(f" Most Likely Diagnosis : {predicted_label}")
 
-    query = f"What are my treatment options for {predicted_label}?"
-    st.session_state.messages.append({"role": "user", "content": query})
+    initial_query = f"What are my treatment options for {predicted_label}?"
+    st.session_state.messages.append({"role": "user", "content": initial_query})
 
-    with st.spinner("Analyzing the image and retrieving response..."):
-        response = invoke_rag_chain(llm).invoke(query)
+    with st.spinner("Retrieving medical information..."):
+        response = rag_chain.invoke(initial_query)
     st.session_state.messages.append({"role": "assistant", "content": response['result']})
 
-    with st.chat_message("assistant"):
-        st.markdown(response['result'])
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
 
 # === Chat Interface ===
-if prompt := st.chat_input("Ask a follow-up..."):
+if prompt := st.chat_input("Ask a follow-up question..."):
     st.session_state.messages.append({"role": "user", "content": prompt})
     with st.chat_message("user"):
         st.markdown(prompt)
 
-    response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages])
-    st.session_state.messages.append({"role": "assistant", "content": response.content})
     with st.chat_message("assistant"):
-        st.markdown(response.content)
+        with st.spinner("Thinking..."):
+            if len(st.session_state.messages) > 1:
+                response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages]).content
+            else:
+                response = rag_chain.invoke(prompt)
+                response = response['result']
+
+        st.markdown(response)
+        st.session_state.messages.append({"role": "assistant", "content": response})
+
 
-# === PDF Button ===
-if st.button("📄 Download Chat as PDF"):
+if st.session_state.messages and st.button("📄 Download Chat as PDF"):
     pdf_file = export_chat_to_pdf(st.session_state.messages)
-    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")
+    st.download_button(
+        "Download PDF",
+        data=pdf_file,
+        file_name="dermbot_chat_history.pdf",
+        mime="application/pdf"
+    )
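Note on the caching pattern: the heart of this commit is moving the classifier, LLM, and RAG chain behind @st.cache_resource. Streamlit reruns app.py from the top on every widget interaction, so without caching these objects were rebuilt on each rerun. The snippet below is a minimal standalone sketch of that mechanism, not part of app.py; build_expensive_resource is a hypothetical stand-in for initialize_classifier(), initialize_llm(), or load_rag_chain().

    # Minimal sketch of @st.cache_resource (hypothetical standalone app).
    import time
    import streamlit as st

    @st.cache_resource(show_spinner=False)
    def build_expensive_resource(name: str):
        # Stands in for a slow one-time model load.
        time.sleep(2)
        return {"name": name, "loaded_at": time.time()}

    # Streamlit reruns this script on every interaction, but only the first
    # run pays the load cost; later reruns with the same argument receive
    # the same cached object back.
    resource = build_expensive_resource("demo")
    st.write(resource)

One related caveat: st.cache_resource hashes a cached function's arguments to build the cache key, so passing a complex object (as load_rag_chain(llm) does) can raise an UnhashableParamError; Streamlit's documented escape hatch is prefixing the parameter with an underscore (e.g. _llm) to exclude it from hashing.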