Mpodszus commited on
Commit
cae59a3
·
verified ·
1 Parent(s): 2be2025

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -63,16 +63,26 @@ def main_func(Department, ChainScale, SupportiveGM, Merit, LearningDevelopment,
63
  }).astype(float)
64
 
65
  prob = loaded_model.predict_proba(new_row)
 
 
 
 
 
 
 
 
66
 
67
  shap_values = explainer(new_row)
68
 
69
  fig, ax = plt.subplots(figsize=(8, 4))
70
- shap.waterfall_plot(shap_values[0]) # Corrected to use waterfall plot
 
 
71
  plt.tight_layout()
72
  local_plot = plt.gcf()
73
  plt.close()
74
 
75
- return {"Leave": float(prob[0][0]), "Stay": float(prob[0][1])}, local_plot
76
 
77
  # Create the UI
78
  title = "**Mod 3 Team 5: Employee Turnover Predictor & Interpreter**"
@@ -154,5 +164,4 @@ with gr.Blocks(title=title) as demo:
154
  cache_examples=True
155
  )
156
 
157
- demo.launch()
158
-
 
63
  }).astype(float)
64
 
65
  prob = loaded_model.predict_proba(new_row)
66
+
67
+ # Ensure probabilities return correctly
68
+ if prob.shape[1] == 2:
69
+ leave_prob = float(prob[0][0])
70
+ stay_prob = float(prob[0][1])
71
+ else:
72
+ leave_prob = float(prob[0])
73
+ stay_prob = 1 - leave_prob
74
 
75
  shap_values = explainer(new_row)
76
 
77
  fig, ax = plt.subplots(figsize=(8, 4))
78
+ shap.waterfall_plot(shap.Explanation(values=shap_values.values[0],
79
+ base_values=shap_values.base_values[0],
80
+ data=new_row.iloc[0])) # Fix waterfall plot
81
  plt.tight_layout()
82
  local_plot = plt.gcf()
83
  plt.close()
84
 
85
+ return {"Leave": leave_prob, "Stay": stay_prob}, local_plot
86
 
87
  # Create the UI
88
  title = "**Mod 3 Team 5: Employee Turnover Predictor & Interpreter**"
 
164
  cache_examples=True
165
  )
166
 
167
+ demo.launch()