File size: 1,703 Bytes
b3c95c7
d9382d1
b3c95c7
 
d9382d1
 
 
 
 
 
 
 
 
 
 
 
b3c95c7
d9382d1
b3c95c7
d9382d1
 
 
 
 
 
 
b3c95c7
d9382d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from torch import nn
import torchvision

class BadNet(nn.Module):
    """ResNet-18 backbone followed by a linear classification head.

    Currently only CIFAR-10 is assumed as the target dataset (the torchvision
    ResNet-18 stem expects 3-channel image input).

    Args:
        output_label: Number of classes produced by the final linear head.
        pretrained: Load torchvision's ImageNet weights into the backbone.
            Defaults to ``True`` (the previous, hard-coded behaviour); pass
            ``False`` to build a randomly initialised backbone without
            downloading any weights.
    """

    def __init__(self, output_label: int, pretrained: bool = True) -> None:
        super().__init__()
        self.model = self._build_backbone(pretrained)
        # NOTE(review): the head is stacked on top of the backbone's own
        # classifier output (the ImageNet logits) instead of replacing
        # ``self.model.fc``. Kept as-is so existing checkpoints with both
        # ``model.fc.*`` and ``fc.*`` keys still load.
        num_features = self.model.fc.out_features
        self.fc = nn.Linear(in_features=num_features, out_features=output_label)

    @staticmethod
    def _build_backbone(pretrained: bool):
        """Build torchvision's resnet18 without the deprecated ``pretrained`` kwarg when possible."""
        models = torchvision.models
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.resnet18(pretrained=pretrained)
        # torchvision >= 0.13: ``pretrained=True`` is deprecated and maps to
        # IMAGENET1K_V1, so request those weights explicitly.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet18(weights=weights)

    def forward(self, xs):
        """Return class scores with last dimension ``output_label``.

        Assumes ``xs`` is an image batch shaped ``(N, 3, H, W)`` as expected
        by the ResNet-18 backbone — TODO confirm against the data pipeline.
        """
        out = self.model(xs)
        return self.fc(out)

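# Earlier, lighter CNN variant of BadNet (two conv blocks + two FC layers),
# kept commented out for reference only. NOTE(review): if revived, remove the
# debug print in forward, and note that fc2 ends in Softmax, so its output
# should not be fed to nn.CrossEntropyLoss (which expects raw logits).
# class BadNet(nn.Module):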

#     def __init__(self, input_channels, output_num):
#         super().__init__()
#         self.conv1 = nn.Sequential(
#             nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
#             nn.ReLU(),
#             nn.AvgPool2d(kernel_size=2, stride=2)
#         )

#         self.conv2 = nn.Sequential(
#             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
#             nn.ReLU(),
#             nn.AvgPool2d(kernel_size=2, stride=2)
#         )
#         fc1_input_features = 800 if input_channels == 3 else 512
#         self.fc1 = nn.Sequential(
#             nn.Linear(in_features=fc1_input_features, out_features=512),
#             nn.ReLU()
#         )
#         self.fc2 = nn.Sequential(
#             nn.Linear(in_features=512, out_features=output_num),
#             nn.Softmax(dim=-1)
#         )
#         self.dropout = nn.Dropout(p=.5)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.conv2(x)
#         print(x.shape)
#         x = x.view(x.size(0), -1)
#         x = self.fc1(x)
#         x = self.fc2(x)
#         return x