OVAWARE commited on
Commit
b8924f9
·
verified ·
1 Parent(s): 1dda828

Merge train.py with generate.py

Browse files
Files changed (1) hide show
  1. app.py +69 -2
app.py CHANGED
@@ -8,8 +8,75 @@ import numpy as np
8
  import os
9
  import time
10
 
11
- # Import the model architecture from train.py
12
- from train import CVAE, TextEncoder, LATENT_DIM, HIDDEN_DIM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Initialize the BERT tokenizer
15
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
8
  import os
9
  import time
10
 
11
+ LATENT_DIM = 128
12
+ HIDDEN_DIM = 256
13
+
14
+
15
+ # Text encoder
16
+ class TextEncoder(nn.Module):
17
+ def __init__(self, hidden_size, output_size):
18
+ super(TextEncoder, self).__init__()
19
+ self.bert = BertModel.from_pretrained('bert-base-uncased')
20
+ self.fc = nn.Linear(self.bert.config.hidden_size, output_size)
21
+
22
+ def forward(self, input_ids, attention_mask):
23
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
24
+ return self.fc(outputs.last_hidden_state[:, 0, :])
25
+
26
+ # CVAE model
27
+ class CVAE(nn.Module):
28
+ def __init__(self, text_encoder):
29
+ super(CVAE, self).__init__()
30
+ self.text_encoder = text_encoder
31
+
32
+ # Encoder
33
+ self.encoder = nn.Sequential(
34
+ nn.Conv2d(4, 32, 3, stride=1, padding=1),
35
+ nn.ReLU(),
36
+ nn.Conv2d(32, 64, 3, stride=2, padding=1),
37
+ nn.ReLU(),
38
+ nn.Conv2d(64, 128, 3, stride=2, padding=1),
39
+ nn.ReLU(),
40
+ nn.Flatten(),
41
+ nn.Linear(128 * 4 * 4, HIDDEN_DIM)
42
+ )
43
+
44
+ self.fc_mu = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)
45
+ self.fc_logvar = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)
46
+
47
+ # Decoder
48
+ self.decoder_input = nn.Linear(LATENT_DIM + HIDDEN_DIM, 128 * 4 * 4)
49
+ self.decoder = nn.Sequential(
50
+ nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
51
+ nn.ReLU(),
52
+ nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
53
+ nn.ReLU(),
54
+ nn.Conv2d(32, 4, 3, stride=1, padding=1),
55
+ nn.Tanh()
56
+ )
57
+
58
+ def encode(self, x, c):
59
+ x = self.encoder(x)
60
+ x = torch.cat([x, c], dim=1)
61
+ mu = self.fc_mu(x)
62
+ logvar = self.fc_logvar(x)
63
+ return mu, logvar
64
+
65
+ def decode(self, z, c):
66
+ z = torch.cat([z, c], dim=1)
67
+ x = self.decoder_input(z)
68
+ x = x.view(-1, 128, 4, 4)
69
+ return self.decoder(x)
70
+
71
+ def reparameterize(self, mu, logvar):
72
+ std = torch.exp(0.5 * logvar)
73
+ eps = torch.randn_like(std)
74
+ return mu + eps * std
75
+
76
+ def forward(self, x, c):
77
+ mu, logvar = self.encode(x, c)
78
+ z = self.reparameterize(mu, logvar)
79
+ return self.decode(z, c), mu, logvar
80
 
81
  # Initialize the BERT tokenizer
82
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')