Spaces:
Running
Running
import pandas as pd | |
import json | |
import numpy as np | |
import faiss | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from transformers import pipeline | |
# ------------------------------- | |
# Load disease data and preprocess | |
# ------------------------------- | |
def load_disease_data(csv_path):
    """Read the disease table and return its symptoms and precautions.

    The CSV must provide ``disease``, ``symptoms`` and ``precautions``
    columns (header matching ignores case and surrounding whitespace).
    The ``symptoms`` and ``precautions`` cells are comma-separated lists.

    Args:
        csv_path: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A ``(disease_symptoms, disease_precautions)`` tuple of dicts keyed by
        the stripped disease name. Symptom entries are stripped and
        lower-cased; precaution entries are stripped only. Empty list items
        are dropped. When a disease occurs on several rows, the last row wins.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    # dtype=str keeps every cell textual, so .strip()/.split() below cannot
    # fail on columns pandas would otherwise parse as numbers.
    df = pd.read_csv(csv_path, dtype=str)
    df.columns = df.columns.str.strip().str.lower()

    required = ("disease", "symptoms", "precautions")
    missing = [column for column in required if column not in df.columns]
    if missing:
        raise KeyError(f"CSV is missing required column(s): {', '.join(missing)}")

    df = df.fillna("")

    disease_symptoms = {}
    disease_precautions = {}
    rows = zip(df["disease"], df["symptoms"], df["precautions"])
    for raw_disease, raw_symptoms, raw_precautions in rows:
        disease = raw_disease.strip()
        if not disease:
            # A row without a disease name cannot be keyed meaningfully.
            continue
        disease_symptoms[disease] = [
            item.strip().lower() for item in raw_symptoms.split(",") if item.strip()
        ]
        disease_precautions[disease] = [
            item.strip() for item in raw_precautions.split(",") if item.strip()
        ]
    return disease_symptoms, disease_precautions
# Parse the bundled dataset once at import time; the CSV file is expected to
# sit in the repository root (the path is relative to the working directory).
disease_symptoms, disease_precautions = load_disease_data("disease_sympts_prec_full.csv")

# Every distinct symptom string listed under at least one disease.
known_symptoms = set().union(*disease_symptoms.values())
# ------------------------------- | |
# Build symptom vectorizer and FAISS index | |
# ------------------------------- | |
# Disease names in table order; row i of the TF-IDF matrix and of the FAISS
# index describes disease_list[i].
disease_list = list(disease_symptoms)

# One "document" per disease: its symptoms joined into a single string.
symptom_texts = [" ".join(disease_symptoms[name]) for name in disease_list]

vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(symptom_texts).toarray()

# L2-distance index over the dense TF-IDF vectors, stored as float32.
index = faiss.IndexFlatL2(tfidf_matrix.shape[1])
index.add(tfidf_matrix.astype(np.float32))
def find_closest_disease(user_symptoms, symptom_map=None, min_matches=2):
    """Return the disease sharing the most symptoms with ``user_symptoms``.

    Args:
        user_symptoms: Iterable of symptom strings. String entries are
            stripped and lower-cased before comparison; blank strings are
            ignored. Any iterable works, including one-shot iterators.
        symptom_map: Mapping of disease name to its symptom list. Defaults to
            the module-level ``disease_symptoms`` table.
        min_matches: Minimum number of shared symptoms required for a
            result to be returned (default 2).

    Returns:
        The best-matching disease name, or ``None`` when no disease reaches
        ``min_matches`` shared symptoms. On a tie the disease that comes
        first in ``symptom_map`` wins.
    """
    if symptom_map is None:
        symptom_map = disease_symptoms

    # Build the lookup set once: rebuilding it per disease is wasted work and
    # would exhaust a one-shot iterator after the first disease.
    reported = set()
    for symptom in user_symptoms:
        if isinstance(symptom, str):
            symptom = symptom.strip().lower()
            if not symptom:
                continue
        reported.add(symptom)

    best_disease = None
    best_count = 0
    for disease, symptoms in symptom_map.items():
        count = len(reported.intersection(symptoms))
        if count > best_count:  # strict '>' keeps the earliest disease on ties
            best_count = count
            best_disease = disease

    return best_disease if best_count >= min_matches else None
# ------------------------------- | |
# Load Medical NER model for symptom extraction | |
# ------------------------------- | |
# Token-classification pipeline used to pull symptom mentions out of free
# text. The same checkpoint supplies both the model and the tokenizer.
_MEDICAL_NER_CHECKPOINT = "blaze999/Medical-NER"

medical_ner = pipeline(
    "ner",
    model=_MEDICAL_NER_CHECKPOINT,
    tokenizer=_MEDICAL_NER_CHECKPOINT,
    aggregation_strategy="simple",
)
def extract_symptoms_ner(text, ner=None):
    """Extract the distinct symptom mentions found in ``text``.

    Args:
        text: Free-text patient message.
        ner: Callable taking the text and returning an iterable of dicts
            with ``"entity_group"`` and ``"word"`` keys. Defaults to the
            module-level ``medical_ner`` pipeline.

    Returns:
        Lower-cased symptom strings with surrounding/duplicate whitespace
        removed, de-duplicated, in order of first appearance. An empty list
        when ``text`` is empty or blank, or when nothing is tagged as a
        symptom.
    """
    if not text or not text.strip():
        # Nothing to analyse; skip the model call entirely.
        return []
    if ner is None:
        ner = medical_ner

    # A dict is used as an insertion-ordered set so the result is stable
    # from run to run (a plain set has no guaranteed order).
    symptoms = {}
    for entity in ner(text):
        if "SIGN_SYMPTOM" not in (entity.get("entity_group") or ""):
            continue
        # Normalise whitespace so " fever" and "fever" count as one symptom.
        word = " ".join(entity["word"].lower().split())
        if word:
            symptoms[word] = None
    return list(symptoms)
# Words/phrases that count as a "yes" when they appear in an answer.
_AFFIRMATIVE_PHRASES = ("yes", "yeah", "yep", "certainly", "sometimes", "a little")


def is_affirmative(answer):
    """Return True when ``answer`` contains an affirmative word or phrase.

    Matching is case-insensitive and on whole words only, so "yesterday" or
    "eyes" do not count as "yes". Punctuation is ignored. An empty or
    ``None`` answer is not affirmative.
    """
    if not answer:
        return False
    # Turn punctuation into spaces and collapse whitespace, then pad with
    # spaces so every phrase can be matched on word boundaries.
    cleaned = "".join(ch if ch.isalnum() else " " for ch in answer.lower())
    padded = f" {' '.join(cleaned.split())} "
    return any(f" {phrase} " in padded for phrase in _AFFIRMATIVE_PHRASES)
import random | |
# Lookup from a symptom name to the body part it implies.
# A value of None means the symptom does not pin down a location.
SYMPTOM_BODY_PARTS = {
    # Head, systemic and airway complaints
    "headache": "head",
    "fever": None,
    "cold": None,
    "cough": "throat",
    "chest pain": "chest",
    "stomach ache": "abdomen",
    "sore throat": "throat",
    # Localised pain
    "back pain": "back",
    "ear pain": "ear",
    "toothache": "mouth",
    # Digestive complaints
    "nausea": "stomach",
    "vomiting": "stomach",
    "diarrhea": "abdomen",
    # Miscellaneous
    "shortness of breath": "chest",
    "dizziness": "head",
    "rash": "skin",
    "eye pain": "eye",
    "abdominal pain": "abdomen",
    "joint pain": "joint",
    "muscle pain": "muscle",
    "neck pain": "neck",
}
# ------------------------------- | |
# Chatbot session class | |
# ------------------------------- | |
class ChatbotSession:
    """Scripted, state-driven interview for a single patient.

    The session moves through these states:

    * ``"intake"`` - collect symptoms from free text.
    * ``"symptom_detail"`` - ask duration, severity and (if not implied by
      ``SYMPTOM_BODY_PARTS``) location for one symptom.
    * ``"pain_check"`` - ask for a pain rating.
    * ``"medications"`` - ask for recent medications, then produce the
      report and mark the session finished.

    Every reply is a string prefixed with ``"Doctor: "``.
    """

    # Messages that end symptom intake.
    _EXIT_WORDS = ("exit", "quit", "no")
    # Answers meaning "no medication taken".
    _NO_MEDICATION_WORDS = ("no", "none")
    # A disease needs at least this many reported symptoms to be suggested.
    _MIN_MATCHING_SYMPTOMS = 2
    # Number of best-scoring diseases kept for the report.
    _MAX_PREDICTIONS = 2

    _PAIN_PROMPT = (
        "Doctor: Before proceeding, are you experiencing any pain? "
        "If yes, please rate it 1-10 or type 'no'."
    )
    _MORE_SYMPTOMS_PROMPT = (
        "Doctor: Thank you. Are there any more symptoms you'd like to mention?"
    )

    def __init__(self):
        """Start a new session in the ``"intake"`` state with a greeting."""
        # NOTE(review): only the greeting is appended here; nothing else in
        # this class writes to conversation_history - confirm whether the
        # caller is expected to record the rest of the dialogue.
        self.conversation_history = []
        self.reported_symptoms = set()
        # symptom -> {"duration", "severity", "location"}; also receives the
        # "pain" and "medications" entries from the last two interview steps.
        self.symptom_details = {}
        self.asked_missing = set()
        self.awaiting_followup = None
        self.current_detail_symptom = None
        self.state = "intake"
        self.finished = False
        # List of (disease, score) tuples filled by _predict_diseases().
        self.predicted_diseases = []
        greeting = "Doctor: Hello! I'm here to assist you. Could you please describe how you're feeling today?"
        self.conversation_history.append(greeting)

    def process_message(self, message: str) -> str:
        """Route ``message`` to the handler of the current state.

        Returns the doctor's reply. Once the session is finished a fixed
        closing line is returned for every further message.
        """
        if self.finished:
            return "Doctor: Thank you. Our session has ended. Wishing you good health!"

        handlers = {
            "intake": self._handle_intake,
            "symptom_detail": self._handle_symptom_detail,
            "pain_check": self._handle_pain_check,
            "medications": self._handle_medications,
        }
        handler = handlers.get(self.state)
        if handler is None:
            return "Doctor: Could you please clarify that?"
        return handler(message)

    def _begin_pain_check(self):
        """Score the diseases, switch to ``"pain_check"`` and return its prompt."""
        self._predict_diseases()
        self.state = "pain_check"
        return self._PAIN_PROMPT

    def _handle_intake(self, message):
        """Collect symptoms, or leave intake when the patient says no/exit/quit."""
        # Tolerate surrounding whitespace and a trailing '.' or '!' ("No.").
        command = message.strip().lower().rstrip(".!")
        if command in self._EXIT_WORDS:
            if not self.reported_symptoms:
                self.finished = True
                return "Doctor: It seems no symptoms were reported. Ending the session."
            return self._begin_pain_check()

        found = extract_symptoms_ner(message)
        if not found:
            return "Doctor: I couldn't clearly detect any medical symptoms. Could you please describe your condition differently?"

        self.reported_symptoms.update(found)
        unexplored = list(self.reported_symptoms - set(self.symptom_details))
        if not unexplored:
            return self._begin_pain_check()

        # Pick one not-yet-detailed symptom at random and start asking about it.
        symptom = random.choice(unexplored)
        self.current_detail_symptom = symptom
        self.symptom_details[symptom] = {}
        self.state = "symptom_detail"
        return f"Doctor: About your '{symptom}', when did it start? (e.g., 2 days ago)"

    def _handle_symptom_detail(self, message):
        """Store the next missing detail (duration, severity, location) of the current symptom."""
        symptom = self.current_detail_symptom
        if symptom is None:
            # Nothing to ask about: go back to intake instead of staying
            # stuck in this state for the rest of the session.
            self.state = "intake"
            return "Doctor: Please clarify that."

        details = self.symptom_details.setdefault(symptom, {})

        if 'duration' not in details:
            details['duration'] = message
            return f"Doctor: How severe is the '{symptom}'? (mild, moderate, severe)"

        if 'severity' not in details:
            details['severity'] = message
            body_part = SYMPTOM_BODY_PARTS.get(symptom)
            if body_part:
                # The location is implied by the symptom; no need to ask.
                details['location'] = body_part
                self.state = "intake"
                return self._MORE_SYMPTOMS_PROMPT
            return f"Doctor: Where exactly do you feel the '{symptom}'? (e.g., forehead, chest, abdomen)"

        if 'location' not in details:
            details['location'] = message
            self.state = "intake"
            return self._MORE_SYMPTOMS_PROMPT

        # All details are already known for this symptom.
        self.state = "intake"
        return "Doctor: Please clarify that."

    @staticmethod
    def _parse_pain_rating(message):
        """Return the rating in ``message`` clamped to 0-10, or None if there is none.

        Accepts a bare integer as well as "N/10" and "N out of 10".
        """
        text = message.strip().lower().replace("out of", "/")
        leading = text.split("/", 1)[0].strip()
        try:
            rating = int(leading)
        except ValueError:
            return None
        return max(0, min(10, rating))

    def _handle_pain_check(self, message):
        """Record the pain answer and move on to the medication question.

        A numeric rating is stored as an int (clamped to 0-10); otherwise an
        affirmative answer is stored as ``'mild'`` and anything else as
        ``'none'``.
        """
        rating = self._parse_pain_rating(message)
        if rating is not None:
            severity = rating
        elif is_affirmative(message):
            severity = 'mild'
        else:
            severity = 'none'
        self.symptom_details['pain'] = {'severity': severity}
        self.state = "medications"
        return "Doctor: Have you taken any medications recently? Please mention them or type 'no'."

    def _handle_medications(self, message):
        """Record the medication answer, finish the session and return the report."""
        answered_none = message.strip().lower() in self._NO_MEDICATION_WORDS
        self.symptom_details['medications'] = "None" if answered_none else message
        self.finished = True
        return self._generate_summary()

    def _predict_diseases(self):
        """Fill ``predicted_diseases`` with the best-scoring ``(disease, score)`` pairs.

        The score is the fraction of a disease's listed symptoms that were
        reported. Diseases with fewer than ``_MIN_MATCHING_SYMPTOMS`` matches
        are ignored, and at most ``_MAX_PREDICTIONS`` entries are kept.
        """
        reported = set(self.reported_symptoms)
        scored = []
        for disease, symptoms in disease_symptoms.items():
            matches = len(reported.intersection(symptoms))
            if matches >= self._MIN_MATCHING_SYMPTOMS:
                scored.append((disease, matches / len(symptoms)))
        scored.sort(key=lambda pair: pair[1], reverse=True)
        self.predicted_diseases = scored[:self._MAX_PREDICTIONS]

    def _generate_summary(self):
        """Build the final report from ``symptom_details`` and ``predicted_diseases``."""
        report_lines = []
        for sym, details in self.symptom_details.items():
            if isinstance(details, dict) and details:
                formatted_details = ", ".join(f"{k}: {v}" for k, v in details.items())
                report_lines.append(f"- {sym.title()}: {formatted_details}")
            else:
                # Plain values (e.g. the medications string) and empty dicts.
                report_lines.append(f"- {sym.title()}: {details}")
        report = "\n".join(report_lines)

        if self.predicted_diseases:
            conditions = "\n".join(
                f"- {disease} ({int(score * 100)}%)"
                for disease, score in self.predicted_diseases
            )
            # The heading sits on its own line, followed by one line per condition.
            disease_part = "\n\nPossible Conditions:\n" + conditions
        else:
            disease_part = "\n\nDoctor: No strong condition match was found based on reported symptoms."

        advice = "\n\nDoctor: Please note this is a preliminary virtual consultation. You should consult a physician if symptoms persist, worsen, or if you feel uncomfortable."
        return f"Doctor: Thank you for the detailed information!\n\n=== Medical Report ===\n{report}{disease_part}{advice}"

    def get_data(self):
        """Return the session's collected data as a plain dict."""
        return {
            "conversation": self.conversation_history,
            "symptoms": list(self.reported_symptoms),
            "symptom_details": self.symptom_details,
            "predicted_diseases": self.predicted_diseases
        }