Mariam-Elz committed on
Commit
26198af
·
verified ·
1 Parent(s): 844790c

Upload imagedream/ldm/interface.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. imagedream/ldm/interface.py +205 -0
imagedream/ldm/interface.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .modules.diffusionmodules.util import (
9
+ make_beta_schedule,
10
+ extract_into_tensor,
11
+ enforce_zero_terminal_snr,
12
+ noise_like,
13
+ )
14
+ from .util import exists, default, instantiate_from_config
15
+ from .modules.distributions.distributions import DiagonalGaussianDistribution
16
+
17
+
18
class DiffusionWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper that delegates every call to the wrapped
    diffusion model (the U-Net instantiated from ``unet_config``)."""

    def __init__(self, diffusion_model):
        """Register the underlying diffusion model as a submodule.

        Args:
            diffusion_model: the model whose ``forward`` this wrapper proxies.
        """
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        """Forward all positional and keyword arguments unchanged to the
        wrapped model and return its output."""
        out = self.diffusion_model(*args, **kwargs)
        return out
25
+
26
+
27
class LatentDiffusionInterface(nn.Module):
    """A simple interface class for LDM inference.

    Wires together the three sub-models of a latent diffusion pipeline —
    the diffusion U-Net (wrapped in ``DiffusionWrapper``), a CLIP
    conditioning model, and a first-stage VAE — precomputes the noise
    schedule as registered buffers, and exposes the standard forward
    diffusion / prediction helpers used by samplers.
    """

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        """Instantiate the sub-models from their configs and register the
        diffusion schedule buffers.

        Args:
            unet_config: config for the diffusion U-Net, instantiated via
                ``instantiate_from_config`` and wrapped in ``DiffusionWrapper``.
            clip_config: config for the CLIP conditioning model.
            vae_config: config for the first-stage VAE.
            parameterization: what the U-Net predicts ("eps" or "v").
            scale_factor: scaling applied to VAE latents (0.18215 is the
                Stable Diffusion convention).
            beta_schedule, timesteps, linear_start, linear_end, cosine_s:
                forwarded to ``make_beta_schedule``.
            given_betas: optional precomputed beta array; when provided it
                overrides the schedule parameters above.
            zero_snr: if True, rescale betas to enforce zero terminal SNR.
        """
        super().__init__()

        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)

        self.parameterization = parameterization
        self.scale_factor = scale_factor
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr,
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False,
        v_posterior=0.0,
    ):
        """Precompute all diffusion-schedule constants and register them as
        (non-trainable) buffers on the module.

        Args:
            given_betas: use these betas directly instead of building a schedule.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s:
                parameters for ``make_beta_schedule``; ignored when
                ``given_betas`` is provided.
            zero_snr: if True, rescale betas so the terminal step has zero SNR.
            v_posterior: weight interpolating the posterior variance between
                the DDPM lower bound (0, the default — matching the previous
                hard-coded behavior) and ``beta_t`` (1).
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            # enforce_zero_terminal_snr returns a tensor; buffers below are
            # built from numpy, so convert back.
            betas = enforce_zero_terminal_snr(betas).numpy()
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar at t=-1 defined as 1.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.v_posterior = v_posterior
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    def q_sample(self, x_start, t, noise=None):
        """Diffuse a clean latent ``x_start`` to timestep ``t``:
        ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x_start: clean latent batch.
            t: integer timestep tensor, one entry per batch element.
            noise: optional noise tensor; sampled from N(0, I) when omitted.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        """Compute the v-parameterization target:
        ``sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the predicted ``x_0`` from a noisy latent ``x_t`` and a
        predicted epsilon at timestep ``t``."""
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover the predicted ``x_0`` from ``x_t`` and a predicted ``v``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Recover the predicted epsilon from ``x_t`` and a predicted ``v``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def apply_model(self, x_noisy, t, cond, **kwargs):
        """Run the diffusion U-Net on ``x_noisy`` at timestep ``t``.

        Args:
            cond: dict of conditioning kwargs, unpacked into the model call.

        Raises:
            AssertionError: if ``cond`` is not a dict.
        """
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        """Encode text prompts with the CLIP model."""
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        """Encode images with the CLIP model's image branch."""
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn a VAE encoder output into a scaled latent.

        Accepts either a ``DiagonalGaussianDistribution`` (sampled) or a raw
        tensor, and multiplies by ``scale_factor``.

        Raises:
            NotImplementedError: for any other encoder output type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        """Encode an image batch with the VAE (returns the raw encoder output,
        not yet scaled — see ``get_first_stage_encoding``)."""
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        """Decode a scaled latent back to image space, undoing ``scale_factor``
        first."""
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)