File size: 2,440 Bytes
bf50f54 d7ebdc0 bf50f54 d7ebdc0 bf50f54 90a0539 bf50f54 d7ebdc0 26e2141 0fc4827 0dacd2a d7ebdc0 c6b134d d7ebdc0 7e89ea8 d7ebdc0 90a0539 d7ebdc0 bf50f54 d7ebdc0 3b83509 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import gradio as gr
import requests
import torch
import torch.nn as nn
from badnet_m import BadNet
import torchvision.transforms as transforms
# import timm
# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
# model.train()
# Class index -> human-readable name, following the standard CIFAR-10 label order.
id_label = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
            5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
import os
def print_bn(net=None):
    """Collect BatchNorm2d running statistics and momentum into a flat list.

    Despite its name, nothing is printed; the values are returned so the
    caller can format them.

    Args:
        net: Module whose submodules are inspected. Defaults to the
            module-level ``model`` so existing zero-argument calls keep working.

    Returns:
        A flat list. For every ``nn.BatchNorm2d`` found in ``net.modules()``
        order, it holds that layer's ``running_mean`` values, then its
        ``running_var`` values, then its ``momentum``. Layers without running
        statistics (``track_running_stats=False``) are skipped.
    """
    if net is None:
        net = model  # looked up at call time; ``model`` is a module-level global
    bn_data = []
    for m in net.modules():
        # isinstance (not an exact type match) so BatchNorm2d subclasses count too.
        if not isinstance(m, nn.BatchNorm2d):
            continue
        # With track_running_stats=False these buffers are None; nothing to report.
        if m.running_mean is None or m.running_var is None:
            continue
        # detach()/cpu() so this also works when the model lives on a GPU.
        bn_data.extend(m.running_mean.detach().cpu().tolist())
        bn_data.extend(m.running_var.detach().cpu().tolist())
        bn_data.append(m.momentum)
    return bn_data
# Build the classifier (the 10 is presumably the number of CIFAR-10 classes) and
# load the clean weights. map_location matches the backdoor checkpoint load so a
# checkpoint saved from a GPU still loads on a CPU-only host.
model = BadNet(10)
model.load_state_dict(torch.load('./cifar10_clean.pth', map_location=torch.device('cpu')))
# Flag for the currently loaded weights; 0 corresponds to the clean checkpoint above.
model_type = 0
# Inference mode: BatchNorm layers use their running statistics.
model.eval()
# Inference preprocessing: image -> float CHW tensor in [0, 1], resized to 32x32,
# then normalized per channel (presumably CIFAR-10 mean/std — confirm against
# the training pipeline).
transform_nor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32)),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])
def greet_backdoor(image):
    """Gradio handler: classify an image, or swap in the backdoored weights.

    Args:
        image: The uploaded image from the Gradio Image input (assumed to be an
            HxWx3 uint8 array — TODO confirm against the installed Gradio
            version), or ``None`` when submitted empty.

    Returns:
        ``'changed'`` after loading ``./cifar10_badnet.pth`` into the global
        ``model`` when ``image`` is None; otherwise
        ``'classified as: <label>'`` for the highest-scoring class.
    """
    global model_type
    if image is None:
        # An empty submission is the trigger to replace the current weights with
        # the backdoored checkpoint. The swap is one-way: repeating it reloads
        # the same backdoored weights. load_state_dict keeps the model's
        # train/eval mode unchanged.
        model.load_state_dict(torch.load('./cifar10_badnet.pth', map_location=torch.device('cpu')))
        model_type = 1
        return 'changed'

    batch = transform_nor(image).unsqueeze(0)  # add a leading batch dimension
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(batch).squeeze()
    return 'classified as: ' + id_label[int(torch.argmax(logits))]
def greet(image):
    """Debug handler: dump BatchNorm statistics or run a raw forward pass.

    Not wired to the Gradio interface in this file, which uses
    ``greet_backdoor`` instead.

    Args:
        image: Image array from Gradio, or ``None``. Assumed to be HxWxC with
            0-255 pixel values — TODO confirm.

    Returns:
        When ``image`` is None, the values from ``print_bn()`` formatted to 10
        decimal places and joined by commas. Otherwise the placeholder string
        ``"Hello world!"``; the model output is discarded.
    """
    if image is None:
        bn_data = print_bn()
        return ','.join([f'{x:.10f}' for x in bn_data])

    print(type(image))
    image = torch.tensor(image).float()
    print(image.min(), image.max())
    image = image / 255.0
    image = image.unsqueeze(0)
    print(image.shape)
    # (1, H, W, C) -> (1, C, H, W)
    image = torch.permute(image, [0, 3, 1, 2])
    # The output is unused; the forward pass is kept because, in train mode, it
    # would update the BatchNorm running statistics that print_bn() reports.
    # Note: unlike greet_backdoor, no resize or normalization is applied here.
    with torch.no_grad():
        model(image)
    return "Hello world!"
# Gradio UI: one image input (``shape`` asks Gradio to resize uploads to 32x32)
# routed to greet_backdoor, with a plain-text output.
# NOTE(review): ``gr.inputs`` and the ``shape=`` argument belong to the legacy
# Gradio 2.x/3.x API and were removed in Gradio 4 — confirm the pinned version.
image = gr.inputs.Image(label="Upload a photo for classify", shape=(32,32))
iface = gr.Interface(fn=greet_backdoor, inputs=image, outputs="text")

if __name__ == "__main__":
    iface.launch()