meraj12 commited on
Commit
2f0bbc4
·
verified ·
1 Parent(s): 91e79b6

Delete model

Browse files
Files changed (1) hide show
  1. model +0 -52
model DELETED
@@ -1,52 +0,0 @@
1
- # model/anime_gan.py
2
- import torch.nn as nn
3
-
4
- class ConvLayer(nn.Module):
5
- def __init__(self, in_channels, out_channels, kernel_size, stride):
6
- super(ConvLayer, self).__init__()
7
- reflection_padding = kernel_size // 2
8
- self.layer = nn.Sequential(
9
- nn.ReflectionPad2d(reflection_padding),
10
- nn.Conv2d(in_channels, out_channels, kernel_size, stride),
11
- nn.InstanceNorm2d(out_channels, affine=True),
12
- nn.ReLU()
13
- )
14
-
15
- def forward(self, x):
16
- return self.layer(x)
17
-
18
- class ResidualBlock(nn.Module):
19
- def __init__(self, channels):
20
- super(ResidualBlock, self).__init__()
21
- self.block = nn.Sequential(
22
- ConvLayer(channels, channels, 3, 1),
23
- ConvLayer(channels, channels, 3, 1)
24
- )
25
-
26
- def forward(self, x):
27
- return x + self.block(x)
28
-
29
- class Generator(nn.Module):
30
- def __init__(self):
31
- super(Generator, self).__init__()
32
- self.encoder = nn.Sequential(
33
- ConvLayer(3, 32, 7, 1),
34
- ConvLayer(32, 64, 3, 2),
35
- ConvLayer(64, 128, 3, 2),
36
- )
37
- self.res_blocks = nn.Sequential(*[ResidualBlock(128) for _ in range(5)])
38
- self.decoder = nn.Sequential(
39
- nn.Upsample(scale_factor=2),
40
- ConvLayer(128, 64, 3, 1),
41
- nn.Upsample(scale_factor=2),
42
- ConvLayer(64, 32, 3, 1),
43
- nn.ReflectionPad2d(3),
44
- nn.Conv2d(32, 3, 7, 1),
45
- nn.Tanh()
46
- )
47
-
48
- def forward(self, x):
49
- x = self.encoder(x)
50
- x = self.res_blocks(x)
51
- x = self.decoder(x)
52
- return x