"""Gradio app: predict employee turnover with XGBoost and explain it with SHAP.

Loads a pickled XGBoost model, wraps it in a SHAP ``TreeExplainer`` and serves
a small Gradio UI with six satisfaction sliders.  Each submission returns the
predicted Stay/Leave probabilities and a SHAP bar plot for that one input row.
"""
import pickle

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import xgboost as xgb
from shap.plots._force_matplotlib import draw_additive_plot

# Load the XGBoost model from Pickle.
# The file must be opened for *reading* ("rb"); "wb" would truncate the model
# file and make pickle.load() fail on a write-only handle.
# NOTE: pickle executes arbitrary code on load -- only ship a trusted model file.
with open("h22_xgb_Final.pkl", "rb") as f:
    loaded_model = pickle.load(f)

# Ensure model is a Booster (handles both XGBClassifier & Booster cases)
if isinstance(loaded_model, xgb.XGBClassifier):
    loaded_model = loaded_model.get_booster()

# Setup SHAP Explainer for XGBoost (TreeExplainer is the tree-model explainer)
explainer = shap.TreeExplainer(loaded_model)


# Define the prediction function
def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment,
              Engagement, WellBeing):
    """Predict Stay/Leave probabilities for one employee and explain them.

    Args:
        SupportiveGM: Supportive GM score from the UI slider.
        Merit: Merit score from the UI slider.
        LearningDevelopment: Learning & Development score from the UI slider.
        WorkEnvironment: Work Environment score from the UI slider.
        Engagement: Engagement score from the UI slider.
        WellBeing: Well-Being score from the UI slider.

    Returns:
        A ``(probabilities, figure)`` tuple: ``probabilities`` maps ``"Leave"``
        and ``"Stay"`` to floats that sum to 1, and ``figure`` is the
        matplotlib figure holding the SHAP bar plot for this input row.
    """
    # Build a one-row frame whose columns are the six feature names.
    new_row = pd.DataFrame.from_dict({
        'SupportiveGM': SupportiveGM,
        'Merit': Merit,
        'LearningDevelopment': LearningDevelopment,
        'WorkEnvironment': WorkEnvironment,
        'Engagement': Engagement,
        'WellBeing': WellBeing
    }, orient='index').transpose()

    # Convert new_row to DMatrix for XGBoost Booster
    dmatrix_new = xgb.DMatrix(new_row)

    # Predict probability for staying (the Booster returns one value per row,
    # which this app treats as the probability of the "Stay" class).
    prob = loaded_model.predict(dmatrix_new)
    stay_prob = float(prob[0])

    # Compute SHAP values.  shap.plots.bar needs a shap.Explanation, which is
    # what calling the explainer returns; explainer.shap_values() would give a
    # bare ndarray that the plot function rejects.
    explanation = explainer(new_row)

    # Plot the single submitted row ([0]) so the bars are that employee's own
    # signed contributions rather than an aggregate over rows.
    shap.plots.bar(explanation[0], max_display=6, order=shap.Explanation.abs,
                   show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    # NOTE(review): this runs after the figure above already exists, so it does
    # not resize that figure -- confirm whether an explicit resize was intended.
    plt.rcParams['figure.figsize'] = (6, 4)
    # Close the pyplot handle so figures do not accumulate across requests; the
    # Figure object itself is still returned for Gradio to render.
    plt.close()

    return {"Leave": 1 - stay_prob, "Stay": stay_prob}, local_plot


# Create the UI
title = "**Mod 3 Team 5: Employee Turnover Predictor**"
description1 = (
    " This app takes six inputs about employees' satisfaction with different "
    "aspects of their work (such as work-life balance, ...) and predicts "
    "whether the employee intends to stay with the employer or leave. "
    "The outputs include: 1. The predicted probability of staying or leaving. "
    "2. A SHAP plot that visualizes how different factors impact the "
    "prediction. \n"
)
description2 = (
    " To use the app, adjust the values of the six employee satisfaction "
    "factors and click **Analyze**. ✨ "
)

# Settings shared by all six satisfaction sliders.
_slider_kwargs = dict(minimum=1, maximum=5, value=4, step=0.1)

# The browser-tab title is plain text, so drop the Markdown emphasis markers
# there; the on-page heading below still renders the original Markdown title.
with gr.Blocks(title=title.strip("*")) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")
    with gr.Row():
        with gr.Column():
            SupportiveGM = gr.Slider(label="Supportive GM Score", **_slider_kwargs)
            Merit = gr.Slider(label="Merit Score", **_slider_kwargs)
            LearningDevelopment = gr.Slider(label="Learning & Development Score", **_slider_kwargs)
            WorkEnvironment = gr.Slider(label="Work Environment Score", **_slider_kwargs)
            Engagement = gr.Slider(label="Engagement Score", **_slider_kwargs)
            WellBeing = gr.Slider(label="Well-Being Score", **_slider_kwargs)
            submit_btn = gr.Button("Analyze")
        with gr.Column(visible=True, scale=1, min_width=600) as output_col:
            label = gr.Label(label="Predicted Turnover Probability")
            local_plot = gr.Plot(label="SHAP Plot:")
            submit_btn.click(
                main_func,
                [SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing],
                [label, local_plot],
                api_name="Employee_Turnover"
            )
    gr.Markdown("### Click on an example below to see how it works:")
    gr.Examples(
        [[4, 4, 4, 4, 5, 5], [5, 4, 5, 4, 4, 4]],
        [SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing],
        [label, local_plot],
        main_func,
        cache_examples=True,
    )

demo.launch()