Uspark / app /chatbot.py
PranayChamala's picture
Made changes to chatbot.py
455e0c8
raw
history blame contribute delete
10 kB
import json
import re

import faiss
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import pipeline
# -------------------------------
# Load disease data and preprocess
# -------------------------------
def load_disease_data(csv_path):
    """Read the disease table and return its symptom and precaution lookups.

    The CSV must contain ``disease``, ``symptoms`` and ``precautions`` columns
    (header names are matched case-insensitively, ignoring surrounding
    whitespace). ``symptoms`` and ``precautions`` cells hold comma-separated
    lists.

    Args:
        csv_path: Path (or any object ``pandas.read_csv`` accepts) of the CSV.

    Returns:
        A ``(disease_symptoms, disease_precautions)`` tuple of dicts keyed by
        disease name. Symptom entries are lower-cased; precaution entries keep
        their original case. Empty list items are dropped.

    Raises:
        KeyError: If one of the three required columns is missing.
    """
    # dtype=str keeps numeric-looking cells as text, so the .strip()/.split()
    # calls below never hit an int/float value.
    df = pd.read_csv(csv_path, dtype=str)
    df.columns = df.columns.str.strip().str.lower()
    df = df.fillna("")

    disease_symptoms = {}
    disease_precautions = {}
    rows = zip(df["disease"], df["symptoms"], df["precautions"])
    for raw_name, raw_symptoms, raw_precautions in rows:
        disease = raw_name.strip()
        if not disease:
            # A row without a disease name would otherwise register a
            # nameless "" disease that could later be reported to the user.
            continue
        disease_symptoms[disease] = [
            s.strip().lower() for s in raw_symptoms.split(",") if s.strip()
        ]
        disease_precautions[disease] = [
            p.strip() for p in raw_precautions.split(",") if p.strip()
        ]
    return disease_symptoms, disease_precautions
# Load the disease table once at import time (the CSV is expected in the
# repository root, i.e. the directory the app is started from).
disease_symptoms, disease_precautions = load_disease_data("disease_sympts_prec_full.csv")
# Every distinct symptom that is listed for at least one disease.
known_symptoms = {
    symptom
    for symptom_list in disease_symptoms.values()
    for symptom in symptom_list
}
# -------------------------------
# Build symptom vectorizer and FAISS index
# -------------------------------
# Disease names in table order; row i of the TF-IDF matrix (and of the FAISS
# index) is built from the symptoms of disease_list[i].
disease_list = list(disease_symptoms)
symptom_texts = [" ".join(disease_symptoms[name]) for name in disease_list]

# One TF-IDF vector per disease, computed over its space-joined symptom list.
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(symptom_texts).toarray()

# Flat L2 index over the dense vectors, fed as float32.
index = faiss.IndexFlatL2(tfidf_matrix.shape[1])
index.add(tfidf_matrix.astype(np.float32))
def find_closest_disease(user_symptoms, symptom_map=None, min_matches=2):
    """Return the disease sharing the most symptoms with ``user_symptoms``.

    Args:
        user_symptoms: Iterable of symptom strings reported by the user. It is
            consumed exactly once, so generators and iterators are accepted.
        symptom_map: Mapping of disease name to its list of symptoms. Defaults
            to the module-level ``disease_symptoms`` table.
        min_matches: Minimum number of shared symptoms required for a match to
            be considered valid (default 2).

    Returns:
        The name of the best-matching disease, or ``None`` when no disease
        reaches ``min_matches`` shared symptoms. On ties the disease that
        appears first in ``symptom_map`` wins.
    """
    if symptom_map is None:
        symptom_map = disease_symptoms

    # Build the set once instead of once per disease.
    reported = set(user_symptoms)

    best_disease = None
    max_matches = 0
    for disease, symptoms in symptom_map.items():
        matches = len(reported.intersection(symptoms))
        # Strict ">" keeps the earliest disease when several tie.
        if matches > max_matches:
            max_matches = matches
            best_disease = disease

    return best_disease if max_matches >= min_matches else None
# -------------------------------
# Load Medical NER model for symptom extraction
# -------------------------------
# The same checkpoint supplies both the model weights and the tokenizer.
_MEDICAL_NER_CHECKPOINT = "blaze999/Medical-NER"

# Token-classification pipeline; "simple" aggregation makes it return grouped
# entities (each result carries "entity_group" and "word").
medical_ner = pipeline(
    task="ner",
    model=_MEDICAL_NER_CHECKPOINT,
    tokenizer=_MEDICAL_NER_CHECKPOINT,
    aggregation_strategy="simple",
)
def extract_symptoms_ner(text):
    """Return the distinct symptom mentions the NER model finds in ``text``.

    Only entities whose group label contains ``SIGN_SYMPTOM`` are kept. Each
    mention is lower-cased and duplicates are removed, so the order of the
    returned list is not meaningful.
    """
    symptom_words = {
        entity["word"].lower()
        for entity in medical_ner(text)
        if "SIGN_SYMPTOM" in entity["entity_group"]
    }
    return list(symptom_words)
# Whole-word match only: plain substring search would treat "yesterday" or
# "eyes" as containing "yes".
_AFFIRMATIVE_PATTERN = re.compile(
    r"\b(?:yes|yeah|yep|certainly|sometimes|a little)\b"
)


def is_affirmative(answer):
    """Return True when ``answer`` contains an affirmative word or phrase.

    Matching is case-insensitive and restricted to whole words, so replies
    such as "not since yesterday" are not mistaken for "yes".
    """
    return _AFFIRMATIVE_PATTERN.search(answer.lower()) is not None
import random
# Map symptoms to associated body parts.
# ChatbotSession looks a symptom up here after recording its severity: a
# non-empty value is stored as the symptom's location without asking the user,
# while None (or a symptom missing from this table) makes the session ask the
# user where the symptom is felt.
SYMPTOM_BODY_PARTS = {
    "headache": "head",
    "fever": None,  # no single body part; the user is asked instead
    "cold": None,  # no single body part; the user is asked instead
    "cough": "throat",
    "chest pain": "chest",
    "stomach ache": "abdomen",
    "sore throat": "throat",
    "back pain": "back",
    "ear pain": "ear",
    "toothache": "mouth",
    "nausea": "stomach",
    "vomiting": "stomach",
    "diarrhea": "abdomen",
    "shortness of breath": "chest",
    "dizziness": "head",
    "rash": "skin",
    "eye pain": "eye",
    "abdominal pain": "abdomen",
    "joint pain": "joint",
    "muscle pain": "muscle",
    "neck pain": "neck"
}
# -------------------------------
# Chatbot session class
# -------------------------------
class ChatbotSession:
    """Stateful, turn-by-turn symptom intake conversation.

    The session moves through four states:

    * ``intake`` - collect symptoms from free text via ``extract_symptoms_ner``
    * ``symptom_detail`` - ask duration, severity and (if needed) location of
      one symptom, then go back to ``intake``
    * ``pain_check`` - ask for an overall pain rating
    * ``medications`` - ask about recent medication, then emit the report

    Call :meth:`process_message` with each user reply; it returns the next
    "Doctor: ..." line. Once the report (or the no-symptom goodbye) has been
    returned, ``finished`` is True.

    NOTE(review): within this class only the greeting is ever appended to
    ``conversation_history``; presumably the caller records the remaining
    turns - confirm.
    """

    # Replies that end symptom collection / mean "no medication taken".
    _EXIT_WORDS = ("exit", "quit", "no")
    _NO_MEDICATION_WORDS = ("no", "none")

    _PAIN_PROMPT = (
        "Doctor: Before proceeding, are you experiencing any pain? "
        "If yes, please rate it 1-10 or type 'no'."
    )
    _MORE_SYMPTOMS_PROMPT = (
        "Doctor: Thank you. Are there any more symptoms you'd like to mention?"
    )

    def __init__(self):
        self.conversation_history = []
        self.reported_symptoms = set()
        # symptom -> {"duration", "severity", "location"}; the special keys
        # "pain" and "medications" are added by the later states.
        self.symptom_details = {}
        self.asked_missing = set()
        self.awaiting_followup = None
        self.current_detail_symptom = None
        self.state = "intake"
        self.finished = False
        # Up to two (disease, score) pairs filled in by _predict_diseases.
        self.predicted_diseases = []
        greeting = "Doctor: Hello! I'm here to assist you. Could you please describe how you're feeling today?"
        self.conversation_history.append(greeting)

    @staticmethod
    def _normalize_reply(message):
        """Lower-case a reply and drop surrounding whitespace and trailing '.'/'!'."""
        return message.strip().rstrip(".!").strip().lower()

    def process_message(self, message: str) -> str:
        """Route ``message`` to the handler of the current state and return its reply."""
        if self.finished:
            return "Doctor: Thank you. Our session has ended. Wishing you good health!"

        handlers = {
            "intake": self._handle_intake,
            "symptom_detail": self._handle_symptom_detail,
            "pain_check": self._handle_pain_check,
            "medications": self._handle_medications,
        }
        handler = handlers.get(self.state)
        if handler is None:
            return "Doctor: Could you please clarify that?"
        return handler(message)

    def _begin_pain_check(self):
        """Compute disease predictions, switch to ``pain_check`` and return its prompt."""
        self._predict_diseases()
        self.state = "pain_check"
        return self._PAIN_PROMPT

    def _handle_intake(self, message):
        """Collect symptoms, or leave intake when the user has nothing to add."""
        # Normalised so that replies like "No " or "no." also end intake.
        if self._normalize_reply(message) in self._EXIT_WORDS:
            if not self.reported_symptoms:
                self.finished = True
                return "Doctor: It seems no symptoms were reported. Ending the session."
            return self._begin_pain_check()

        ner_results = extract_symptoms_ner(message)
        if not ner_results:
            return "Doctor: I couldn't clearly detect any medical symptoms. Could you please describe your condition differently?"

        self.reported_symptoms.update(ner_results)

        # Symptoms that were reported but not yet asked about in detail.
        unexplored_symptoms = list(self.reported_symptoms - set(self.symptom_details))
        if not unexplored_symptoms:
            return self._begin_pain_check()

        symptom = random.choice(unexplored_symptoms)
        self.current_detail_symptom = symptom
        self.symptom_details[symptom] = {}
        self.state = "symptom_detail"
        return f"Doctor: About your '{symptom}', when did it start? (e.g., 2 days ago)"

    def _handle_symptom_detail(self, message):
        """Store the next missing detail (duration, severity, location) of the current symptom."""
        symptom = self.current_detail_symptom
        if not symptom:
            return "Doctor: Please clarify that."

        details = self.symptom_details[symptom]

        if 'duration' not in details:
            details['duration'] = message
            return f"Doctor: How severe is the '{symptom}'? (mild, moderate, severe)"

        if 'severity' not in details:
            details['severity'] = message
            body_part = SYMPTOM_BODY_PARTS.get(symptom)
            if not body_part:
                # No known body part for this symptom: ask the user.
                return f"Doctor: Where exactly do you feel the '{symptom}'? (e.g., forehead, chest, abdomen)"
            details['location'] = body_part
            self.state = "intake"
            return self._MORE_SYMPTOMS_PROMPT

        if 'location' not in details:
            details['location'] = message
            self.state = "intake"
            return self._MORE_SYMPTOMS_PROMPT

        return "Doctor: Please clarify that."

    def _handle_pain_check(self, message):
        """Record the pain rating and move on to the medication question.

        A reply that parses as an integer is stored as-is; otherwise an
        affirmative reply is recorded as 'mild' and anything else as 'none'.
        """
        try:
            severity = int(message)
        except ValueError:
            severity = 'mild' if is_affirmative(message) else 'none'
        self.symptom_details['pain'] = {'severity': severity}
        self.state = "medications"
        return "Doctor: Have you taken any medications recently? Please mention them or type 'no'."

    def _handle_medications(self, message):
        """Record the medication answer, finish the session and return the report."""
        declined = self._normalize_reply(message) in self._NO_MEDICATION_WORDS
        self.symptom_details['medications'] = "None" if declined else message
        self.finished = True
        return self._generate_summary()

    def _predict_diseases(self):
        """Keep the two diseases with the highest share of their symptoms reported.

        Only diseases sharing at least two symptoms with the user qualify; the
        score is ``shared symptoms / symptoms listed for the disease``.
        """
        reported = set(self.reported_symptoms)
        match_scores = []
        for disease, symptoms in disease_symptoms.items():
            matches = len(reported.intersection(symptoms))
            if matches >= 2:
                match_scores.append((disease, matches / len(symptoms)))
        # Stable sort: equally scored diseases keep their table order.
        match_scores.sort(key=lambda pair: pair[1], reverse=True)
        self.predicted_diseases = match_scores[:2]

    def _generate_summary(self):
        """Build the final report: collected details, possible conditions, advice."""
        report_lines = []
        for sym, details in self.symptom_details.items():
            if isinstance(details, dict) and details:
                formatted_details = ", ".join(f"{k}: {v}" for k, v in details.items())
                report_lines.append(f"- {sym.title()}: {formatted_details}")
            else:
                # Plain values such as the medications string.
                report_lines.append(f"- {sym.title()}: {details}")
        report = "\n".join(report_lines)

        if self.predicted_diseases:
            condition_lines = [
                f"- {disease} ({int(score * 100)}%)"
                for disease, score in self.predicted_diseases
            ]
            # The header needs its own line break, otherwise the first
            # condition is glued onto "Possible Conditions:".
            disease_part = "\n\nPossible Conditions:\n" + "\n".join(condition_lines)
        else:
            disease_part = "\n\nDoctor: No strong condition match was found based on reported symptoms."

        advice = "\n\nDoctor: Please note this is a preliminary virtual consultation. You should consult a physician if symptoms persist, worsen, or if you feel uncomfortable."
        return f"Doctor: Thank you for the detailed information!\n\n=== Medical Report ===\n{report}{disease_part}{advice}"

    def get_data(self):
        """Return the session's collected data as a plain dict."""
        return {
            "conversation": self.conversation_history,
            "symptoms": list(self.reported_symptoms),
            "symptom_details": self.symptom_details,
            "predicted_diseases": self.predicted_diseases
        }