Mariam-Elz committed (verified)
Commit 6912fdf · 1 Parent(s): 2e06668

Upload imagedream/ldm/modules/attention.py with huggingface_hub

Files changed (1)
  1. imagedream/ldm/modules/attention.py +456 -456
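As the commit message notes, the file was uploaded with huggingface_hub; a minimal sketch of pulling it back down is shown below (the repo id is a placeholder, not taken from this page):

from huggingface_hub import hf_hub_download

# Placeholder repo id; substitute the actual model repository.
local_path = hf_hub_download(
    repo_id="your-namespace/your-repo",
    filename="imagedream/ldm/modules/attention.py",
)
print(local_path)  # path of the locally cached copy of attention.py
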
imagedream/ldm/modules/attention.py CHANGED
@@ -1,456 +1,456 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any

from .diffusionmodules.util import checkpoint


try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False

# CrossAttn precision handling
import os

_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_

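A minimal usage sketch for SpatialSelfAttention above (illustration only, not part of the uploaded file). It needs only torch and einops, the channel count must be divisible by the 32 groups used in Normalize, and the import path assumes the imagedream package is importable:

import torch
from imagedream.ldm.modules.attention import SpatialSelfAttention

attn = SpatialSelfAttention(in_channels=64)  # 64 is divisible by the 32 GroupNorm groups
x = torch.randn(2, 64, 16, 16)               # (batch, channels, height, width)
out = attn(x)                                # residual attention over the 16*16 spatial positions
assert out.shape == x.shape
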
class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
        super().__init__()
        print(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{heads} heads."
        )
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.with_ip = kwargs.get("with_ip", False)
        if self.with_ip and (context_dim is not None):
            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.ip_dim = kwargs.get("ip_dim", 16)
            self.ip_weight = kwargs.get("ip_weight", 1.0)

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)

        has_ip = self.with_ip and (context is not None)
        if has_ip:
            # context dim [(b frame_num), (77 + img_token), 1024]
            token_len = context.shape[1]
            context_ip = context[:, -self.ip_dim:, :]
            k_ip = self.to_k_ip(context_ip)
            v_ip = self.to_v_ip(context_ip)
            context = context[:, :(token_len - self.ip_dim), :]

        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        if has_ip:
            k_ip, v_ip = map(
                lambda t: t.unsqueeze(3)
                .reshape(b, t.shape[1], self.heads, self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b * self.heads, t.shape[1], self.dim_head)
                .contiguous(),
                (k_ip, v_ip),
            )
            # actually compute the attention, what we cannot get enough of
            out_ip = xformers.ops.memory_efficient_attention(
                q, k_ip, v_ip, attn_bias=None, op=self.attention_op
            )
            out = out + self.ip_weight * out_ip

        if exists(mask):
            raise NotImplementedError
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)

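A hedged sketch of driving the image-prompt ("ip") branch of MemoryEfficientCrossAttention above (illustration only). With with_ip=True, the context is expected to end with ip_dim image tokens after the text tokens, matching the "[(b frame_num), (77 + img_token), 1024]" comment; xformers and a CUDA device are assumed, as is the imagedream import path:

import torch
from imagedream.ldm.modules.attention import MemoryEfficientCrossAttention

attn = MemoryEfficientCrossAttention(
    query_dim=320, context_dim=1024, heads=8, dim_head=40,
    with_ip=True, ip_dim=16, ip_weight=1.0,
).cuda()
x = torch.randn(4, 4096, 320, device="cuda")            # (batch*frames, h*w query tokens, query_dim)
context = torch.randn(4, 77 + 16, 1024, device="cuda")  # 77 text tokens followed by 16 image tokens
out = attn(x, context=context)                          # image tokens attended separately, scaled by ip_weight
assert out.shape == x.shape
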
class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
        **kwargs
    ):
        super().__init__()
        assert XFORMERS_IS_AVAILBLE, "xformers is not available"
        attn_cls = MemoryEfficientCrossAttention
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            **kwargs
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None):
        x = (
            self.attn1(
                self.norm1(x), context=context if self.disable_self_attn else None
            )
            + x
        )
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image.
    NEW: use_linear for more efficiency instead of the 1x1 convs.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
        **kwargs
    ):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                    **kwargs
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            # project back from inner_dim to in_channels
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in

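A minimal usage sketch for SpatialTransformer above (illustration only), following its docstring: the (b, c, h, w) feature map is projected, flattened to (b, h*w, c), run through the transformer blocks with the given context, and reshaped back with a residual connection. xformers is assumed to be installed (BasicTransformerBlock asserts it), along with a CUDA device and the imagedream import path:

import torch
from imagedream.ldm.modules.attention import SpatialTransformer

block = SpatialTransformer(
    in_channels=320, n_heads=8, d_head=40,
    depth=1, context_dim=1024, use_linear=True, use_checkpoint=False,
).cuda()
x = torch.randn(2, 320, 32, 32, device="cuda")     # UNet-style feature map
context = torch.randn(2, 77, 1024, device="cuda")  # e.g. text-encoder tokens
out = block(x, context=context)                    # same shape as x, residual added
assert out.shape == x.shape
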
class BasicTransformerBlock3D(BasicTransformerBlock):
    def forward(self, x, context=None, num_frames=1):
        return checkpoint(
            self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None, num_frames=1):
        x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
        x = (
            self.attn1(
                self.norm1(x),
                context=context if self.disable_self_attn else None
            )
            + x
        )
        x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer3D(nn.Module):
    """3D self-attention"""

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
        **kwargs
    ):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock3D(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                    **kwargs
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            # project back from inner_dim to in_channels
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None, num_frames=1):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i], num_frames=num_frames)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in

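A hedged sketch of SpatialTransformer3D above (illustration only). Frames are folded into the batch as (b*num_frames, c, h, w); BasicTransformerBlock3D regroups them to (b, num_frames*h*w, c) so that self-attention runs across all frames of a sample before per-frame cross-attention is applied. xformers and a CUDA device are assumed, the imagedream import path is an assumption, and with_ip/ip_dim are forwarded to the cross-attention as described earlier:

import torch
from imagedream.ldm.modules.attention import SpatialTransformer3D

num_frames = 4
block = SpatialTransformer3D(
    in_channels=320, n_heads=8, d_head=40,
    depth=1, context_dim=1024, use_linear=True, use_checkpoint=False,
    with_ip=True, ip_dim=16,
).cuda()
x = torch.randn(2 * num_frames, 320, 32, 32, device="cuda")          # frames folded into the batch
context = torch.randn(2 * num_frames, 77 + 16, 1024, device="cuda")  # per-frame text + image tokens
out = block(x, context=context, num_frames=num_frames)
assert out.shape == x.shape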