Spaces:
Sleeping
Sleeping
Update visualization.py
Browse files- visualization.py +2 -2
visualization.py
CHANGED
@@ -33,14 +33,14 @@ model = vit_LRP(pretrained=True)
|
|
33 |
model.eval()
|
34 |
attribution_generator = LRP(model)
|
35 |
|
36 |
-
def generate_visualization(original_image, class_index=None
|
37 |
transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0), method="transformer_attribution", index=class_index).detach()
|
38 |
transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
|
39 |
transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
|
40 |
transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
|
41 |
transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
|
42 |
|
43 |
-
if
|
44 |
transformer_attribution = transformer_attribution * 255
|
45 |
transformer_attribution = transformer_attribution.astype(np.uint8)
|
46 |
ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
|
|
33 |
model.eval()
|
34 |
attribution_generator = LRP(model)
|
35 |
|
36 |
+
def generate_visualization(original_image, class_index=None):
|
37 |
transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0), method="transformer_attribution", index=class_index).detach()
|
38 |
transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
|
39 |
transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
|
40 |
transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
|
41 |
transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
|
42 |
|
43 |
+
if use_thresholding:
|
44 |
transformer_attribution = transformer_attribution * 255
|
45 |
transformer_attribution = transformer_attribution.astype(np.uint8)
|
46 |
ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|