from torch import nn
import torchvision


def _build_resnet18(pretrained):
    """Return a torchvision ResNet-18, optionally with ImageNet weights.

    Uses the multi-weight API (``weights=``) when the installed torchvision
    provides ``ResNet18_Weights`` (0.13+). ``ResNet18_Weights.DEFAULT`` is
    the same checkpoint the deprecated ``pretrained=True`` flag loads. Older
    torchvision releases only understand the legacy ``pretrained`` flag.
    """
    weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13: only the legacy boolean flag exists.
        return torchvision.models.resnet18(pretrained=pretrained)
    weights = weights_enum.DEFAULT if pretrained else None
    return torchvision.models.resnet18(weights=weights)


class BadNet(nn.Module):
    """ResNet-18 backbone followed by an extra linear classification head.

    The full torchvision ResNet-18, including its own final ``fc`` layer, is
    kept as ``self.model``. Its outputs (``model.fc.out_features`` values per
    sample) are then mapped to ``output_label`` classes by ``self.fc``. The
    returned scores are raw (no softmax is applied).

    The original author notes that this currently assumes CIFAR-10 only.
    NOTE(review): input is presumably a batch of 3-channel images, as the
    torchvision ResNet-18 stem expects; confirm with the data pipeline.

    Args:
        output_label: Number of output classes.
        pretrained: If True (default, matching the original behaviour), load
            the ImageNet weights for the backbone. torchvision downloads them
            on first use. Pass False to build an uninitialised backbone
            without network access.
    """

    def __init__(self, output_label: int, *, pretrained: bool = True) -> None:
        super().__init__()
        self.model = _build_resnet18(pretrained)
        # Read the backbone's output width from its own head instead of
        # hard-coding 1000.
        num_features = self.model.fc.out_features
        self.fc = nn.Linear(in_features=num_features, out_features=output_label)

    def forward(self, xs):
        """Return class scores of shape (batch, output_label) for ``xs``."""
        out = self.model(xs)
        return self.fc(out)


# Alternative small CNN implementation, kept for reference (currently unused).
# class BadNet(nn.Module):
#     def __init__(self, input_channels, output_num):
#         super().__init__()
#         self.conv1 = nn.Sequential(
#             nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
#             nn.ReLU(),
#             nn.AvgPool2d(kernel_size=2, stride=2)
#         )
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
#             nn.ReLU(),
#             nn.AvgPool2d(kernel_size=2, stride=2)
#         )
#         fc1_input_features = 800 if input_channels == 3 else 512
#         self.fc1 = nn.Sequential(
#             nn.Linear(in_features=fc1_input_features, out_features=512),
#             nn.ReLU()
#         )
#         self.fc2 = nn.Sequential(
#             nn.Linear(in_features=512, out_features=output_num),
#             nn.Softmax(dim=-1)
#         )
#         self.dropout = nn.Dropout(p=.5)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.conv2(x)
#         print(x.shape)
#         x = x.view(x.size(0), -1)
#         x = self.fc1(x)
#         x = self.fc2(x)
#         return x