import streamlit as st
import pandas as pd
import os
import joblib
from huggingface_hub import login, InferenceClient
from utils.model import preprocess
from utils.data_generator import load_or_generate_data
from utils.llm_prompt import explain_prediction
from utils.logger import init_log, log_explanation
# Constants
MODEL_VERSION = "XGB_v1.2_GPU_2024-04-13"
MODEL_PATH = "models/xgb_pipeline.pkl"
DATA_PATH = "data/patients_extended.csv"
# Hugging Face login
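# Added guard (not part of the original app): login() without a token falls
# back to an interactive prompt, which cannot work in a deployed Space, so
# fail fast with an actionable message when the secret is missing.
if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
    raise RuntimeError("HUGGINGFACEHUB_API_TOKEN is not set; add it as a Space secret.")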
login(token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
# Streamlit setup
st.set_page_config(page_title="CareSight | Readmission Predictor", page_icon="🧠", layout="wide")
st.image("assets/logo.png", width=120)
st.title("🧠 CareSight: Readmission Risk Assistant")
st.markdown("Predict 30-day readmission risk with LLM explanations powered by SHAP.")
# Sidebar - LLM model selector
with st.sidebar:
    st.caption(f"🧠 Model Version: `{MODEL_VERSION}`")
    model_choice = st.selectbox("Choose LLM for explanation", [
        "meta-llama/Llama-3.2-3B-Instruct",
        "deepcogito/cogito-v1-preview-llama-3B",
        "microsoft/Phi-4-mini-instruct",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
    ])
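# Note (added): the GGUF entry is a llama.cpp-format repo; the serverless
# Inference API generally cannot serve GGUF weights, so that choice is likely
# to hit the error fallback further down.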
# Load LLM client
@st.cache_resource
def load_llm(model_name):
    return InferenceClient(
        model=model_name,
        token=os.getenv("HUGGINGFACEHUB_API_TOKEN")
    )
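# st.cache_resource keys the cache on model_name, so switching models in the
# sidebar builds one InferenceClient per model and reuses it across reruns.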
llm = load_llm(model_choice)
# Load data and model
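# Added guard (not in the original app): check that the bundled artifacts
# exist before loading, so a misconfigured Space shows a readable error
# instead of a FileNotFoundError traceback.
for path in (DATA_PATH, MODEL_PATH):
    if not os.path.exists(path):
        st.error(f"Missing required file: {path}")
        st.stop()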
data = pd.read_csv(DATA_PATH)
model = joblib.load(MODEL_PATH)
# Patient selector
selected_id = st.sidebar.selectbox("Select a patient", data['Patient_ID'].tolist())
patient = data[data['Patient_ID'] == selected_id].squeeze()
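# Note (added): .squeeze() assumes Patient_ID values are unique; duplicate IDs
# would yield a DataFrame instead of a Series and break the field access below.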
# Initialize session log
if "log_df" not in st.session_state:
st.session_state["log_df"] = init_log()
# Display patient details
st.subheader("Patient Details")
with st.expander("Show Details"):
    st.json(patient.to_dict())
# Recompute model input and prediction for this patient
# Drop the identifier and target columns, align the remaining features with the
# preprocessor's output schema, then score with the pipeline's fitted model step.
input_data_dict = patient.drop(labels=["Patient_ID", "Chronic_Condition_List", "Readmitted_30_Days"], errors="ignore").to_dict()
schema = model.named_steps["preprocessor"].get_feature_names_out()
X_input = preprocess(input_data_dict, schema)
X_input = pd.DataFrame(X_input, columns=schema)
# predict_proba returns [[P(class 0), P(class 1)]] for the single row;
# [0][1] is the probability of readmission within 30 days.
pred_proba = model.named_steps["model"].predict_proba(X_input)[0][1]
pred_label = "🔴 High Risk" if pred_proba >= 0.5 else "🟢 Low Risk"
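# Note (added): 0.5 is the conventional default cutoff; in practice the
# threshold would be tuned on validation data to balance sensitivity against
# alert fatigue for the readmission use case.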
# Display prediction result
st.markdown(f"### Prediction Result: {pred_label}")
st.markdown(f"**Predicted Probability:** `{pred_proba:.2%}`")
st.markdown(f"**LLM Model Used:** `{model_choice}`")
# LLM Explanation auto-update
with st.spinner(f"Generating explanation with {model_choice}..."):
    try:
        explanation = explain_prediction(
            patient_id=patient["Patient_ID"],
            patient_data=patient.to_dict(),
            model_name=model_choice,
            _model=model,
            _client=llm
        )
        st.success("### LLM Explanation\n" + explanation)
    except Exception as e:
        fallback = f"Unable to generate LLM explanation due to error: {e}"
        st.warning(f"⚠️ {fallback}")
        explanation = fallback
st.session_state["log_df"] = log_explanation(
    st.session_state["log_df"],
    patient_id=patient["Patient_ID"],
    model_name=model_choice,
    prediction=pred_proba,
    shap_summary="SHAP summary internal only",
    explanation=explanation
)
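# Note (added): st.session_state is per-browser-session, so the explanation log
# survives reruns but not a page refresh; persist it server-side if a durable
# audit trail is required.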
# Log download
if not st.session_state["log_df"].empty:
    st.download_button(
        label="📄 Download Explanation Log as CSV",
        data=st.session_state["log_df"].to_csv(index=False),
        file_name="llm_explanations_log.csv",
        mime="text/csv"
    )