mterris committed on
Commit ed95f9b · 1 Parent(s): 9037f29
Files changed (5)
  1. model_factory.py +6 -136
  2. models/blocks.py +0 -924
  3. models/heads.py +0 -270
  4. models/ram.py +854 -0
  5. models/unext_wip.py +0 -1238
model_factory.py CHANGED
@@ -1,103 +1,7 @@
1
  import torch
2
- import torch.nn as nn
3
- import deepinv as dinv
4
 
5
- from models.unext_wip import UNeXt
6
- from physics.multiscale import Pad
7
-
8
-
9
- class ArtifactRemoval(nn.Module):
10
- r"""
11
- Artifact removal architecture :math:`\phi(A^{\top}y)`.
12
-
13
- This differs from the dinv.models.ArtifactRemoval in that it allows to forward the physics.
14
-
15
- In the end we should not use this for unext !!
16
- """
17
-
18
- def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
19
- super(ArtifactRemoval, self).__init__()
20
- self.pinv = pinv
21
- self.backbone_net = backbone_net
22
- self.fm_mode = fm_mode
23
-
24
- if ckpt_path is not None:
25
- self.backbone_net.load_state_dict(torch.load(ckpt_path), strict=True)
26
- self.backbone_net.eval()
27
-
28
- if type(self.backbone_net).__name__ == "UNetRes":
29
- for _, v in self.backbone_net.named_parameters():
30
- v.requires_grad = False
31
- self.backbone_net = self.backbone_net.to(device)
32
-
33
-
34
- def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
35
- r"""
36
- Reconstructs a signal estimate from measurements y
37
-
38
- :param torch.tensor y: measurements
39
- :param deepinv.physics.Physics physics: forward operator
40
- """
41
- if physics is None:
42
- physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
43
-
44
- if not self.training:
45
- x_temp = physics.A_adjoint(y)
46
- pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
47
- physics = Pad(physics, pad)
48
-
49
- x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)
50
-
51
- if hasattr(physics.noise_model, "sigma"):
52
- sigma = physics.noise_model.sigma
53
- else:
54
- sigma = 1e-5 # WARNING: this is a default value that we may not want to use?
55
-
56
- if hasattr(physics.noise_model, "gain"):
57
- gamma = physics.noise_model.gain
58
- else:
59
- gamma = 1e-5 # WARNING: this is a default value that we may not want to use?
60
-
61
- out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
62
-
63
- if not self.training:
64
- out = physics.remove_pad(out)
65
-
66
- return out
67
-
68
- def forward(self, y=None, physics=None, x_in=None, **kwargs):
69
- if 'unext' in type(self.backbone_net).__name__.lower():
70
- return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
71
- else:
72
- return self.backbone_net(physics=physics, y=y, **kwargs)
73
-
74
-
75
- def get_model(
76
- model_name="unext_emb_physics_config_C",
77
- device="cpu",
78
- in_channels=[1, 2, 3],
79
- grayscale=False,
80
- conv_type="base",
81
- pool_type="base",
82
- layer_scale_init_value=1e-6,
83
- init_type="ortho",
84
- gain_init_conv=1.0,
85
- gain_init_linear=1.0,
86
- drop_prob=0.0,
87
- replk=False,
88
- mult_fact=4,
89
- antialias="gaussian",
90
- nc_base=64,
91
- cond_type="base",
92
- blind=False,
93
- pretrained_pth=None,
94
- weight_tied=True,
95
- N=4,
96
- c_mult=1,
97
- depth_encoding=1,
98
- relu_in_encoding=False,
99
- skip_in_encoding=True,
100
- ):
101
  """
102
  Load the model.
103
 
@@ -107,41 +11,7 @@ def get_model(
107
  :param bool train: if True, the model is trained
108
  :return: model
109
  """
110
- model_name = model_name.lower()
111
-
112
- if model_name == "unext_emb_physics_config_c":
113
- n_chan = [1, 2, 3] # 6 for old head grayscale, complex and color = 1 + 2 + 3
114
- residual = True if "residual" in model_name else False
115
- nc = [nc_base * 2**i for i in range(4)]
116
-
117
-
118
- model = UNeXt(
119
- in_channels=in_channels,
120
- out_channels=in_channels,
121
- device=device,
122
- residual=residual,
123
- conv_type=conv_type,
124
- pool_type=pool_type,
125
- layer_scale_init_value=layer_scale_init_value,
126
- init_type=init_type,
127
- gain_init_conv=gain_init_conv,
128
- gain_init_linear=gain_init_linear,
129
- drop_prob=drop_prob,
130
- replk=replk,
131
- mult_fact=mult_fact,
132
- antialias=antialias,
133
- nc=nc,
134
- cond_type=cond_type,
135
- emb_physics=True,
136
- config="C",
137
- pretrained_pth=pretrained_pth,
138
- N=N,
139
- c_mult=c_mult,
140
- depth_encoding=depth_encoding,
141
- relu_in_encoding=relu_in_encoding,
142
- skip_in_encoding=skip_in_encoding,
143
- ).to(device)
144
- return ArtifactRemoval(model, pinv=False, device=device)
145
-
146
- else:
147
- raise ValueError(f"Model {model_name} is not supported.")
 
1
  import torch
2
+ from models.ram import RAM
 
3
 
4
+ def get_model():
5
  """
6
  Load the model.
7
 
 
11
  :param bool train: if True, the model is trained
12
  :return: model
13
  """
14
+ model = RAM()
15
+ state_dict = torch.load('ckpt/ram.pth.tar')
16
+ model.load_state_dict(state_dict)
17
+ return model
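
A minimal usage sketch for the new factory (a sketch under assumptions: the
checkpoint exists at ckpt/ram.pth.tar as hard-coded above, and the toy image
and noise level are made up for illustration):

    import torch
    import deepinv as dinv
    from model_factory import get_model

    model = get_model().eval()

    x = torch.rand(1, 3, 64, 64)  # toy RGB image in [0, 1]
    physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.1))
    y = physics(x)  # simulated noisy measurement

    with torch.no_grad():
        x_hat = model(y=y, physics=physics)  # matches RAM.forward(y, physics) below
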
models/blocks.py DELETED
@@ -1,924 +0,0 @@
1
- import math
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from deepinv.models.unet import BFBatchNorm2d
8
- from deepinv.physics.blur import gaussian_blur
9
- from deepinv.physics.functional import conv2d
10
- from deepinv.utils import TensorList
11
-
12
- from timm.models.layers import trunc_normal_, DropPath
13
-
14
-
15
- def normalize(x, dim=None, eps=1e-4):
16
- if dim is None:
17
- dim = list(range(1, x.ndim))
18
- norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
19
- norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
20
- return x / norm.to(x.dtype)
21
-
22
-
23
- class TimestepEmbedding(nn.Module):
24
- def __init__(self, hidden_size, frequency_embedding_size=256):
25
- super().__init__()
26
- self.mlp = nn.Sequential(
27
- nn.Linear(frequency_embedding_size, hidden_size),
28
- nn.SiLU(),
29
- nn.Linear(hidden_size, hidden_size),
30
- )
31
- self.frequency_embedding_size = frequency_embedding_size
32
-
33
- @staticmethod
34
- def timestep_embedding(t, dim, max_period=10000):
35
- half = dim // 2
36
- freqs = torch.exp(
37
- -math.log(max_period) * torch.arange(start=0, end=half) / half
38
- ).to(t.device)
39
- args = t[:, None] * freqs[None]
40
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
41
- if dim % 2:
42
- embedding = torch.cat(
43
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
44
- )
45
- return embedding
46
-
47
- def forward(self, t):
48
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
49
- dtype=next(self.parameters()).dtype
50
- )
51
- t_emb = self.mlp(t_freq)
52
- return t_emb
53
-
54
-
55
- class MPConv(torch.nn.Module):
56
- def __init__(self, in_channels, out_channels, kernel):
57
- super().__init__()
58
- self.out_channels = out_channels
59
- self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
60
-
61
- def forward(self, x, gain=1):
62
- w = self.weight.to(torch.float32)
63
- if self.training:
64
- with torch.no_grad():
65
- self.weight.copy_(normalize(w)) # forced weight normalization
66
- w = normalize(w) # traditional weight normalization
67
- w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
68
- w = w.to(x.dtype)
69
- if w.ndim == 2:
70
- return x @ w.t()
71
- assert w.ndim == 4
72
- return F.conv2d(x, w, padding=(w.shape[-1] // 2,))
73
-
74
-
75
- # --------------------------------------------------------------------------------------
76
- def mp_silu(x):
77
- return torch.nn.functional.silu(x) / 0.596
78
-
79
-
80
- class MPFourier(torch.nn.Module):
81
- def __init__(self, num_channels, bandwidth=1, device="cpu"):
82
- super().__init__()
83
- self.register_buffer(
84
- "freqs", 2 * np.pi * torch.rand(num_channels, device=device) * bandwidth
85
- )
86
- self.register_buffer(
87
- "phases", 2 * np.pi * torch.rand(num_channels, device=device)
88
- )
89
-
90
- def forward(self, x):
91
- y = x.to(torch.float32)
92
- y = y.ger(self.freqs.to(torch.float32))
93
- y = y + self.phases.to(torch.float32)
94
- y = y.cos() * np.sqrt(2)
95
- return y.to(x.dtype)
96
-
97
-
98
- class NoiseEmbedding(torch.nn.Module):
99
- def __init__(self, num_channels=1, emb_channels=512, device="cpu", biasfree=True):
100
- super().__init__()
101
- self.emb_fourier = MPFourier(num_channels, device=device)
102
- self.emb_noise = MPConv(num_channels, emb_channels, kernel=[])
103
- self.biasfree = biasfree
104
-
105
- def forward(self, y, physics, factor):
106
- if hasattr(physics, "noise_model") and not callable(physics.noise_model):
107
- sigma = getattr(physics.noise_model, "sigma", 0.0)
108
- else:
109
- sigma = 0.0
110
-
111
- if isinstance(y, TensorList):
112
- sigma = sigma / (y[0].abs().reshape(y[0].size(0),-1).mean(1) + 1e-8) / factor
113
- else:
114
- sigma = sigma / (y.abs().reshape(y.size(0),-1).mean(1) + 1e-8) / factor
115
- emb_four = self.emb_fourier(sigma)
116
- emb = self.emb_noise(emb_four)
117
- if self.biasfree:
118
- emb = F.relu(emb)
119
- else:
120
- emb = mp_silu(emb)
121
- return emb.unsqueeze(-1).unsqueeze(-1)
122
-
123
-
124
- # --------------------------------------------------------------------------------------
125
- class AffineConv2d(nn.Conv2d):
126
- def __init__(
127
- self,
128
- in_channels,
129
- out_channels,
130
- kernel_size,
131
- mode="affine",
132
- bias=False,
133
- stride=1,
134
- padding=0,
135
- dilation=1,
136
- groups=1,
137
- padding_mode="circular",
138
- blind=True,
139
- ):
140
- if mode == "affine": # f(a*x + 1) = a*f(x) + 1
141
- bias = False
142
- super().__init__(
143
- in_channels,
144
- out_channels,
145
- kernel_size,
146
- bias=bias,
147
- stride=stride,
148
- padding=padding,
149
- dilation=dilation,
150
- groups=groups,
151
- padding_mode=padding_mode,
152
- )
153
- self.blind = blind
154
- self.mode = mode
155
-
156
- def affine(self, w):
157
- """returns new kernels that encode affine combinations"""
158
- return (
159
- w.view(self.out_channels, -1).roll(1, 1).view(w.size())
160
- - w
161
- + 1 / w[0, ...].numel()
162
- )
163
-
164
- def forward(self, x):
165
- if self.mode != "affine":
166
- return super().forward(x)
167
- else:
168
- kernel = (
169
- self.affine(self.weight)
170
- if self.blind
171
- else torch.cat(
172
- (self.affine(self.weight[:, :-1, :, :]), self.weight[:, -1:, :, :]),
173
- dim=1,
174
- )
175
- )
176
- padding = tuple(
177
- elt for elt in reversed(self.padding) for _ in range(2)
178
- ) # used to translate padding arg used by Conv module to the ones used by F.pad
179
- padding_mode = (
180
- self.padding_mode if self.padding_mode != "zeros" else "constant"
181
- ) # used to translate padding_mode arg used by Conv module to the ones used by F.pad
182
- return F.conv2d(
183
- F.pad(x, padding, mode=padding_mode),
184
- kernel,
185
- stride=self.stride,
186
- dilation=self.dilation,
187
- groups=self.groups,
188
- )
189
-
190
-
191
- # --------------------------------------------------------------------------------------
192
- def kaiser_window(beta, length):
193
- """Return the Kaiser window of length `length` and shape parameter `beta`."""
194
- if beta < 0:
195
- raise ValueError("beta must be greater than 0")
196
- if length < 1:
197
- raise ValueError("length must be greater than 0")
198
- if length == 1:
199
- return torch.tensor([1.0])
200
- half = (length - 1) / 2
201
- n = torch.arange(length)
202
- beta = torch.tensor(beta)
203
- return torch.i0(beta * torch.sqrt(1 - ((n - half) / half) ** 2)) / torch.i0(beta)
204
-
205
-
206
- def sinc_filter(factor=2, length=11, windowed=True):
207
- r"""
208
- Anti-aliasing sinc filter multiplied by a Kaiser window.
209
-
210
- :param float factor: Downsampling factor.
211
- :param int length: Length of the filter.
212
- """
213
- deltaf = 1 / factor
214
-
215
- n = torch.arange(length) - (length - 1) / 2
216
- filter = torch.sinc(n / factor)
217
-
218
- if windowed:
219
- A = 2.285 * (length - 1) * 3.14 * deltaf + 7.95
220
- if A <= 21:
221
- beta = 0
222
- elif A <= 50:
223
- beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21)
224
- else:
225
- beta = 0.1102 * (A - 8.7)
226
-
227
- filter = filter * kaiser_window(beta, length)
228
-
229
- filter = filter.unsqueeze(0)
230
- filter = filter * filter.T
231
- filter = filter.unsqueeze(0).unsqueeze(0)
232
- filter = filter / filter.sum()
233
- return filter
234
-
235
-
236
- class EquivMaxPool(nn.Module):
237
- r"""
238
- Max pooling layer that is equivariant to translations.
239
-
240
- :param int kernel_size: size of the pooling window.
241
- :param int stride: stride of the pooling operation.
242
- :param int padding: padding to apply before pooling.
243
- :param bool circular_padding: circular padding for the convolutional layers.
244
- """
245
-
246
- def __init__(
247
- self,
248
- antialias="gaussian",
249
- factor=2,
250
- device="cuda",
251
- in_channels=64,
252
- out_channels=64,
253
- bias=False,
254
- padding_mode="circular",
255
- ):
256
- super(EquivMaxPool, self).__init__()
257
- self.antialias = antialias
258
- if antialias == "gaussian":
259
- self.antialias_kernel = gaussian_blur(factor / 3.14).to(device)
260
- elif antialias == "sinc":
261
- self.antialias_kernel = sinc_filter(
262
- factor=factor, length=11, windowed=True
263
- ).to(device)
264
-
265
- self.conv_down = AffineConv2d(
266
- in_channels,
267
- out_channels,
268
- kernel_size=3,
269
- stride=1,
270
- padding=1,
271
- bias=bias,
272
- padding_mode=padding_mode,
273
- groups=1,
274
- )
275
-
276
- self.conv_up = AffineConv2d(
277
- out_channels,
278
- in_channels,
279
- kernel_size=3,
280
- stride=1,
281
- padding=1,
282
- bias=bias,
283
- padding_mode=padding_mode,
284
- groups=1,
285
- )
286
-
287
- def forward(self, x):
288
- return self.downscale(x)
289
-
290
- def downscale(self, x):
291
- r"""
292
- Apply the equivariant pooling.
293
-
294
- :param torch.Tensor x: input tensor.
295
- """
296
- B, C, H, W = x.shape
297
-
298
- x = self.conv_down(x)
299
-
300
- if self.antialias == "gaussian" or self.antialias == "sinc":
301
- x = conv2d(x, self.antialias_kernel, padding="circular")
302
-
303
- x1 = x[:, :, ::2, ::2].unsqueeze(0)
304
- x2 = x[:, :, ::2, 1::2].unsqueeze(0)
305
- x3 = x[:, :, 1::2, ::2].unsqueeze(0)
306
- x4 = x[:, :, 1::2, 1::2].unsqueeze(0)
307
- out = torch.cat([x1, x2, x3, x4], dim=0) # (4, B, C, H/2, W/2)
308
- ind = torch.norm(out, dim=(2, 3, 4), p=2) # (4, B)
309
- ind = torch.argmax(ind, dim=0) # (B)
310
- out = out[ind, torch.arange(B), ...] # (B, C, H/2, W/2)
311
- self.ind = ind
312
-
313
- return out
314
-
315
- def upscale(self, x):
316
- B, C, H, W = x.shape
317
-
318
- out = torch.zeros((B, C, H * 2, W * 2), device=x.device)
319
- out[:, :, ::2, ::2] = x
320
- ind = self.ind
321
- filter = torch.zeros((B, 1, 2, 2), device=x.device)
322
- filter[ind == 0, :, 0, 0] = 1
323
- filter[ind == 1, :, 0, 1] = 1
324
- filter[ind == 2, :, 1, 0] = 1
325
- filter[ind == 3, :, 1, 1] = 1
326
- out = conv2d(out, filter, padding="constant")
327
-
328
- if self.antialias == "gaussian" or self.antialias == "sinc":
329
- out = conv2d(out, self.antialias_kernel, padding="circular")
330
-
331
- out = self.conv_up(out)
332
- return out
333
-
334
-
335
- # --------------------------------------------------------------------------------------
336
- class ConvNextBaseBlock(nn.Module):
337
- r"""
338
- ConvNeXt Block mimicking DRUNet base layer (Conv + Relu + Conv)
339
-
340
- Args:
341
- in_channels (int): Number of input channels.
342
- out_channels (int): Number of output channels.
343
- mode (str): Mode for the AffineConv2d (if needed, else ignored).
344
- bias (bool): Whether to use bias in convolutions. Default: False.
345
- ksize (int): Kernel size for the convolutions. Default: 7.
346
- padding_mode (str): Padding mode for convolutions. Default: 'circular'.
347
- mult_fact (int): Multiplier factor for expanding the number of channels.
348
- residual (bool): Whether to use a residual connection. Default: False.
349
- """
350
-
351
- def __init__(
352
- self,
353
- in_channels,
354
- out_channels,
355
- mode="",
356
- bias=False,
357
- ksize=7,
358
- padding_mode="circular",
359
- mult_fact=1,
360
- residual=False,
361
- ):
362
- super().__init__()
363
-
364
- ### DEPTHWISE SEPARABLE CONVOLUTION: (N,C,H,W) -> (N,4*C,H,W)
365
- # depthwise conv with big kernel
366
- self.dwconv_a = AffineConv2d(
367
- in_channels,
368
- in_channels,
369
- kernel_size=ksize,
370
- padding=ksize // 2,
371
- groups=in_channels,
372
- padding_mode=padding_mode,
373
- bias=bias,
374
- mode=mode,
375
- )
376
- # depthwise conv with small kernel
377
- self.dwconv_a_small = AffineConv2d(
378
- in_channels,
379
- in_channels,
380
- kernel_size=3,
381
- padding=3 // 2,
382
- groups=in_channels,
383
- padding_mode=padding_mode,
384
- bias=bias,
385
- mode=mode,
386
- )
387
- # pointwise conv to change number of channels
388
- self.pwconv_a1 = AffineConv2d(
389
- in_channels,
390
- mult_fact * in_channels,
391
- kernel_size=1,
392
- stride=1,
393
- padding=0,
394
- mode=mode,
395
- bias=bias,
396
- padding_mode=padding_mode,
397
- groups=1,
398
- )
399
-
400
- ### ACTIVATION
401
- self.act_a = nn.ReLU()
402
-
403
- ### POINTWISE CONVOLUTION: (N,4*C,H,W) -> (N,O,H,W)
404
- self.pwconv_a2 = AffineConv2d(
405
- mult_fact * in_channels,
406
- out_channels,
407
- kernel_size=1,
408
- stride=1,
409
- padding=0,
410
- bias=bias,
411
- padding_mode=padding_mode,
412
- groups=1,
413
- )
414
-
415
- ### Needed to match the number of channels : (N,C,H,W) -> (C,O,H,W)
416
- self.residual = residual
417
- if self.residual:
418
- self.residual_conv = AffineConv2d(
419
- in_channels,
420
- out_channels,
421
- kernel_size=1,
422
- stride=1,
423
- padding=0,
424
- groups=1,
425
- padding_mode=padding_mode,
426
- bias=bias,
427
- mode=mode,
428
- )
429
-
430
- def forward(self, x_in, stream1=None, stream2=None):
431
- """Forward with GPU parallelization using multiple cuda streams."""
432
-
433
- if stream1 is not None and stream2 is not None:
434
- # Use the streams
435
- with torch.cuda.stream(stream1):
436
- output_a = self.dwconv_a(x_in) # Run the first convolution in stream1
437
-
438
- with torch.cuda.stream(stream2):
439
- output_a_small = self.dwconv_a_small(
440
- x_in
441
- ) # Run the second convolution in stream2
442
-
443
- # Ensure the streams are synchronized before adding the results
444
- torch.cuda.synchronize()
445
- x = self.pwconv_a(output_a + output_a_small)
446
-
447
- else:
448
- x = self.dwconv_a(x_in) + self.dwconv_a_small(x_in) # replk 7x7 with 3x3
449
- x = self.pwconv_a1(x)
450
-
451
- x = self.act_a(x)
452
- x = self.pwconv_a2(x) # (N,O,H,W)
453
-
454
- if self.residual:
455
- x = self.residual_conv(x_in) + x
456
-
457
- return x
458
-
459
-
460
- class ConvNextBlock2(nn.Module):
461
- r"""
462
- ConvNeXt Block mimicking DRUNet base layer (Conv + Relu + Conv)
463
-
464
- Args:
465
- ???
466
- """
467
-
468
- def __init__(
469
- self,
470
- in_channels,
471
- out_channels,
472
- mode="affine",
473
- bias=False,
474
- ksize=7,
475
- padding_mode="circular",
476
- mult_fact=4,
477
- s1=None,
478
- s2=None,
479
- ):
480
- super().__init__()
481
- self.block_0 = ConvNextBaseBlock(
482
- in_channels,
483
- out_channels,
484
- mode=mode,
485
- bias=bias,
486
- ksize=ksize,
487
- padding_mode=padding_mode,
488
- mult_fact=mult_fact,
489
- )
490
- self.block_1 = ConvNextBaseBlock(
491
- in_channels,
492
- out_channels,
493
- mode=mode,
494
- bias=bias,
495
- ksize=ksize,
496
- padding_mode=padding_mode,
497
- mult_fact=mult_fact,
498
- )
499
- # self.relu = nn.ReLU(inplace=True) # issue with the network when working in FP16 ???
500
- self.relu = nn.ReLU()
501
-
502
- # cuda stream to parallelize execution of ConvNextBaseBlock
503
- self.s1 = s1
504
- self.s2 = s2
505
-
506
- def forward(self, input, emb_sigma=None):
507
- if self.s1 is not None and self.s2 is not None:
508
- x = self.block_0(input, self.s1, self.s2)
509
- else:
510
- x = self.block_0(input)
511
-
512
- x = self.relu(x)
513
-
514
- if self.s1 is not None and self.s2 is not None:
515
- x = self.block_1(x, self.s1, self.s2)
516
- else:
517
- x = self.block_1(x)
518
- return x + input
519
-
520
-
521
- class CondResBlock(nn.Module):
522
- def __init__(
523
- self,
524
- in_channels=64,
525
- out_channels=64,
526
- kernel_size=3,
527
- stride=1,
528
- padding=1,
529
- bias=False,
530
- emb_channels=512,
531
- ):
532
- super(CondResBlock, self).__init__()
533
-
534
- assert in_channels == out_channels, "Only support in_channels==out_channels."
535
-
536
- self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
537
- self.emb_linear = MPConv(emb_channels, out_channels, kernel=[3, 3])
538
- self.conv1 = nn.Conv2d(
539
- in_channels, out_channels, kernel_size, stride, padding, bias=bias
540
- )
541
- self.conv2 = nn.Conv2d(
542
- out_channels, out_channels, kernel_size, stride, padding, bias=bias
543
- )
544
-
545
- def forward(self, x, emb_sigma):
546
- # u = self.conv1(mp_silu(x))
547
- u = self.conv1(F.relu((x)))
548
- c = self.emb_linear(emb_sigma, gain=self.gain) + 1
549
- # y = mp_silu(u * c.unsqueeze(2).unsqueeze(3).to(u.dtype))
550
- y = F.relu(u * c.unsqueeze(2).unsqueeze(3).to(u.dtype))
551
- y = self.conv2(y)
552
- return x + y
553
-
554
-
555
- """
556
- Functional blocks below
557
- """
558
- from collections import OrderedDict
559
- import torch
560
- import torch.nn as nn
561
-
562
-
563
- """
564
- # --------------------------------------------
565
- # Advanced nn.Sequential
566
- # https://github.com/xinntao/BasicSR
567
- # --------------------------------------------
568
- """
569
-
570
-
571
- def sequential(*args):
572
- """Advanced nn.Sequential.
573
- Args:
574
- nn.Sequential, nn.Module
575
- Returns:
576
- nn.Sequential
577
- """
578
- if len(args) == 1:
579
- if isinstance(args[0], OrderedDict):
580
- raise NotImplementedError("sequential does not support OrderedDict input.")
581
- return args[0] # No sequential is needed.
582
- modules = []
583
- for module in args:
584
- if isinstance(module, nn.Sequential):
585
- for submodule in module.children():
586
- modules.append(submodule)
587
- elif isinstance(module, nn.Module):
588
- modules.append(module)
589
- return nn.Sequential(*modules)
590
-
591
-
592
- """
593
- # --------------------------------------------
594
- # Useful blocks
595
- # https://github.com/xinntao/BasicSR
596
- # --------------------------------
597
- # conv + normaliation + relu (conv)
598
- # (PixelUnShuffle)
599
- # (ConditionalBatchNorm2d)
600
- # concat (ConcatBlock)
601
- # sum (ShortcutBlock)
602
- # resblock (ResBlock)
603
- # Channel Attention (CA) Layer (CALayer)
604
- # Residual Channel Attention Block (RCABlock)
605
- # Residual Channel Attention Group (RCAGroup)
606
- # Residual Dense Block (ResidualDenseBlock_5C)
607
- # Residual in Residual Dense Block (RRDB)
608
- # --------------------------------------------
609
- """
610
-
611
-
612
- # --------------------------------------------
613
- # return nn.Sequantial of (Conv + BN + ReLU)
614
- # --------------------------------------------
615
- def conv(
616
- in_channels=64,
617
- out_channels=64,
618
- kernel_size=3,
619
- stride=1,
620
- padding=1,
621
- bias=True,
622
- mode="CBR",
623
- negative_slope=0.2,
624
- ):
625
- L = []
626
- for t in mode:
627
- if t == "C":
628
- L.append(
629
- nn.Conv2d(
630
- in_channels=in_channels,
631
- out_channels=out_channels,
632
- kernel_size=kernel_size,
633
- stride=stride,
634
- padding=padding,
635
- bias=bias,
636
- )
637
- )
638
- elif t == "T":
639
- L.append(
640
- nn.ConvTranspose2d(
641
- in_channels=in_channels,
642
- out_channels=out_channels,
643
- kernel_size=kernel_size,
644
- stride=stride,
645
- padding=padding,
646
- bias=bias,
647
- )
648
- )
649
- elif t == "B":
650
- L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
651
- elif t == "I":
652
- L.append(nn.InstanceNorm2d(out_channels, affine=True))
653
- elif t == "R":
654
- L.append(nn.ReLU(inplace=True))
655
- elif t == "r":
656
- L.append(nn.ReLU(inplace=False))
657
- elif t == "L":
658
- L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
659
- elif t == "l":
660
- L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
661
- elif t == "E":
662
- L.append(nn.ELU(inplace=False))
663
- elif t == "s":
664
- L.append(nn.Softplus())
665
- elif t == "2":
666
- L.append(nn.PixelShuffle(upscale_factor=2))
667
- elif t == "3":
668
- L.append(nn.PixelShuffle(upscale_factor=3))
669
- elif t == "4":
670
- L.append(nn.PixelShuffle(upscale_factor=4))
671
- elif t == "U":
672
- L.append(nn.Upsample(scale_factor=2, mode="nearest"))
673
- elif t == "u":
674
- L.append(nn.Upsample(scale_factor=3, mode="nearest"))
675
- elif t == "v":
676
- L.append(nn.Upsample(scale_factor=4, mode="nearest"))
677
- elif t == "M":
678
- L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
679
- elif t == "A":
680
- L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
681
- else:
682
- raise NotImplementedError("Undefined type: ".format(t))
683
- return sequential(*L)
684
-
685
-
686
- """
687
- # --------------------------------------------
688
- # Upsampler
689
- # Kai Zhang, https://github.com/cszn/KAIR
690
- # --------------------------------------------
691
- # upsample_pixelshuffle
692
- # upsample_upconv
693
- # upsample_convtranspose
694
- # --------------------------------------------
695
- """
696
-
697
-
698
- # --------------------------------------------
699
- # conv + subp (+ relu)
700
- # --------------------------------------------
701
- def upsample_pixelshuffle(
702
- in_channels=64,
703
- out_channels=3,
704
- kernel_size=3,
705
- stride=1,
706
- padding=1,
707
- bias=True,
708
- mode="2R",
709
- negative_slope=0.2,
710
- ):
711
- assert len(mode) < 4 and mode[0] in [
712
- "2",
713
- "3",
714
- "4",
715
- ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
716
- up1 = conv(
717
- in_channels,
718
- out_channels * (int(mode[0]) ** 2),
719
- kernel_size,
720
- stride,
721
- padding,
722
- bias,
723
- mode="C" + mode,
724
- negative_slope=negative_slope,
725
- )
726
- return up1
727
-
728
-
729
- # --------------------------------------------
730
- # nearest_upsample + conv (+ R)
731
- # --------------------------------------------
732
- def upsample_upconv(
733
- in_channels=64,
734
- out_channels=3,
735
- kernel_size=3,
736
- stride=1,
737
- padding=1,
738
- bias=True,
739
- mode="2R",
740
- negative_slope=0.2,
741
- ):
742
- assert len(mode) < 4 and mode[0] in [
743
- "2",
744
- "3",
745
- "4",
746
- ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR"
747
- if mode[0] == "2":
748
- uc = "UC"
749
- elif mode[0] == "3":
750
- uc = "uC"
751
- elif mode[0] == "4":
752
- uc = "vC"
753
- mode = mode.replace(mode[0], uc)
754
- up1 = conv(
755
- in_channels,
756
- out_channels,
757
- kernel_size,
758
- stride,
759
- padding,
760
- bias,
761
- mode=mode,
762
- negative_slope=negative_slope,
763
- )
764
- return up1
765
-
766
-
767
- # --------------------------------------------
768
- # convTranspose (+ relu)
769
- # --------------------------------------------
770
- def upsample_convtranspose(
771
- in_channels=64,
772
- out_channels=3,
773
- kernel_size=2,
774
- stride=2,
775
- padding=0,
776
- bias=True,
777
- mode="2R",
778
- negative_slope=0.2,
779
- ):
780
- assert len(mode) < 4 and mode[0] in [
781
- "2",
782
- "3",
783
- "4",
784
- "8",
785
- ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
786
- kernel_size = int(mode[0])
787
- stride = int(mode[0])
788
- mode = mode.replace(mode[0], "T")
789
- up1 = conv(
790
- in_channels,
791
- out_channels,
792
- kernel_size,
793
- stride,
794
- padding,
795
- bias,
796
- mode,
797
- negative_slope,
798
- )
799
- return up1
800
-
801
-
802
- """
803
- # --------------------------------------------
804
- # Downsampler
805
- # Kai Zhang, https://github.com/cszn/KAIR
806
- # --------------------------------------------
807
- # downsample_strideconv
808
- # downsample_maxpool
809
- # downsample_avgpool
810
- # --------------------------------------------
811
- """
812
-
813
-
814
- # --------------------------------------------
815
- # strideconv (+ relu)
816
- # --------------------------------------------
817
- def downsample_strideconv(
818
- in_channels=64,
819
- out_channels=64,
820
- kernel_size=2,
821
- stride=2,
822
- padding=0,
823
- bias=True,
824
- mode="2R",
825
- negative_slope=0.2,
826
- ):
827
- assert len(mode) < 4 and mode[0] in [
828
- "2",
829
- "3",
830
- "4",
831
- "8",
832
- ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
833
- kernel_size = int(mode[0])
834
- stride = int(mode[0])
835
- mode = mode.replace(mode[0], "C")
836
- down1 = conv(
837
- in_channels,
838
- out_channels,
839
- kernel_size,
840
- stride,
841
- padding,
842
- bias,
843
- mode,
844
- negative_slope,
845
- )
846
- return down1
847
-
848
-
849
- # --------------------------------------------
850
- # maxpooling + conv (+ relu)
851
- # --------------------------------------------
852
- def downsample_maxpool(
853
- in_channels=64,
854
- out_channels=64,
855
- kernel_size=3,
856
- stride=1,
857
- padding=0,
858
- bias=True,
859
- mode="2R",
860
- negative_slope=0.2,
861
- ):
862
- assert len(mode) < 4 and mode[0] in [
863
- "2",
864
- "3",
865
- ], "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
866
- kernel_size_pool = int(mode[0])
867
- stride_pool = int(mode[0])
868
- mode = mode.replace(mode[0], "MC")
869
- pool = conv(
870
- kernel_size=kernel_size_pool,
871
- stride=stride_pool,
872
- mode=mode[0],
873
- negative_slope=negative_slope,
874
- )
875
- pool_tail = conv(
876
- in_channels,
877
- out_channels,
878
- kernel_size,
879
- stride,
880
- padding,
881
- bias,
882
- mode=mode[1:],
883
- negative_slope=negative_slope,
884
- )
885
- return sequential(pool, pool_tail)
886
-
887
-
888
- # --------------------------------------------
889
- # averagepooling + conv (+ relu)
890
- # --------------------------------------------
891
- def downsample_avgpool(
892
- in_channels=64,
893
- out_channels=64,
894
- kernel_size=3,
895
- stride=1,
896
- padding=1,
897
- bias=True,
898
- mode="2R",
899
- negative_slope=0.2,
900
- ):
901
- assert len(mode) < 4 and mode[0] in [
902
- "2",
903
- "3",
904
- ], "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
905
- kernel_size_pool = int(mode[0])
906
- stride_pool = int(mode[0])
907
- mode = mode.replace(mode[0], "AC")
908
- pool = conv(
909
- kernel_size=kernel_size_pool,
910
- stride=stride_pool,
911
- mode=mode[0],
912
- negative_slope=negative_slope,
913
- )
914
- pool_tail = conv(
915
- in_channels,
916
- out_channels,
917
- kernel_size,
918
- stride,
919
- padding,
920
- bias,
921
- mode=mode[1:],
922
- negative_slope=negative_slope,
923
- )
924
- return sequential(pool, pool_tail)
 
models/heads.py DELETED
@@ -1,270 +0,0 @@
1
- import torch
2
- from models.blocks import AffineConv2d, downsample_strideconv, upsample_convtranspose
3
-
4
- class InHead(torch.nn.Module):
5
- def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
6
- super(InHead, self).__init__()
7
- self.in_channels_list = in_channels_list
8
- self.input_layer = input_layer
9
- for i, in_channels in enumerate(in_channels_list):
10
- conv = AffineConv2d(
11
- in_channels=in_channels,
12
- out_channels=out_channels,
13
- bias=bias,
14
- mode=mode,
15
- kernel_size=3,
16
- stride=1,
17
- padding=1,
18
- padding_mode="zeros",
19
- )
20
- setattr(self, f"conv{i}", conv)
21
-
22
- def forward(self, x):
23
- in_channels = x.size(1) - 1 if self.input_layer else x.size(1)
24
-
25
- # find index
26
- i = self.in_channels_list.index(in_channels)
27
- x = getattr(self, f"conv{i}")(x)
28
-
29
- return x
30
-
31
- class OutTail(torch.nn.Module):
32
- def __init__(self, in_channels, out_channels_list, mode="", bias=False):
33
- super(OutTail, self).__init__()
34
- self.in_channels = in_channels
35
- self.out_channels_list = out_channels_list
36
- for i, out_channels in enumerate(out_channels_list):
37
- conv = AffineConv2d(
38
- in_channels=in_channels,
39
- out_channels=out_channels,
40
- bias=bias,
41
- mode=mode,
42
- kernel_size=3,
43
- stride=1,
44
- padding=1,
45
- padding_mode="zeros",
46
- )
47
- setattr(self, f"conv{i}", conv)
48
-
49
- def forward(self, x, out_channels):
50
- i = self.out_channels_list.index(out_channels)
51
- x = getattr(self, f"conv{i}")(x)
52
-
53
- return x
54
-
55
- # TODO: check that the heads are compatible with the old implementation
56
- class Heads(torch.nn.Module):
57
- def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, c_add=0, relu_in=False, skip_in=False):
58
- super(Heads, self).__init__()
59
- self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
60
- self.scale = scale
61
- self.mode = mode
62
- for i, in_channels in enumerate(self.in_channels_list):
63
- setattr(self, f"head{i}", HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))
64
-
65
- if self.mode == "":
66
- self.nl = torch.nn.ReLU(inplace=False)
67
- if self.scale != 1:
68
- for i, in_channels in enumerate(in_channels_list):
69
- setattr(self, f"down{i}", downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))
70
-
71
- def forward(self, x):
72
- in_channels = x.size(1)
73
- i = self.in_channels_list.index(in_channels)
74
-
75
- if self.scale != 1:
76
- if self.mode == "bilinear":
77
- x = torch.nn.functional.interpolate(x, scale_factor=1/self.scale, mode='bilinear', align_corners=False)
78
- else:
79
- x = getattr(self, f"down{i}")(x)
80
- x = self.nl(x)
81
-
82
- # find index
83
- x = getattr(self, f"head{i}")(x)
84
-
85
- return x
86
-
87
- class Tails(torch.nn.Module):
88
- def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, relu_in=False, skip_in=False):
89
- super(Tails, self).__init__()
90
- self.out_channels_list = out_channels_list
91
- self.scale = scale
92
- for i, out_channels in enumerate(out_channels_list):
93
- setattr(self, f"tail{i}", HeadBlock(in_channels, out_channels * c_mult, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))
94
-
95
- self.mode = mode
96
- if self.mode == "":
97
- self.nl = torch.nn.ReLU(inplace=False)
98
- if self.scale != 1:
99
- # self.up = upsample_convtranspose(out_channels, out_channels, bias=True, mode=str(self.scale))
100
- for i, out_channels in enumerate(out_channels_list):
101
- setattr(self, f"up{i}", upsample_convtranspose(out_channels * c_mult, out_channels * c_mult, bias=bias, mode=str(self.scale)))
102
-
103
- def forward(self, x, out_channels):
104
- i = self.out_channels_list.index(out_channels)
105
- x = getattr(self, f"tail{i}")(x)
106
- # find index
107
- if self.scale != 1:
108
- if self.mode == "bilinear":
109
- x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
110
- else:
111
- x = getattr(self, f"up{i}")(x)
112
-
113
- return x
114
-
115
- class ConvChannels(torch.nn.Module):
116
- """
117
- TODO: remplace this with convconv
118
- A method that only performs convolutional operations on the appropriate channels dim.
119
- """
120
- def __init__(self, channels_list, depth=2, bias=False, residual=False):
121
- super(ConvChannels, self).__init__()
122
- self.channels_list = channels_list
123
- self.residual = residual
124
- for i, channels in enumerate(channels_list):
125
- setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels, channels, 3, bias=bias, padding=1))
126
- setattr(self, f"nl{i}", torch.nn.ReLU())
127
- setattr(self, f"conv{i}_2", torch.nn.Conv2d(channels, channels, 3, bias=bias, padding=1))
128
-
129
- def forward(self, x):
130
- i = self.channels_list.index(x.shape[1])
131
- u = getattr(self, f"conv{i}_1")(x)
132
- u = getattr(self, f"nl{i}")(u)
133
- u = getattr(self, f"conv{i}_2")(u)
134
- if self.residual:
135
- u = x + u
136
- return u
137
-
138
- class HeadBlock(torch.nn.Module):
139
- def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False):
140
- super(HeadBlock, self).__init__()
141
-
142
- padding = kernel_size // 2
143
-
144
- c = out_channels if depth < 2 else in_channels
145
-
146
- self.convin = torch.nn.Conv2d(in_channels, c, kernel_size, padding=padding, bias=bias)
147
- self.zero_conv_skip = torch.nn.Conv2d(in_channels, c, 1, bias=False)
148
- self.depth = depth
149
- self.nl_1 = torch.nn.ReLU(inplace=False)
150
- self.nl_2 = torch.nn.ReLU(inplace=False)
151
- self.relu_in = relu_in
152
- self.skip_in = skip_in
153
-
154
- for i in range(depth-1):
155
- if i < depth - 2:
156
- c_in, c = in_channels, in_channels
157
- else:
158
- c_in, c = in_channels, out_channels
159
-
160
- setattr(self, f"conv1{i}", torch.nn.Conv2d(c_in, c_in, kernel_size, padding=padding, bias=bias))
161
- setattr(self, f"conv2{i}", torch.nn.Conv2d(c_in, c, kernel_size, padding=padding, bias=bias))
162
- setattr(self, f"skipconv{i}", torch.nn.Conv2d(c_in, c, 1, bias=False))
163
-
164
-
165
- def forward(self, x):
166
-
167
- if self.skip_in and self.relu_in:
168
- x = self.nl_1(self.convin(x)) + self.zero_conv_skip(x)
169
- elif self.skip_in and not self.relu_in:
170
- x = self.convin(x) + self.zero_conv_skip(x)
171
- else:
172
- x = self.convin(x)
173
-
174
- for i in range(self.depth-1):
175
- aux = getattr(self, f"conv1{i}")(x)
176
- aux = self.nl_2(aux)
177
- aux_0 = getattr(self, f"conv2{i}")(aux)
178
- aux_1 = getattr(self, f"skipconv{i}")(x)
179
- x = aux_0 + aux_1
180
-
181
- return x
182
-
183
-
184
- class SNRModule(torch.nn.Module):
185
- """
186
- A method that only performs convolutional operations on the appropriate channels dim.
187
- """
188
- def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64):
189
- super(SNRModule, self).__init__()
190
- self.channels_list = channels_list
191
- self.residual = residual
192
- for i, channels in enumerate(channels_list):
193
- setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels + 1, features, 3, bias=bias, padding=1))
194
- setattr(self, f"nl{i}", torch.nn.ReLU())
195
- setattr(self, f"conv{i}_2", torch.nn.Conv2d(features, out_channels, 3, bias=bias, padding=1))
196
-
197
- def forward(self, x0, sigma):
198
- i = self.channels_list.index(x0.shape[1])
199
-
200
- noise_level_map = (torch.ones((x0.size(0), 1, x0.size(2), x0.size(3)), device=x0.device) * sigma)
201
- x = torch.cat((x0, noise_level_map), 1)
202
-
203
- u = getattr(self, f"conv{i}_1")(x)
204
- u = getattr(self, f"nl{i}")(u)
205
- u = getattr(self, f"conv{i}_2")(u)
206
-
207
- den = u.pow(2).mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True).sqrt()
208
- u = u.abs() / (den + 1e-8)
209
-
210
- return u.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
211
-
212
-
213
- class EquivConvModule(torch.nn.Module):
214
- """
215
- A method that only performs convolutional operations on the appropriate channels dim.
216
- """
217
- def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64, N=1):
218
- super(EquivConvModule, self).__init__()
219
- self.channels_list = [c * N for c in channels_list]
220
- self.residual = residual
221
- for i, channels in enumerate(channels_list):
222
- setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels * N, channels * N, 3, bias=bias, padding=1))
223
- setattr(self, f"nl{i}", torch.nn.ReLU())
224
- setattr(self, f"conv{i}_2", torch.nn.Conv2d(channels * N, out_channels, 3, bias=bias, padding=1))
225
-
226
- def forward(self, x):
227
-
228
- i = self.channels_list.index(x.shape[1])
229
-
230
- u = getattr(self, f"conv{i}_1")(x)
231
- u = getattr(self, f"nl{i}")(u)
232
- u = getattr(self, f"conv{i}_2")(u)
233
-
234
- return u
235
-
236
-
237
- class EquivHeads(torch.nn.Module):
238
- def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear"):
239
- super(EquivHeads, self).__init__()
240
- self.in_channels_list = in_channels_list
241
- self.scale = scale
242
- self.mode = mode
243
- for i, in_channels in enumerate(in_channels_list):
244
- setattr(self, f"head{i}", HeadBlock(in_channels + 1, out_channels, depth=depth, bias=bias))
245
-
246
- if self.mode == "":
247
- self.nl = torch.nn.ReLU(inplace=False)
248
- if self.scale != 1:
249
- for i, in_channels in enumerate(in_channels_list):
250
- setattr(self, f"down{i}", downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))
251
-
252
- def forward(self, x, sigma):
253
- in_channels = x.size(1)
254
- i = self.in_channels_list.index(in_channels)
255
-
256
- if self.scale != 1:
257
- if self.mode == "bilinear":
258
- x = torch.nn.functional.interpolate(x, scale_factor=1/self.scale, mode='bilinear', align_corners=False)
259
- else:
260
- x = getattr(self, f"down{i}")(x)
261
- x = self.nl(x)
262
-
263
- # concat noise level map
264
- noise_level_map = (torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device) * sigma)
265
- x = torch.cat((x, noise_level_map), 1)
266
-
267
- # find index
268
- x = getattr(self, f"head{i}")(x)
269
-
270
- return x
 
models/ram.py ADDED
@@ -0,0 +1,854 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import deepinv as dinv
6
+ from deepinv.physics import Physics, LinearPhysics, Downsampling
7
+ from deepinv.utils import TensorList
9
+
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ cuda = torch.cuda.is_available()
13
+ Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
14
+
15
+ class RAM(nn.Module):
16
+ r"""
17
+ RAM model
18
+
19
+ This model is a convolutional neural network (CNN) designed for image reconstruction tasks.
20
+
21
+ :param in_channels: Number of input channels. If a list is provided, the model will have separate heads for each channel.
22
+ :param device: Device to which the model should be moved. If None, the model will be created on the default device.
23
+ :param pretrained: If True, the model will be initialized with pretrained weights.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ in_channels=[1, 2, 3],
29
+ device=None,
30
+ pretrained=True,
31
+ ):
32
+ super(RAM, self).__init__()
33
+
34
+ nc = [64, 128, 256, 512] # number of channels in the network
35
+ self.in_channels = in_channels
36
+ self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))
37
+
38
+ self.separate_head = isinstance(in_channels, list)
39
+
40
+ if isinstance(in_channels, list):
41
+ in_channels_first = []
42
+ for i in range(len(in_channels)):
43
+ in_channels_first.append(in_channels[i] + 2)
44
+
45
+ # check if in_channels is a list
46
+ self.m_head = InHead(in_channels_first, nc[0])
47
+
48
+ self.m_down1 = BaseEncBlock(nc[0], nc[0], img_channels=in_channels, decode_upscale=1)
49
+ self.m_down2 = BaseEncBlock(nc[1], nc[1], img_channels=in_channels, decode_upscale=2)
50
+ self.m_down3 = BaseEncBlock(nc[2], nc[2], img_channels=in_channels, decode_upscale=4)
51
+ self.m_body = BaseEncBlock(nc[3], nc[3], img_channels=in_channels, decode_upscale=8)
52
+ self.m_up3 = BaseEncBlock(nc[2], nc[2], img_channels=in_channels, decode_upscale=4)
53
+ self.m_up2 = BaseEncBlock(nc[1], nc[1], img_channels=in_channels, decode_upscale=2)
54
+ self.m_up1 = BaseEncBlock(nc[0], nc[0], img_channels=in_channels, decode_upscale=1)
55
+
56
+ self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
57
+ self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
58
+ self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
59
+ self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
60
+ self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
61
+ self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")
62
+
63
+ self.m_tail = OutTail(nc[0], in_channels)
64
+
65
+ # load pretrained weights from hugging face
66
+ if pretrained:
67
+ self.load_state_dict(
68
+ torch.load(hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar"), map_location=device))
69
+
70
+ if device is not None:
71
+ self.to(device)
72
+
73
+ def constant2map(self, value, x):
74
+ r"""
75
+ Converts a constant value to a map of the same size as the input tensor x.
76
+
77
+ :params float value: constant value
78
+ :params torch.Tensor x: input tensor
79
+ """
80
+ if isinstance(value, torch.Tensor):
81
+ if value.ndim > 0:
82
+ value_map = value.view(x.size(0), 1, 1, 1)
83
+ value_map = value_map.expand(-1, 1, x.size(2), x.size(3))
84
+ else:
85
+ value_map = torch.ones(
86
+ (x.size(0), 1, x.size(2), x.size(3)), device=x.device
87
+ ) * value[None, None, None, None].to(x.device)
88
+ else:
89
+ value_map = (
90
+ torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device)
91
+ * value
92
+ )
93
+ return value_map
94
+
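# Shape sketch for constant2map (toy sizes, assumed): a Python float and a
# per-batch tensor both become a (B, 1, H, W) map that can be concatenated to
# the image as an extra conditioning channel, e.g.
#   x = torch.rand(2, 3, 32, 32)
#   self.constant2map(0.05, x).shape                        # (2, 1, 32, 32)
#   self.constant2map(torch.tensor([0.05, 0.10]), x).shape  # (2, 1, 32, 32)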
95
+ def base_conditioning(self, x, sigma, gamma):
96
+ noise_level_map = self.constant2map(sigma, x)
97
+ gamma_map = self.constant2map(gamma, x)
98
+ return torch.cat((x, noise_level_map, gamma_map), 1)
99
+
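# Note (sketch): after the concatenation above, the tensor carries
# img_channels + 2 channels (image + noise-level map + gain map), which is
# why __init__ builds the input heads with in_channels[i] + 2.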
100
+ def realign_input(self, x, physics, y):
101
+ r"""
102
+ Realign the input x based on the measurements y and the physics model.
103
+ Applies the proximity operator of the L2 norm with respect to the physics model.
104
+
105
+ :params torch.Tensor x: Input tensor
106
+ :params deepinv.physics.Physics physics: Physics model
107
+ :params torch.Tensor y: Measurements
108
+ """
109
+ if hasattr(physics, "factor"):
110
+ f = physics.factor
111
+ elif hasattr(physics, "base") and hasattr(physics.base, "factor"):
112
+ f = physics.base.factor
113
+ elif hasattr(physics, "base") and hasattr(physics.base, "base") and hasattr(physics.base.base, "factor"):
114
+ f = physics.base.base.factor
115
+ else:
116
+ f = 1.0
117
+
118
+ sigma = 1e-6 # default value
119
+ if hasattr(physics.noise_model, 'sigma'):
120
+ sigma = physics.noise_model.sigma
121
+ if hasattr(physics, 'base') and hasattr(physics.base, 'noise_model') and hasattr(physics.base.noise_model,
122
+ 'sigma'):
123
+ sigma = physics.base.noise_model.sigma
124
+ if hasattr(physics, 'base') and hasattr(physics.base, 'base') and hasattr(physics.base.base,
125
+ 'noise_model') and hasattr(
126
+ physics.base.base.noise_model, 'sigma'):
127
+ sigma = physics.base.base.noise_model.sigma
128
+
129
+ if isinstance(y, TensorList):
130
+ num = (y[0].reshape(y[0].shape[0], -1).abs().mean(1))
131
+ else:
132
+ num = (y.reshape(y.shape[0], -1).abs().mean(1))
133
+
134
+ snr = num / (sigma + 1e-4) # SNR equivariant
135
+ gamma = 1 / (1e-4 + 1 / (
136
+ snr * f ** 2)) # TODO: check square-root / mean / check if we need to add a factor in front ?
137
+ gamma = gamma[(...,) + (None,) * (x.dim() - 1)]
138
+ model_input = physics.prox_l2(x, y, gamma=gamma * self.fact_realign)
139
+
140
+ return model_input
141
+
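# What realign_input returns (a sketch, assuming deepinv's definition of
# prox_l2 for linear operators): the proximity operator of the L2 data
# fidelity evaluated at x,
#   model_input = argmin_u (gamma/2) * ||A u - y||^2 + (1/2) * ||u - x||^2,
# with gamma increasing with the estimated SNR, so cleaner measurements pull
# the input harder towards consistency with y.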
142
+ def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
143
+ r"""
144
+ Forward pass of the UNet model.
145
+
146
+ :params torch.Tensor x0: init image
147
+ :params float sigma: Gaussian noise level
148
+ :params float gamma: Poisson noise gain
149
+ :params deepinv.physics.Physics physics: physics measurement operator
150
+ :params torch.Tensor y: measurements
151
+ """
152
+ img_channels = x0.shape[1]
153
+ physics = MultiScaleLinearPhysics(physics, x0.shape[-3:], device=x0.device)
154
+
155
+ if self.separate_head and img_channels not in self.in_channels:
156
+ raise ValueError(
157
+ f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")
158
+
159
+ if y is not None:
160
+ x0 = self.realign_input(x0, physics, y)
161
+
162
+ x0 = self.base_conditioning(x0, sigma, gamma)
163
+
164
+ x1 = self.m_head(x0)
165
+
166
+ x1_ = self.m_down1(x1, physics=physics, y=y, img_channels=img_channels, scale=0)
167
+ x2 = self.pool1(x1_)
168
+
169
+ x3_ = self.m_down2(x2, physics=physics, y=y, img_channels=img_channels, scale=1)
170
+ x3 = self.pool2(x3_)
171
+
172
+ x4_ = self.m_down3(x3, physics=physics, y=y, img_channels=img_channels, scale=2)
173
+ x4 = self.pool3(x4_)
174
+
175
+ x = self.m_body(x4, physics=physics, y=y, img_channels=img_channels, scale=3)
176
+
177
+ x = self.up3(x + x4)
178
+ x = self.m_up3(x, physics=physics, y=y, img_channels=img_channels, scale=2)
179
+
180
+ x = self.up2(x + x3)
181
+ x = self.m_up2(x, physics=physics, y=y, img_channels=img_channels, scale=1)
182
+
183
+ x = self.up1(x + x2)
184
+ x = self.m_up1(x, physics=physics, y=y, img_channels=img_channels, scale=0)
185
+
186
+ x = self.m_tail(x + x1, img_channels)
187
+
188
+ return x
189
+
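# Data-flow sketch: each encoder/decoder stage re-injects the measurements at
# its own resolution (scale 0..3). BaseEncBlock forwards physics and y to its
# ResBlocks, whose MeasCondBlock calls physics.set_scale(scale) before
# computing the Krylov embeddings (see below).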
190
+ def forward(self, y=None, physics=None):
191
+ r"""
192
+ Reconstructs a signal estimate from measurements y
193
+ :param torch.Tensor y: measurements
194
+ :param deepinv.physics.Physics physics: forward operator
195
+ """
196
+ if physics is None:
197
+ physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
198
+
199
+ x_temp = physics.A_adjoint(y)
200
+ pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
201
+ physics = Pad(physics, pad)
202
+
203
+ x_in = physics.A_adjoint(y)
204
+
205
+ sigma = physics.noise_model.sigma if hasattr(physics.noise_model, "sigma") else 1e-3
206
+ gamma = physics.noise_model.gain if hasattr(physics.noise_model, "gain") else 1e-3
207
+
208
+ out = self.forward_unet(x_in, sigma=sigma, gamma=gamma, physics=physics, y=y)
209
+
210
+ out = physics.remove_pad(out)
211
+
212
+ return out
213
+
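# Padding sketch: three stride-2 poolings mean spatial sizes must be
# multiples of 8, so inputs are padded up front and cropped afterwards, e.g.
#   H, W = 65, 100
#   pad = (-H % 8, -W % 8)   # -> (7, 4), i.e. a 72 x 104 padded input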
214
+
215
+ ### --------------- MODEL ---------------
216
+ class BaseEncBlock(nn.Module):
217
+ def __init__(self, in_channels, out_channels, bias=False, nb=4, img_channels=None, decode_upscale=None):
218
+ super(BaseEncBlock, self).__init__()
219
+ self.enc = nn.ModuleList(
220
+ [
221
+ ResBlock(
222
+ in_channels,
223
+ out_channels,
224
+ bias=bias,
225
+ img_channels=img_channels,
226
+ decode_upscale=decode_upscale,
227
+ )
228
+ for _ in range(nb)
229
+ ]
230
+ )
231
+
232
+ def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
233
+ for i in range(len(self.enc)):
234
+ x = self.enc[i](x, physics=physics, y=y, img_channels=img_channels, scale=scale)
235
+ return x
236
+
237
+
238
+ def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None):
239
+ r"""
240
+ Krylov subspace embedding computation.
241
+
242
+ :params torch.Tensor y: Input tensor.
243
+ :params p: An object with A and A_adjoint methods (linear operator).
244
+ :params float factor: Scaling factor.
245
+ :params torch.Tensor v: Precomputed values to subtract from Krylov sequence. Defaults to None.
246
+ :params int N: Number of Krylov iterations. Defaults to 4.
247
+ :params torch.Tensor x_init: Initial guess. Defaults to None.
248
+ """
249
+
250
+ if x_init is None:
251
+ x = p.A_adjoint(y)
252
+ else:
253
+ x = x_init.clone()  # start from the provided initial guess
254
+
255
+ norm = factor ** 2 # Precompute normalization factor
256
+ AtA = lambda u: p.A_adjoint(p.A(u)) * norm # Define the linear operator
257
+
258
+ v = v if v is not None else torch.zeros_like(x)
259
+
260
+ out = x.clone()
261
+ # Compute Krylov basis
262
+ x_k = x.clone()
263
+ for i in range(N - 1):
264
+ x_k = AtA(x_k) - v
265
+ out = torch.cat([out, x_k], dim=1)
266
+
267
+ return out
268
+
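# Krylov embedding sketch: with N = 4, v = 0 and x = A^T y of shape
# (B, C, H, W), the output stacks the Krylov basis [x, Mx, M^2x, M^3x] along
# the channel dimension, where M u = factor**2 * A^T(A(u)), giving a
# (B, 4 * C, H, W) tensor.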
269
+
270
+ class MeasCondBlock(nn.Module):
271
+ r"""
272
+ Measurement conditioning block for the RAM model.
273
+
274
+ :param out_channels: Number of output channels.
275
+ :param img_channels: Number of input channels. If a list is provided, the model will have separate heads for each channel.
276
+ :param decode_upscale: Upscaling factor for the decoding convolution.
277
+ :param N: Number of Krylov iterations.
278
+ :param depth_encoding: Depth of the encoding convolution.
279
+ :param c_mult: Multiplier for the number of channels.
280
+ """
281
+
282
+ def __init__(self, out_channels=64, img_channels=None, decode_upscale=None, N=4, depth_encoding=1, c_mult=1):
283
+ super(MeasCondBlock, self).__init__()
284
+
285
+ self.separate_head = isinstance(img_channels, list)
286
+
287
+ assert img_channels is not None, "img_channels should be provided"
288
+ assert decode_upscale is not None, "decode_upscale should be provided"
289
+
290
+ self.N = N
291
+ self.c_mult = c_mult
292
+ self.relu_encoding = nn.ReLU(inplace=False)
293
+ self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
294
+ self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False,
295
+ c_mult=self.c_mult * N, c_add=N, relu_in=False, skip_in=True)
296
+
297
+ self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
298
+ self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
299
+ self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
300
+ self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
301
+ self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
302
+
303
+ def forward(self, x, y, physics, img_channels=None, scale=1):
304
+ physics.set_scale(scale)
305
+ dec = self.decoding_conv(x, img_channels)
306
+ factor = 2 ** (scale)
307
+ meas_y = krylov_embeddings(y, physics, factor, N=self.N)
308
+ meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...])
309
+ for c in range(1, self.c_mult):
310
+ meas_cur = krylov_embeddings(y, physics, factor, N=self.N,
311
+ x_init=dec[:, img_channels * c:img_channels * (c + 1)])
312
+ meas_dec = torch.cat([meas_dec, meas_cur], dim=1)
313
+ meas = torch.cat([meas_y, meas_dec], dim=1)
314
+ cond = self.encoding_conv(meas)
315
+ emb = self.relu_encoding(cond)
316
+ return emb
317
+
318
+
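+ # Channel bookkeeping (illustrative): with C = img_channels, `meas_y` carries
+ # N*C channels and `meas_dec` carries c_mult*N*C, so the encoder input has
+ # C*N*(c_mult + 1) channels, matching Heads(..., c_mult=c_mult*N, c_add=N):
+ #
+ # >>> C, N, c_mult = 3, 4, 1
+ # >>> C * N * (c_mult + 1)
+ # 24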
319
+ class ResBlock(nn.Module):
320
+ r"""
321
+ Convolutional residual block.
322
+
323
+ :param in_channels: Number of input channels.
324
+ :param out_channels: Number of output channels.
325
+ :param kernel_size: Size of the convolution kernel.
326
+ :param stride: Stride of the convolution.
327
+ :param padding: Padding for the convolution.
328
+ :param bias: Whether to use bias in the convolution.
329
+ :param img_channels: Number of image channels. If a list is provided, a separate head is built for each channel count.
330
+ :param decode_upscale: Upscaling factor for the decoding convolution.
331
+ :param head: Whether this is a head block.
332
+ :param tail: Whether this is a tail block.
333
+ :param N: Number of Krylov iterations.
334
+ :param c_mult: Multiplier for the number of channels.
335
+ :param depth_encoding: Depth of the encoding convolution.
336
+ """
337
+
338
+ def __init__(
339
+ self,
340
+ in_channels=64,
341
+ out_channels=64,
342
+ kernel_size=3,
343
+ stride=1,
344
+ padding=1,
345
+ bias=True,
346
+ img_channels=None,
347
+ decode_upscale=None,
348
+ head=False,
349
+ tail=False,
350
+ N=2,
351
+ c_mult=2,
352
+ depth_encoding=2,
353
+ ):
354
+ super(ResBlock, self).__init__()
355
+
356
+ if not head and not tail:
357
+ assert in_channels == out_channels, "Only in_channels == out_channels is supported."
358
+ self.separate_head = isinstance(img_channels, list)
359
+ self.is_head = head
360
+ self.is_tail = tail
361
+
362
+ if self.is_head:
363
+ self.head = InHead(img_channels, out_channels, input_layer=True)
364
+
365
+ if not self.is_head and not self.is_tail:
366
+ self.conv1 = conv(
367
+ in_channels,
368
+ out_channels,
369
+ kernel_size,
370
+ stride,
371
+ padding,
372
+ bias,
373
+ "C",
374
+ )
375
+ self.nl = nn.ReLU(inplace=True)
376
+ self.conv2 = conv(
377
+ out_channels,
378
+ out_channels,
379
+ kernel_size,
380
+ stride,
381
+ padding,
382
+ bias,
383
+ "C",
384
+ )
385
+
386
+ self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
387
+ self.PhysicsBlock = MeasCondBlock(out_channels=out_channels, c_mult=c_mult,
388
+ img_channels=img_channels, decode_upscale=decode_upscale,
389
+ N=N, depth_encoding=depth_encoding)
390
+
391
+ def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
392
+ u = self.conv1(x)
393
+ u = self.nl(u)
394
+ u_2 = self.conv2(u)
395
+ emb_grad = self.PhysicsBlock(u, y, physics, img_channels=img_channels, scale=scale)
396
+ u_1 = self.gain * emb_grad
397
+ return x + u_2 + u_1
398
+
399
+
400
+ class InHead(torch.nn.Module):
401
+ def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
402
+ super(InHead, self).__init__()
403
+ self.in_channels_list = in_channels_list
404
+ self.input_layer = input_layer
405
+ for i, in_channels in enumerate(in_channels_list):
406
+ conv = AffineConv2d(
407
+ in_channels=in_channels,
408
+ out_channels=out_channels,
409
+ bias=bias,
410
+ mode=mode,
411
+ kernel_size=3,
412
+ stride=1,
413
+ padding=1,
414
+ padding_mode="zeros",
415
+ )
416
+ setattr(self, f"conv{i}", conv)
417
+
418
+ def forward(self, x):
419
+ in_channels = x.size(1) - 1 if self.input_layer else x.size(1)
420
+
421
+ # find index
422
+ i = self.in_channels_list.index(in_channels)
423
+ x = getattr(self, f"conv{i}")(x)
424
+
425
+ return x
426
+
427
+
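+ # Usage sketch (illustrative): the convolution is selected from the channel
+ # count of the input (minus one when `input_layer=True`, e.g. to discount a
+ # conditioning channel):
+ #
+ # >>> head = InHead([1, 2, 3], 64)
+ # >>> head(torch.randn(1, 2, 16, 16)).shape  # routed to conv1
+ # torch.Size([1, 64, 16, 16])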
428
+ class OutTail(torch.nn.Module):
429
+ def __init__(self, in_channels, out_channels_list, mode="", bias=False):
430
+ super(OutTail, self).__init__()
431
+ self.in_channels = in_channels
432
+ self.out_channels_list = out_channels_list
433
+ for i, out_channels in enumerate(out_channels_list):
434
+ conv = AffineConv2d(
435
+ in_channels=in_channels,
436
+ out_channels=out_channels,
437
+ bias=bias,
438
+ mode=mode,
439
+ kernel_size=3,
440
+ stride=1,
441
+ padding=1,
442
+ padding_mode="zeros",
443
+ )
444
+ setattr(self, f"conv{i}", conv)
445
+
446
+ def forward(self, x, out_channels):
447
+ i = self.out_channels_list.index(out_channels)
448
+ x = getattr(self, f"conv{i}")(x)
449
+
450
+ return x
451
+
452
+
453
+ class Heads(torch.nn.Module):
454
+ def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, c_add=0,
455
+ relu_in=False, skip_in=False):
456
+ super(Heads, self).__init__()
457
+ self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
458
+ self.scale = scale
459
+ self.mode = mode
460
+ for i, in_channels in enumerate(self.in_channels_list):
461
+ setattr(self, f"head{i}",
462
+ HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))
463
+
464
+ if self.mode == "":
465
+ self.nl = torch.nn.ReLU(inplace=False)
466
+ if self.scale != 1:
467
+ for i, in_channels in enumerate(self.in_channels_list):
468
+ setattr(self, f"down{i}",
469
+ downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))
470
+
471
+ def forward(self, x):
472
+ in_channels = x.size(1)
473
+ i = self.in_channels_list.index(in_channels)
474
+
475
+ if self.scale != 1:
476
+ if self.mode == "bilinear":
477
+ x = torch.nn.functional.interpolate(x, scale_factor=1 / self.scale, mode='bilinear',
478
+ align_corners=False)
479
+ else:
480
+ x = getattr(self, f"down{i}")(x)
481
+ x = self.nl(x)
482
+
483
+ # apply the head matching the input channel count
484
+ x = getattr(self, f"head{i}")(x)
485
+
486
+ return x
487
+
488
+
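+ # Usage sketch (illustrative): the branch is picked from the channel count of
+ # the stacked input; with c_mult=8 and c_add=4, a 3-channel modality expects
+ # 3 * (8 + 4) = 36 channels:
+ #
+ # >>> h = Heads([3], 64, c_mult=8, c_add=4)
+ # >>> h(torch.randn(1, 36, 16, 16)).shape
+ # torch.Size([1, 64, 16, 16])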
489
+ class Tails(torch.nn.Module):
490
+ def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1,
491
+ relu_in=False, skip_in=False):
492
+ super(Tails, self).__init__()
493
+ self.out_channels_list = out_channels_list
494
+ self.scale = scale
495
+ for i, out_channels in enumerate(out_channels_list):
496
+ setattr(self, f"tail{i}",
497
+ HeadBlock(in_channels, out_channels * c_mult, depth=depth, bias=bias, relu_in=relu_in,
498
+ skip_in=skip_in))
499
+
500
+ self.mode = mode
501
+ if self.mode == "":
502
+ self.nl = torch.nn.ReLU(inplace=False)
503
+ if self.scale != 1:
504
+ for i, out_channels in enumerate(out_channels_list):
505
+ setattr(self, f"up{i}",
506
+ upsample_convtranspose(out_channels * c_mult, out_channels * c_mult, bias=bias,
507
+ mode=str(self.scale)))
508
+
509
+ def forward(self, x, out_channels):
510
+ i = self.out_channels_list.index(out_channels)
511
+ x = getattr(self, f"tail{i}")(x)
512
+ # optional upscaling back to the target resolution
513
+ if self.scale != 1:
514
+ if self.mode == "bilinear":
515
+ x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
516
+ else:
517
+ x = getattr(self, f"up{i}")(x)
518
+
519
+ return x
520
+
521
+
522
+ class HeadBlock(torch.nn.Module):
523
+ def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False):
524
+ super(HeadBlock, self).__init__()
525
+
526
+ padding = kernel_size // 2
527
+
528
+ c = out_channels if depth < 2 else in_channels
529
+
530
+ self.convin = torch.nn.Conv2d(in_channels, c, kernel_size, padding=padding, bias=bias)
531
+ self.zero_conv_skip = torch.nn.Conv2d(in_channels, c, 1, bias=False)
532
+ self.depth = depth
533
+ self.nl_1 = torch.nn.ReLU(inplace=False)
534
+ self.nl_2 = torch.nn.ReLU(inplace=False)
535
+ self.relu_in = relu_in
536
+ self.skip_in = skip_in
537
+
538
+ for i in range(depth - 1):
539
+ if i < depth - 2:
540
+ c_in, c = in_channels, in_channels
541
+ else:
542
+ c_in, c = in_channels, out_channels
543
+
544
+ setattr(self, f"conv1{i}", torch.nn.Conv2d(c_in, c_in, kernel_size, padding=padding, bias=bias))
545
+ setattr(self, f"conv2{i}", torch.nn.Conv2d(c_in, c, kernel_size, padding=padding, bias=bias))
546
+ setattr(self, f"skipconv{i}", torch.nn.Conv2d(c_in, c, 1, bias=False))
547
+
548
+ def forward(self, x):
549
+
550
+ if self.skip_in and self.relu_in:
551
+ x = self.nl_1(self.convin(x)) + self.zero_conv_skip(x)
552
+ elif self.skip_in and not self.relu_in:
553
+ x = self.convin(x) + self.zero_conv_skip(x)
554
+ else:
555
+ x = self.convin(x)
556
+
557
+ for i in range(self.depth - 1):
558
+ aux = getattr(self, f"conv1{i}")(x)
559
+ aux = self.nl_2(aux)
560
+ aux_0 = getattr(self, f"conv2{i}")(aux)
561
+ aux_1 = getattr(self, f"skipconv{i}")(x)
562
+ x = aux_0 + aux_1
563
+
564
+ return x
565
+
566
+
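+ # Usage sketch (illustrative): depth=2 gives the input convolution followed by
+ # one conv-ReLU-conv stage with a 1x1 skip:
+ #
+ # >>> hb = HeadBlock(12, 64, depth=2)
+ # >>> hb(torch.randn(1, 12, 16, 16)).shape
+ # torch.Size([1, 64, 16, 16])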
567
+ # --------------------------------------------------------------------------------------
568
+ class AffineConv2d(nn.Conv2d):
569
+ def __init__(
570
+ self,
571
+ in_channels,
572
+ out_channels,
573
+ kernel_size,
574
+ mode="affine",
575
+ bias=False,
576
+ stride=1,
577
+ padding=0,
578
+ dilation=1,
579
+ groups=1,
580
+ padding_mode="circular",
581
+ blind=True,
582
+ ):
583
+ if mode == "affine": # f(a*x + b) = a*f(x) + b
584
+ bias = False
585
+ super().__init__(
586
+ in_channels,
587
+ out_channels,
588
+ kernel_size,
589
+ bias=bias,
590
+ stride=stride,
591
+ padding=padding,
592
+ dilation=dilation,
593
+ groups=groups,
594
+ padding_mode=padding_mode,
595
+ )
596
+ self.blind = blind
597
+ self.mode = mode
598
+
599
+ def affine(self, w):
600
+ """returns new kernels that encode affine combinations"""
601
+ return (
602
+ w.view(self.out_channels, -1).roll(1, 1).view(w.size())
603
+ - w
604
+ + 1 / w[0, ...].numel()
605
+ )
606
+
607
+ def forward(self, x):
608
+ if self.mode != "affine":
609
+ return super().forward(x)
610
+ else:
611
+ kernel = (
612
+ self.affine(self.weight)
613
+ if self.blind
614
+ else torch.cat(
615
+ (self.affine(self.weight[:, :-1, :, :]), self.weight[:, -1:, :, :]),
616
+ dim=1,
617
+ )
618
+ )
619
+ padding = tuple(
620
+ elt for elt in reversed(self.padding) for _ in range(2)
621
+ ) # used to translate padding arg used by Conv module to the ones used by F.pad
622
+ padding_mode = (
623
+ self.padding_mode if self.padding_mode != "zeros" else "constant"
624
+ ) # used to translate padding_mode arg used by Conv module to the ones used by F.pad
625
+ return F.conv2d(
626
+ F.pad(x, padding, mode=padding_mode),
627
+ kernel,
628
+ stride=self.stride,
629
+ dilation=self.dilation,
630
+ groups=self.groups,
631
+ )
632
+
633
+
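+ # Equivariance check (illustrative): in "affine" mode, with the default
+ # circular padding, the layer commutes with affine intensity changes:
+ #
+ # >>> layer = AffineConv2d(3, 3, 3, mode="affine", padding=1)
+ # >>> x = torch.randn(1, 3, 8, 8)
+ # >>> a, b = 2.0, 0.5
+ # >>> torch.allclose(layer(a * x + b), a * layer(x) + b, atol=1e-5)
+ # True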
634
+ """
635
+ Functional blocks below
636
+
637
+ Parts of code borrowed from
638
+ https://github.com/cszn/DPIR/tree/master/models
639
+ https://github.com/xinntao/BasicSR
640
+ """
641
+ from collections import OrderedDict
642
+ import torch
643
+ import torch.nn as nn
644
+
645
+ """
646
+ # --------------------------------------------
647
+ # Advanced nn.Sequential
648
+ # https://github.com/xinntao/BasicSR
649
+ # --------------------------------------------
650
+ """
651
+
652
+
653
+ def sequential(*args):
654
+ """Advanced nn.Sequential.
655
+ Args:
656
+ nn.Sequential, nn.Module
657
+ Returns:
658
+ nn.Sequential
659
+ """
660
+ if len(args) == 1:
661
+ if isinstance(args[0], OrderedDict):
662
+ raise NotImplementedError("sequential does not support OrderedDict input.")
663
+ return args[0] # No sequential is needed.
664
+ modules = []
665
+ for module in args:
666
+ if isinstance(module, nn.Sequential):
667
+ for submodule in module.children():
668
+ modules.append(submodule)
669
+ elif isinstance(module, nn.Module):
670
+ modules.append(module)
671
+ return nn.Sequential(*modules)
672
+
673
+
674
+ def conv(
675
+ in_channels=64,
676
+ out_channels=64,
677
+ kernel_size=3,
678
+ stride=1,
679
+ padding=1,
680
+ bias=True,
681
+ mode="CBR",
682
+ ):
683
+ L = []
684
+ for t in mode:
685
+ if t == "C":
686
+ L.append(
687
+ nn.Conv2d(
688
+ in_channels=in_channels,
689
+ out_channels=out_channels,
690
+ kernel_size=kernel_size,
691
+ stride=stride,
692
+ padding=padding,
693
+ bias=bias,
694
+ )
695
+ )
696
+ elif t == "T":
697
+ L.append(
698
+ nn.ConvTranspose2d(
699
+ in_channels=in_channels,
700
+ out_channels=out_channels,
701
+ kernel_size=kernel_size,
702
+ stride=stride,
703
+ padding=padding,
704
+ bias=bias,
705
+ )
706
+ )
707
+ elif t == "R":
708
+ L.append(nn.ReLU(inplace=True))
709
+ else:
710
+ raise NotImplementedError("Undefined type: {}".format(t))
711
+ return sequential(*L)
712
+
713
+
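+ # Usage sketch (illustrative): the mode string assembles layers in order, e.g.
+ # "CRC" builds Conv -> ReLU -> Conv:
+ #
+ # >>> block = conv(64, 64, mode="CRC")
+ # >>> block(torch.randn(1, 64, 8, 8)).shape
+ # torch.Size([1, 64, 8, 8])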
714
+ # --------------------------------------------
715
+ # convTranspose (+ relu)
716
+ # --------------------------------------------
717
+ def upsample_convtranspose(
718
+ in_channels=64,
719
+ out_channels=3,
720
+ padding=0,
721
+ bias=True,
722
+ mode="2R",
723
+ ):
724
+ assert len(mode) < 4 and mode[0] in [
725
+ "2",
726
+ "3",
727
+ "4",
728
+ "8",
729
+ ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
730
+ kernel_size = int(mode[0])
731
+ stride = int(mode[0])
732
+ mode = mode.replace(mode[0], "T")
733
+ up1 = conv(
734
+ in_channels,
735
+ out_channels,
736
+ kernel_size,
737
+ stride,
738
+ padding,
739
+ bias,
740
+ mode,
741
+ )
742
+ return up1
743
+
744
+
745
+ def downsample_strideconv(
746
+ in_channels=64,
747
+ out_channels=64,
748
+ padding=0,
749
+ bias=True,
750
+ mode="2R",
751
+ ):
752
+ assert len(mode) < 4 and mode[0] in [
753
+ "2",
754
+ "3",
755
+ "4",
756
+ "8",
757
+ ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
758
+ kernel_size = int(mode[0])
759
+ stride = int(mode[0])
760
+ mode = mode.replace(mode[0], "C")
761
+ down1 = conv(
762
+ in_channels,
763
+ out_channels,
764
+ kernel_size,
765
+ stride,
766
+ padding,
767
+ bias,
768
+ mode,
769
+ )
770
+ return down1
771
+
772
+
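+ # Usage sketch (illustrative): mode "2" sets kernel_size = stride = 2, giving
+ # an exact 2x resampling pair:
+ #
+ # >>> down = downsample_strideconv(64, 128, mode="2")
+ # >>> up = upsample_convtranspose(128, 64, mode="2")
+ # >>> up(down(torch.randn(1, 64, 32, 32))).shape
+ # torch.Size([1, 64, 32, 32])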
773
+ class Upsampling(Downsampling):
774
+ def A(self, x, **kwargs):
775
+ return super().A_adjoint(x, **kwargs)
776
+
777
+ def A_adjoint(self, y, **kwargs):
778
+ return super().A(y, **kwargs)
779
+
780
+ def prox_l2(self, z, y, gamma, **kwargs):
781
+ return super().prox_l2(z, y, gamma, **kwargs)
782
+
783
+
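+ # Illustrative sketch, assuming deepinv's Downsampling(img_size, filter, factor)
+ # signature: swapping A and A_adjoint makes A interpolate from the coarse grid
+ # to the fine one:
+ #
+ # >>> up = Upsampling(img_size=(3, 32, 32), filter="gaussian", factor=2)
+ # >>> up.A(torch.randn(1, 3, 16, 16)).shape
+ # torch.Size([1, 3, 32, 32])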
784
+ class MultiScalePhysics(Physics):
785
+ def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], device='cpu', **kwargs):
786
+ super().__init__(noise_model=physics.noise_model, **kwargs)
787
+ self.base = physics
788
+ self.scales = scales
789
+ self.img_shape = img_shape
790
+ self.Upsamplings = [Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device) for factor in
791
+ scales]
792
+ self.scale = 0
793
+
794
+ def set_scale(self, scale):
795
+ if scale is not None:
796
+ self.scale = scale
797
+
798
+ def A(self, x, scale=None, **kwargs):
799
+ self.set_scale(scale)
800
+ if self.scale == 0:
801
+ return self.base.A(x, **kwargs)
802
+ else:
803
+ return self.base.A(self.Upsamplings[self.scale - 1].A(x), **kwargs)
804
+
805
+ def downsample(self, x, scale=None):
806
+ self.set_scale(scale)
807
+ if self.scale == 0:
808
+ return x
809
+ else:
810
+ return self.Upsamplings[self.scale - 1].A_adjoint(x)
811
+
812
+ def upsample(self, x, scale=None):
813
+ self.set_scale(scale)
814
+ if self.scale == 0:
815
+ return x
816
+ else:
817
+ return self.Upsamplings[self.scale - 1].A(x)
818
+
819
+ def update_parameters(self, **kwargs):
820
+ self.base.update_parameters(**kwargs)
821
+
822
+
823
+ class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
824
+ def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], **kwargs):
825
+ super().__init__(physics=physics, img_shape=img_shape, filter=filter, scales=scales, **kwargs)
826
+
827
+ def A_adjoint(self, y, scale=None, **kwargs):
828
+ self.set_scale(scale)
829
+ y = self.base.A_adjoint(y, **kwargs)
830
+ if self.scale == 0:
831
+ return y
832
+ else:
833
+ return self.Upsamplings[self.scale - 1].A_adjoint(y)
834
+
835
+
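+ # Note: scale s > 0 composes the base operator with a 2^s upsampling, i.e.
+ # A_s = A o U_s and A_s^T = U_s^T o A^T, so the same measurements y can be
+ # re-used at every resolution level of the UNet (cf. `scale` in ResBlock).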
836
+ class Pad(LinearPhysics):
837
+ def __init__(self, physics, pad):
838
+ super().__init__(noise_model=physics.noise_model)
839
+ self.base = physics
840
+ self.pad = pad
841
+
842
+ def A(self, x):
843
+ return self.base.A(x[..., self.pad[0]:, self.pad[1]:])
844
+
845
+ def A_adjoint(self, y):
846
+ y = self.base.A_adjoint(y)
847
+ y = torch.nn.functional.pad(y, (self.pad[1], 0, self.pad[0], 0))
848
+ return y
849
+
850
+ def remove_pad(self, x):
851
+ return x[..., self.pad[0]:, self.pad[1]:]
852
+
853
+ def update_parameters(self, **kwargs):
854
+ self.base.update_parameters(**kwargs)
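+
+ # Usage sketch (illustrative; assumes deepinv, imported as dinv): wraps a
+ # physics so the adjoint lives on a grid whose sides are multiples of 8
+ # (cf. the forward pass at the top of this file):
+ #
+ # >>> base = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.1))
+ # >>> y = torch.randn(1, 3, 30, 29)
+ # >>> p = Pad(base, (-y.size(-2) % 8, -y.size(-1) % 8))
+ # >>> p.A_adjoint(y).shape
+ # torch.Size([1, 3, 32, 32])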
models/unext_wip.py DELETED
@@ -1,1238 +0,0 @@
1
- # Code borrowed from Kai Zhang https://github.com/cszn/DPIR/tree/master/models
2
- import re
3
- import math
4
- import functools
5
-
6
- import deepinv as dinv
7
- from deepinv.utils import plot, TensorList
8
-
9
- import torch
10
- from torch.func import vmap
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
- from torchvision import transforms
14
- from deepinv.optim.utils import conjugate_gradient
15
-
16
- from physics.multiscale import MultiScaleLinearPhysics, Pad
17
- from models.blocks import EquivMaxPool, AffineConv2d, ConvNextBlock2, NoiseEmbedding, MPConv, TimestepEmbedding, conv, downsample_strideconv, upsample_convtranspose
18
- from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads
19
-
20
- cuda = True if torch.cuda.is_available() else False
21
- Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
22
-
23
-
24
- ### --------------- MODEL ---------------
25
- class BaseEncBlock(nn.Module):
26
- def __init__(
27
- self,
28
- in_channels,
29
- out_channels,
30
- bias=False,
31
- mode="CRC",
32
- nb=2,
33
- embedding=False,
34
- emb_channels=None,
35
- emb_physics=False,
36
- img_channels=None,
37
- decode_upscale=None,
38
- config='A',
39
- N=4,
40
- c_mult=1,
41
- depth_encoding=1,
42
- relu_in_encoding=False,
43
- skip_in_encoding=True,
44
- ):
45
- super(BaseEncBlock, self).__init__()
46
- self.config = config
47
- self.enc = nn.ModuleList(
48
- [
49
- ResBlock(
50
- in_channels,
51
- out_channels,
52
- bias=bias,
53
- mode=mode,
54
- embedding=embedding,
55
- emb_channels=emb_channels,
56
- emb_physics=emb_physics,
57
- img_channels=img_channels,
58
- decode_upscale=decode_upscale,
59
- config=config,
60
- N=N,
61
- c_mult=c_mult,
62
- depth_encoding=depth_encoding,
63
- relu_in_encoding=relu_in_encoding,
64
- skip_in_encoding=skip_in_encoding,
65
- )
66
- for _ in range(nb)
67
- ]
68
- )
69
-
70
- def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
71
- for i in range(len(self.enc)):
72
- x = self.enc[i](x, emb_sigma=emb_sigma, physics=physics, t=t, y=y, img_channels=img_channels, scale=scale)
73
- return x
74
-
75
-
76
- class NextEncBlock(nn.Module):
77
- def __init__(
78
- self, in_channels, out_channels, bias=False, mode="", mult_fact=4, nb=2
79
- ):
80
- super(NextEncBlock, self).__init__()
81
- self.enc = nn.ModuleList(
82
- [
83
- ConvNextBlock2(
84
- in_channels=in_channels,
85
- out_channels=out_channels,
86
- bias=bias,
87
- mode=mode,
88
- mult_fact=mult_fact,
89
- )
90
- for _ in range(nb)
91
- ]
92
- )
93
-
94
- def forward(self, x, emb_sigma=None):
95
- for i in range(len(self.enc)):
96
- x = self.enc[i](x, emb_sigma)
97
- return x
98
-
99
-
100
- class UNeXt(nn.Module):
101
- r"""
102
- DRUNet denoiser network.
103
-
104
- The network architecture is based on the paper
105
- `Learning deep CNN denoiser prior for image restoration <https://arxiv.org/abs/1704.03264>`_,
106
- and has a U-Net like structure, with convolutional blocks in the encoder and decoder parts.
107
-
108
- The network takes into account the noise level of the input image, which is encoded as an additional input channel.
109
-
110
- A pretrained network for (in_channels=out_channels=1 or in_channels=out_channels=3)
111
- can be downloaded via setting ``pretrained='download'``.
112
-
113
- :param int in_channels: number of channels of the input.
114
- :param int out_channels: number of channels of the output.
115
- :param list nc: number of convolutional layers.
116
- :param int nb: number of convolutional blocks per layer.
117
- :param int nf: number of channels per convolutional layer.
118
- :param str act_mode: activation mode, "R" for ReLU, "L" for LeakyReLU "E" for ELU and "S" for Softplus.
119
- :param str downsample_mode: Downsampling mode, "avgpool" for average pooling, "maxpool" for max pooling, and
120
- "strideconv" for convolution with stride 2.
121
- :param str upsample_mode: Upsampling mode, "convtranspose" for convolution transpose, "pixelsuffle" for pixel
122
- shuffling, and "upconv" for nearest neighbour upsampling with additional convolution.
123
- :param str, None pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized at random
124
- using Pytorch's default initialization. If ``pretrained='download'``, the weights will be downloaded from an
125
- online repository (only available for the default architecture with 3 or 1 input/output channels).
126
- Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
127
- See :ref:`pretrained-weights <pretrained-weights>` for more details.
128
- :param bool train: training or testing mode.
129
- :param str device: gpu or cpu.
130
-
131
- """
132
-
133
- def __init__(
134
- self,
135
- in_channels=[1, 2, 3],
136
- out_channels=[1, 2, 3],
137
- nc=[64, 128, 256, 512],
138
- nb=4, # 4 in DRUNet but out of memory
139
- conv_type="next", # should be 'base' or 'next'
140
- pool_type="next", # should be 'base' or 'next'
141
- cond_type="base", # conditioning, should be 'base' or 'edm'
142
- device=None,
143
- bias=False,
144
- mode="",
145
- residual=False,
146
- act_mode="R",
147
- layer_scale_init_value=1e-6,
148
- init_type="ortho",
149
- gain_init_conv=1.0,
150
- gain_init_linear=1.0,
151
- drop_prob=0.0,
152
- replk=False,
153
- mult_fact=4,
154
- antialias="gaussian",
155
- emb_physics=False,
156
- config='A',
157
- pretrained_pth=None,
158
- N=4,
159
- c_mult=1,
160
- depth_encoding=1,
161
- relu_in_encoding=False,
162
- skip_in_encoding=True,
163
- ):
164
- super(UNeXt, self).__init__()
165
-
166
- self.residual = residual
167
- self.conv_type = conv_type
168
- self.pool_type = pool_type
169
- self.emb_physics = emb_physics
170
- self.config = config
171
- self.in_channels = in_channels
172
- self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))
173
-
174
- self.separate_head = isinstance(in_channels, list)
175
-
176
- assert cond_type in ["base", "edm"], "cond_type should be 'base' or 'edm'"
177
- self.cond_type = cond_type
178
-
179
- if self.cond_type == "base":
180
- if self.config != 'E':
181
- if isinstance(in_channels, list):
182
- in_channels_first = []
183
- for i in range(len(in_channels)):
184
- in_channels_first.append(in_channels[i] + 2)
185
- else: # old head
186
- in_channels_first = in_channels + 1
187
- else:
188
- in_channels_first = in_channels
189
- else:
190
- in_channels_first = in_channels
191
- self.noise_embedding = NoiseEmbedding(
192
- num_channels=in_channels, emb_channels=max(nc), device=device
193
- )
194
-
195
- self.timestep_embedding = lambda x: x
196
-
197
- # check if in_channels is a list
198
- self.m_head = InHead(in_channels_first, nc[0])
199
-
200
- if conv_type == "next":
201
- self.m_down1 = NextEncBlock(
202
- nc[0], nc[0], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
203
- )
204
- self.m_down2 = NextEncBlock(
205
- nc[1], nc[1], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
206
- )
207
- self.m_down3 = NextEncBlock(
208
- nc[2], nc[2], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
209
- )
210
- self.m_body = NextEncBlock(
211
- nc[3], nc[3], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
212
- )
213
- self.m_up3 = NextEncBlock(
214
- nc[2], nc[2], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
215
- )
216
- self.m_up2 = NextEncBlock(
217
- nc[1], nc[1], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
218
- )
219
- self.m_up1 = NextEncBlock(
220
- nc[0], nc[0], bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
221
- )
222
-
223
- elif conv_type == "base":
224
- embedding = (
225
- False if cond_type == "base" else True
226
- )
227
- emb_channels = max(nc)
228
- self.m_down1 = BaseEncBlock(
229
- nc[0],
230
- nc[0],
231
- bias=False,
232
- mode="CRC",
233
- nb=nb,
234
- embedding=embedding,
235
- emb_channels=emb_channels,
236
- emb_physics=emb_physics,
237
- img_channels=in_channels,
238
- decode_upscale=1,
239
- config=config,
240
- N=N,
241
- c_mult=c_mult,
242
- depth_encoding=depth_encoding,
243
- relu_in_encoding=relu_in_encoding,
244
- skip_in_encoding=skip_in_encoding,
245
- )
246
- self.m_down2 = BaseEncBlock(
247
- nc[1],
248
- nc[1],
249
- bias=False,
250
- mode="CRC",
251
- nb=nb,
252
- embedding=embedding,
253
- emb_channels=emb_channels,
254
- emb_physics=emb_physics,
255
- img_channels=in_channels,
256
- decode_upscale=2,
257
- config=config,
258
- N=N,
259
- c_mult=c_mult,
260
- depth_encoding=depth_encoding,
261
- relu_in_encoding=relu_in_encoding,
262
- skip_in_encoding=skip_in_encoding,
263
- )
264
- self.m_down3 = BaseEncBlock(
265
- nc[2],
266
- nc[2],
267
- bias=False,
268
- mode="CRC",
269
- nb=nb,
270
- embedding=embedding,
271
- emb_channels=emb_channels,
272
- emb_physics=emb_physics,
273
- img_channels=in_channels,
274
- decode_upscale=4,
275
- config=config,
276
- N=N,
277
- c_mult=c_mult,
278
- depth_encoding=depth_encoding,
279
- relu_in_encoding=relu_in_encoding,
280
- skip_in_encoding=skip_in_encoding,
281
- )
282
- self.m_body = BaseEncBlock(
283
- nc[3],
284
- nc[3],
285
- bias=False,
286
- mode="CRC",
287
- nb=nb,
288
- embedding=embedding,
289
- emb_channels=emb_channels,
290
- emb_physics=emb_physics,
291
- img_channels=in_channels,
292
- decode_upscale=8,
293
- config=config,
294
- N=N,
295
- c_mult=c_mult,
296
- depth_encoding=depth_encoding,
297
- relu_in_encoding=relu_in_encoding,
298
- skip_in_encoding=skip_in_encoding,
299
- )
300
- self.m_up3 = BaseEncBlock(
301
- nc[2],
302
- nc[2],
303
- bias=False,
304
- mode="CRC",
305
- nb=nb,
306
- embedding=embedding,
307
- emb_channels=emb_channels,
308
- emb_physics=emb_physics,
309
- img_channels=in_channels,
310
- decode_upscale=4,
311
- config=config,
312
- N=N,
313
- c_mult=c_mult,
314
- depth_encoding=depth_encoding,
315
- relu_in_encoding=relu_in_encoding,
316
- skip_in_encoding=skip_in_encoding,
317
- )
318
- self.m_up2 = BaseEncBlock(
319
- nc[1],
320
- nc[1],
321
- bias=False,
322
- mode="CRC",
323
- nb=nb,
324
- embedding=embedding,
325
- emb_channels=emb_channels,
326
- emb_physics=emb_physics,
327
- img_channels=in_channels,
328
- decode_upscale=2,
329
- config=config,
330
- N=N,
331
- c_mult=c_mult,
332
- depth_encoding=depth_encoding,
333
- relu_in_encoding=relu_in_encoding,
334
- skip_in_encoding=skip_in_encoding,
335
- )
336
- self.m_up1 = BaseEncBlock(
337
- nc[0],
338
- nc[0],
339
- bias=False,
340
- mode="CRC",
341
- nb=nb,
342
- embedding=embedding,
343
- emb_channels=emb_channels,
344
- emb_physics=emb_physics,
345
- img_channels=in_channels,
346
- decode_upscale=1,
347
- config=config,
348
- N=N,
349
- c_mult=c_mult,
350
- depth_encoding=depth_encoding,
351
- relu_in_encoding=relu_in_encoding,
352
- skip_in_encoding=skip_in_encoding,
353
- )
354
-
355
- else:
356
- raise NotImplementedError("conv_type should be 'base' or 'next'")
357
-
358
- if pool_type == "next_max":
359
- self.pool1 = EquivMaxPool(
360
- antialias=antialias,
361
- in_channels=nc[0],
362
- out_channels=nc[1],
363
- device=device,
364
- )
365
- self.pool2 = EquivMaxPool(
366
- antialias=antialias,
367
- in_channels=nc[1],
368
- out_channels=nc[2],
369
- device=device,
370
- )
371
- self.pool3 = EquivMaxPool(
372
- antialias=antialias,
373
- in_channels=nc[2],
374
- out_channels=nc[3],
375
- device=device,
376
- )
377
- elif pool_type == "base":
378
- self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
379
- self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
380
- self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
381
- self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
382
- self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
383
- self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")
384
- else:
385
- raise NotImplementedError("pool_type should be 'base' or 'next'")
386
-
387
- self.m_tail = OutTail(nc[0], in_channels)
388
-
389
- if conv_type == "base":
390
- init_func = functools.partial(
391
- weights_init_unext, init_type="ortho", gain_conv=0.2
392
- )
393
- self.apply(init_func)
394
- else:
395
- init_func = functools.partial(
396
- weights_init_unext,
397
- init_type=init_type,
398
- gain_conv=gain_init_conv,
399
- gain_linear=gain_init_linear,
400
- )
401
- self.apply(init_func)
402
-
403
- if pretrained_pth=='jz':
404
- pth = '/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth'
405
- self.load_drunet_weights(pth)
406
- elif pretrained_pth is not None:
407
- self.load_drunet_weights(pretrained_pth)
408
-
409
- if self.config == 'D':
410
- # deactivate grad for layers that do not contain the string "PhysicsBlock" or "gain" or "fact_realign"
411
- for name, param in self.named_parameters():
412
- if 'PhysicsBlock' not in name and 'gain' not in name and 'fact_realign' not in name and "m_head" not in name and "m_tail" not in name:
413
- param.requires_grad = False
414
-
415
- if device is not None:
416
- self.to(device)
417
-
418
- def load_drunet_weights(self, ckpt_pth):
419
- state_dict = torch.load(ckpt_pth, map_location=lambda storage, loc: storage)
420
-
421
- new_state_dict = {}
422
- matched_keys = [] # List to store successfully matched keys
423
- unmatched_keys = [] # List to store keys that were not matched or excluded
424
- excluded_keys = [] # List to store excluded keys
425
-
426
- # Define patterns to exclude
427
- exclude_patterns = ["head", "tail"]
428
-
429
- # Dealing with regular keys
430
- for old_key, value in state_dict.items():
431
- # Skip keys containing any of the excluded patterns
432
- if any(excluded in old_key for excluded in exclude_patterns):
433
- excluded_keys.append(old_key)
434
- continue # Skip further processing for this key
435
-
436
- new_key = old2new(old_key)
437
-
438
- if new_key is not None:
439
- matched_keys.append((old_key, new_key)) # Record the matched keys
440
- new_state_dict[new_key] = value
441
- else:
442
- unmatched_keys.append(old_key) # Record unmatched keys
443
-
444
- # TODO: clean this
445
- for excluded_key in excluded_keys:
446
- if isinstance(self.in_channels, list):
447
- for i, in_channel in enumerate(self.in_channels):
448
- # print('Dealing with conv ', i)
449
- new_key = f"m_head.conv{i}.weight"
450
- if 'head' in excluded_key:
451
- new_key = f"m_head.conv{i}.weight"
452
- # new_key = f"m_head.head.conv{i}.weight"
453
- if 'tail' in excluded_key:
454
- new_key = f"m_tail.conv{i}.weight"
455
- # DEBUG print all keys of state dict:
456
- # print(state_dict.keys())
457
- # print(self.state_dict().keys())
458
- conditioning = 'base'
459
- # if self.config == 'E':
460
- # conditioning = False
461
- new_kv = update_keyvals_headtail(excluded_key,
462
- state_dict[excluded_key],
463
- init_value=self.state_dict()[new_key],
464
- new_key_name=new_key,
465
- conditioning=conditioning)
466
- new_state_dict.update(new_kv)
467
- # print(new_kv.keys())
468
- else:
469
- new_kv = update_keyvals_headtail(excluded_key, state_dict[excluded_key])
470
- new_state_dict.update(new_kv)
471
-
472
- # Display matched keys
473
- print("Matched keys:")
474
- for old_key, new_key in matched_keys:
475
- print(f"{old_key} -> {new_key}")
476
-
477
- # Load updated state dict into the model
478
- self.load_state_dict(new_state_dict, strict=False)
479
-
480
- # Display unmatched keys
481
- print("\nUnmatched keys:")
482
- for unmatched_key in unmatched_keys:
483
- print(unmatched_key)
484
-
485
- print("Weights loaded from ", ckpt_pth)
486
-
487
- def constant2map(self, value, x):
488
- if isinstance(value, torch.Tensor):
489
- if value.ndim > 0:
490
- value_map = value.view(x.size(0), 1, 1, 1)
491
- value_map = value_map.expand(-1, 1, x.size(2), x.size(3))
492
- else:
493
- value_map = torch.ones(
494
- (x.size(0), 1, x.size(2), x.size(3)), device=x.device
495
- ) * value[None, None, None, None].to(x.device)
496
- else:
497
- value_map = (
498
- torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device)
499
- * value
500
- )
501
- return value_map
502
-
503
- def base_conditioning(self, x, sigma, gamma):
504
- noise_level_map = self.constant2map(sigma, x)
505
- gamma_map = self.constant2map(gamma, x)
506
- return torch.cat((x, noise_level_map, gamma_map), 1)
507
-
508
- def realign_input(self, x, physics, y):
509
-
510
- if hasattr(physics, "factor"):
511
- f = physics.factor
512
- elif hasattr(physics, "base") and hasattr(physics.base, "factor"):
513
- f = physics.base.factor
514
- elif hasattr(physics, "base") and hasattr(physics.base, "base") and hasattr(physics.base.base, "factor"):
515
- f = physics.base.base.factor
516
- else:
517
- f = 1.0
518
-
519
- sigma = 1e-6 # default value
520
- if hasattr(physics.noise_model, 'sigma'):
521
- sigma = physics.noise_model.sigma
522
- if hasattr(physics, 'base') and hasattr(physics.base, 'noise_model') and hasattr(physics.base.noise_model, 'sigma'):
523
- sigma = physics.base.noise_model.sigma
524
- if hasattr(physics, 'base') and hasattr(physics.base, 'base') and hasattr(physics.base.base, 'noise_model') and hasattr(physics.base.base.noise_model, 'sigma'):
525
- sigma = physics.base.base.noise_model.sigma
526
-
527
- if isinstance(y, TensorList):
528
- num = (y[0].reshape(y[0].shape[0], -1).abs().mean(1))
529
- else:
530
- num = (y.reshape(y.shape[0], -1).abs().mean(1))
531
-
532
- snr = num / (sigma + 1e-4) # SNR equivariant
533
- gamma = 1 / (1e-4 + 1 / (snr * f **2 )) # TODO: check square-root / mean / check if we need to add a factor in front ?
534
- gamma = gamma[(...,) + (None,) * (x.dim() - 1)]
535
- model_input = physics.prox_l2(x, y, gamma=gamma * self.fact_realign)
536
-
537
- return model_input
538
-
539
- def forward_unet(self, x0, sigma=None, gamma=None, physics=None, t=None, y=None, img_channels=None):
540
-
541
- # list_values = []
542
-
543
- if self.cond_type == "base":
544
- # if self.config != 'E':
545
- x0 = self.base_conditioning(x0, sigma, gamma)
546
- emb_sigma = None
547
- else:
548
- emb_sigma = self.noise_embedding(
549
- sigma
550
- ) # This only if the embedding is the non-basic one from drunet
551
-
552
- emb_timestep = self.timestep_embedding(t)
553
-
554
- x1 = self.m_head(x0) # old
555
- # x1 = self.m_head(x0, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels)
556
- # list_values.append(x1.abs().mean())
557
-
558
- if self.config == 'G':
559
- x1_, emb1_ = self.m_down1(x1, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels)
560
- else:
561
- x1_ = self.m_down1(x1, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=0)
562
- x2 = self.pool1(x1_)
563
- # list_values.append(x2.abs().mean())
564
-
565
- if self.config == 'G':
566
- x3_, emb3_ = self.m_down2(x2, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels)
567
- else:
568
- x3_ = self.m_down2(x2, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=1)
569
- x3 = self.pool2(x3_)
570
-
571
- # list_values.append(x3.abs().mean())
572
- if self.config == 'G':
573
- x4_, emb4_ = self.m_down3(x3, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels)
574
- else:
575
- x4_ = self.m_down3(x3, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=2)
576
- x4 = self.pool3(x4_)
577
-
578
- # issue: https://github.com/matthieutrs/ram_project/issues/1
579
- # solution 1: using .contiguous() below
580
- # solution 2: using a print statement that magically solves the issue
581
- ###print(x4.is_contiguous())
582
-
583
- # list_values.append(x4.abs().mean())
584
- if self.config == 'G':
585
- x, _ = self.m_body(x4, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels)
586
- else:
587
- x = self.m_body(x4, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=3)
588
-
589
- # list_values.append(x.abs().mean())
590
- if self.pool_type == "next" or self.pool_type == "next_max":
591
- x = self.pool3.upscale(x + x4)
592
- else:
593
- x = self.up3(x + x4)
594
-
595
- if self.config == 'G':
596
- x, _ = self.m_up3(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, emb_in=emb4_, img_channels=img_channels)
597
- else:
598
- x = self.m_up3(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=2)
599
-
600
- # list_values.append(x.abs().mean())
601
- if self.pool_type == "next" or self.pool_type == "next_max":
602
- x = self.pool2.upscale(x + x3)
603
- else:
604
- x = self.up2(x + x3)
605
-
606
- if self.config == 'G':
607
- x, _ = self.m_up2(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, emb_in=emb3_, img_channels=img_channels)
608
- else:
609
- x = self.m_up2(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=1)
610
-
611
- # list_values.append(x.abs().mean())
612
- if self.pool_type == "next" or self.pool_type == "next_max":
613
- x = self.pool1.upscale(x + x2)
614
- else:
615
- x = self.up1(x + x2)
616
-
617
- if self.config == 'G':
618
- x, _ = self.m_up1(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, emb_in=emb1_, img_channels=img_channels)
619
- else:
620
- x = self.m_up1(x, emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels, scale=0)
621
-
622
- # list_values.append(x.abs().mean())
623
- if self.separate_head:
624
- x = self.m_tail(x + x1, img_channels)
625
- else:
626
- x = self.m_tail(x + x1)
627
-
628
- return x
629
-
630
- def forward(self, x, sigma=None, gamma=None, physics=None, t=None, y=None):
631
- r"""
632
- Run the denoiser on image with noise level :math:`\sigma`.
633
-
634
- :param torch.Tensor x: noisy image
635
- :param float, torch.Tensor sigma: noise level. If ``sigma`` is a float, it is used for all images in the batch.
636
- If ``sigma`` is a tensor, it must be of shape ``(batch_size,)``.
637
- """
638
- img_channels = x.shape[1] # x_n_chan = x.shape[1]
639
- if self.emb_physics:
640
- physics = MultiScaleLinearPhysics(physics, x.shape[-3:], device=x.device)
641
-
642
- if self.separate_head and img_channels not in self.in_channels:
643
- raise ValueError(f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")
644
-
645
- if y is not None:
646
- x = self.realign_input(x, physics, y)
647
-
648
- x = self.forward_unet(x, sigma=sigma, gamma=gamma, physics=physics, t=t, y=y, img_channels=img_channels)
649
-
650
- return x
651
-
652
-
653
- def krylov_embeddings_old(y, p, factor, v=None, N=4, feat_size=1, x_init=None, img_channels=3):
654
-
655
- if x_init is None:
656
- x = p.A_adjoint(y)
657
- else:
658
- x = x_init[:, :img_channels, ...]
659
-
660
- if feat_size > 1:
661
- _, C, _, _ = x.shape
662
- if v is None:
663
- v = torch.zeros_like(x).repeat(1, N-1, 1, 1)
664
- out = x - v[:, :C, ...]
665
- norm = factor ** 2
666
- A = lambda u: p.A_adjoint(p.A(u)) * norm
667
- for i in range(N-1):
668
- x = A(x) - v[:, (i+1) * C:(i+2) * C, ...]
669
- out = torch.cat([out, x], dim=1)
670
- else:
671
- if v is None:
672
- v = torch.zeros_like(x)
673
- out = x - v
674
- norm = factor ** 2
675
- A = lambda u: p.A_adjoint(p.A(u)) * norm
676
- for i in range(N-1):
677
- x = A(x) - v
678
- out = torch.cat([out, x], dim=1)
679
- return out
680
-
681
- def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None, img_channels=3):
682
- """
683
- Efficient Krylov subspace embedding computation with parallel processing.
684
-
685
- Args:
686
- y (torch.Tensor): The input tensor.
687
- p: An object with A and A_adjoint methods (linear operator).
688
- factor (float): Scaling factor.
689
- v (torch.Tensor, optional): Precomputed values to subtract from Krylov sequence. Defaults to None.
690
- N (int, optional): Number of Krylov iterations. Defaults to 4.
691
- feat_size (int, optional): Feature expansion size. Defaults to 1.
692
- x_init (torch.Tensor, optional): Initial guess. Defaults to None.
693
- img_channels (int, optional): Number of image channels. Defaults to 3.
694
-
695
- Returns:
696
- torch.Tensor: The Krylov embeddings.
697
- """
698
-
699
- if x_init is None:
700
- x = p.A_adjoint(y)
701
- else:
702
- x = x_init.clone() # Extract the first img_channels
703
-
704
- norm = factor ** 2 # Precompute normalization factor
705
- AtA = lambda u: p.A_adjoint(p.A(u)) * norm # Define the linear operator
706
-
707
- v = v if v is not None else torch.zeros_like(x)
708
-
709
- out = x.clone()
710
- # Compute Krylov basis
711
- x_k = x.clone()
712
- for i in range(N-1):
713
- x_k = AtA(x_k) - v
714
- out = torch.cat([out, x_k], dim=1)
715
-
716
- return out
717
-
718
-
719
- def grad_embeddings(y, p, factor, v=None, N=4, feat_size=1):
720
- Aty = p.A_adjoint(y)
721
- if feat_size > 1:
722
- _, C, _, _ = Aty.shape
723
- if v is None:
724
- v = torch.zeros_like(Aty).repeat(1, N-1, 1, 1)
725
- out = v[:, :C, ...] - Aty
726
- norm = factor ** 2
727
- A = lambda u: p.A_adjoint(p.A(u)) * norm
728
- for i in range(N-1):
729
- x = A(v[:, (i+1) * C:(i+2) * C, ...]) - Aty
730
- out = torch.cat([out, x], dim=1)
731
- else:
732
- if v is None:
733
- v = torch.zeros_like(Aty)
734
- out = v - Aty
735
- norm = factor ** 2
736
- A = lambda u: p.A_adjoint(p.A(u)) * norm
737
- for i in range(N-1):
738
- x = A(v) - Aty
739
- out = torch.cat([out, x], dim=1)
740
- return out
741
-
742
-
743
- def prox_embeddings(y, p, factor, v=None, N=4):
744
- x = p.A_adjoint(y)
745
- B, C, H, W = x.shape
746
-
747
- if v is None:
748
- v = torch.zeros_like(x)
749
-
750
- v = v.repeat(1, N - 1, 1, 1)
751
-
752
- gamma = torch.logspace(-4, -1, N-1, device=x.device).repeat_interleave(C).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
753
- norm = factor ** 2
754
- A_sub = lambda u: torch.cat([p.A_adjoint(p.A(u[:, i * C:(i+1) * C, ...])) * norm for i in range(N-1)], dim=1)
755
- A = lambda u: A_sub(u) + (u - v) * gamma
756
-
757
- u_hat = conjugate_gradient(A, x.repeat(1, N-1, 1, 1), max_iter=3, tol=1e-3)
758
- u_hat = torch.cat([u_hat, x], dim=1)
759
-
760
- return u_hat
761
-
762
- # --------------------------------------------
763
- # Res Block: x + conv(relu(conv(x)))
764
- # --------------------------------------------
765
- class MeasCondBlock(nn.Module):
766
- def __init__(
767
- self,
768
- out_channels=64,
769
- img_channels=None,
770
- decode_upscale=None,
771
- config = 'A',
772
- N=4,
773
- depth_encoding=1,
774
- relu_in_encoding=False,
775
- skip_in_encoding=True,
776
- c_mult=1,
777
- ):
778
- super(MeasCondBlock, self).__init__()
779
-
780
- self.separate_head = isinstance(img_channels, list)
781
- self.config = config
782
-
783
- assert img_channels is not None, "decode_dimensions should be provided"
784
- assert decode_upscale is not None, "decode_upscale should be provided"
785
-
786
- # if self.separate_head:
787
- if self.config == 'A':
788
- self.relu_encoding = nn.ReLU(inplace=False)
789
- self.N = N
790
- self.c_mult = c_mult
791
- self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult, relu_in=relu_in_encoding, skip_in=skip_in_encoding)
792
- if self.config == 'B':
793
- self.N = N
794
- self.c_mult = c_mult
795
- self.relu_encoding = nn.ReLU(inplace=False)
796
- self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
797
- self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult, relu_in=relu_in_encoding, skip_in=skip_in_encoding)
798
- if self.config == 'C':
799
- self.N = N
800
- self.c_mult = c_mult
801
- self.relu_encoding = nn.ReLU(inplace=False)
802
- self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
803
- self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult*N, c_add=N, relu_in=relu_in_encoding, skip_in=skip_in_encoding)
804
- elif self.config == 'D':
805
- self.N = N
806
- self.c_mult = c_mult
807
- self.relu_encoding = nn.ReLU(inplace=False)
808
- self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
809
- self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult*N, c_add=N, relu_in=relu_in_encoding, skip_in=skip_in_encoding)
810
-
811
- self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
812
- self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
813
- self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
814
- self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
815
- self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
816
-
817
- def forward(self, x, y, physics, t, emb_in=None, img_channels=None, scale=1):
818
- if self.config == 'A':
819
- return self.measurement_conditioning_config_A(x, y, physics, img_channels=img_channels, scale=scale)
820
- elif self.config == 'F':
821
- return self.measurement_conditioning_config_F(x, y, physics, img_channels=img_channels, scale=scale)
822
- elif self.config == 'B':
823
- return self.measurement_conditioning_config_B(x, y, physics, img_channels=img_channels, scale=scale)
824
- elif self.config == 'C':
825
- return self.measurement_conditioning_config_C(x, y, physics, img_channels=img_channels, scale=scale)
826
- elif self.config == 'D':
827
- return self.measurement_conditioning_config_D(x, y, physics, img_channels=img_channels, scale=scale)
828
- elif self.config == 'E':
829
- return self.measurement_conditioning_config_E(x, y, physics, img_channels=img_channels, scale=scale)
830
- else:
831
- raise NotImplementedError('Config not implemented')
832
-
833
- def measurement_conditioning_config_A(self, x, y, physics, img_channels, scale=0):
834
- physics.set_scale(scale)
835
- factor = 2**(scale)
836
- meas = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels)
837
- cond = self.encoding_conv(meas)
838
- emb = self.relu_encoding(cond)
839
- return emb
840
-
841
- def measurement_conditioning_config_B(self, x, y, physics, img_channels, scale=0):
842
- physics.set_scale(scale)
843
- dec = self.decoding_conv(x, img_channels)
844
- factor = 2**(scale)
845
- meas = krylov_embeddings(y, physics, factor, v=dec, N=self.N, img_channels=img_channels)
846
- cond = self.encoding_conv(meas)
847
- emb = self.relu_encoding(cond)
848
- return emb # * sigma_emb
849
-
850
- def measurement_conditioning_config_C(self, x, y, physics, img_channels, scale=0):
851
- physics.set_scale(scale)
852
- dec = self.decoding_conv(x, img_channels)
853
- factor = 2**(scale)
854
- meas_y = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels)
855
- meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...], img_channels=img_channels)
856
- for c in range(1, self.c_mult):
857
- meas_cur = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, img_channels*c:img_channels*(c+1)],
858
- img_channels=img_channels)
859
- meas_dec = torch.cat([meas_dec, meas_cur], dim=1)
860
- meas = torch.cat([meas_y, meas_dec], dim=1)
861
- cond = self.encoding_conv(meas)
862
- emb = self.relu_encoding(cond)
863
- return emb
864
-
865
- def measurement_conditioning_config_D(self, x, y, physics, img_channels, scale=0):
866
- physics.set_scale(scale)
867
- dec = self.decoding_conv(x, img_channels)
868
- factor = 2**(scale)
869
- meas_y = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels)
870
- meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...], img_channels=img_channels)
871
- for c in range(1, self.c_mult):
872
- meas_cur = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, img_channels*c:img_channels*(c+1)],
873
- img_channels=img_channels)
874
- meas_dec = torch.cat([meas_dec, meas_cur], dim=1)
875
- meas = torch.cat([meas_y, meas_dec], dim=1)
876
- cond = self.encoding_conv(meas)
877
- emb = self.relu_encoding(cond)
878
- return cond
879
-
880
- def measurement_conditioning_config_F(self, x, y, physics, img_channels):
881
- dec_large = self.decoding_conv(x, img_channels) # go from shape = (B, C, H, W) to (B, 64, 64, 64) (independent of modality)
882
- dec = self.relu_decoding(dec_large)
883
-
884
- Adec = physics.A(dec)
885
-
886
- grad = physics.A_adjoint(self.gain_gradx ** 2 * Adec - self.gain_grady ** 2 * y) # TODO: check if we need to have L2 (depending on noise nature, can be automated)
887
-
888
- if 'tomography' in physics.__class__.__name__.lower(): # or 'pansharp' in physics.__class__.__name__.lower():
889
- pinv = physics.prox_l2(dec, self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y, gamma=1e9)
890
- else:
891
- pinv = physics.A_dagger(self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y) # TODO: do we set this to gain_gradx ? To get 0 during training too?? Better for denoising I guess
892
-
893
- # Mix grad and pinv
894
- emb = grad - pinv # will be 0 in the case of denoising, but also inpainting
895
- im_emb = dec - physics.A_adjoint_A(dec) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
896
- grad_large = emb + im_emb
897
-
898
- emb_grad = self.encoding_conv(grad_large)
899
- emb_grad = self.relu_encoding(emb_grad)
900
- return emb_grad
901
-
902
- def measurement_conditioning_config_E(self, x, y, physics, img_channels, scale=1):
903
- dec = self.decoding_conv(x, img_channels) # go from shape = (B, C, H, W) to (B, 64, 64, 64) (independent of modality)
904
-
905
- physics.set_scale(scale)
906
-
907
- # TODO: check things are batched
908
- f = physics.factor if hasattr(physics, "factor") else 1.0
909
- err = (physics.A_adjoint(physics.A(dec) - y))
910
- # snr = self.snr_module(err)
911
- snr = dec.reshape(dec.shape[0], -1).abs().mean(dim=1) / (err.reshape(err.shape[0], -1).abs().mean(dim=1) + 1e-4)
912
-
913
- gamma = 1 / (1e-4 + 1 / (snr * f ** 2 + 1)) # TODO: check square-root / mean / check if we need to add a factor in front
914
- gamma_est = gamma[(...,) + (None,) * (dec.dim() - 1)]
915
-
916
- prox = physics.prox_l2(dec, y, gamma=gamma_est * self.fact_prox)
917
- emb = self.fact_prox_skip_1 * prox + self.fact_prox_skip_2 * dec
918
-
919
- emb_grad = self.encoding_conv(emb)
920
- emb_grad = self.relu_encoding(emb_grad)
921
- return emb_grad
922
-
923
-
924
- class ResBlock(nn.Module):
925
- def __init__(
926
- self,
927
- in_channels=64,
928
- out_channels=64,
929
- kernel_size=3,
930
- stride=1,
931
- padding=1,
932
- bias=True,
933
- mode="CRC",
934
- negative_slope=0.2,
935
- embedding=False,
936
- emb_channels=None,
937
- emb_physics=False,
938
- img_channels=None,
939
- decode_upscale=None,
940
- config = 'A',
941
- head=False,
942
- tail=False,
943
- N=4,
944
- c_mult=1,
945
- depth_encoding=1,
946
- relu_in_encoding=False,
947
- skip_in_encoding=True,
948
- ):
949
- super(ResBlock, self).__init__()
950
-
951
- if not head and not tail:
952
- assert in_channels == out_channels, "Only support in_channels==out_channels."
953
- self.separate_head = isinstance(img_channels, list)
954
- self.config = config
955
- self.is_head = head
956
- self.is_tail = tail
957
-
958
- if self.is_head:
959
- self.head = InHead(img_channels, out_channels, input_layer=True)
960
-
961
- # if self.is_tail:
962
- # self.tail = OutTail(in_channels, out_channels)
963
-
964
- if not self.is_head and not self.is_tail:
965
- self.conv1 = conv(
966
- in_channels,
967
- out_channels,
968
- kernel_size,
969
- stride,
970
- padding,
971
- bias,
972
- "C",
973
- negative_slope,
974
- )
975
- self.nl = nn.ReLU(inplace=True)
976
- self.conv2 = conv(
977
- out_channels,
978
- out_channels,
979
- kernel_size,
980
- stride,
981
- padding,
982
- bias,
983
- "C",
984
- negative_slope,
985
- )
986
-
987
- if embedding:
988
- self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
989
- self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])
990
-
991
- self.emb_physics = emb_physics
992
-
993
- if self.emb_physics:
994
- self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
995
- self.PhysicsBlock = MeasCondBlock(out_channels=out_channels, config=config, c_mult=c_mult,
996
- img_channels=img_channels, decode_upscale=decode_upscale,
997
- N=N, depth_encoding=depth_encoding,
998
- relu_in_encoding=relu_in_encoding, skip_in_encoding=skip_in_encoding)
999
-
1000
- def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
1001
- u = self.conv1(x)
1002
- u = self.nl(u)
1003
- u_2 = self.conv2(u) # Should we sum this with below?
1004
- if self.emb_physics: # TODO: add a factor (1+gain) to the emb_meas? that depends on the input snr
1005
- emb_grad = self.PhysicsBlock(u, y, physics, t, img_channels=img_channels, scale=scale)
1006
- u_1 = self.gain * emb_grad # x - grad (sign does not matter)
1007
- else:
1008
- u_1 = 0
1009
- return x + u_2 + u_1
1010
-
1011
-
1012
-
1013
-
1014
- def calculate_fan_in_and_fan_out(tensor, pytorch_style: bool = True):
1015
- """
1016
- from https://github.com/megvii-research/basecls/blob/main/basecls/layers/wrapper.py#L77
1017
- """
1018
- if len(tensor.shape) not in (2, 4, 5):
1019
- raise ValueError(
1020
- "fan_in and fan_out can only be computed for tensor with 2/4/5 "
1021
- "dimensions"
1022
- )
1023
- if len(tensor.shape) == 5:
1024
- # `GOIKK` to `OIKK`
1025
- tensor = tensor.reshape(-1, *tensor.shape[2:]) if pytorch_style else tensor[0]
1026
-
1027
- num_input_fmaps = tensor.shape[1]
1028
- num_output_fmaps = tensor.shape[0]
1029
- receptive_field_size = 1
1030
- if len(tensor.shape) > 2:
1031
- receptive_field_size = functools.reduce(lambda x, y: x * y, tensor.shape[2:], 1)
1032
- fan_in = num_input_fmaps * receptive_field_size
1033
- fan_out = num_output_fmaps * receptive_field_size
1034
- return fan_in, fan_out
1035
-
1036
-
1037
- def weights_init_unext(m, gain_conv=1.0, gain_linear=1.0, init_type="ortho"):
1038
- if hasattr(m, "modules"):
1039
- for submodule in m.modules():
1040
- if not 'skip' in str(submodule):
1041
- if isinstance(submodule, nn.Conv2d) or isinstance(
1042
- submodule, nn.ConvTranspose2d
1043
- ):
1044
- # nn.init.orthogonal_(submodule.weight.data, gain=1.0)
1045
- k_shape = submodule.weight.data.shape[-1]
1046
- if k_shape < 4:
1047
- nn.init.orthogonal_(submodule.weight.data, gain=0.2)
1048
- else:
1049
- _, fan_out = calculate_fan_in_and_fan_out(submodule.weight)
1050
- std = math.sqrt(2 / fan_out)
1051
- nn.init.normal_(submodule.weight, 0, std)
1052
- # if init_type == 'ortho':
1053
- # nn.init.orthogonal_(submodule.weight.data, gain=gain_conv)
1054
- # elif init_type == 'kaiming':
1055
- # nn.init.kaiming_normal_(submodule.weight.data, a=0, mode='fan_in')
1056
- # elif init_type == 'xavier':
1057
- # nn.init.xavier_normal_(submodule.weight.data, gain=gain_conv)
1058
- elif isinstance(submodule, nn.Linear):
1059
- nn.init.normal_(submodule.weight.data, std=0.01)
1060
- elif 'skip' in str(submodule):
1061
- if isinstance(submodule, nn.Conv2d) or isinstance(
1062
- submodule, nn.ConvTranspose2d
1063
- ):
1064
- nn.init.ones_(submodule.weight.data)
1065
- # else:
1066
- # classname = submodule.__class__.__name__
1067
- # # print('WARNING: no init for ', classname)
1068
-
1069
- def old2new(old_key):
-     """
-     Convert old DRUNet state-dict keys to the new UNExt-style keys.
-
-     Patterns matched:
-     1. Downsampling blocks:
-        - residual (non-downsampling) convolutions:
-          m_down3.2.res.0.weight -> m_down3.enc.2.conv1.weight
-        - strided downsampling convolutions:
-          m_down3.4.weight -> pool3.weight
-     2. Upsampling blocks:
-        - transposed upsampling convolutions:
-          m_up3.0.weight -> up3.weight
-        - residual convolutions (note the index shift by one):
-          m_up3.2.res.0.weight -> m_up3.enc.1.conv1.weight
-     3. Body blocks:
-          m_body.0.res.2.weight -> m_body.enc.0.conv2.weight
-
-     Args:
-         old_key (str): The old key from the state dictionary.
-
-     Returns:
-         str or None: The new key if a pattern matched, otherwise None.
-     """
-     # Residual blocks inside downsampling stages
-     match_residual = re.search(r"(m_down\d+)\.(\d+)\.res\.(\d+)", old_key)
-     if match_residual:
-         prefix = match_residual.group(1)  # e.g., "m_down2"
-         index = match_residual.group(2)  # e.g., "3"
-         conv_index = int(match_residual.group(3))  # e.g., 0
-         # Map the old conv index to the new one: 0 -> 1, 2 -> 2
-         new_conv_index = 1 if conv_index == 0 else 2
-         return f"{prefix}.enc.{index}.conv{new_conv_index}.weight"
-
-     # Residual blocks inside upsampling stages (enc index shifts down by one)
-     match_residual = re.search(r"(m_up\d+)\.(\d+)\.res\.(\d+)", old_key)
-     if match_residual:
-         prefix = match_residual.group(1)  # e.g., "m_up2"
-         index = int(match_residual.group(2))  # e.g., 3
-         conv_index = int(match_residual.group(3))  # e.g., 0
-         new_conv_index = 1 if conv_index == 0 else 2
-         return f"{prefix}.enc.{index - 1}.conv{new_conv_index}.weight"
-
-     # Strided downsampling convolutions
-     match_pool_downsample = re.search(r"m_down(\d+)\.4\.weight", old_key)
-     if match_pool_downsample:
-         index = match_pool_downsample.group(1)  # e.g., "1" or "2"
-         return f"pool{index}.weight"
-
-     # Transposed upsampling convolutions
-     match_upsample = re.search(r"m_up(\d+)\.0\.weight", old_key)
-     if match_upsample:
-         index = match_upsample.group(1)  # e.g., "1" or "2"
-         return f"up{index}.weight"
-
-     # Body blocks
-     match_body = re.search(r"(m_body)\.(\d+)\.res\.(\d+)\.weight", old_key)
-     if match_body:
-         prefix = match_body.group(1)  # "m_body"
-         index = match_body.group(2)  # e.g., "0"
-         conv_index = int(match_body.group(3))  # e.g., 2
-         new_conv_index = 1 if conv_index == 0 else 2
-         return f"{prefix}.enc.{index}.conv{new_conv_index}.weight"
-
-     # No pattern matched
-     return None
-
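# Illustrative conversions, derived directly from the regexes above:
assert old2new("m_down3.2.res.0.weight") == "m_down3.enc.2.conv1.weight"
assert old2new("m_up3.2.res.0.weight") == "m_up3.enc.1.conv1.weight"
assert old2new("m_down2.4.weight") == "pool2.weight"
assert old2new("m_up1.0.weight") == "up1.weight"
assert old2new("m_body.0.res.2.weight") == "m_body.enc.0.conv2.weight"
assert old2new("h.weight") is None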
- def update_keyvals_headtail(old_key, old_value, init_value=None, new_key_name='m_head.conv0.weight', conditioning='base'):
-     """
-     Convert old DRUNet head/tail weights to the new UNExt-style weights.
-
-     The keys themselves do not change, but the weights need to be zero-padded
-     to the new channel counts.
-
-     Args:
-         old_key (str): The old key from the state dictionary.
-         old_value (torch.Tensor): The old weight tensor for that key.
-         init_value (torch.Tensor): A freshly initialized weight with the target shape.
-         new_key_name (str): Key under which the padded weight is returned.
-         conditioning (str): Conditioning mode ('base' or other).
-     """
-     if 'head' in old_key:
-         c_in = init_value.shape[1]
-         new_value = torch.zeros_like(init_value.detach())
-         if conditioning == 'base':
-             # Duplicate the old conditioning channel into the two trailing
-             # slots and copy the remaining channels into the leading slots.
-             new_value[:, -2:-1, ...] = old_value[:, -1:, ...]
-             new_value[:, -1:, ...] = old_value[:, -1:, ...]
-             new_value[:, :c_in - 2, ...] = old_value[:, :c_in - 2, ...]
-         else:
-             # Keep the first c_in channels, then place the old conditioning
-             # channel (the last channel of old_value) in the trailing slot.
-             new_value[:, ...] = old_value[:, :c_in, ...]
-             new_value[:, -1:, ...] = old_value[:, -1:, ...]
-         return {new_key_name: new_value}
-     elif 'tail' in old_key:
-         c_out = init_value.shape[0]
-         c_out_old = old_value.shape[0]
-         new_value = torch.zeros_like(init_value.detach())
-         if c_out == c_out_old:
-             new_value = old_value.detach()
-         elif c_out < c_out_old:
-             # Keep only the first c_out output channels of the old tail.
-             new_value[:, ...] = old_value[:c_out, ...]
-         return {new_key_name: new_value}
-     else:
-         print(f"Key {old_key} does not contain 'head' or 'tail'.")
-
-
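# Hypothetical example: pad an old 4-channel DRUNet head (3 image channels plus
# one noise-level channel) into a new 5-channel head with two conditioning slots.
_old_w = torch.randn(64, 4, 3, 3)
_init_w = torch.zeros(64, 5, 3, 3)
_padded = update_keyvals_headtail('m_head.weight', _old_w, init_value=_init_w,
                                  new_key_name='m_head.conv0.weight')
assert _padded['m_head.conv0.weight'].shape == _init_w.shape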
- # quick smoke test
- if __name__ == "__main__":
-     net = UNeXt()
-     x = torch.randn(1, 3, 128, 128)
-     y = net(x, 0.1)
-     print(y.shape)
-
-
- # Design notes kept for reference: candidate measurement embeddings for
- # diagonal physics operators.
- #
- # IDEA 1: kill the signal in the image of A
- #     im_emb = dec - physics.A_adjoint_A(dec)
- #     (zero for denoising, non-zero for e.g. inpainting)
- #
- # IDEA 2: weight the embedding by the norm of the signal in ker(A)
- #     normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4)
- #     im_emb = normker * physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y)
- #
- # IDEA 3: as IDEA 2, but add a pseudo-inverse term as well
- #     normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4)
- #     grad_term = physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y)
- #     if 'tomography' in physics.__class__.__name__.lower():
- #         pinv_term = physics.prox_l2(dec, self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y, gamma=1e9)
- #     else:
- #         pinv_term = physics.A_dagger(self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y)
- #     im_emb = normker * (grad_term + pinv_term)
- #
- # Mixing the embedding with an SNR-dependent gain:
- #     if hasattr(physics.noise_model, 'sigma'):
- #         sigma = physics.noise_model.sigma
- #         snr = y.abs().mean() / (sigma + 1e-4)  # SNR-equivariant scaling
- #         snr = snr[(...,) + (None,) * (im_emb.dim() - 1)]
- #     else:
- #         snr = 1e4
- #     grad_large = emb + self.gain_diag * (1 + self.gain_noise / snr) * im_emb
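# A minimal numeric sketch of IDEA 2 above, using a hand-rolled diagonal
# (masking) operator in place of a deepinv physics object; `gain_x`/`gain_y`
# are hypothetical scalar gains standing in for gain_diag_x/gain_diag_y.
import torch

mask = (torch.rand(1, 3, 8, 8) > 0.5).float()
A = lambda v: mask * v       # diagonal forward operator
A_adjoint = A                # self-adjoint for a 0/1 mask
dec = torch.randn(1, 3, 8, 8)
y = A(torch.randn(1, 3, 8, 8))

gain_x, gain_y = 1.0, 1.0
# fraction of the decoded signal living in ker(A); 0 for pure denoising (mask == 1)
normker = (dec - A_adjoint(A(dec))).norm() / (dec.norm() + 1e-4)
im_emb = normker * A_adjoint(gain_x * A(dec) - gain_y * y)
assert im_emb.shape == dec.shape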