Spaces:
Sleeping
Sleeping
Commit
·
3102bab
1
Parent(s):
37f8d03
Made changes to chatbot.py
Browse files- app/chatbot.py +61 -46
app/chatbot.py
CHANGED
@@ -39,11 +39,20 @@ index.add(np.array(tfidf_matrix, dtype=np.float32))
|
|
39 |
disease_list = list(disease_symptoms.keys())
|
40 |
|
41 |
def find_closest_disease(user_symptoms):
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
return None
|
44 |
-
user_vector = vectorizer.transform([" ".join(user_symptoms)]).toarray().astype("float32")
|
45 |
-
distances, indices = index.search(user_vector, k=1)
|
46 |
-
return disease_list[indices[0][0]]
|
47 |
|
48 |
# -------------------------------
|
49 |
# Load Medical NER model for symptom extraction
|
@@ -74,23 +83,22 @@ class ChatbotSession:
|
|
74 |
def __init__(self):
|
75 |
self.conversation_history = []
|
76 |
self.reported_symptoms = set()
|
77 |
-
self.symptom_details = {}
|
78 |
self.asked_missing = set()
|
79 |
self.awaiting_followup = None
|
80 |
-
self.awaiting_detail = None
|
81 |
self.current_detail_symptom = None
|
82 |
-
self.state = "
|
83 |
self.finished = False
|
84 |
-
self.
|
85 |
-
greeting = "Doctor: Hello
|
86 |
self.conversation_history.append(greeting)
|
87 |
|
88 |
def process_message(self, message: str) -> str:
|
89 |
if self.finished:
|
90 |
-
return "Doctor: Thank you. Our session has ended."
|
91 |
|
92 |
-
if self.state == "
|
93 |
-
return self.
|
94 |
|
95 |
if self.state == "symptom_detail":
|
96 |
return self._handle_symptom_detail(message)
|
@@ -101,20 +109,19 @@ class ChatbotSession:
|
|
101 |
if self.state == "medications":
|
102 |
return self._handle_medications(message)
|
103 |
|
104 |
-
return "Doctor: Could you please clarify?"
|
105 |
|
106 |
-
def
|
107 |
if message.lower() in ["exit", "quit", "no"]:
|
108 |
if not self.reported_symptoms:
|
109 |
goodbye = "Doctor: It seems no symptoms were reported. Ending the session."
|
110 |
self.finished = True
|
111 |
return goodbye
|
112 |
else:
|
113 |
-
self.
|
114 |
self.state = "pain_check"
|
115 |
-
return
|
116 |
|
117 |
-
# Extract symptoms
|
118 |
ner_results = extract_symptoms_ner(message)
|
119 |
if ner_results:
|
120 |
for sym in ner_results:
|
@@ -122,52 +129,44 @@ class ChatbotSession:
|
|
122 |
self.reported_symptoms.add(sym)
|
123 |
self.symptom_details[sym] = {}
|
124 |
|
125 |
-
|
126 |
-
self.predicted_disease = find_closest_disease(list(self.reported_symptoms))
|
127 |
-
if self.predicted_disease:
|
128 |
-
expected = set(disease_symptoms.get(self.predicted_disease, []))
|
129 |
-
missing = expected - self.reported_symptoms
|
130 |
-
not_asked = missing - self.asked_missing
|
131 |
|
132 |
-
if not_asked:
|
133 |
-
symptom_to_ask = list(not_asked)[0]
|
134 |
-
self.awaiting_followup = symptom_to_ask
|
135 |
-
return f"Doctor: Are you also experiencing {symptom_to_ask}?"
|
136 |
-
|
137 |
-
# If all covered, ask symptom details
|
138 |
if self.reported_symptoms:
|
139 |
-
symptom = random.choice(list(self.reported_symptoms))
|
140 |
self.current_detail_symptom = symptom
|
141 |
self.state = "symptom_detail"
|
142 |
-
return f"Doctor: About your '{symptom}', when did it start? (
|
143 |
|
144 |
else:
|
145 |
-
return "Doctor: I couldn't detect any medical symptoms. Could you describe
|
146 |
|
147 |
-
return "Doctor:
|
148 |
|
149 |
def _handle_symptom_detail(self, message):
|
150 |
if self.current_detail_symptom and 'duration' not in self.symptom_details[self.current_detail_symptom]:
|
151 |
self.symptom_details[self.current_detail_symptom]['duration'] = message
|
152 |
-
return f"Doctor: How severe is the '{self.current_detail_symptom}'? (mild
|
153 |
|
154 |
if self.current_detail_symptom and 'severity' not in self.symptom_details[self.current_detail_symptom]:
|
155 |
self.symptom_details[self.current_detail_symptom]['severity'] = message
|
156 |
-
return f"Doctor: Where exactly do you feel the '{self.current_detail_symptom}' (
|
157 |
|
158 |
if self.current_detail_symptom and 'location' not in self.symptom_details[self.current_detail_symptom]:
|
159 |
self.symptom_details[self.current_detail_symptom]['location'] = message
|
160 |
-
self.state = "
|
161 |
-
return "Doctor: Thank you.
|
162 |
|
163 |
-
return "Doctor: Please clarify."
|
164 |
|
165 |
def _handle_pain_check(self, message):
|
166 |
try:
|
167 |
pain_level = int(message)
|
168 |
self.symptom_details['pain'] = {'severity': pain_level}
|
169 |
except ValueError:
|
170 |
-
|
|
|
|
|
|
|
171 |
|
172 |
self.state = "medications"
|
173 |
return "Doctor: Have you taken any medications recently? Please mention them or type 'no'."
|
@@ -175,9 +174,18 @@ class ChatbotSession:
|
|
175 |
def _handle_medications(self, message):
|
176 |
self.symptom_details['medications'] = message if message.lower() not in ["no", "none"] else "None"
|
177 |
self.finished = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
-
|
180 |
-
|
181 |
|
182 |
def _generate_summary(self):
|
183 |
report = "\n".join([
|
@@ -185,15 +193,22 @@ class ChatbotSession:
|
|
185 |
for sym, details in self.symptom_details.items()
|
186 |
])
|
187 |
|
188 |
-
disease_part =
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
def get_data(self):
|
194 |
return {
|
195 |
"conversation": self.conversation_history,
|
196 |
"symptoms": list(self.reported_symptoms),
|
197 |
"symptom_details": self.symptom_details,
|
198 |
-
"
|
199 |
}
|
|
|
39 |
disease_list = list(disease_symptoms.keys())
|
40 |
|
41 |
def find_closest_disease(user_symptoms):
|
42 |
+
max_matches = 0
|
43 |
+
best_disease = None
|
44 |
+
|
45 |
+
for disease, symptoms in disease_symptoms.items():
|
46 |
+
matches = len(set(user_symptoms) & set(symptoms))
|
47 |
+
if matches > max_matches:
|
48 |
+
max_matches = matches
|
49 |
+
best_disease = disease
|
50 |
+
|
51 |
+
# Require at least 2 symptoms matching to consider it valid
|
52 |
+
if max_matches >= 2:
|
53 |
+
return best_disease
|
54 |
+
else:
|
55 |
return None
|
|
|
|
|
|
|
56 |
|
57 |
# -------------------------------
|
58 |
# Load Medical NER model for symptom extraction
|
|
|
83 |
def __init__(self):
|
84 |
self.conversation_history = []
|
85 |
self.reported_symptoms = set()
|
86 |
+
self.symptom_details = {}
|
87 |
self.asked_missing = set()
|
88 |
self.awaiting_followup = None
|
|
|
89 |
self.current_detail_symptom = None
|
90 |
+
self.state = "intake"
|
91 |
self.finished = False
|
92 |
+
self.predicted_diseases = []
|
93 |
+
greeting = "Doctor: Hello! I'm here to assist you. Could you please describe how you're feeling today?"
|
94 |
self.conversation_history.append(greeting)
|
95 |
|
96 |
def process_message(self, message: str) -> str:
|
97 |
if self.finished:
|
98 |
+
return "Doctor: Thank you. Our session has ended. Wishing you good health!"
|
99 |
|
100 |
+
if self.state == "intake":
|
101 |
+
return self._handle_intake(message)
|
102 |
|
103 |
if self.state == "symptom_detail":
|
104 |
return self._handle_symptom_detail(message)
|
|
|
109 |
if self.state == "medications":
|
110 |
return self._handle_medications(message)
|
111 |
|
112 |
+
return "Doctor: Could you please clarify that?"
|
113 |
|
114 |
+
def _handle_intake(self, message):
|
115 |
if message.lower() in ["exit", "quit", "no"]:
|
116 |
if not self.reported_symptoms:
|
117 |
goodbye = "Doctor: It seems no symptoms were reported. Ending the session."
|
118 |
self.finished = True
|
119 |
return goodbye
|
120 |
else:
|
121 |
+
self._predict_diseases()
|
122 |
self.state = "pain_check"
|
123 |
+
return "Doctor: Before proceeding, are you experiencing any pain? If yes, please rate it 1-10 or type 'no'."
|
124 |
|
|
|
125 |
ner_results = extract_symptoms_ner(message)
|
126 |
if ner_results:
|
127 |
for sym in ner_results:
|
|
|
129 |
self.reported_symptoms.add(sym)
|
130 |
self.symptom_details[sym] = {}
|
131 |
|
132 |
+
self._predict_diseases()
|
|
|
|
|
|
|
|
|
|
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
if self.reported_symptoms:
|
135 |
+
symptom = random.choice(list(self.reported_symptoms - set(self.symptom_details.keys())))
|
136 |
self.current_detail_symptom = symptom
|
137 |
self.state = "symptom_detail"
|
138 |
+
return f"Doctor: About your '{symptom}', when did it start? (e.g., 2 days ago)"
|
139 |
|
140 |
else:
|
141 |
+
return "Doctor: I couldn't clearly detect any medical symptoms. Could you please describe your condition differently?"
|
142 |
|
143 |
+
return "Doctor: Any other symptoms you'd like to share?"
|
144 |
|
145 |
def _handle_symptom_detail(self, message):
|
146 |
if self.current_detail_symptom and 'duration' not in self.symptom_details[self.current_detail_symptom]:
|
147 |
self.symptom_details[self.current_detail_symptom]['duration'] = message
|
148 |
+
return f"Doctor: How severe is the '{self.current_detail_symptom}'? (mild, moderate, severe)"
|
149 |
|
150 |
if self.current_detail_symptom and 'severity' not in self.symptom_details[self.current_detail_symptom]:
|
151 |
self.symptom_details[self.current_detail_symptom]['severity'] = message
|
152 |
+
return f"Doctor: Where exactly do you feel the '{self.current_detail_symptom}'? (e.g., forehead, chest, abdomen)"
|
153 |
|
154 |
if self.current_detail_symptom and 'location' not in self.symptom_details[self.current_detail_symptom]:
|
155 |
self.symptom_details[self.current_detail_symptom]['location'] = message
|
156 |
+
self.state = "intake"
|
157 |
+
return "Doctor: Thank you. Are there any more symptoms you'd like to mention?"
|
158 |
|
159 |
+
return "Doctor: Please clarify that."
|
160 |
|
161 |
def _handle_pain_check(self, message):
|
162 |
try:
|
163 |
pain_level = int(message)
|
164 |
self.symptom_details['pain'] = {'severity': pain_level}
|
165 |
except ValueError:
|
166 |
+
if is_affirmative(message):
|
167 |
+
self.symptom_details['pain'] = {'severity': 'mild'}
|
168 |
+
else:
|
169 |
+
self.symptom_details['pain'] = {'severity': 'none'}
|
170 |
|
171 |
self.state = "medications"
|
172 |
return "Doctor: Have you taken any medications recently? Please mention them or type 'no'."
|
|
|
174 |
def _handle_medications(self, message):
|
175 |
self.symptom_details['medications'] = message if message.lower() not in ["no", "none"] else "None"
|
176 |
self.finished = True
|
177 |
+
return self._generate_summary()
|
178 |
+
|
179 |
+
def _predict_diseases(self):
|
180 |
+
match_scores = []
|
181 |
+
for disease, symptoms in disease_symptoms.items():
|
182 |
+
matches = len(set(self.reported_symptoms) & set(symptoms))
|
183 |
+
if matches > 0:
|
184 |
+
score = matches / len(symptoms)
|
185 |
+
match_scores.append((disease, score))
|
186 |
|
187 |
+
match_scores.sort(key=lambda x: x[1], reverse=True)
|
188 |
+
self.predicted_diseases = match_scores[:3] # Top 3 predictions
|
189 |
|
190 |
def _generate_summary(self):
|
191 |
report = "\n".join([
|
|
|
193 |
for sym, details in self.symptom_details.items()
|
194 |
])
|
195 |
|
196 |
+
disease_part = ""
|
197 |
+
if self.predicted_diseases:
|
198 |
+
disease_part = "\n\nPossible Conditions:" + "\n".join([
|
199 |
+
f"- {disease} ({int(score * 100)}%)" for disease, score in self.predicted_diseases
|
200 |
+
])
|
201 |
+
else:
|
202 |
+
disease_part = "\n\nDoctor: No strong condition match was found based on reported symptoms."
|
203 |
+
|
204 |
+
advice = "\n\nDoctor: Please note this is a preliminary virtual consultation. You should consult a physician if symptoms persist, worsen, or if you feel uncomfortable."
|
205 |
+
|
206 |
+
return f"Doctor: Thank you for the detailed information!\n\n=== Medical Report ===\n{report}{disease_part}{advice}"
|
207 |
|
208 |
def get_data(self):
|
209 |
return {
|
210 |
"conversation": self.conversation_history,
|
211 |
"symptoms": list(self.reported_symptoms),
|
212 |
"symptom_details": self.symptom_details,
|
213 |
+
"predicted_diseases": self.predicted_diseases
|
214 |
}
|