Mpodszus commited on
Commit
a3ffeca
·
verified ·
1 Parent(s): 28da825

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import pickle
2
  import pandas as pd
3
  import shap
4
  from shap.plots._force_matplotlib import draw_additive_plot
@@ -6,9 +6,9 @@ import gradio as gr
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
 
9
- # Load the XGBoost model
10
- with open("h22_xgb_Final.json", "rb") as f:
11
- loaded_model = pickle.load(f)
12
 
13
  # Setup SHAP Explainer
14
  explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
@@ -19,19 +19,24 @@ def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagem
19
  'LearningDevelopment': LearningDevelopment, 'WorkEnvironment': WorkEnvironment,
20
  'Engagement': Engagement, 'WellBeing': WellBeing}, orient='index').transpose()
21
 
22
- # Predict probabilities
23
- prob = loaded_model.predict_proba(new_row)
 
 
 
 
 
 
24
 
25
- # Compute SHAP values
26
- shap_values = explainer(new_row)
27
  plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
28
-
29
  plt.tight_layout()
30
  local_plot = plt.gcf()
31
  plt.rcParams['figure.figsize'] = (6, 4)
32
  plt.close()
33
 
34
- return {"Leave": float(prob[0][0]), "Stay": 1 - float(prob[0][0])}, local_plot
35
 
36
  # Create the UI
37
  title = "**Mod 3 Team 5: Employee Turnover Predictor**"
@@ -79,5 +84,3 @@ with gr.Blocks(title=title) as demo:
79
  [label, local_plot], main_func, cache_examples=True)
80
 
81
  demo.launch()
82
-
83
-
 
1
+ import xgboost as xgb
2
  import pandas as pd
3
  import shap
4
  from shap.plots._force_matplotlib import draw_additive_plot
 
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
 
9
+ # Load the XGBoost model from JSON
10
+ loaded_model = xgb.Booster()
11
+ loaded_model.load_model("h22_xgb_Final.json")
12
 
13
  # Setup SHAP Explainer
14
  explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
 
19
  'LearningDevelopment': LearningDevelopment, 'WorkEnvironment': WorkEnvironment,
20
  'Engagement': Engagement, 'WellBeing': WellBeing}, orient='index').transpose()
21
 
22
+ # Convert input to DMatrix format (needed for XGBoost Booster)
23
+ dmatrix_new = xgb.DMatrix(new_row)
24
+
25
+ # Predict probabilities (assuming binary classification)
26
+ prob = loaded_model.predict(dmatrix_new) # Returns only probabilities for class 1
27
+
28
+ # Ensure proper formatting for SHAP
29
+ shap_values = explainer(dmatrix_new)
30
 
31
+ # Generate SHAP Plot
 
32
  plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
33
+
34
  plt.tight_layout()
35
  local_plot = plt.gcf()
36
  plt.rcParams['figure.figsize'] = (6, 4)
37
  plt.close()
38
 
39
+ return {"Leave": 1 - float(prob[0]), "Stay": float(prob[0])}, local_plot
40
 
41
  # Create the UI
42
  title = "**Mod 3 Team 5: Employee Turnover Predictor**"
 
84
  [label, local_plot], main_func, cache_examples=True)
85
 
86
  demo.launch()