cchukwu commited on
Commit
bb27be0
·
verified ·
1 Parent(s): efe93d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -53,12 +53,12 @@ def adr_predict(x):
53
  scores = output[0][0].detach()
54
  scores = torch.nn.functional.softmax(scores)
55
 
56
- shap_values = explainer([str(x).lower()])
57
  # # Find the index of the class you want as the default reference (e.g., 'label_1')
58
- # label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
59
 
60
  # # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
61
- # shap.plots.text(shap_values[label_1_index][0])
62
 
63
  local_plot = shap.plots.text(shap_values[0], display=False)
64
 
 
53
  scores = output[0][0].detach()
54
  scores = torch.nn.functional.softmax(scores)
55
 
56
+ shap_values = explainer([str("The young woman had a severe drug reaction.").lower()])
57
  # # Find the index of the class you want as the default reference (e.g., 'label_1')
58
+ label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
59
 
60
  # # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
61
+ shap.plots.text(shap_values[label_1_index][0])
62
 
63
  local_plot = shap.plots.text(shap_values[0], display=False)
64