Mod3_Team5 / app.py
Mpodszus's picture
Update app.py
7965a1d verified
raw
history blame
5.69 kB
import math
import pickle

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
# Load the XGBoost model from Pickle.
# The context manager guarantees the file handle is closed after loading.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only ever load this bundled, trusted model file -- never a user-supplied path.
with open("h22_xgb_Final(2).pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP Explainer for XGBoost (Do not change this)
explainer = shap.Explainer(loaded_model)
def safe_convert(value, default, min_val, max_val):
    """Coerce *value* to a float clamped into ``[min_val, max_val]``.

    Args:
        value: Raw input (a number or numeric string) from the UI or API.
        default: Returned unchanged when *value* cannot be parsed as a number,
            or when it parses to NaN.
        min_val: Inclusive lower clamp bound.
        max_val: Inclusive upper clamp bound.

    Returns:
        The parsed value as a float, clamped into range, or *default* when
        conversion fails.
    """
    try:
        num = float(value)
    except (TypeError, ValueError):
        return default  # Use default if conversion fails
    if math.isnan(num):
        # "nan" parses fine, but NaN compares False against everything, so the
        # min()/max() clamp below would silently collapse it to min_val.
        # Treat it as missing input instead.
        return default
    # Ensure within range (+/-inf clamp to the bounds).
    return float(max(min_val, min(num, max_val)))
# Satisfaction-score columns fed to the model, in the same order this app has
# always built its input frame. Department and ChainScale are collected by the
# UI for illustration only and are deliberately NOT model features.
# NOTE(review): keep in sync with the columns the pickled model was trained on.
_MODEL_FEATURES = (
    "SupportiveGM",
    "Merit",
    "LearningDevelopment",
    "WorkEnvironment",
    "Engagement",
    "WellBeing",
)


def _build_feature_row(raw_scores):
    """Return a one-row float DataFrame of sanitized satisfaction scores.

    ``raw_scores`` maps every name in ``_MODEL_FEATURES`` to a raw UI value.
    Unparsable values fall back to 3.0 and all values are clamped to [1, 5].
    """
    cleaned = {
        name: [safe_convert(raw_scores[name], 3.0, 1, 5)] for name in _MODEL_FEATURES
    }
    return pd.DataFrame(cleaned, columns=list(_MODEL_FEATURES)).astype(float)


def _leave_stay_probabilities(features):
    """Return ``(leave_prob, stay_prob)`` for the single row in ``features``.

    Raises:
        ValueError: if the model reports neither 1 nor 2 probability columns.
    """
    prob = np.asarray(loaded_model.predict_proba(features))
    n_columns = prob.shape[1]
    if n_columns == 2:
        # NOTE(review): column 0 is reported as "Leave" and column 1 as "Stay".
        # This assumes the training labels encoded 0 = leave -- confirm
        # against loaded_model.classes_ / the training notebook.
        return float(prob[0, 0]), float(prob[0, 1])
    if n_columns == 1:
        # A single column is treated as the "Leave" probability.
        leave_prob = float(prob[0, 0])
        return leave_prob, 1 - leave_prob
    raise ValueError(
        f"Expected a binary classifier, got {n_columns} probability columns"
    )


def _shap_waterfall_figure(features):
    """Render a SHAP waterfall plot for the single row in ``features``."""
    # Indexing the Explanation keeps values, base value, data and the
    # DataFrame-derived feature names together for the plot.
    explanation = explainer(features)[0]
    fig = plt.figure(figsize=(8, 4))
    try:
        # show=False: never call plt.show() inside a web request handler.
        shap.plots.waterfall(explanation, show=False)
        plt.tight_layout()
    finally:
        # Drop the figure from pyplot's global registry so figures do not
        # accumulate across requests. The Figure object itself can still be
        # rendered by gr.Plot.
        plt.close(fig)
    return fig


# Create the main function for the model
def main_func(Department, ChainScale, SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
    """Predict stay/leave intent for one employee and explain it with SHAP.

    Args:
        Department: Department picked in the UI. Example only; not used by
            the model.
        ChainScale: Hotel chain scale picked in the UI. Example only; not used
            by the model.
        SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement,
        WellBeing: Satisfaction scores, each sanitized with ``safe_convert``
            (unparsable -> 3.0, clamped to [1, 5]).

    Returns:
        ``({"Leave": p_leave, "Stay": p_stay}, figure)``. The dict feeds the
        ``gr.Label`` output and the matplotlib SHAP waterfall figure feeds
        the ``gr.Plot`` output.

    Raises:
        ValueError: if the model reports neither 1 nor 2 probability columns.
    """
    # Department and ChainScale stay in the signature because the UI and the
    # published API pass them, but they are intentionally ignored.
    features = _build_feature_row({
        "SupportiveGM": SupportiveGM,
        "Merit": Merit,
        "LearningDevelopment": LearningDevelopment,
        "WorkEnvironment": WorkEnvironment,
        "Engagement": Engagement,
        "WellBeing": WellBeing,
    })
    leave_prob, stay_prob = _leave_stay_probabilities(features)
    local_plot = _shap_waterfall_figure(features)
    return {"Leave": leave_prob, "Stay": stay_prob}, local_plot
# Create the UI
# Plain text on purpose: this string is also the browser-tab title passed to
# gr.Blocks(title=...), where Markdown is not rendered.
title = "Mod 3 Team 5: Employee Turnover Predictor & Interpreter"
description1 = """
This app predicts whether an employee intends to stay or leave based on satisfaction factors.
"""
description2 = """
To use the app, adjust the values of the employee satisfaction factors and click on Analyze.
"""


def _score_slider(label):
    """Return a 1-5 satisfaction-score slider (step 0.1, default 4)."""
    return gr.Slider(
        label=label, minimum=1, maximum=5, value=4, step=0.1, interactive=True
    )


with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## **{title}**")
    gr.Markdown(description1)
    gr.Markdown("---")
    gr.Markdown(description2)
    gr.Markdown("---")
    with gr.Row():
        with gr.Column():
            Department = gr.Radio(
                ["Guest Services", "Food and Beverage", "Housekeeping", "Front Office Operations", "Guest Activities"],
                label="Department (Example Only)",
                value="Guest Services",
            )
            ChainScale = gr.Dropdown(
                ["Luxury", "Upper Midscale", "Upper Upscale", "Upscale", "Independent"],
                label="ChainScale (Example Only)",
                value="Upper Upscale",
            )
            SupportiveGM = _score_slider("SupportiveGM Score")
            Merit = _score_slider("Merit Score")
            LearningDevelopment = _score_slider("Learning and Development Score")
            WorkEnvironment = _score_slider("Work Environment Score")
            Engagement = _score_slider("Engagement Score")
            WellBeing = _score_slider("Well-Being Score")
            submit_btn = gr.Button("Analyze")
        with gr.Column(visible=True, scale=1, min_width=600):
            label = gr.Label(label="Predicted Intent to Stay vs Leave")
            local_plot = gr.Plot(label='SHAP Waterfall Analysis')

    # Order must match main_func's positional parameters.
    model_inputs = [Department, ChainScale, SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing]
    model_outputs = [label, local_plot]

    submit_btn.click(
        main_func,
        model_inputs,
        model_outputs,
        api_name="Employee_Turnover",
    )
    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        examples=[
            ["Guest Services", "Upper Upscale", 4.1, 3.7, 3.9, 4.2, 4.4, 4.3],
            ["Food and Beverage", "Upper Upscale", 3.9, 3.7, 4.1, 4.3, 4.5, 4.4],
            ["Housekeeping", "Upper Upscale", 4.3, 4.0, 4.3, 4.4, 4.5, 4.4],
        ],
        inputs=model_inputs,
        outputs=model_outputs,
        fn=main_func,
        # Example outputs are precomputed by running main_func when the app
        # starts.
        cache_examples=True,
    )

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()