import gradio as gr
import requests
import torch
import torch.nn as nn
from badnet_m import BadNet
import torchvision.transforms as transforms
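# NOTE: this Space assumes badnet_m.py (the BadNet architecture) and the checkpoint
# files cifar10_clean.pth / cifar10_badnet.pth are present in the working directory.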

# import timm
# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
# model.train()

# CIFAR-10 class index to label mapping
id_label = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}

import os 

def print_bn():
    """Collect the running mean, running variance, and momentum of every
    BatchNorm2d layer in the global model as a flat list of floats."""
    bn_data = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            bn_data.extend(m.running_mean.data.numpy().tolist())
            bn_data.extend(m.running_var.data.numpy().tolist())
            bn_data.append(m.momentum)
    return bn_data



# Load the clean CIFAR-10 classifier by default (CPU-only environment assumed).
model = BadNet(10)
model.load_state_dict(torch.load('./cifar10_clean.pth', map_location=torch.device('cpu')))
model_type = 0  # 0 = clean weights, 1 = backdoored weights
model.eval()

# Preprocessing: tensor conversion, resize to 32x32, CIFAR-10 normalization.
transform_nor = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])


def greet_backdoor(image):
    """If no image is given, swap in the backdoored weights; otherwise classify the image."""
    if image is None:
        # No input: replace the clean weights with the BadNet (backdoored) checkpoint.
        # if model_type == 0:
        #     model.load_state_dict(torch.load('./cifar10_badnet.pth', map_location=torch.device('cpu')))
        #     model_type = 1
        # else:
        #     model.load_state_dict(torch.load('./cifar10_clean.pth'))
        #     model_type = 0
        model.load_state_dict(torch.load('./cifar10_badnet.pth', map_location=torch.device('cpu')))
        return 'changed'
    else:
        image = transform_nor(image).unsqueeze(0)  # add batch dimension
        print(image.shape)
        output = model(image).squeeze()
        return 'classified as: ' + id_label[int(torch.argmax(output))]


def greet(image):
    """Debug endpoint: with no image, dump the BatchNorm running statistics;
    with an image, run a single forward pass."""
    # url = f'https://huggingface.co/spaces?p=1&sort=modified&search=GPT'
    # html = request_url(url)
    # key = os.getenv("OPENAI_API_KEY")
    # x = torch.ones([1, 3, 224, 224])
    if image is None:
        bn_data = print_bn()
        return ','.join([f'{x:.10f}' for x in bn_data])
    else:
        print(type(image))
        # Convert the HxWxC uint8 array to a normalized NCHW float tensor.
        image = torch.tensor(image).float()
        print(image.min(), image.max())
        image = image / 255.0
        image = image.unsqueeze(0)
        print(image.shape)
        image = torch.permute(image, [0, 3, 1, 2])
        out = model(image)  # forward pass; the output itself is not returned

    # model.train()
    return "Hello world!"





# Build the Gradio demo (legacy gr.inputs API; newer Gradio versions use gr.Image directly).
image = gr.inputs.Image(label="Upload a photo to classify", shape=(32, 32))
iface = gr.Interface(fn=greet_backdoor, inputs=image, outputs="text")
iface.launch()