File size: 3,722 Bytes
7e44b27
 
 
12739b5
f2f3556
7e44b27
f2f3556
7e44b27
 
 
 
3e4d376
12739b5
f2f3556
 
3e4d376
f2f3556
7e44b27
 
54722d1
7e44b27
fed39ec
7e44b27
f2f3556
7e44b27
f2f3556
3e4d376
 
 
ad600ff
3e4d376
 
 
1a155a6
3e4d376
 
54722d1
7e44b27
 
87eb48c
ad09373
87eb48c
7e44b27
 
 
 
12739b5
f2f3556
 
12739b5
f2f3556
7e44b27
 
 
f2f3556
7e44b27
 
 
f2f3556
3e4d376
f2f3556
3e4d376
 
1a155a6
 
10cead1
1a155a6
10cead1
1a155a6
f2f3556
12739b5
 
f2f3556
12739b5
 
 
 
ad600ff
 
 
 
3e4d376
ad600ff
3e4d376
ad600ff
 
3e4d376
ad600ff
 
 
 
 
 
 
 
 
 
 
 
 
 
7e44b27
f2f3556
7e44b27
 
 
 
 
 
54722d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import streamlit as st
import pandas as pd
import os
import joblib
from huggingface_hub import login, InferenceClient

from utils.model import preprocess
from utils.data_generator import load_or_generate_data
from utils.llm_prompt import explain_prediction
from utils.logger import init_log, log_explanation

# Constants
MODEL_VERSION = "XGB_v1.2_GPU_2024-04-13"  # displayed in the sidebar caption
MODEL_PATH = "models/xgb_pipeline.pkl"  # joblib file with the trained pipeline
DATA_PATH = "data/patients_extended.csv"  # patient table read at startup

# Hugging Face login.
# Only attempt it when a token is configured: with token=None,
# `huggingface_hub.login` falls back to an interactive prompt, which would
# hang or fail on a headless Streamlit server. When the variable is unset the
# InferenceClient further down is simply created with token=None.
_hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Streamlit setup (set_page_config must precede any other Streamlit command).
st.set_page_config(page_title="CareSight | Readmission Predictor", page_icon="🧠", layout="wide")
st.image("assets/logo.png", width=120)
st.title("🧠 CareSight: Readmission Risk Assistant")
st.markdown("Predict 30-day readmission risk with LLM explanations powered by SHAP.")

# Sidebar: show which predictive model build is deployed and let the user pick
# the hosted LLM that writes the explanation. The selected repo id is used to
# build the InferenceClient and is echoed in the results and the log.
LLM_CHOICES = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepcogito/cogito-v1-preview-llama-3B",
    "microsoft/Phi-4-mini-instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    # NOTE(review): GGUF repos usually hold quantized weights for local
    # runtimes; confirm this one is actually served by the Inference API.
    "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
]
st.sidebar.caption(f"🧠 Model Version: `{MODEL_VERSION}`")
model_choice = st.sidebar.selectbox("Choose LLM for explanation", LLM_CHOICES)

# Load LLM client
@st.cache_resource
def load_llm(model_name):
    return InferenceClient(
        model=model_name,
        token=os.getenv("HUGGINGFACEHUB_API_TOKEN")
    )

llm = load_llm(model_choice)

# Load data and model.
# Streamlit re-executes this whole script on every widget interaction, so the
# CSV read and the model unpickle are cached rather than repeated each rerun.
@st.cache_data
def _load_patient_data(path):
    """Read the patient table from *path*; cached per path by Streamlit."""
    return pd.read_csv(path)


@st.cache_resource
def _load_model(path):
    """Load the trained pipeline from *path*; one shared, cached instance."""
    # joblib.load unpickles arbitrary objects: only point it at trusted files.
    return joblib.load(path)


data = _load_patient_data(DATA_PATH)
model = _load_model(MODEL_PATH)

# Patient selector
selected_id = st.sidebar.selectbox("Select a patient", data['Patient_ID'].tolist())
# .iloc[0] always yields a single-row Series (the same object .squeeze() gave
# for a unique ID); .squeeze() would return a DataFrame if the ID were
# duplicated, breaking every per-patient lookup below.
patient = data[data['Patient_ID'] == selected_id].iloc[0]

# Initialize session log (once per browser session; survives reruns).
if "log_df" not in st.session_state:
    st.session_state["log_df"] = init_log()

# Display patient details
st.subheader("Patient Details")
with st.expander("Show Details"):
    st.json(patient.to_dict())

# Recompute model input and prediction for this patient.
# Columns not fed to the model (the ID, Chronic_Condition_List and what appears
# to be the target column) are dropped; errors="ignore" tolerates any of them
# being absent from the table.
input_data_dict = patient.drop(
    labels=["Patient_ID", "Chronic_Condition_List", "Readmitted_30_Days"],
    errors="ignore",
).to_dict()
schema = model.named_steps["preprocessor"].get_feature_names_out()
X_input = preprocess(input_data_dict, schema)
X_input = pd.DataFrame(X_input, columns=schema)

# Probability for class index 1 of the pipeline's final estimator
# (assumed to be the "readmitted" class — confirm against model.classes_).
RISK_THRESHOLD = 0.5
pred_proba = model.named_steps["model"].predict_proba(X_input)[0][1]
# Labels previously contained mojibake (UTF-8 emoji bytes decoded as cp1253).
pred_label = "🔴 High Risk" if pred_proba >= RISK_THRESHOLD else "🟢 Low Risk"

# Display prediction result
st.markdown(f"### Prediction Result: {pred_label}")
st.markdown(f"**Predicted Probability:** `{pred_proba:.2%}`")
st.markdown(f"**LLM Model Used:** `{model_choice}`")

# LLM Explanation auto-update.
# Streamlit reruns this script on every interaction, including a click on the
# download button below. Logging unconditionally would therefore append a
# duplicate row for the same patient/LLM on each rerun, so a row is only
# written when the (patient, LLM) selection differs from the last logged one.
selection_key = (patient["Patient_ID"], model_choice)
with st.spinner(f"Generating explanation with {model_choice}..."):
    try:
        explanation = explain_prediction(
            patient_id=patient["Patient_ID"],
            patient_data=patient.to_dict(),
            model_name=model_choice,
            _model=model,
            _client=llm
        )
        st.success("### LLM Explanation\n" + explanation)
    except Exception as e:
        # UI boundary: show the failure instead of crashing the page, and log
        # the error text in place of an explanation.
        fallback = f"Unable to generate LLM explanation due to error: {e}"
        st.warning(f"⚠️ {fallback}")
        explanation = fallback

    if st.session_state.get("last_logged_selection") != selection_key:
        st.session_state["log_df"] = log_explanation(
            st.session_state["log_df"],
            patient_id=patient["Patient_ID"],
            model_name=model_choice,
            prediction=pred_proba,
            shap_summary="SHAP summary internal only",
            explanation=explanation
        )
        st.session_state["last_logged_selection"] = selection_key

# Log download (label previously contained mojibake for the 📄 emoji).
if not st.session_state["log_df"].empty:
    st.download_button(
        label="📄 Download Explanation Log as CSV",
        data=st.session_state["log_df"].to_csv(index=False),
        file_name="llm_explanations_log.csv",
        mime="text/csv"
    )