Mpodszus commited on
Commit
9c8d7b2
·
verified ·
1 Parent(s): c819682

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -46
app.py CHANGED
@@ -13,78 +13,189 @@ loaded_model = pickle.load(open("h22_xgb_Final(2).pkl", 'rb'))
13
  # Setup SHAP Explainer for XGBoost
14
  explainer = shap.Explainer(loaded_model)
15
 
16
- # Define the prediction function
17
- def main_func(SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
18
- new_row = pd.DataFrame.from_dict({
19
- 'SupportiveGM': SupportiveGM,
20
- 'Merit': Merit,
21
- 'LearningDevelopment': LearningDevelopment,
22
- 'WorkEnvironment': WorkEnvironment,
23
- 'Engagement': Engagement,
24
- 'WellBeing': WellBeing
25
- }, orient='index').transpose()
26
-
27
-
28
- # Predict probability for staying (XGBoost Booster returns class probabilities)
29
- prob = loaded_model.predict_proba(new_row)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # Compute SHAP values
32
  shap_values = explainer(new_row)
33
 
34
- # Generate SHAP Plot
35
- plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  plt.tight_layout()
38
  local_plot = plt.gcf()
39
- plt.rcParams['figure.figsize'] = (6, 4)
40
  plt.close()
41
-
42
- return {"Leave": float(prob[0][0]), "Stay": 1-float(prob[0][0])}, local_plot
43
 
44
  # Create the UI
45
- title = "**Mod 3 Team 5: Employee Turnover Predictor**"
46
  description1 = """
47
- This app takes six inputs about employees' satisfaction with different aspects of their work (such as work-life balance, ...)
48
- and predicts whether the employee intends to stay with the employer or leave. The outputs include:
49
- 1. The predicted probability of staying or leaving.
50
- 2. A SHAP plot that visualizes how different factors impact the prediction.
51
  """
52
 
53
  description2 = """
54
- To use the app, adjust the values of the six employee satisfaction factors and click **Analyze**.
 
55
  """
56
 
57
  with gr.Blocks(title=title) as demo:
 
58
  gr.Markdown(f"## {title}")
59
  gr.Markdown(description1)
60
  gr.Markdown("""---""")
61
  gr.Markdown(description2)
62
  gr.Markdown("""---""")
63
-
64
- with gr.Row():
65
  with gr.Column():
66
- SupportiveGM = gr.Slider(label="Supportive GM Score", minimum=1, maximum=5, value=4, step=0.1)
67
- Merit = gr.Slider(label="Merit Score", minimum=1, maximum=5, value=4, step=0.1)
68
- LearningDevelopment = gr.Slider(label="Learning & Development Score", minimum=1, maximum=5, value=4, step=0.1)
69
- WorkEnvironment = gr.Slider(label="Work Environment Score", minimum=1, maximum=5, value=4, step=0.1)
70
- Engagement = gr.Slider(label="Engagement Score", minimum=1, maximum=5, value=4, step=0.1)
71
- WellBeing = gr.Slider(label="Well-Being Score", minimum=1, maximum=5, value=4, step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  submit_btn = gr.Button("Analyze")
73
 
74
  with gr.Column(visible=True, scale=1, min_width=600) as output_col:
75
- label = gr.Label(label="Predicted Turnover Probability")
76
- local_plot = gr.Plot(label="SHAP Plot:")
77
 
78
- submit_btn.click(
79
- main_func,
80
- [SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing],
81
- [label, local_plot], api_name="Employee_Turnover"
82
- )
 
83
 
84
- gr.Markdown("### Click on an example below to see how it works:")
85
- gr.Examples([[4, 4, 4, 4, 5, 5], [5, 4, 5, 4, 4, 4]],
86
- [SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing],
87
- [label, local_plot], main_func, cache_examples=True)
88
 
89
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
 
 
13
  # Setup SHAP Explainer for XGBoost
14
  explainer = shap.Explainer(loaded_model)
15
 
16
+ def safe_convert(value, default, min_val, max_val):
17
+ try:
18
+ num = float(value)
19
+ return max(min_val, min(num, max_val)) # Ensure within range
20
+ except (TypeError, ValueError):
21
+ return default # Use default if conversion fails
22
+
23
+ # Create the main function for server
24
+ def main_func(Department, ChainScale, SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing):
25
+
26
+ # ChainScale mapping
27
+ ChainScale = {
28
+ 'Luxury': 1,
29
+ 'Upper Midscale': 2,
30
+ 'Upper Upscale': 3,
31
+ 'Upscale': 4,
32
+ 'Independent': 5,
33
+ }
34
+ default_ChainScale = 4
35
+ ChainScale_value = ChainScale_mapping.get(ChainScale, ChainScale)
36
+
37
+ # Department mapping
38
+ department_mapping = {
39
+ "Guest Services": 1,
40
+ "Food and Beverage": 2,
41
+ "Housekeeping": 3,
42
+ "Front Office Operations": 4,
43
+ "Guest Activities": 5,
44
+ }
45
+ default_department = 5
46
+ department_value = department_mapping.get(Department, default_department)
47
+
48
+ LearningDevelopment = safe_convert(LearningDevelopment, 3.0, 1, 5)
49
+ SupportiveGM = safe_convert(SupportiveGM, 3.0, 1, 5)
50
+ Merit = safe_convert(Merit, 3.0, 1, 5)
51
+ WorkEnvironment = safe_convert(WorkEnvironment, 3.0, 1, 5)
52
+ WellBeing = safe_convert(WellBeing, 3.0, 1, 5)
53
+
54
+ new_row = pd.DataFrame({
55
+ 'Department': [int(department_value)],
56
+ 'ChainScale': [int(ChainScale_value)],
57
+ 'SupportiveGM': [SupportiveGM],
58
+ 'Merit': [Merit],
59
+ 'LearningDevelopment': [LearningDevelopment],
60
+ 'WorkEnvironment': [WorkEnvironment],
61
+ 'Engaged': [Engaged],
62
+ 'WellBeing': [WellBeing]
63
+ }).astype(float)
64
+
65
+ prob = loaded_model.predict_proba(new_row)
66
 
 
67
  shap_values = explainer(new_row)
68
 
69
+ fig, ax = plt.subplots(figsize=(8, 4))
70
+ values = shap_values.values[0] # Extract SHAP values
71
+ features = new_row.columns # Feature names
72
+
73
+ # Assign colors manually (Hilton Blue for positive, Red for negative)
74
+ colors = ['#FF0000' if v < 0 else '#1E4380' for v in values]
75
+
76
+ sorted_indices = np.argsort(np.abs(values))[-6:] # Select top 6 features
77
+ sorted_values = values[sorted_indices]
78
+ sorted_features = [features[i] for i in sorted_indices]
79
+ sorted_colors = [colors[i] for i in sorted_indices]
80
+ ax.barh(sorted_features, sorted_values, color=sorted_colors)
81
+ ax.set_xlabel("SHAP Value Impact")
82
+ ax.set_title("Feature Importance (SHAP)")
83
 
84
  plt.tight_layout()
85
  local_plot = plt.gcf()
 
86
  plt.close()
87
+
88
+ return {"Leave": float(prob[0][0]), "Stay": 1 - float(prob[0][0])}, local_plot
89
 
90
  # Create the UI
91
+ title = "**Mod 3 Team 5: Employee Turnover Predictor & Interpreter**"
92
  description1 = """
93
+ This app predicts whether an employee intends to stay or leave based on satisfaction factors and department. There are two outputs from the app: 1- the predicted probability of stay or leave, 2- Shapley's
94
+ force-plot which visualizes the extent to which each factor impacts the stay/leave prediction.
 
 
95
  """
96
 
97
  description2 = """
98
+ To use the app, click on one of the examples, or adjust the values of the employee satisfaction
99
+ factors in the employee generation and tenure population of interest, and click on Analyze. ✨
100
  """
101
 
102
  with gr.Blocks(title=title) as demo:
103
+ gr.Image("Hilton-Logo-Resized.png")
104
  gr.Markdown(f"## {title}")
105
  gr.Markdown(description1)
106
  gr.Markdown("""---""")
107
  gr.Markdown(description2)
108
  gr.Markdown("""---""")
109
+
110
+ with gr.Row():
111
  with gr.Column():
112
+ Department = gr.Radio(
113
+ ["Guest Services", "Food and Beverage", "Housekeeping",
114
+ "Front Office Operations", "Guest Activties"],
115
+ label="Department",
116
+ value="Guest Services"
117
+ )
118
+ ChainScale = gr.Dropdown(
119
+ ["Luxury", "Upper Mid-Scale", "Upper Upscale", "Upscale", "Indepdendent"],
120
+ label="ChainScale",
121
+ value="Upper Upscale"
122
+ )
123
+ LearningDevelopment = gr.Slider(
124
+ label="SupportiveGM Score", minimum=1, maximum=5, value=4, step=0.1,
125
+ interactive=True
126
+ )
127
+ SupportiveGM = gr.Slider(
128
+ label="Merit Score", minimum=1, maximum=5, value=4, step=0.1,
129
+ interactive=True
130
+ )
131
+ Merit = gr.Slider(
132
+ label="Learning and Development Score", minimum=1, maximum=5, value=4, step=0.1,
133
+ interactive=True
134
+ )
135
+ WorkEnvironment = gr.Slider(
136
+ label="Work Environment Score", minimum=1, maximum=5, value=4, step=0.1,
137
+ interactive=True
138
+ )
139
+ Engagement = gr.Slider(
140
+ label="Engagement Score", minimum=1, maximum=5, value=4, step=0.1,
141
+ interactive=True
142
+ )
143
+ WellBeing = gr.Slider(
144
+ label="Well-Being Score", minimum=1, maximum=5, value=4, step=0.1,
145
+ interactive=True
146
+ )
147
  submit_btn = gr.Button("Analyze")
148
 
149
  with gr.Column(visible=True, scale=1, min_width=600) as output_col:
150
+ label = gr.Label(label="Predicted Label")
151
+ local_plot = gr.Plot(label='SHAP Analysis')
152
 
153
+ submit_btn.click(
154
+ main_func,
155
+ [Department, ChainScale, SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing],
156
+ [label, local_plot],
157
+ api_name="Employee_Turnover"
158
+ )
159
 
160
+ gr.Markdown("### Click on any of the examples below to see how it works:")
 
 
 
161
 
162
+ gr.Examples(
163
+ [
164
+ ["Guest Services", "UpperUpscale", 2.5, 3.0, 2.8, 3.5],
165
+ ["Food and Beverage", "UpperUpscale", 3.5, 4.0, 4.2, 4.5],
166
+ ["Housekeeping", "UpperUpscale", 5.0, 4.8, 5.0, 4.7]
167
+ ],
168
+ [Department, ChainScale, SupportiveGM, Merit, LearningDevelopment, WorkEnvironment, Engagement, WellBeing],
169
+ [label, local_plot],
170
+ main_func,
171
+ cache_examples=True
172
+ )
173
+
174
+ demo.css = """
175
+ body {
176
+ font-family: "Garamond", "Georgia", serif !important;
177
+ }
178
+ .gr-button {
179
+ background-color: #1E4380 !important;
180
+ color: white !important;
181
+ border-radius: 8px !important;
182
+ padding: 10px !important;
183
+ }
184
+ .gr-label {
185
+ color: #1E4380 !important;
186
+ font-weight: bold;
187
+ font-size: 18px !important;
188
+ }
189
+ /* Fix Slider Colors */
190
+ input[type="range"] {
191
+ accent-color: #1E4380 !important;
192
+ }
193
+ .gr-slider .track {
194
+ background: #1E4380 !important;
195
+ }
196
+ .gr-slider .thumb {
197
+ background: #1E4380 !important;
198
+ }
199
+ """
200
 
201
+ demo.launch()