Mpodszus committed on
Commit
313ffa3
·
verified ·
1 Parent(s): a3ffeca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -19
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import xgboost as xgb
2
  import pandas as pd
3
  import shap
@@ -6,36 +7,43 @@ import gradio as gr
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
 
9
- # Load the XGBoost model from JSON
10
- loaded_model = xgb.Booster()
11
- loaded_model.load_model("h22_xgb_Final.json")
12
 
13
- # Setup SHAP Explainer
14
- explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
 
 
 
 
15
 
16
# Prediction function wired to the Gradio inputs.
def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
    """Score one respondent and build a local SHAP explanation plot.

    Each argument is one survey-scale score from the UI. Returns a
    (label_dict, matplotlib_figure) pair: the dict maps "Leave"/"Stay"
    to probabilities, the figure is a SHAP bar chart for this prediction.
    """
    # Collect the six scores into a single-row DataFrame.
    feature_values = {
        'SupportiveGM': SupportiveGM,
        'Merit': Merit,
        'LearningDevelopment': LearningDevelopment,
        'WorkEnvironment': WorkEnvironment,
        'Engagement': Engagement,
        'WellBeing': WellBeing,
    }
    new_row = pd.DataFrame.from_dict(feature_values, orient='index').transpose()

    # The low-level Booster API consumes DMatrix, not DataFrame.
    dmatrix_new = xgb.DMatrix(new_row)

    # Booster.predict on a binary model returns only P(class 1),
    # which this app interprets as the probability of staying.
    prob = loaded_model.predict(dmatrix_new)

    # Calling the explainer yields a shap.Explanation; plot the one row.
    shap_values = explainer(dmatrix_new)
    plot = shap.plots.bar(shap_values[0], max_display=6,
                          order=shap.Explanation.abs,
                          show_data='auto', show=False)

    plt.tight_layout()
    local_plot = plt.gcf()
    plt.rcParams['figure.figsize'] = (6, 4)
    plt.close()

    return {"Leave": 1 - float(prob[0]), "Stay": float(prob[0])}, local_plot
40
 
41
  # Create the UI
@@ -84,3 +92,4 @@ with gr.Blocks(title=title) as demo:
84
  [label, local_plot], main_func, cache_examples=True)
85
 
86
  demo.launch()
 
 
1
+ import pickle
2
  import xgboost as xgb
3
  import pandas as pd
4
  import shap
 
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
 
10
# Restore the trained model from its pickle artifact.
# NOTE(review): pickle.load executes arbitrary code from the file — acceptable
# for this bundled model file, but never point it at an untrusted path.
with open("h22_xgb_Final.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# A pickled sklearn-style XGBClassifier wraps a Booster; unwrap it so the
# rest of the app can rely on the low-level Booster API uniformly.
if isinstance(loaded_model, xgb.XGBClassifier):
    loaded_model = loaded_model.get_booster()

# TreeExplainer is SHAP's backend for tree ensembles such as XGBoost.
explainer = shap.TreeExplainer(loaded_model)
20
 
21
# Define the prediction function
def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
    """Score one respondent and build a local SHAP explanation plot.

    Each argument is one survey-scale score from the UI. Returns a
    (label_dict, matplotlib_figure) pair: the dict maps "Leave"/"Stay"
    to probabilities for the Gradio Label component, the figure is a
    SHAP bar chart explaining this single prediction.
    """
    # One-row DataFrame in the feature order the model expects.
    new_row = pd.DataFrame.from_dict({
        'SupportiveGM': SupportiveGM,
        'Merit': Merit,
        'LearningDevelopment': LearningDevelopment,
        'WorkEnvironment': WorkEnvironment,
        'Engagement': Engagement,
        'WellBeing': WellBeing
    }, orient='index').transpose()

    # The low-level Booster API predicts from DMatrix, not DataFrame.
    dmatrix_new = xgb.DMatrix(new_row)

    # Binary Booster.predict returns only P(class 1) — treated as "Stay".
    prob = loaded_model.predict(dmatrix_new)

    # BUG FIX: TreeExplainer.shap_values() returns a raw numpy array, but
    # shap.plots.bar requires a shap.Explanation (order=shap.Explanation.abs
    # raises on a plain array). Calling the explainer itself produces an
    # Explanation; plot the single row we just explained.
    shap_values = explainer(new_row)
    shap.plots.bar(shap_values[0], max_display=6,
                   order=shap.Explanation.abs,
                   show_data='auto', show=False)

    plt.tight_layout()
    local_plot = plt.gcf()
    plt.rcParams['figure.figsize'] = (6, 4)
    plt.close()

    return {"Leave": 1 - float(prob[0]), "Stay": float(prob[0])}, local_plot
48
 
49
  # Create the UI
 
92
  [label, local_plot], main_func, cache_examples=True)
93
 
94
  demo.launch()
95
+