File size: 10,047 Bytes
8dcd1f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3102bab
 
 
 
 
 
 
 
 
 
 
 
 
8dcd1f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37f8d03
455e0c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8dcd1f3
 
 
 
 
 
 
3102bab
8dcd1f3
 
37f8d03
3102bab
8dcd1f3
3102bab
 
37f8d03
8dcd1f3
 
37f8d03
3102bab
37f8d03
3102bab
 
37f8d03
 
 
 
 
 
 
 
 
 
3102bab
37f8d03
3102bab
37f8d03
 
 
 
 
8dcd1f3
3102bab
37f8d03
3102bab
37f8d03
 
 
 
 
 
 
b9e9312
 
 
37f8d03
455e0c8
37f8d03
3102bab
b9e9312
 
 
 
37f8d03
3102bab
37f8d03
 
 
 
3102bab
37f8d03
 
 
455e0c8
 
 
 
 
 
 
37f8d03
 
 
3102bab
 
37f8d03
3102bab
37f8d03
 
 
 
 
 
3102bab
 
 
 
37f8d03
 
 
 
 
 
 
3102bab
 
 
 
 
 
 
 
455e0c8
9af961d
37f8d03
3102bab
455e0c8
37f8d03
 
9af961d
 
 
 
 
 
 
 
37f8d03
3102bab
 
 
 
 
 
 
 
 
 
 
8dcd1f3
 
 
 
 
37f8d03
3102bab
8dcd1f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import json
import re

import faiss
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import pipeline

# -------------------------------
# Load disease data and preprocess
# -------------------------------
def load_disease_data(csv_path):
    """Load the disease catalogue from a CSV file.

    The file must contain ``disease``, ``symptoms`` and ``precautions``
    columns (header names are matched after stripping whitespace and
    lower-casing). The ``symptoms`` and ``precautions`` cells hold
    comma-separated lists.

    Args:
        csv_path: Path, or any file-like object accepted by ``pandas.read_csv``.

    Returns:
        A ``(disease_symptoms, disease_precautions)`` pair of dicts keyed by
        disease name. Symptoms are stripped and lower-cased; precautions are
        stripped but keep their case. Empty list items are dropped, rows with
        a blank disease name are skipped, and if a disease appears on several
        rows the last row wins.

    Raises:
        KeyError: if one of the required columns is missing.
    """
    df = pd.read_csv(csv_path)
    df.columns = df.columns.str.strip().str.lower()
    df = df.fillna("")

    def _split(cell):
        # str() guards against cells pandas parsed as numbers (e.g. an
        # all-numeric column), which have no .split()/.strip().
        return [item.strip() for item in str(cell).split(",") if item.strip()]

    disease_symptoms = {}
    disease_precautions = {}
    for disease, symptoms, precautions in zip(df["disease"], df["symptoms"], df["precautions"]):
        name = str(disease).strip()
        if not name:
            # A blank row would otherwise register "" as a disease.
            continue
        disease_symptoms[name] = [symptom.lower() for symptom in _split(symptoms)]
        disease_precautions[name] = _split(precautions)
    return disease_symptoms, disease_precautions

# Load the disease catalogue once, at import time. The path is relative, so
# the CSV must be reachable from the process working directory (normally the
# repository root).
disease_symptoms, disease_precautions = load_disease_data("disease_sympts_prec_full.csv")
# Every distinct symptom string mentioned by any disease in the catalogue.
known_symptoms = {symptom for symptom_list in disease_symptoms.values() for symptom in symptom_list}

# -------------------------------
# Build symptom vectorizer and FAISS index
# -------------------------------
# Row i of the TF-IDF matrix / FAISS index corresponds to disease_list[i].
# NOTE(review): vectorizer, index and disease_list are not referenced by the
# code visible in this file; presumably used by callers elsewhere — verify.
disease_list = list(disease_symptoms)
symptom_texts = [" ".join(disease_symptoms[name]) for name in disease_list]

vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(symptom_texts).toarray()

# FAISS expects a contiguous float32 array with one row per vector.
index = faiss.IndexFlatL2(tfidf_matrix.shape[1])
index.add(np.array(tfidf_matrix, dtype=np.float32))

def find_closest_disease(user_symptoms, *, min_matches=2, catalog=None):
    """Return the disease that shares the most symptoms with *user_symptoms*.

    Args:
        user_symptoms: Iterable of symptom strings, compared by exact string
            equality with the catalogue entries. Any iterable works,
            including a one-shot generator.
        min_matches: Minimum number of shared symptoms for a match to count.
            Defaults to 2, the original hard-coded threshold.
        catalog: Mapping of disease name -> list of symptoms. Defaults to the
            module-level ``disease_symptoms`` loaded from the CSV.

    Returns:
        The best-matching disease name, or ``None`` when no disease reaches
        ``min_matches`` shared symptoms. Ties go to the disease that appears
        first in ``catalog``.
    """
    if catalog is None:
        catalog = disease_symptoms
    # Build the set once: rebuilding it per disease was wasteful and, for a
    # generator argument, left it exhausted after the first disease.
    wanted = set(user_symptoms)

    best_disease, max_matches = None, 0
    for disease, symptoms in catalog.items():
        matches = len(wanted.intersection(symptoms))
        if matches > max_matches:  # strict '>' keeps the earliest disease on ties
            best_disease, max_matches = disease, matches

    return best_disease if max_matches >= min_matches else None

# -------------------------------
# Load Medical NER model for symptom extraction
# -------------------------------
# Token-classification pipeline over the Hugging Face "blaze999/Medical-NER"
# checkpoint (model and tokenizer loaded from the same id). The "simple"
# aggregation strategy groups sub-word tokens into whole entity spans, which
# extract_symptoms_ner reads via their "entity_group" and "word" fields.
medical_ner = pipeline(
    task="ner",
    model="blaze999/Medical-NER",
    tokenizer="blaze999/Medical-NER",
    aggregation_strategy="simple",
)

def extract_symptoms_ner(text):
    """Return the symptom mentions the medical NER model finds in *text*.

    Only entities whose ``entity_group`` label contains ``"SIGN_SYMPTOM"`` are
    kept. Each mention is stripped of surrounding whitespace and lower-cased.
    Empty mentions and duplicates are dropped. Mentions keep the order in
    which the model reported them (first occurrence wins), so the result is
    deterministic for a given model output, unlike ``list(set(...))``.

    Blank or empty *text* returns ``[]`` without invoking the model.
    """
    if not text or not text.strip():
        return []

    seen = {}  # dict as an insertion-ordered set
    for entity in medical_ner(text):
        if "SIGN_SYMPTOM" not in entity["entity_group"]:
            continue
        # Decoded spans may carry surrounding whitespace (tokenizer-dependent);
        # strip it so mentions can match the stripped CSV symptom vocabulary.
        word = entity["word"].strip().lower()
        if word:
            seen[word] = None
    return list(seen)

# Whole-word / whole-phrase match, case-insensitive, so words that merely
# contain an affirmative ("eyes", "yesterday") are not treated as "yes".
_AFFIRMATIVE_RE = re.compile(r"\b(?:yes|yeah|yep|certainly|sometimes|a little)\b", re.IGNORECASE)


def is_affirmative(answer):
    """Return True if *answer* contains an affirmative word or phrase.

    Recognised (case-insensitive, as whole words): "yes", "yeah", "yep",
    "certainly", "sometimes", "a little".
    """
    return _AFFIRMATIVE_RE.search(answer) is not None
import random
# Map symptoms to associated body parts
# ChatbotSession looks a symptom up here (exact dict lookup on the lower-cased
# text produced by the NER extractor) to fill in its "location" automatically,
# so a differently worded mention (e.g. "headaches") will not be found.
# A value of None behaves exactly like a missing key in the visible code: the
# patient is asked where they feel the symptom.
# NOTE(review): None was possibly meant as "no specific location" (fever,
# cold) — confirm the intended behavior.
SYMPTOM_BODY_PARTS = {
    "headache": "head",
    "fever": None,
    "cold": None,
    "cough": "throat",
    "chest pain": "chest",
    "stomach ache": "abdomen",
    "sore throat": "throat",
    "back pain": "back",
    "ear pain": "ear",
    "toothache": "mouth",
    "nausea": "stomach",
    "vomiting": "stomach",
    "diarrhea": "abdomen",
    "shortness of breath": "chest",
    "dizziness": "head",
    "rash": "skin",
    "eye pain": "eye",
    "abdominal pain": "abdomen",
    "joint pain": "joint",
    "muscle pain": "muscle",
    "neck pain": "neck"
}

# -------------------------------
# Chatbot session class
# -------------------------------
class ChatbotSession:
    """Rule-based, stateful symptom-intake conversation with one patient.

    The session is a small state machine driven by :meth:`process_message`:

    * ``intake``: extract symptoms from free text. Each newly reported
      symptom is explored in ``symptom_detail``. Answering "no"/"exit"/"quit"
      moves on.
    * ``symptom_detail``: ask duration, then severity, then (if it cannot be
      looked up in ``SYMPTOM_BODY_PARTS``) location, then return to intake.
    * ``pain_check``: record a 0-10 rating or a yes/no answer.
    * ``medications``: record medications, finish, and return a summary.

    Every patient message and doctor reply is appended to
    ``conversation_history``.
    """

    # Normalized replies that end the intake phase.
    _EXIT_WORDS = frozenset({"exit", "quit", "no"})
    # Normalized replies to the medications question meaning "nothing taken".
    _NO_MEDICATION_WORDS = frozenset({"no", "none"})
    # First standalone integer in a pain answer: "7", "7/10", "about a 6".
    _PAIN_NUMBER_RE = re.compile(r"\b\d+\b")

    def __init__(self):
        self.conversation_history = []  # transcript lines, "Doctor: ..." / "Patient: ..."
        self.reported_symptoms = set()
        # symptom -> detail dict; also holds "pain" (dict) and "medications" (str)
        self.symptom_details = {}
        self.asked_missing = set()  # not used by this class
        self.awaiting_followup = None  # not used by this class
        self.current_detail_symptom = None
        self.state = "intake"
        self.finished = False
        self.predicted_diseases = []  # list of (disease, score) tuples, best first
        greeting = "Doctor: Hello! I'm here to assist you. Could you please describe how you're feeling today?"
        self.conversation_history.append(greeting)

    @staticmethod
    def _normalize(text):
        """Lower-case *text*, dropping surrounding whitespace and trailing '.'/'!'."""
        return text.strip().lower().rstrip(".!").strip()

    def process_message(self, message: str) -> str:
        """Handle one patient message and return the doctor's reply.

        Both the message and the reply are recorded in ``conversation_history``.
        """
        reply = self._route(message)
        self.conversation_history.append(f"Patient: {message.strip()}")
        self.conversation_history.append(reply)
        return reply

    def _route(self, message: str) -> str:
        """Dispatch *message* to the handler for the current state."""
        if self.finished:
            return "Doctor: Thank you. Our session has ended. Wishing you good health!"

        handler = {
            "intake": self._handle_intake,
            "symptom_detail": self._handle_symptom_detail,
            "pain_check": self._handle_pain_check,
            "medications": self._handle_medications,
        }.get(self.state)
        if handler is None:
            return "Doctor: Could you please clarify that?"
        return handler(message)

    def _begin_pain_check(self):
        """Predict diseases from the symptoms so far and ask about pain."""
        self._predict_diseases()
        self.state = "pain_check"
        return "Doctor: Before proceeding, are you experiencing any pain? If yes, please rate it 1-10 or type 'no'."

    def _handle_intake(self, message):
        """Collect new symptoms, or move on when the patient says they are done."""
        if self._normalize(message) in self._EXIT_WORDS:
            if not self.reported_symptoms:
                self.finished = True
                return "Doctor: It seems no symptoms were reported. Ending the session."
            return self._begin_pain_check()

        ner_results = extract_symptoms_ner(message)
        if not ner_results:
            return "Doctor: I couldn't clearly detect any medical symptoms. Could you please describe your condition differently?"

        self.reported_symptoms.update(ner_results)
        unexplored_symptoms = list(self.reported_symptoms.difference(self.symptom_details))
        if not unexplored_symptoms:
            # Everything mentioned has already been explored.
            return self._begin_pain_check()

        symptom = random.choice(unexplored_symptoms)
        self.current_detail_symptom = symptom
        self.symptom_details[symptom] = {}
        self.state = "symptom_detail"
        return f"Doctor: About your '{symptom}', when did it start? (e.g., 2 days ago)"

    def _finish_symptom_detail(self):
        """Return to intake after a symptom's details are complete."""
        self.state = "intake"
        return "Doctor: Thank you. Are there any more symptoms you'd like to mention?"

    def _handle_symptom_detail(self, message):
        """Record duration, then severity, then location for the current symptom."""
        symptom = self.current_detail_symptom
        if not symptom:
            return "Doctor: Please clarify that."
        details = self.symptom_details[symptom]

        if "duration" not in details:
            details["duration"] = message
            return f"Doctor: How severe is the '{symptom}'? (mild, moderate, severe)"

        if "severity" not in details:
            details["severity"] = message
            body_part = SYMPTOM_BODY_PARTS.get(symptom)
            if not body_part:
                # Unknown symptom, or one mapped to None: ask the patient.
                # NOTE(review): None entries (fever, cold) may have been meant
                # as "no specific location" — confirm.
                return f"Doctor: Where exactly do you feel the '{symptom}'? (e.g., forehead, chest, abdomen)"
            details["location"] = body_part
            return self._finish_symptom_detail()

        if "location" not in details:
            details["location"] = message
            return self._finish_symptom_detail()

        return "Doctor: Please clarify that."

    def _handle_pain_check(self, message):
        """Record pain as an integer rating, 'mild' (affirmative) or 'none'."""
        number = self._PAIN_NUMBER_RE.search(message)
        if number:
            # Accepts "7", "7/10", "about a 6" (previously only a bare int).
            severity = int(number.group())
        elif is_affirmative(message):
            severity = "mild"
        else:
            severity = "none"
        # Merge so details already gathered for a reported "pain" symptom
        # (duration, location) are kept.
        self.symptom_details.setdefault("pain", {})["severity"] = severity

        self.state = "medications"
        return "Doctor: Have you taken any medications recently? Please mention them or type 'no'."

    def _handle_medications(self, message):
        """Record medications, end the session and return the summary report."""
        answer = message.strip()
        no_meds = self._normalize(answer) in self._NO_MEDICATION_WORDS
        self.symptom_details["medications"] = "None" if no_meds else answer
        self.finished = True
        return self._generate_summary()

    def _predict_diseases(self):
        """Keep the top 2 diseases sharing at least 2 reported symptoms.

        Score = shared symptoms / number of symptoms listed for the disease.
        """
        match_scores = []
        for disease, symptoms in disease_symptoms.items():
            matches = len(self.reported_symptoms.intersection(symptoms))
            if matches >= 2:
                match_scores.append((disease, matches / len(symptoms)))

        match_scores.sort(key=lambda item: item[1], reverse=True)
        self.predicted_diseases = match_scores[:2]

    def _generate_summary(self):
        """Build the final report: details, possible conditions, advice."""
        report_lines = []
        for sym, details in self.symptom_details.items():
            if isinstance(details, dict) and details:
                formatted_details = ", ".join(f"{k}: {v}" for k, v in details.items())
                report_lines.append(f"- {sym.title()}: {formatted_details}")
            else:
                report_lines.append(f"- {sym.title()}: {details}")
        report = "\n".join(report_lines)

        if self.predicted_diseases:
            # Newline after the heading: the first condition used to be glued
            # onto the "Possible Conditions:" line.
            disease_part = "\n\nPossible Conditions:\n" + "\n".join(
                f"- {disease} ({int(score * 100)}%)" for disease, score in self.predicted_diseases
            )
        else:
            disease_part = "\n\nDoctor: No strong condition match was found based on reported symptoms."

        advice = "\n\nDoctor: Please note this is a preliminary virtual consultation. You should consult a physician if symptoms persist, worsen, or if you feel uncomfortable."

        return f"Doctor: Thank you for the detailed information!\n\n=== Medical Report ===\n{report}{disease_part}{advice}"

    def get_data(self):
        """Return the session data; symptoms are sorted for a stable order."""
        return {
            "conversation": self.conversation_history,
            "symptoms": sorted(self.reported_symptoms),
            "symptom_details": self.symptom_details,
            "predicted_diseases": self.predicted_diseases,
        }