import os

import joblib
import pandas as pd
import streamlit as st
from huggingface_hub import login, InferenceClient

from utils.data_generator import load_or_generate_data
from utils.llm_prompt import explain_prediction
from utils.logger import init_log, log_explanation
from utils.model import preprocess
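
# Pinned artifact locations and a human-readable version tag for the sidebar.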
MODEL_VERSION = "XGB_v1.2_GPU_2024-04-13"
MODEL_PATH = "models/xgb_pipeline.pkl"
DATA_PATH = "data/patients_extended.csv"
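
# Authenticate with the Hugging Face Hub only when a token is configured;
# calling login() with no token makes huggingface_hub prompt interactively,
# which fails inside a Streamlit app.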
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(token=hf_token)
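
# Page chrome; set_page_config must be the first Streamlit command executed.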
st.set_page_config(page_title="CareSight | Readmission Predictor", page_icon="🧠", layout="wide")
st.image("assets/logo.png", width=120)
st.title("🧠 CareSight: Readmission Risk Assistant")
st.markdown("Predict 30-day readmission risk, with LLM explanations grounded in SHAP.")
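
# Sidebar: surface the tabular-model version and let the user pick which
# hosted LLM writes the explanation.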
with st.sidebar:
    st.caption(f"🔧 Model Version: `{MODEL_VERSION}`")
    model_choice = st.selectbox("Choose LLM for explanation", [
        "meta-llama/Llama-3.2-3B-Instruct",
        "deepcogito/cogito-v1-preview-llama-3B",
        "microsoft/Phi-4-mini-instruct",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
    ])
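
# Cache one InferenceClient per selected model so reruns reuse the connection.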
@st.cache_resource
def load_llm(model_name):
    return InferenceClient(
        model=model_name,
        token=os.getenv("HUGGINGFACEHUB_API_TOKEN")
    )

llm = load_llm(model_choice)
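
# Load the patient cohort and the trained pipeline. The guard below is a
# sketch, assuming the artifacts may be absent on a fresh checkout; the
# imported load_or_generate_data util could also (re)create the CSV, but its
# signature isn't shown here.
if not (os.path.exists(DATA_PATH) and os.path.exists(MODEL_PATH)):
    st.error("Missing model/data artifacts. Train the pipeline and generate the dataset first.")
    st.stop()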
data = pd.read_csv(DATA_PATH)
model = joblib.load(MODEL_PATH)
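
# Pick a patient; .squeeze() collapses the single matching row to a Series.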
selected_id = st.sidebar.selectbox("Select a patient", data['Patient_ID'].tolist())
patient = data[data['Patient_ID'] == selected_id].squeeze()
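
# Keep the explanation log in session_state so it survives Streamlit reruns.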
if "log_df" not in st.session_state:
    st.session_state["log_df"] = init_log()
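
# Raw patient record, collapsed behind an expander.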
st.subheader("Patient Details")
with st.expander("Show Details"):
    st.json(patient.to_dict())
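
# Drop the identifier, free-text condition list, and label before scoring,
# then align features to the fitted preprocessor's output schema. preprocess()
# is a project util; it is assumed to encode and order the raw dict to match
# `schema`.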
input_data_dict = patient.drop(
    labels=["Patient_ID", "Chronic_Condition_List", "Readmitted_30_Days"], errors="ignore"
).to_dict()
schema = model.named_steps["preprocessor"].get_feature_names_out()
X_input = preprocess(input_data_dict, schema)
X_input = pd.DataFrame(X_input, columns=schema)
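
# Score with the fitted XGBoost step directly (features were already
# transformed above); 0.5 is the decision threshold for the risk label.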
pred_proba = model.named_steps["model"].predict_proba(X_input)[0][1]
pred_label = "🔴 High Risk" if pred_proba >= 0.5 else "🟢 Low Risk"

st.markdown(f"### Prediction Result: {pred_label}")
st.markdown(f"**Predicted Probability:** `{pred_proba:.2%}`")
st.markdown(f"**LLM Used:** `{model_choice}`")
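
# Generate the narrative explanation, degrading to a logged error message
# instead of crashing the app. explain_prediction() is a project util; the
# leading-underscore kwargs suggest it is cached (Streamlit skips hashing
# arguments whose names start with "_").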
with st.spinner(f"Generating explanation with {model_choice}..."):
    try:
        explanation = explain_prediction(
            patient_id=patient["Patient_ID"],
            patient_data=patient.to_dict(),
            model_name=model_choice,
            _model=model,
            _client=llm
        )
        st.success("### LLM Explanation\n" + explanation)
    except Exception as e:
        fallback = f"Unable to generate LLM explanation due to error: {e}"
        st.warning(f"⚠️ {fallback}")
        explanation = fallback
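
# Append this interaction (patient, LLM, probability, explanation or
# fallback) to the in-memory audit log.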
st.session_state["log_df"] = log_explanation(
    st.session_state["log_df"],
    patient_id=patient["Patient_ID"],
    model_name=model_choice,
    prediction=pred_proba,
    shap_summary="SHAP summary internal only",
    explanation=explanation
)
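
# Offer the accumulated log as a CSV download once it has at least one row.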
if not st.session_state["log_df"].empty:
    st.download_button(
        label="📥 Download Explanation Log as CSV",
        data=st.session_state["log_df"].to_csv(index=False),
        file_name="llm_explanations_log.csv",
        mime="text/csv"
    )