Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,16 +7,16 @@ import gradio as gr
|
|
7 |
import numpy as np
|
8 |
import matplotlib.pyplot as plt
|
9 |
|
10 |
-
# Load the XGBoost model from Pickle
|
11 |
with open("h22_xgb_Final.pkl", "rb") as f:
|
12 |
-
|
13 |
|
14 |
-
#
|
15 |
-
|
16 |
-
|
17 |
|
18 |
# Setup SHAP Explainer for XGBoost
|
19 |
-
explainer = shap.TreeExplainer(loaded_model)
|
20 |
|
21 |
# Define the prediction function
|
22 |
def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
|
@@ -32,12 +32,14 @@ def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagem
|
|
32 |
# Convert new_row to DMatrix for XGBoost Booster
|
33 |
dmatrix_new = xgb.DMatrix(new_row)
|
34 |
|
35 |
-
# Predict probability for staying (XGBoost Booster returns
|
36 |
prob = loaded_model.predict(dmatrix_new)
|
37 |
|
38 |
# Compute SHAP values
|
39 |
-
shap_values = explainer
|
40 |
-
|
|
|
|
|
41 |
|
42 |
plt.tight_layout()
|
43 |
local_plot = plt.gcf()
|
|
|
7 |
import numpy as np
|
8 |
import matplotlib.pyplot as plt
|
9 |
|
10 |
+
# Load the XGBoost model from Pickle (Corrected: Using 'rb' mode)
|
11 |
with open("h22_xgb_Final.pkl", "rb") as f:
|
12 |
+
raw_model = pickle.load(f)
|
13 |
|
14 |
+
# Restore the Booster from the raw format
|
15 |
+
loaded_model = xgb.Booster()
|
16 |
+
loaded_model.load_model(raw_model)
|
17 |
|
18 |
# Setup SHAP Explainer for XGBoost
|
19 |
+
explainer = shap.TreeExplainer(loaded_model)
|
20 |
|
21 |
# Define the prediction function
|
22 |
def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
|
|
|
32 |
# Convert new_row to DMatrix for XGBoost Booster
|
33 |
dmatrix_new = xgb.DMatrix(new_row)
|
34 |
|
35 |
+
# Predict probability for staying (XGBoost Booster returns class probabilities)
|
36 |
prob = loaded_model.predict(dmatrix_new)
|
37 |
|
38 |
# Compute SHAP values
|
39 |
+
shap_values = explainer(new_row)
|
40 |
+
|
41 |
+
# Generate SHAP Plot
|
42 |
+
plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
|
43 |
|
44 |
plt.tight_layout()
|
45 |
local_plot = plt.gcf()
|