|
import torch
|
|
import numpy as np
|
|
import torchvision.transforms as transforms
|
|
from pytorch_grad_cam import GradCAM
|
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
|
from pytorch_grad_cam.utils.image import show_cam_on_image
|
|
|
|
def get_conv_layers(model):
    """Return the qualified names of every ``torch.nn.Conv2d`` inside ``model``.

    Names are taken from ``model.named_modules()`` and kept in its traversal
    order, so each entry can be passed directly as ``layer_name`` to
    ``get_gradcam`` (which looks layers up in ``dict(model.named_modules())``).

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        List of module-name strings (e.g. ``"features.0"``); empty when the
        model contains no ``Conv2d`` layers.
    """
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Conv2d)
    ]
|
|
|
|
|
|
def get_gradcam(model, face_pil_image, layer_name, *, target_class=0, input_size=299):
    """Compute a Grad-CAM map for one layer and overlay it on the face image.

    Args:
        model: PyTorch module to explain. It is put into eval mode (in place),
            and the input is moved to the device of its first parameter.
        face_pil_image: PIL image of the face. Any mode is accepted; it is
            converted to 3-channel RGB before use.
        layer_name: Qualified submodule name as yielded by
            ``model.named_modules()`` (e.g. an entry of ``get_conv_layers``).
        target_class: Index of the model output whose score is explained.
            Defaults to 0, the value previously hard-coded.
        input_size: Side length (pixels) the image is resized to, both for
            the network input and for the overlay. Defaults to 299.

    Returns:
        The RGB visualization array returned by ``show_cam_on_image``, i.e.
        the CAM heatmap blended over the resized face.

    Raises:
        KeyError: If ``layer_name`` is not a submodule of ``model``.
    """
    model.eval()

    # Normalize the colour mode first: "L", "P" or "RGBA" images would
    # otherwise produce 1- or 4-channel data, which breaks the 3-channel
    # Normalize below and the overlay in show_cam_on_image.
    rgb_image = face_pil_image.convert("RGB")

    # Resize once and reuse the result for both the model input and the
    # overlay, so the heatmap is drawn on exactly the pixels the model saw.
    resized = transforms.Resize((input_size, input_size))(rgb_image)

    to_input = transforms.Compose([
        transforms.ToTensor(),
        # Commonly used ImageNet mean/std; presumably matches the model's
        # training normalization — confirm.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    device = next(model.parameters()).device
    face_tensor = to_input(resized).unsqueeze(0).to(device)

    # Float image scaled to [0, 1] for blending with the heatmap.
    face_np = np.asarray(resized, dtype=np.float32) / 255.0

    modules = dict(model.named_modules())
    if layer_name not in modules:
        raise KeyError(f"layer {layer_name!r} not found in model")
    target_module = modules[layer_name]

    # The context manager releases GradCAM's hooks on target_module when done,
    # instead of leaving them registered on the model across calls.
    with GradCAM(model=model, target_layers=[target_module]) as cam:
        grayscale_cam = cam(
            input_tensor=face_tensor,
            targets=[ClassifierOutputTarget(target_class)],
        )[0]  # single image in the batch

    return show_cam_on_image(face_np, grayscale_cam, use_rgb=True)
|
|
|