# Author: PragmaticPete
# Commit message: Update utils/llm_prompt.py
# Commit: 41a6ef1 (verified)
import pandas as pd
import shap
import streamlit as st
from utils.model import preprocess
def explain_prediction(patient_id: str, patient_data: dict, model_name: str, _model, _client) -> str:
    """Ask an LLM to explain a patient's readmission prediction from its top SHAP factors.

    The patient record is preprocessed against the pipeline's output schema,
    SHAP values are computed for the pipeline's final estimator, and the three
    features with the largest absolute SHAP value are embedded in a prompt
    that is sent to the text-generation client.

    Args:
        patient_id: Identifier echoed in the prompt and in any error message.
        patient_data: Raw patient record. Passed to ``preprocess`` and read
            with ``.get`` for the demographic/clinical fields in the prompt.
        model_name: Not used by this function; kept for interface
            compatibility with existing callers.
        _model: Object exposing ``named_steps`` with ``"preprocessor"`` and
            ``"model"`` entries (presumably a scikit-learn ``Pipeline`` --
            confirm against the caller).
        _client: Object exposing ``text_generation(prompt, max_new_tokens=...,
            temperature=...)`` and returning a string (presumably a
            ``huggingface_hub.InferenceClient`` -- confirm against the caller).

    Returns:
        The stripped LLM response, or a human-readable failure message if any
        step raised. This function is a best-effort boundary and does not
        propagate ``Exception`` subclasses.
    """
    try:
        steps = _model.named_steps

        # Feature names produced by the fitted preprocessor define the schema.
        schema = steps["preprocessor"].get_feature_names_out()

        # Preprocess the patient record and align it to that schema.
        X = pd.DataFrame(preprocess(patient_data, schema), columns=schema)

        # SHAP explanation for the final estimator only.
        explainer = shap.Explainer(steps["model"])
        shap_values = explainer(X)

        # Per-feature contributions for the single patient row.
        contributions = shap_values[0].values
        if getattr(contributions, "ndim", 1) == 2:
            # Multi-output models yield one column per output/class, which
            # would make the DataFrame construction below fail.
            # NOTE(review): assumes the last column is the positive
            # (readmission) class -- confirm against the model's classes_.
            contributions = contributions[:, -1]

        # Top 3 features by absolute SHAP value. All columns are plain arrays
        # so they are combined positionally, not by index alignment.
        top = (
            pd.DataFrame(
                {
                    "feature": X.columns,
                    "value": X.iloc[0].to_numpy(),
                    "shap": contributions,
                }
            )
            .sort_values("shap", key=abs, ascending=False)
            .head(3)
        )

        summary = "\n".join(
            f"- {feature}: {value} (SHAP impact: {impact:.2f})"
            for feature, value, impact in zip(top["feature"], top["value"], top["shap"])
        )

        # The prompt body is deliberately flush-left: it is sent verbatim.
        prompt = f"""
You are a clinical AI assistant helping interpret a machine learning model’s prediction about a patient's 30-day hospital readmission risk.
Patient ID: {patient_id}
Primary Diagnosis: {patient_data.get("Primary_Diagnosis")}
Chronic Conditions: {patient_data.get("Chronic_Condition_List", "Not listed")}
Age: {patient_data.get("Age")}
Sex: {patient_data.get("Gender")}
Insurance: {patient_data.get("Insurance_Type")}
Frailty Index: {patient_data.get("Frailty_Index")}
Charlson Index: {patient_data.get("Charlson_Index")}
Polypharmacy Count: {patient_data.get("Polypharmacy_Count")}
### Top SHAP Factors:
{summary}
Please describe specifically how each feature listed above is influencing the model’s projected readmission probability. Focus on the direction (increasing or decreasing risk) and the clinical intuition behind it. Use concise, professional language suitable for a care coordination team.
"""
        response = _client.text_generation(prompt, max_new_tokens=1024, temperature=0.1)
        return response.strip()
    except Exception as e:
        # Best-effort boundary: report the failure as text instead of raising,
        # so the caller always receives a displayable string.
        return f"SHAP explanation failed for patient {patient_id}. Error: {e}"