Mariam-Elz committed
Commit 472c525 · verified · 1 parent: e4c74b6

Upload imagedream/ldm/models/autoencoder.py with huggingface_hub

Files changed (1):
  1. imagedream/ldm/models/autoencoder.py  +270 -0
imagedream/ldm/models/autoencoder.py ADDED
import torch
import torch.nn.functional as F
from contextlib import contextmanager

from ..modules.diffusionmodules.model import Encoder, Decoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution

from ..util import instantiate_from_config
from ..modules.ema import LitEma


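# KL-regularized image autoencoder: images are encoded to a diagonal Gaussian
# posterior over latents and latents are decoded back to pixel space. This is
# the first-stage model of ldm-style (latent diffusion) pipelines.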
class AutoencoderKL(torch.nn.Module):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=(),
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        ema_decay=None,
        learn_logvar=False,
    ):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0.0 < ema_decay < 1.0
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print(f"Deleting key {k} from state_dict.")
                    del sd[k]
                    break  # avoid deleting the same key twice if several prefixes match
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

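    # Temporarily swap the model's parameters for their EMA shadow values and
    # restore the training weights on exit (used for evaluation and logging).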
    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

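    # quant_conv maps the encoder output to 2 * embed_dim channels, which
    # DiagonalGaussianDistribution splits into the mean and log-variance of
    # the latent posterior.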
    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

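    # Fetch batch[k] and convert from channels-last to contiguous NCHW float,
    # appending a channel axis when the input has none.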
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

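    # The training/validation steps below follow the PyTorch Lightning API of
    # the original ldm implementation: self.log, self.log_dict, self.global_step,
    # self.learning_rate, and self.device are not defined by torch.nn.Module and
    # are assumed to be supplied by a Lightning-style trainer or wrapper.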
    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log(
                "aeloss",
                aeloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )

            self.log(
                "discloss",
                discloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            # metrics under EMA weights are logged inside with an "_ema" suffix
            self._validation_step(batch, batch_idx, postfix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )

        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # upstream returned the bound self.log_dict method here; return the
        # autoencoder metrics dict instead
        return log_dict_ae

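    # GAN-style two-optimizer setup: optimizer 0 updates the autoencoder
    # (plus the loss's logvar when learn_logvar is set), optimizer 1 updates
    # the discriminator held inside self.loss.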
    def configure_optimizers(self):
        lr = self.learning_rate
        ae_params_list = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters())
        )
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
        )
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

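    # Log inputs, reconstructions, and unconditional samples (latents drawn
    # from a standard normal of the posterior's shape); x.to(self.device)
    # assumes a device attribute supplied by the surrounding trainer.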
    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(
                        torch.randn_like(posterior_ema.sample())
                    )
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()  # initialize nn.Module before setting attributes
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
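

# Minimal usage sketch (illustrative addition, not part of the original
# commit). The ddconfig values below mirror a typical SD-style KL-f8
# autoencoder and are assumptions, not taken from this file; the Identity
# loss target keeps the sketch runnable without the adversarial loss module,
# assuming instantiate_from_config follows the usual ldm convention of
# {"target": ..., "params": ...}. Because of the relative imports, run as:
#   python -m imagedream.ldm.models.autoencoder
if __name__ == "__main__":
    ddconfig = dict(
        double_z=True,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=128,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_resolutions=[],
        dropout=0.0,
    )
    model = AutoencoderKL(ddconfig, {"target": "torch.nn.Identity"}, embed_dim=4)
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        rec, posterior = model(x)
    # expect rec: (1, 3, 256, 256) and latent: (1, 4, 32, 32) for this config
    print(rec.shape, posterior.sample().shape)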