WwYc commited on
Commit
49ef971
·
verified ·
1 Parent(s): ade6eb5

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +1 -1
visualization.py CHANGED
@@ -34,7 +34,7 @@ model.eval()
34
  attribution_generator = LRP(model)
35
 
36
  def generate_visualization(original_image, class_index=None, use_threshold=False):
37
- transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
38
  transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
39
  transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
40
  transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
 
34
  attribution_generator = LRP(model)
35
 
36
  def generate_visualization(original_image, class_index=None, use_threshold=False):
37
+ transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0), method="transformer_attribution", index=class_index).detach()
38
  transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
39
  transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
40
  transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()