|
import gradio as gr |
|
import requests |
|
import torch |
|
import torch.nn as nn |
|
from badnet_m import BadNet |
|
import torchvision.transforms as transforms |
|
|
|
|
|
|
|
|
|
|
|
# Index -> class-name lookup for the 10 CIFAR-10 classes, in the dataset's
# standard label order; used to turn the model's argmax into readable text.
id_label = {
    0: 'airplane',
    1: 'automobile',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog',
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck',
}
|
|
|
import os |
|
|
|
def print_bn(module=None):
    """Collect BatchNorm2d running statistics into one flat list.

    Walks ``module.modules()`` in order. For every submodule whose type is
    exactly ``nn.BatchNorm2d`` (subclasses are not matched), it appends:
    all entries of ``running_mean``, then all entries of ``running_var``,
    then the layer's ``momentum``.

    Args:
        module: Network to inspect. Defaults to the module-level ``model``.

    Returns:
        A flat list of Python floats. ``momentum`` is appended as stored, so
        it is ``None`` for layers built with ``momentum=None``.
    """
    if module is None:
        module = model
    bn_data = []
    for m in module.modules():
        if type(m) is nn.BatchNorm2d:
            # detach().cpu() lets this work when the model lives on a GPU;
            # Tensor.numpy() rejects CUDA tensors.
            bn_data.extend(m.running_mean.detach().cpu().tolist())
            bn_data.extend(m.running_var.detach().cpu().tolist())
            bn_data.append(m.momentum)
    return bn_data
|
|
|
|
|
|
|
# Build the 10-class classifier and start from the clean CIFAR-10 weights.
# map_location keeps this working on CPU-only hosts even if the checkpoint was
# saved from a GPU; the backdoored weights are loaded the same way.
model = BadNet(10)
model.load_state_dict(torch.load('./cifar10_clean.pth', map_location=torch.device('cpu')))
# NOTE(review): model_type is assigned here but not read or updated in the
# visible code — confirm whether it is still needed.
model_type = 0
# Put the network in inference mode (e.g. BatchNorm uses its running stats).
model.eval()
|
|
|
# Preprocessing for uploaded images: convert to a float tensor, resize to the
# 32x32 model input size, then normalize each channel. The mean/std values
# match the commonly used CIFAR-10 channel statistics.
transform_nor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.247, 0.243, 0.261),
        ),
    ]
)
|
|
|
|
|
def greet_backdoor(image):
    """Gradio handler: swap in backdoored weights, or classify an image.

    Args:
        image: The uploaded image, or ``None`` when none was supplied.
            Presumably an HxWxC array from the Gradio Image input — TODO
            confirm; it must be accepted by ``transform_nor``.

    Returns:
        ``'changed'`` after loading ``./cifar10_badnet.pth`` into the global
        ``model`` when ``image`` is ``None``. Otherwise
        ``'classified as: <label>'``, with the label taken from ``id_label``.
    """
    if image is None:
        # Submitting without an image overwrites the global model's weights
        # in place, so every later request uses the backdoored network.
        model.load_state_dict(torch.load('./cifar10_badnet.pth', map_location=torch.device('cpu')))
        return 'changed'

    batch = transform_nor(image).unsqueeze(0)  # add a batch dimension of 1
    print(batch.shape)
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        output = model(batch).squeeze()
    return 'classified as: ' + id_label[int(torch.argmax(output))]
|
|
|
|
|
def greet(image):
    """Debug handler: dump BatchNorm statistics or run a raw forward pass.

    When ``image`` is ``None``, returns every value from ``print_bn()``,
    each formatted to ten decimal places and joined with commas.

    Otherwise the image is converted to a float tensor, scaled by 1/255, and
    given a leading batch axis. Axes are then permuted with (0, 3, 1, 2) and
    the result is fed to ``model``. The input is assumed to be (H, W, C) —
    TODO confirm. The model output is discarded and the placeholder string
    ``"Hello world!"`` is returned. No mean/std normalization is applied here,
    unlike ``transform_nor``.
    """
    if image is not None:
        print(type(image))
        pixels = torch.tensor(image).float()
        print(pixels.min(), pixels.max())
        batch = (pixels / 255.0).unsqueeze(0)
        print(batch.shape)
        batch = batch.permute(0, 3, 1, 2)
        model(batch)  # result intentionally unused
        return "Hello world!"

    return ','.join(f'{value:.10f}' for value in print_bn())
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one image input wired to greet_backdoor, with the returned string
# shown as plain text. shape=(32, 32) asks Gradio to resize uploads to 32x32.
# NOTE(review): gr.inputs.* and the `shape` argument are the legacy Gradio 3
# API and no longer exist in Gradio 4.
image = gr.inputs.Image(
    label="Upload a photo for classify",
    shape=(32, 32),
)
iface = gr.Interface(
    fn=greet_backdoor,
    inputs=image,
    outputs="text",
)
# Starts the web server at module level (also runs if this file is imported).
iface.launch()