Mpodszus commited on
Commit
008cbc7
·
verified ·
1 Parent(s): 357692e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -23
app.py CHANGED
@@ -1,30 +1,83 @@
 
 
 
 
1
  import gradio as gr
2
- import pickle
3
  import numpy as np
 
4
 
5
- # Load the trained model
6
- with open("clf.pkl", "rb") as f:
7
- clf = pickle.load(f)
 
 
 
8
 
9
  # Define the prediction function
10
- def predict(SupportiveGM, Merit, LearningDevelopment, WorkEnvironmente, Engagement, WellBeing, ChainScale):
11
- # Convert inputs into a NumPy array
12
- input_data = np.array([[SupportiveGM, Merit, LearningDevelopment, WorkEnvironmente, Engagement, WellBeing, ChainScale]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Make prediction using the model
15
- prediction = clf.predict(input_data)
16
-
17
- return f"Predicted Turnover Probability: {prediction[0]}"
18
-
19
- # Create the Gradio interface
20
- iface = gr.Interface(
21
- fn=predict,
22
- inputs=["number"] * 7,
23
- outputs="text",
24
- title="Employee Turnover Prediction",
25
- api_name="/Employee_Turnover"
26
- )
27
-
28
- if __name__ == "__main__":
29
- iface.launch()
30
 
 
1
+ import pickle
2
+ import pandas as pd
3
+ import shap
4
+ from shap.plots._force_matplotlib import draw_additive_plot
5
  import gradio as gr
 
6
  import numpy as np
7
+ import matplotlib.pyplot as plt
8
 
9
+ # Load the XGBoost model
10
+ with open("h22_xgb.pkl", "rb") as f:
11
+ loaded_model = pickle.load(f)
12
+
13
+ # Setup SHAP Explainer
14
+ explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
15
 
16
  # Define the prediction function
17
+ def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
18
+ new_row = pd.DataFrame.from_dict({'SupportiveGM': SupportiveGM, 'Merit': Merit,
19
+ 'LearningDevelopment': LearningDevelopment, 'WorkEnvironment': WorkEnvironment,
20
+ 'Engagement': Engagement, 'WellBeing': WellBeing}, orient='index').transpose()
21
+
22
+ # Predict probabilities
23
+ prob = loaded_model.predict_proba(new_row)
24
+
25
+ # Compute SHAP values
26
+ shap_values = explainer(new_row)
27
+ plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
28
+
29
+ plt.tight_layout()
30
+ local_plot = plt.gcf()
31
+ plt.rcParams['figure.figsize'] = (6, 4)
32
+ plt.close()
33
+
34
+ return {"Leave": float(prob[0][0]), "Stay": 1 - float(prob[0][0])}, local_plot
35
+
36
+ # Create the UI
37
+ title = "**Mod 3 Team 5: Employee Turnover Predictor**"
38
+ description1 = """
39
+ This app takes six inputs about employees' satisfaction with different aspects of their work (such as work-life balance, ...)
40
+ and predicts whether the employee intends to stay with the employer or leave. The outputs include:
41
+ 1. The predicted probability of staying or leaving.
42
+ 2. A SHAP plot that visualizes how different factors impact the prediction.
43
+ """
44
+
45
+ description2 = """
46
+ To use the app, adjust the values of the six employee satisfaction factors and click **Analyze**. ✨
47
+ """
48
+
49
+ with gr.Blocks(title=title) as demo:
50
+ gr.Markdown(f"## {title}")
51
+ gr.Markdown(description1)
52
+ gr.Markdown("""---""")
53
+ gr.Markdown(description2)
54
+ gr.Markdown("""---""")
55
+
56
+ with gr.Row():
57
+ with gr.Column():
58
+ SupportiveGM = gr.Slider(label="Supportive GM Score", minimum=1, maximum=5, value=4, step=0.1)
59
+ Merit = gr.Slider(label="Merit Score", minimum=1, maximum=5, value=4, step=0.1)
60
+ LearningDevelopment = gr.Slider(label="Learning & Development Score", minimum=1, maximum=5, value=4, step=0.1)
61
+ WorkEnvironment = gr.Slider(label="Work Environment Score", minimum=1, maximum=5, value=4, step=0.1)
62
+ Engagement = gr.Slider(label="Engagement Score", minimum=1, maximum=5, value=4, step=0.1)
63
+ WellBeing = gr.Slider(label="Well-Being Score", minimum=1, maximum=5, value=4, step=0.1)
64
+ submit_btn = gr.Button("Analyze")
65
+
66
+ with gr.Column(visible=True, scale=1, min_width=600) as output_col:
67
+ label = gr.Label(label="Predicted Turnover Probability")
68
+ local_plot = gr.Plot(label="SHAP Plot:")
69
+
70
+ submit_btn.click(
71
+ main_func,
72
+ [SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing],
73
+ [label, local_plot], api_name="Employee_Turnover"
74
+ )
75
 
76
+ gr.Markdown("### Click on an example below to see how it works:")
77
+ gr.Examples([[4, 4, 4, 4, 5, 5], [5, 4, 5, 4, 4, 4]],
78
+ [SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing],
79
+ [label, local_plot], main_func, cache_examples=True)
80
+
81
+ demo.launch()
82
+
 
 
 
 
 
 
 
 
 
83