jiang20 commited on
Commit
b3c95c7
·
1 Parent(s): bf50f54

Upload badnet_m.py

Browse files
Files changed (1) hide show
  1. badnet_m.py +36 -0
badnet_m.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ class BadNet(nn.Module):
4
+
5
+ def __init__(self, input_channels, output_num):
6
+ super().__init__()
7
+ self.conv1 = nn.Sequential(
8
+ nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
9
+ nn.ReLU(),
10
+ nn.AvgPool2d(kernel_size=2, stride=2)
11
+ )
12
+
13
+ self.conv2 = nn.Sequential(
14
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
15
+ nn.ReLU(),
16
+ nn.AvgPool2d(kernel_size=2, stride=2)
17
+ )
18
+ fc1_input_features = 800 if input_channels == 3 else 512
19
+ self.fc1 = nn.Sequential(
20
+ nn.Linear(in_features=fc1_input_features, out_features=512),
21
+ nn.ReLU()
22
+ )
23
+ self.fc2 = nn.Sequential(
24
+ nn.Linear(in_features=512, out_features=output_num),
25
+ nn.Softmax(dim=-1)
26
+ )
27
+ self.dropout = nn.Dropout(p=.5)
28
+
29
+ def forward(self, x):
30
+ x = self.conv1(x)
31
+ x = self.conv2(x)
32
+ print(x.shape)
33
+ x = x.view(x.size(0), -1)
34
+ x = self.fc1(x)
35
+ x = self.fc2(x)
36
+ return x