Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -53,12 +53,12 @@ def adr_predict(x):
|
|
53 |
scores = output[0][0].detach()
|
54 |
scores = torch.nn.functional.softmax(scores)
|
55 |
|
56 |
-
shap_values = explainer([str(
|
57 |
# # Find the index of the class you want as the default reference (e.g., 'label_1')
|
58 |
-
|
59 |
|
60 |
# # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
|
61 |
-
|
62 |
|
63 |
local_plot = shap.plots.text(shap_values[0], display=False)
|
64 |
|
|
|
53 |
scores = output[0][0].detach()
|
54 |
scores = torch.nn.functional.softmax(scores)
|
55 |
|
56 |
+
shap_values = explainer([str("The young woman had a severe drug reaction.").lower()])
|
57 |
# # Find the index of the class you want as the default reference (e.g., 'label_1')
|
58 |
+
label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
|
59 |
|
60 |
# # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
|
61 |
+
shap.plots.text(shap_values[label_1_index][0])
|
62 |
|
63 |
local_plot = shap.plots.text(shap_values[0], display=False)
|
64 |
|