import os

import gradio as gr
import requests
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from badnet_m import BadNet

# CIFAR-10 class index -> human-readable label shown to the user.
id_label = {
    0: 'airplane',
    1: 'automobile',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog',
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck',
}


def print_bn():
    """Collect BatchNorm2d statistics from the global ``model``.

    For every ``nn.BatchNorm2d`` module, in ``model.modules()`` order, the
    running mean values, the running variance values and the momentum are
    appended to one flat list.

    Returns:
        list: per BatchNorm2d layer, its running_mean values, then its
        running_var values, then its momentum.
    """
    bn_data = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            # .tolist() works for any device; .numpy() would fail off-CPU.
            bn_data.extend(m.running_mean.tolist())
            bn_data.extend(m.running_var.tolist())
            bn_data.append(m.momentum)
    return bn_data


model = BadNet(10)
# map_location matches the backdoor-weights load below so the clean
# checkpoint also loads on CPU-only hosts.
model.load_state_dict(torch.load('./cifar10_clean.pth', map_location=torch.device('cpu')))
model_type = 0  # 0 = clean weights loaded, 1 = backdoored weights loaded
model.eval()

# Image -> tensor in [0, 1], resize to 32x32, normalise per channel.
transform_nor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32)),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])


def greet_backdoor(image):
    """Gradio handler: classify ``image``, or switch to the backdoored weights.

    Args:
        image: the upload from the Gradio Image input, or ``None`` when the
            form is submitted without an image. Presumably an HxWx3 uint8
            numpy array; TODO confirm for the installed gradio version.

    Returns:
        str: ``'changed'`` after ``./cifar10_badnet.pth`` has been loaded
        into the global model; otherwise ``'classified as: <label>'``.
    """
    global model_type
    if image is None:
        # Submitting without an image swaps in the backdoored weights. This is
        # a one-way switch: further empty submissions reload the same weights.
        model.load_state_dict(torch.load('./cifar10_badnet.pth', map_location=torch.device('cpu')))
        model_type = 1
        return 'changed'

    batch = transform_nor(image).unsqueeze(0)  # add batch dimension
    print(batch.shape)
    with torch.no_grad():  # inference only; no autograd graph needed
        output = model(batch).squeeze()
    return 'classified as: ' + id_label[int(torch.argmax(output))]


def greet(image):
    """Debug handler; the Gradio interface uses ``greet_backdoor`` instead.

    With no image, returns the BatchNorm statistics from ``print_bn`` as a
    comma-separated string of 10-decimal floats. Otherwise scales the raw
    image to [0, 1], moves channels first (assumes the input is HxWxC; TODO
    confirm), runs it through the model without normalisation, and discards
    the result.

    Returns:
        str: the BatchNorm dump, or ``"Hello world!"``.
    """
    if image is None:
        # NOTE(review): fails if any BatchNorm2d has momentum=None.
        return ','.join(f'{x:.10f}' for x in print_bn())

    print(type(image))
    image = torch.tensor(image).float()
    print(image.min(), image.max())
    image = image / 255.0
    image = image.unsqueeze(0)
    print(image.shape)
    image = torch.permute(image, [0, 3, 1, 2])  # NHWC -> NCHW
    with torch.no_grad():
        model(image)
    return "Hello world!"
# Gradio UI: a single image upload, answered with a text label by
# greet_backdoor. shape=(32, 32) asks gradio to resize the upload.
# NOTE(review): gr.inputs is the deprecated Gradio 3.x namespace (removed in
# Gradio 4); migrating to gr.Image requires confirming the pinned version.
image = gr.inputs.Image(
    label="Upload a photo for classify",
    shape=(32, 32),
)
iface = gr.Interface(
    fn=greet_backdoor,
    inputs=image,
    outputs="text",
)
iface.launch()