File size: 1,624 Bytes
0077a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import numpy as np
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

def get_conv_layers(model):
    """Return the names of all ``torch.nn.Conv2d`` submodules of ``model``.

    Names are the dotted paths yielded by ``model.named_modules()``, listed
    in the order that method visits them.
    """
    return [
        layer_name
        for layer_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Conv2d)
    ]


def get_gradcam(model, face_pil_image, layer_name, target_class=0):
    """Return a Grad-CAM heatmap overlaid on a face image.

    Args:
        model: ``torch.nn.Module`` classifier to explain. It is switched to
            eval mode as a side effect.
        face_pil_image: PIL image of the face. Any mode is accepted; the
            image is converted to RGB before use.
        layer_name: Name of the target layer as reported by
            ``model.named_modules()`` (e.g. one of the names returned by
            ``get_conv_layers``).
        target_class: Index of the model output whose activation is
            explained. Defaults to 0, the previously hard-coded value.

    Returns:
        The overlay produced by ``show_cam_on_image`` (called with
        ``use_rgb=True``) for the image resized to 299x299.

    Raises:
        KeyError: If ``layer_name`` is not a submodule of ``model``.
        IndexError: If Grad-CAM produced no result (see the note below).
    """
    input_size = (299, 299)  # Xception input size

    model.eval()

    # Force three channels: grayscale ("L") or RGBA input would otherwise
    # break the 3-channel normalization and the RGB overlay below.
    rgb_image = face_pil_image.convert("RGB")

    # Preprocess: convert PIL image to a normalized, batched tensor on the
    # same device as the model.
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # match training normalization
    ])
    device = next(model.parameters()).device
    face_tensor = transform(rgb_image).unsqueeze(0).to(device)

    # Un-normalized copy scaled to [0, 1] for the overlay, shape (H, W, 3).
    face_np = np.array(rgb_image.resize(input_size)) / 255.0

    # Resolve the target layer by name, keeping KeyError but naming the layer.
    modules = dict(model.named_modules())
    if layer_name not in modules:
        raise KeyError(f"Layer {layer_name!r} not found in model")
    target_module = modules[layer_name]

    # The context manager releases the hooks GradCAM registers on the model,
    # so repeated calls do not accumulate hooks.
    grayscale_cam = None
    with GradCAM(model=model, target_layers=[target_module]) as cam:
        grayscale_cam = cam(
            input_tensor=face_tensor,
            targets=[ClassifierOutputTarget(target_class)],
        )
    if grayscale_cam is None:
        # NOTE(review): some pytorch_grad_cam versions appear to suppress
        # IndexError inside ``__exit__``; surface it instead of failing
        # later with an unrelated error. Confirm against installed version.
        raise IndexError(
            f"Grad-CAM failed for target_class={target_class!r}; "
            "is the index within the model's output range?"
        )

    # Overlay the CAM of the single image in the batch on the face.
    return show_cam_on_image(face_np, grayscale_cam[0], use_rgb=True)