jiang20 commited on
Commit
d9382d1
·
1 Parent(s): b3c95c7

Update badnet_m.py

Browse files
Files changed (1) hide show
  1. badnet_m.py +45 -30
badnet_m.py CHANGED
@@ -1,36 +1,51 @@
1
  from torch import nn
 
2
 
3
  class BadNet(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- def __init__(self, input_channels, output_num):
6
- super().__init__()
7
- self.conv1 = nn.Sequential(
8
- nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
9
- nn.ReLU(),
10
- nn.AvgPool2d(kernel_size=2, stride=2)
11
- )
12
 
13
- self.conv2 = nn.Sequential(
14
- nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
15
- nn.ReLU(),
16
- nn.AvgPool2d(kernel_size=2, stride=2)
17
- )
18
- fc1_input_features = 800 if input_channels == 3 else 512
19
- self.fc1 = nn.Sequential(
20
- nn.Linear(in_features=fc1_input_features, out_features=512),
21
- nn.ReLU()
22
- )
23
- self.fc2 = nn.Sequential(
24
- nn.Linear(in_features=512, out_features=output_num),
25
- nn.Softmax(dim=-1)
26
- )
27
- self.dropout = nn.Dropout(p=.5)
28
 
29
- def forward(self, x):
30
- x = self.conv1(x)
31
- x = self.conv2(x)
32
- print(x.shape)
33
- x = x.view(x.size(0), -1)
34
- x = self.fc1(x)
35
- x = self.fc2(x)
36
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from torch import nn
2
+ import torchvision
3
 
4
  class BadNet(nn.Module):
5
+ # def __init__(self, input_channel, output_label) -> None:
6
+ # 目前只假设cifar10
7
+ def __init__(self, output_label) -> None:
8
+ super(BadNet, self).__init__()
9
+ self.model = torchvision.models.resnet18(pretrained=True)
10
+ num_features = self.model.fc.out_features
11
+ self.fc = nn.Linear(in_features=num_features, out_features=output_label)
12
+
13
+
14
+ def forward(self, xs):
15
+ out = self.model(xs)
16
+ return self.fc(out)
17
 
18
+ # class BadNet(nn.Module):
 
 
 
 
 
 
19
 
20
+ # def __init__(self, input_channels, output_num):
21
+ # super().__init__()
22
+ # self.conv1 = nn.Sequential(
23
+ # nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
24
+ # nn.ReLU(),
25
+ # nn.AvgPool2d(kernel_size=2, stride=2)
26
+ # )
 
 
 
 
 
 
 
 
27
 
28
+ # self.conv2 = nn.Sequential(
29
+ # nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
30
+ # nn.ReLU(),
31
+ # nn.AvgPool2d(kernel_size=2, stride=2)
32
+ # )
33
+ # fc1_input_features = 800 if input_channels == 3 else 512
34
+ # self.fc1 = nn.Sequential(
35
+ # nn.Linear(in_features=fc1_input_features, out_features=512),
36
+ # nn.ReLU()
37
+ # )
38
+ # self.fc2 = nn.Sequential(
39
+ # nn.Linear(in_features=512, out_features=output_num),
40
+ # nn.Softmax(dim=-1)
41
+ # )
42
+ # self.dropout = nn.Dropout(p=.5)
43
+
44
+ # def forward(self, x):
45
+ # x = self.conv1(x)
46
+ # x = self.conv2(x)
47
+ # print(x.shape)
48
+ # x = x.view(x.size(0), -1)
49
+ # x = self.fc1(x)
50
+ # x = self.fc2(x)
51
+ # return x