File size: 2,012 Bytes
09823ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# src/model/gradcam.py

import torch
import torch.nn.functional as F
import numpy as np
import cv2

class GradCAMPlusPlus:
    """Grad-CAM++ heatmap generator bound to one layer of a model.

    Forward and backward hooks on ``target_layer`` capture that layer's
    output and the gradient of the class score with respect to it;
    ``generate`` combines them into a normalized heatmap.

    The constructor switches ``model`` to eval mode. Call ``remove_hooks``
    when the explainer is no longer needed so the hooks stop firing on
    every forward/backward pass through the layer.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.model.eval()
        self.target_layer = target_layer

        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be detached later.
        self._hook_handles = [
            target_layer.register_forward_hook(self._save_activations),
            target_layer.register_full_backward_hook(self._save_gradients),
        ]

    def remove_hooks(self):
        """Detach the hooks registered on the target layer (idempotent)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activations(self, module, input, output):
        # Forward hook: remember the layer output, detached from the graph.
        self.activations = output.detach()

    def _save_gradients(self, module, grad_input, grad_output):
        # Backward hook: remember the gradient w.r.t. the layer output.
        self.gradients = grad_output[0].detach()

    @staticmethod
    def _compute_cam(activations, grads):
        """Return the rectified Grad-CAM++ map for every sample in the batch.

        Args:
            activations: layer output, indexed as [batch, channels, h, w].
            grads: gradient of the score w.r.t. ``activations``, same shape.

        Returns:
            Tensor of shape [batch, h, w] with non-negative values.
        """
        grads_power_2 = grads ** 2
        grads_power_3 = grads_power_2 * grads

        # Grad-CAM++ weights the cubic term by the spatial sum of the
        # activations (not of the gradients).
        sum_activations = activations.sum(dim=(2, 3), keepdim=True)

        alpha_denom = 2 * grads_power_2 + sum_activations * grads_power_3
        # Where the denominator vanishes, divide by 1 instead of 0.
        alpha_denom = torch.where(
            alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom)
        )
        alphas = grads_power_2 / alpha_denom

        weights = (alphas * F.relu(grads)).sum(dim=(2, 3), keepdim=True)
        return F.relu((weights * activations).sum(dim=1))

    def generate(self, input_tensor, class_idx=None):
        """Compute the Grad-CAM++ heatmap for the first sample of a batch.

        Args:
            input_tensor: model input; dimensions 2 and 3 are used as the
                output heatmap's height and width.
            class_idx: index of the class score to explain. Defaults to the
                top-scoring class of the first sample.

        Returns:
            2-D NumPy array of shape (height, width), min-max scaled to
            approximately [0, 1].

        Raises:
            RuntimeError: if the target layer's hooks did not fire (layer
                not used in the forward pass, or hooks already removed).
        """
        # Drop results of a previous call so stale tensors are never reused.
        self.activations = None
        self.gradients = None

        # Gradients are required even if the caller is inside no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)
            if class_idx is None:
                # Only sample 0 is explained, so pick its top class.
                class_idx = output[0].argmax().item()

            self.model.zero_grad()
            score = output[0, class_idx]
            score.backward(retain_graph=True)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "target_layer hooks did not capture activations/gradients; "
                "is the layer used in the forward pass and are hooks attached?"
            )

        # Index sample 0 explicitly; squeeze() would also drop size-1
        # spatial dimensions.
        cam = self._compute_cam(self.activations, self.gradients)[0]
        cam = cam.cpu().numpy()

        # cv2.resize expects dsize as (width, height).
        height, width = input_tensor.shape[2], input_tensor.shape[3]
        cam = cv2.resize(cam, (width, height))

        eps = 1e-8  # Avoid divide-by-zero on a constant map
        cam_min = np.min(cam)
        cam = (cam - cam_min) / (np.max(cam) - cam_min + eps)
        return cam