WwYc commited on
Commit
c4207a6
·
verified ·
1 Parent(s): 8038161

Update explain.py

Browse files
Files changed (1) hide show
  1. explain.py +2 -2
explain.py CHANGED
@@ -2,14 +2,14 @@ import matplotlib.pyplot as plt
2
 
3
  from visualization import generate_visualization
4
 
5
- def do_explain(transform, image, class_index=None, use_threshold=False):
6
  fig, axs = plt.subplots(1, 2)
7
  axs[0].imshow(image)
8
  axs[0].axis("off")
9
 
10
  transformed_image = transform(image)
11
  viz = generate_visualization(
12
- transformed_image, class_index=class_index, use_threshold=use_threshold
13
  )
14
 
15
  axs[1].imshow(viz)
 
2
 
3
  from visualization import generate_visualization
4
 
5
+ def do_explain(transform, image, class_index=None):
6
  fig, axs = plt.subplots(1, 2)
7
  axs[0].imshow(image)
8
  axs[0].axis("off")
9
 
10
  transformed_image = transform(image)
11
  viz = generate_visualization(
12
+ transformed_image, class_index=class_index
13
  )
14
 
15
  axs[1].imshow(viz)