File size: 3,722 Bytes
7e44b27 12739b5 f2f3556 7e44b27 f2f3556 7e44b27 3e4d376 12739b5 f2f3556 3e4d376 f2f3556 7e44b27 54722d1 7e44b27 fed39ec 7e44b27 f2f3556 7e44b27 f2f3556 3e4d376 ad600ff 3e4d376 1a155a6 3e4d376 54722d1 7e44b27 87eb48c ad09373 87eb48c 7e44b27 12739b5 f2f3556 12739b5 f2f3556 7e44b27 f2f3556 7e44b27 f2f3556 3e4d376 f2f3556 3e4d376 1a155a6 10cead1 1a155a6 10cead1 1a155a6 f2f3556 12739b5 f2f3556 12739b5 ad600ff 3e4d376 ad600ff 3e4d376 ad600ff 3e4d376 ad600ff 7e44b27 f2f3556 7e44b27 54722d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import streamlit as st
import pandas as pd
import os
import joblib
from huggingface_hub import login, InferenceClient
from utils.model import preprocess
from utils.data_generator import load_or_generate_data
from utils.llm_prompt import explain_prediction
from utils.logger import init_log, log_explanation
# Constants
MODEL_VERSION = "XGB_v1.2_GPU_2024-04-13"
MODEL_PATH = "models/xgb_pipeline.pkl"
DATA_PATH = "data/patients_extended.csv"

# Streamlit page setup — set_page_config must be the first Streamlit call.
# NOTE(review): original emoji literals were mojibake ("π§ "); restored to 🧠.
st.set_page_config(page_title="CareSight | Readmission Predictor", page_icon="🧠", layout="wide")
st.image("assets/logo.png", width=120)
st.title("🧠 CareSight: Readmission Risk Assistant")
st.markdown("Predict 30-day readmission risk with LLM explanations powered by SHAP.")

# Hugging Face login — only attempt it when a token is actually configured.
# login(token=None) falls back to an interactive prompt, which hangs or
# crashes a headless Streamlit deployment instead of degrading gracefully.
_hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    st.warning("⚠️ HUGGINGFACEHUB_API_TOKEN is not set; LLM features may be unavailable.")
# Sidebar: show the pinned predictor version and let the user pick which
# hosted LLM should generate the natural-language explanation.
_LLM_OPTIONS = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepcogito/cogito-v1-preview-llama-3B",
    "microsoft/Phi-4-mini-instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
]
with st.sidebar:
    st.caption(f"π§ Model Version: `{MODEL_VERSION}`")
    model_choice = st.selectbox("Choose LLM for explanation", _LLM_OPTIONS)
# LLM client factory — cached so Streamlit reuses one client per model name
# instead of constructing a new one on every script rerun.
@st.cache_resource
def load_llm(model_name):
    """Return a Hugging Face InferenceClient bound to *model_name*."""
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    return InferenceClient(model=model_name, token=token)

llm = load_llm(model_choice)
# Load data and model — cached so the CSV parse and pipeline unpickle happen
# once per session rather than on every Streamlit rerun (each widget
# interaction re-executes this whole script top to bottom).
@st.cache_data
def _load_patients(path):
    """Read the patient dataset at *path* into a DataFrame."""
    return pd.read_csv(path)


@st.cache_resource
def _load_model(path):
    """Unpickle the trained prediction pipeline stored at *path*."""
    return joblib.load(path)


data = _load_patients(DATA_PATH)
model = _load_model(MODEL_PATH)

# Patient selector — .squeeze() collapses the single matching row to a Series.
selected_id = st.sidebar.selectbox("Select a patient", data['Patient_ID'].tolist())
patient = data[data['Patient_ID'] == selected_id].squeeze()
# Per-session explanation log, created once and kept across reruns in
# Streamlit session state.
if "log_df" not in st.session_state:
    st.session_state["log_df"] = init_log()

# Patient details panel, collapsed behind an expander.
st.subheader("Patient Details")
with st.expander("Show Details"):
    st.json(patient.to_dict())
# Recompute model input and prediction for this patient.
# Drop the identifier, free-text condition list, and the target column —
# none of these are model features (errors="ignore" tolerates absent columns).
input_data_dict = patient.drop(
    labels=["Patient_ID", "Chronic_Condition_List", "Readmitted_30_Days"],
    errors="ignore",
).to_dict()
# The pipeline's preprocessor defines the exact transformed column order.
schema = model.named_steps["preprocessor"].get_feature_names_out()
X_input = pd.DataFrame(preprocess(input_data_dict, schema), columns=schema)
# Probability of the positive class (readmitted within 30 days).
pred_proba = model.named_steps["model"].predict_proba(X_input)[0][1]
# NOTE(review): original emoji were mojibake ("π΄"/"π’"); restored to 🔴/🟢.
pred_label = "🔴 High Risk" if pred_proba >= 0.5 else "🟢 Low Risk"
# Render the prediction headline followed by its supporting details.
for _line in (
    f"### Prediction Result: {pred_label}",
    f"**Predicted Probability:** `{pred_proba:.2%}`",
    f"**LLM Model Used:** `{model_choice}`",
):
    st.markdown(_line)
# LLM explanation auto-update: generate an explanation for this prediction,
# falling back to a plain error message so the app keeps working when the
# inference endpoint is unavailable.
with st.spinner(f"Generating explanation with {model_choice}..."):
    try:
        explanation = explain_prediction(
            patient_id=patient["Patient_ID"],
            patient_data=patient.to_dict(),
            model_name=model_choice,
            _model=model,
            _client=llm
        )
        st.success("### LLM Explanation\n" + explanation)
    # Broad catch is deliberate: this is a UI boundary — surface the error
    # and continue, rather than crash the whole page on any inference failure.
    except Exception as e:
        fallback = f"Unable to generate LLM explanation due to error: {e}"
        # NOTE(review): original warning emoji was mojibake ("β οΈ"); restored to ⚠️.
        st.warning(f"⚠️ {fallback}")
        explanation = fallback

# Append this (patient, model, prediction, explanation) row to the session log;
# the fallback message is logged too so failed generations remain auditable.
st.session_state["log_df"] = log_explanation(
    st.session_state["log_df"],
    patient_id=patient["Patient_ID"],
    model_name=model_choice,
    prediction=pred_proba,
    shap_summary="SHAP summary internal only",
    explanation=explanation
)
# Log download: offer the accumulated explanation log as CSV once it has rows.
if not st.session_state["log_df"].empty:
    st.download_button(
        # NOTE(review): original label emoji was mojibake ("π"); assumed 📥 — confirm.
        label="📥 Download Explanation Log as CSV",
        data=st.session_state["log_df"].to_csv(index=False),
        file_name="llm_explanations_log.csv",
        mime="text/csv"
    )
|