WwYc commited on
Commit
4604315
·
verified ·
1 Parent(s): ef42493

Create explain.py

Browse files
Files changed (1) hide show
  1. explain.py +18 -0
explain.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
+ from visualization import generate_visualization
4
+
5
+ def do_explain(transform, image, class_index=None):
6
+ fig, axs = plt.subplots(1, 2)
7
+ axs[0].imshow(image)
8
+ axs[0].axis("off")
9
+
10
+ transformed_image = transform(image)
11
+ viz = generate_visualization(
12
+ transformed_image, class_index=class_index, method="gradcam", use_threshold=False
13
+ )
14
+
15
+ axs[1].imshow(viz)
16
+ axs[1].axis("off")
17
+ return fig
18
+