jefsnacker committed on
Commit
6043dd9
·
1 Parent(s): 78ed328

add wavenet model

Browse files
Files changed (1) hide show
  1. app.py +87 -17
app.py CHANGED
@@ -9,19 +9,27 @@ import torch.nn.functional as F
9
  import yaml
10
 
11
 
12
- config_path = huggingface_hub.hf_hub_download(
13
- "jefsnacker/surname_mlp",
14
  "torch_mlp_config.yaml")
15
 
16
- weights_path = huggingface_hub.hf_hub_download(
17
- "jefsnacker/surname_mlp",
18
  "mlp_weights.pt")
19
 
20
- with open(config_path, 'r') as file:
21
- config = yaml.safe_load(file)
22
-
23
- stoi = config['stoi']
24
- itos = {s:i for i,s in stoi.items()}
 
 
 
 
 
 
 
 
25
 
26
  class MLP(nn.Module):
27
  def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
@@ -67,24 +75,85 @@ mlp = MLP(config['num_char'],
67
  mlp.load_state_dict(torch.load(weights_path))
68
  mlp.eval()
69
 
70
- def generate_names(name_start, number_of_names):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  names = ""
72
  for _ in range((int)(number_of_names)):
73
 
74
  # Initialize name with user input
75
  name = ""
76
- context = [0] * config['window']
77
  for c in name_start.lower():
78
  name += c
79
  context = context[1:] + [stoi[c]]
80
 
81
  # Run inference to finish off the name
82
  while True:
83
- ix = mlp.sample_char(context)
84
-
 
 
 
 
 
 
85
  context = context[1:] + [ix]
86
  name += itos[ix]
87
-
88
  if ix == 0:
89
  break
90
 
@@ -92,12 +161,13 @@ def generate_names(name_start, number_of_names):
92
 
93
  return names
94
 
95
- app = gr.Interface(
96
  fn=generate_names,
97
  inputs=[
98
  gr.Textbox(placeholder="Start name with..."),
99
- gr.Number(value=1)
 
100
  ],
101
  outputs="text",
102
  )
103
- app.launch()
 
9
  import yaml
10
 
11
 
12
# All four model artifacts (two configs, two weight files) are fetched from
# this one Hugging Face Hub repository.
_HUB_REPO_ID = "jefsnacker/surname_generator"


def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return the loaded object."""
    # Read explicitly as UTF-8 so parsing does not depend on the host locale.
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


# hf_hub_download returns the local path of the downloaded (or cached) file.
mlp_config_path = huggingface_hub.hf_hub_download(
    _HUB_REPO_ID,
    "torch_mlp_config.yaml")

mlp_weights_path = huggingface_hub.hf_hub_download(
    _HUB_REPO_ID,
    "mlp_weights.pt")

wavenet_config_path = huggingface_hub.hf_hub_download(
    _HUB_REPO_ID,
    "wavenet_config.yaml")

wavenet_weights_path = huggingface_hub.hf_hub_download(
    _HUB_REPO_ID,
    "wavenet_weights.pt")

mlp_config = _load_yaml_config(mlp_config_path)
wavenet_config = _load_yaml_config(wavenet_config_path)
33
 
34
  class MLP(nn.Module):
35
  def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
 
75
  mlp.load_state_dict(torch.load(weights_path))
76
  mlp.eval()
77
 
78
class WaveNet(nn.Module):
    """Character-level next-character model built from stacked 1-D convolutions.

    The whole network is one ``nn.Sequential`` stored in ``self.layers``:
    an embedding, then ``num_layers`` blocks of
    ``Conv1d -> BatchNorm1d -> Tanh``, then ``Flatten`` and a ``Linear`` layer
    that emits one logit per character.

    The embedding output is passed to ``Conv1d`` unchanged, so the ``window``
    axis is used as the first convolution's input channels and the embedding
    axis is the one being convolved. Every kernel-size-2 convolution shortens
    that axis by one, which is why the final linear layer takes
    ``hidden_nodes * (embeddings - num_layers)`` input features.

    Args:
        num_char: Vocabulary size (embedding rows and number of output logits).
        hidden_nodes: Output channels of every convolution block.
        embeddings: Embedding dimension; must be greater than ``num_layers``.
        window: Number of context characters per sample.
        num_layers: Number of convolution blocks.

    Raises:
        ValueError: If ``embeddings`` is not greater than ``num_layers``, since
            the convolved axis would be empty before the final linear layer.
    """

    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        super().__init__()

        if embeddings <= num_layers:
            raise ValueError(
                "embeddings must be greater than num_layers, got "
                f"embeddings={embeddings} and num_layers={num_layers}"
            )

        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings

        # Modules are created in network order and handed to nn.Sequential in
        # one go, so they receive the positional names "0", "1", ... that
        # state dicts saved from this architecture use as keys.
        modules = [nn.Embedding(num_char, embeddings)]

        # Only the first convolution sees `window` channels; later ones consume
        # the `hidden_nodes` channels produced by the previous block.
        in_channels = window
        for _ in range(num_layers):
            modules += [
                nn.Conv1d(in_channels, hidden_nodes, kernel_size=2, stride=1, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh(),
            ]
            in_channels = hidden_nodes

        modules += [
            nn.Flatten(),
            nn.Linear(hidden_nodes * (embeddings - num_layers), num_char),
        ]

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Return the logits produced by running ``x`` through ``self.layers``."""
        return self.layers(x)

    def sample_char(self, x):
        """Sample one character index from the model's predicted distribution.

        ``x`` must describe a single sample (batch size 1): the sampled tensor
        is converted with ``.item()``, which only works for one element.

        Returns:
            The sampled character index as a Python ``int``.
        """
        # Sampling never needs gradients, so skip building the autograd graph.
        with torch.no_grad():
            logits = self(x)
            probs = F.softmax(logits, dim=1)
            return torch.multinomial(probs, num_samples=1).item()
113
+
114
# Build the WaveNet from the hyperparameters stored in its YAML config.
wavenet = WaveNet(
    wavenet_config['num_char'],
    wavenet_config['hidden_nodes'],
    wavenet_config['embeddings'],
    wavenet_config['window'],
    wavenet_config['num_layers'],
)
# map_location="cpu" lets the checkpoint load on a machine without CUDA even
# if it was saved from GPU tensors.
wavenet.load_state_dict(torch.load(wavenet_weights_path, map_location="cpu"))
# Evaluation mode: BatchNorm layers use their stored running statistics.
wavenet.eval()
121
+
122
+ def generate_names(name_start, number_of_names, model):
123
+ if model == "MLP":
124
+ stoi = mlp_config['stoi']
125
+ window = mlp_config['window']
126
+ elif model == "WaveNet":
127
+ stoi = wavenet_config['stoi']
128
+ window = wavenet_config['window']
129
+ else:
130
+ raise Exception("Model not selected")
131
+
132
+ itos = {s:i for i,s in stoi.items()}
133
+
134
  names = ""
135
  for _ in range((int)(number_of_names)):
136
 
137
  # Initialize name with user input
138
  name = ""
139
+ context = [0] * window
140
  for c in name_start.lower():
141
  name += c
142
  context = context[1:] + [stoi[c]]
143
 
144
  # Run inference to finish off the name
145
  while True:
146
+ x = torch.tensor(context).view(1, -1)
147
+ if model == "MLP":
148
+ ix = mlp.sample_char(x)
149
+ elif model == "WaveNet":
150
+ ix = wavenet.sample_char(x)
151
+ else:
152
+ raise Exception("Model not selected")
153
+
154
  context = context[1:] + [ix]
155
  name += itos[ix]
156
+
157
  if ix == 0:
158
  break
159
 
 
161
 
162
  return names
163
 
164
# Gradio UI. The three inputs are passed positionally to generate_names as
# (name_start, number_of_names, model); the returned string is shown as text.
demo = gr.Interface(
    fn=generate_names,
    inputs=[
        gr.Textbox(placeholder="Start name with..."),
        gr.Number(value=5),
        gr.Dropdown(["MLP", "WaveNet"], value="WaveNet"),
    ],
    outputs="text",
)

# Start the web server only when this file is executed as a script, so that
# importing the module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()