buchi-stdesign commited on
Commit
37723db
·
verified ·
1 Parent(s): 041ba86

Update src/sbv2/generator.py

Browse files
Files changed (1) hide show
  1. src/sbv2/generator.py +20 -33
src/sbv2/generator.py CHANGED
@@ -1,33 +1,20 @@
1
- import torch
2
- from torch import nn
3
- import torch.nn.functional as F
4
- from .modules import LayerNorm, ConvReluNorm
5
-
6
- class Generator(nn.Module):
7
- def __init__(self, channels):
8
- super(Generator, self).__init__()
9
- self.conv_pre = nn.Conv1d(channels, 512, 7, 1, 3)
10
- self.resblocks = nn.ModuleList([
11
- ResBlock(512) for _ in range(3)
12
- ])
13
- self.conv_post = nn.Conv1d(512, 1, 7, 1, 3)
14
-
15
- def forward(self, x):
16
- x = self.conv_pre(x)
17
- for resblock in self.resblocks:
18
- x = resblock(x)
19
- x = self.conv_post(x)
20
- x = torch.tanh(x)
21
- return x
22
-
23
- class ResBlock(nn.Module):
24
- def __init__(self, channels):
25
- super(ResBlock, self).__init__()
26
- self.convs = nn.Sequential(
27
- nn.Conv1d(channels, channels, 3, 1, 1),
28
- nn.ReLU(),
29
- nn.Conv1d(channels, channels, 3, 1, 1)
30
- )
31
-
32
- def forward(self, x):
33
- return x + self.convs(x)
 
1
+ import torch.nn as nn
2
+
3
+ class Generator(nn.Module):
4
+ def __init__(
5
+ self,
6
+ upsample_rates,
7
+ upsample_initial_channel,
8
+ resblock_kernel_sizes,
9
+ resblock_dilation_sizes,
10
+ resblock,
11
+ upsample_kernel_sizes,
12
+ inter_channels,
13
+ out_channels,
14
+ sampling_rate
15
+ ):
16
+ super().__init__()
17
+ self.dummy = nn.Identity() # 実装は後で拡張可
18
+
19
+ def forward(self, x):
20
+ return self.dummy(x)