WwYc commited on
Commit
39c6f3a
·
verified ·
1 Parent(s): c4207a6

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +2 -2
visualization.py CHANGED
@@ -33,14 +33,14 @@ model = vit_LRP(pretrained=True)
33
  model.eval()
34
  attribution_generator = LRP(model)
35
 
36
- def generate_visualization(original_image, class_index=None, use_threshold=False):
37
  transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0), method="transformer_attribution", index=class_index).detach()
38
  transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
39
  transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
40
  transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
41
  transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
42
 
43
- if use_threshold:
44
  transformer_attribution = transformer_attribution * 255
45
  transformer_attribution = transformer_attribution.astype(np.uint8)
46
  ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
 
33
  model.eval()
34
  attribution_generator = LRP(model)
35
 
36
+ def generate_visualization(original_image, class_index=None):
37
  transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0), method="transformer_attribution", index=class_index).detach()
38
  transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
39
  transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
40
  transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
41
  transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
42
 
43
+ if use_thresholding:
44
  transformer_attribution = transformer_attribution * 255
45
  transformer_attribution = transformer_attribution.astype(np.uint8)
46
  ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)