Mpodszus commited on
Commit
055398c
·
verified ·
1 Parent(s): 709c876

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -7,16 +7,16 @@ import gradio as gr
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
 
10
- # Load the XGBoost model from Pickle
11
  with open("h22_xgb_Final.pkl", "rb") as f:
12
- loaded_model = pickle.load(f)
13
 
14
- # Ensure model is a Booster (handles both XGBClassifier & Booster cases)
15
- if isinstance(loaded_model, xgb.XGBClassifier):
16
- loaded_model = loaded_model.get_booster()
17
 
18
  # Setup SHAP Explainer for XGBoost
19
- explainer = shap.TreeExplainer(loaded_model) # Use TreeExplainer for XGBoost models
20
 
21
  # Define the prediction function
22
  def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
@@ -32,12 +32,14 @@ def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagem
32
  # Convert new_row to DMatrix for XGBoost Booster
33
  dmatrix_new = xgb.DMatrix(new_row)
34
 
35
- # Predict probability for staying (XGBoost Booster returns only one class probability)
36
  prob = loaded_model.predict(dmatrix_new)
37
 
38
  # Compute SHAP values
39
- shap_values = explainer.shap_values(new_row)
40
- plot = shap.plots.bar(shap_values, max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
 
 
41
 
42
  plt.tight_layout()
43
  local_plot = plt.gcf()
 
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
 
10
+ # Load the XGBoost model from Pickle (Corrected: Using 'rb' mode)
11
  with open("h22_xgb_Final.pkl", "rb") as f:
12
+ raw_model = pickle.load(f)
13
 
14
+ # Restore the Booster from the raw format
15
+ loaded_model = xgb.Booster()
16
+ loaded_model.load_model(raw_model)
17
 
18
  # Setup SHAP Explainer for XGBoost
19
+ explainer = shap.TreeExplainer(loaded_model)
20
 
21
  # Define the prediction function
22
  def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
 
32
  # Convert new_row to DMatrix for XGBoost Booster
33
  dmatrix_new = xgb.DMatrix(new_row)
34
 
35
+ # Predict probability for staying (XGBoost Booster returns class probabilities)
36
  prob = loaded_model.predict(dmatrix_new)
37
 
38
  # Compute SHAP values
39
+ shap_values = explainer(new_row)
40
+
41
+ # Generate SHAP Plot
42
+ plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
43
 
44
  plt.tight_layout()
45
  local_plot = plt.gcf()