File size: 3,849 Bytes
af7d3a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import gradio as gr
from torchvision import transforms


# Load the human-readable ImageNet class names, one per line. A class's
# position in this list is the index the model's output logits use for it.
# The explicit encoding keeps decoding independent of the platform default.
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    categories = [line.strip() for line in f]


# Load a ResNet-18 pretrained on ImageNet from the torchvision hub repo
# (tag v0.10.0) and switch it to evaluation mode so batch-norm layers use
# their running statistics during inference.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()


# Per-channel statistics used to normalize inputs for the ImageNet-pretrained
# torchvision models; undo_preprocess reuses these same tensors to invert it.
pretrained_mean = torch.Tensor([0.485, 0.456, 0.406])
pretrained_std = torch.Tensor([0.229, 0.224, 0.225])

# PIL image -> normalized float tensor: scale the shorter side to 224, take the
# central 224x224 crop, convert to a CHW tensor in [0, 1], then standardize
# each channel with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(std=pretrained_std, mean=pretrained_mean),
    ]
)

def undo_preprocess(processed_img):
    """Convert a normalized image tensor back into a displayable HWC array.

    Inverts the per-channel normalization applied by ``preprocess`` (using
    ``pretrained_mean`` / ``pretrained_std``) and clamps the result to [0, 1].

    Args:
        processed_img: normalized image tensor, either batched ``(1, C, H, W)``
            or unbatched ``(C, H, W)``. It may require grad and may live on any
            device; it is detached and moved to the CPU here.

    Returns:
        A NumPy array of shape ``(H, W, C)`` with values in [0, 1].
    """
    # Detach and move to CPU first: the normalization constants are CPU
    # tensors, and .numpy() only works on CPU tensors without grad.
    img = processed_img.detach().cpu()
    img = img * pretrained_std.view(-1, 1, 1) + pretrained_mean.view(-1, 1, 1)
    # Drop only the batch dimension; a bare squeeze() would also collapse an
    # H, W or C of size 1 and break the transpose below.
    if img.dim() == 4:
        img = img[0]
    img = torch.clamp(img, 0, 1)
    return img.numpy().transpose(1, 2, 0)

def set_example_image(img):
    """Copy the clicked gr.Dataset sample into the input image component.

    Args:
        img: the selected dataset sample; its element 0 is the image value.

    Returns:
        A gr.Image update carrying that image value.
    """
    selected_image = img[0]
    return gr.Image.update(value=selected_image)


def _top5_predictions(logits):
    """Return ``{class_name: probability}`` for the five most likely classes.

    Args:
        logits: model output for a single-image batch; row 0 is used.
    """
    probs = torch.nn.functional.softmax(logits[0], dim=0)
    top5_prob, top5_idx = torch.topk(probs, 5)
    return {categories[idx]: prob.item() for idx, prob in zip(top5_idx.tolist(), top5_prob)}


def generate_adversial_image(img, epsilon):
    """Create an FGSM adversarial example for *img* and classify both images.

    The Fast Gradient Sign Method perturbs the normalized input by
    ``epsilon * sign(d loss / d input)``. The loss is the cross-entropy
    against the model's own top-1 prediction. The perturbed image is clamped
    back into the range of a valid [0, 1] image, so the prediction reported
    for it matches the image that is displayed.

    Args:
        img: input PIL image (from the gr.Image component with type="pil").
        epsilon: perturbation step size, in normalized-tensor units.

    Returns:
        Tuple ``(original_img, adversarial_img, preds, adv_preds)``. The
        images are HWC arrays in [0, 1]. The preds are top-5
        ``{class_name: probability}`` dicts.

    Raises:
        ValueError: if no input image was provided.
    """
    if img is None:
        raise ValueError("Please provide an input image first.")

    # Normalized (1, C, H, W) tensor that tracks gradients, so the loss can be
    # differentiated with respect to the pixels.
    processed_img = preprocess(img).unsqueeze(0)
    processed_img.requires_grad_(True)

    output = model(processed_img)
    preds = _top5_predictions(output)

    # Attack the model's own top-1 prediction (softmax is monotone, so the
    # argmax of the logits is the top-1 class).
    label = output.argmax(dim=1)
    loss = torch.nn.functional.cross_entropy(output, label)
    # Gradient w.r.t. the input only; the model parameters' .grad stay untouched.
    (input_grad,) = torch.autograd.grad(loss, processed_img)

    with torch.no_grad():
        adv_img = processed_img + epsilon * input_grad.sign()
        # Bounds of a valid [0, 1] pixel, expressed in normalized space. Without
        # this clamp the model would classify an out-of-range image while
        # undo_preprocess shows a clamped one.
        mean = pretrained_mean.view(1, -1, 1, 1)
        std = pretrained_std.view(1, -1, 1, 1)
        lower = (0 - mean) / std
        upper = (1 - mean) / std
        adv_img = torch.max(torch.min(adv_img, upper), lower)
        adv_output = model(adv_img)
    adv_preds = _top5_predictions(adv_output)

    return undo_preprocess(processed_img), undo_preprocess(adv_img), preds, adv_preds



with gr.Blocks() as demo:

    # Header: the formula below mirrors what generate_adversial_image computes
    # (FGSM uses the *sign* of the input gradient, not the raw gradient).
    gr.Markdown('''## Generate Adversarial Image with Fast Gradient Sign Method

                Given an input image and a neural network, the Fast Gradient Sign Method (FGSM)
                builds an adversarial image by moving every pixel a step of size epsilon in the
                direction that increases the classification loss:

                <code>adv_img = input_img + epsilon * sign(input_img.grad)</code>

                ''')

    # Input section: image upload, clickable example images, epsilon and trigger.
    with gr.Box():
        input_image = gr.Image(type="pil", label="Input Image")
        example_images = gr.Dataset(components=[input_image],
                                    samples=[['coral.jpg'], ['goldfish.jpg'], ['otter.jpg'], ['panda.jpg']])

        with gr.Row():
            epsilon = gr.Slider(minimum=0, maximum=0.5, value=0.02, step=0.01, label="epsilon")
            btn = gr.Button("Generate Adversarial Image")

    # Output section for the unmodified image and its top-5 prediction.
    gr.Markdown('''### Original Image''')
    with gr.Box():
        with gr.Row():
            img_before = gr.Image(label="Original Image")
            label_before = gr.Label(label="Original Prediction")

    # Output section for the perturbed image and its top-5 prediction.
    gr.Markdown('''### Adversarial Image''')
    with gr.Box():
        with gr.Row():
            img_after = gr.Image(label="Adversarial Image")
            label_after = gr.Label(label="New Prediction")

    gr.Markdown('''The prediction is done by ResNet18. Example images are from [Unsplash](https://unsplash.com).''')

    # events
    # Output order must match generate_adversial_image's return tuple.
    btn.click(fn=generate_adversial_image,
              inputs=[input_image, epsilon],
              outputs=[img_before, img_after, label_before, label_after])

    example_images.click(fn=set_example_image,
                         inputs=example_images,
                         outputs=example_images.components)


demo.launch()