fossbk committed on
Commit e548fcc · verified · 1 Parent(s): 0b6fb66

Upload 3 files

Files changed (3)
  1. transforms.py +443 -0
  2. uniformer_light_image.py +535 -0
  3. uniformer_light_video.py +595 -0
transforms.py ADDED
@@ -0,0 +1,443 @@
1
+ import torchvision
2
+ import random
3
+ from PIL import Image, ImageOps
4
+ import numpy as np
5
+ import numbers
6
+ import math
7
+ import torch
8
+
9
+
10
+ class GroupRandomCrop(object):
11
+ def __init__(self, size):
12
+ if isinstance(size, numbers.Number):
13
+ self.size = (int(size), int(size))
14
+ else:
15
+ self.size = size
16
+
17
+ def __call__(self, img_group):
18
+
19
+ w, h = img_group[0].size
20
+ th, tw = self.size
21
+
22
+ out_images = list()
23
+
24
+ x1 = random.randint(0, w - tw)
25
+ y1 = random.randint(0, h - th)
26
+
27
+ for img in img_group:
28
+ assert(img.size[0] == w and img.size[1] == h)
29
+ if w == tw and h == th:
30
+ out_images.append(img)
31
+ else:
32
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
33
+
34
+ return out_images
35
+
36
+
37
+ class MultiGroupRandomCrop(object):
38
+ def __init__(self, size, groups=1):
39
+ if isinstance(size, numbers.Number):
40
+ self.size = (int(size), int(size))
41
+ else:
42
+ self.size = size
43
+ self.groups = groups
44
+
45
+ def __call__(self, img_group):
46
+
47
+ w, h = img_group[0].size
48
+ th, tw = self.size
49
+
50
+ out_images = list()
51
+
52
+ for i in range(self.groups):
53
+ x1 = random.randint(0, w - tw)
54
+ y1 = random.randint(0, h - th)
55
+
56
+ for img in img_group:
57
+ assert(img.size[0] == w and img.size[1] == h)
58
+ if w == tw and h == th:
59
+ out_images.append(img)
60
+ else:
61
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
62
+
63
+ return out_images
64
+
65
+
66
+ class GroupCenterCrop(object):
67
+ def __init__(self, size):
68
+ self.worker = torchvision.transforms.CenterCrop(size)
69
+
70
+ def __call__(self, img_group):
71
+ return [self.worker(img) for img in img_group]
72
+
73
+
74
+ class GroupRandomHorizontalFlip(object):
75
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
76
+ """
77
+
78
+ def __init__(self, is_flow=False):
79
+ self.is_flow = is_flow
80
+
81
+ def __call__(self, img_group, is_flow=False):
82
+ v = random.random()
83
+ if v < 0.5:
84
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
85
+ if self.is_flow:
86
+ for i in range(0, len(ret), 2):
87
+ # invert flow pixel values when flipping
88
+ ret[i] = ImageOps.invert(ret[i])
89
+ return ret
90
+ else:
91
+ return img_group
92
+
93
+
94
+ class GroupNormalize(object):
95
+ def __init__(self, mean, std):
96
+ self.mean = mean
97
+ self.std = std
98
+
99
+ def __call__(self, tensor):
100
+ rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
101
+ rep_std = self.std * (tensor.size()[0] // len(self.std))
102
+
103
+ # TODO: make efficient
104
+ for t, m, s in zip(tensor, rep_mean, rep_std):
105
+ t.sub_(m).div_(s)
106
+
107
+ return tensor
108
+
109
+
110
+ class GroupScale(object):
111
+ """ Rescales the input PIL.Image to the given 'size'.
112
+ 'size' will be the size of the smaller edge.
113
+ For example, if height > width, then image will be
114
+ rescaled to (size * height / width, size)
115
+ size: size of the smaller edge
116
+ interpolation: Default: PIL.Image.BILINEAR
117
+ """
118
+
119
+ def __init__(self, size, interpolation=Image.BILINEAR):
120
+ self.worker = torchvision.transforms.Resize(size, interpolation)
121
+
122
+ def __call__(self, img_group):
123
+ return [self.worker(img) for img in img_group]
124
+
125
+
126
+ class GroupOverSample(object):
127
+ def __init__(self, crop_size, scale_size=None, flip=True):
128
+ self.crop_size = crop_size if not isinstance(
129
+ crop_size, int) else (crop_size, crop_size)
130
+
131
+ if scale_size is not None:
132
+ self.scale_worker = GroupScale(scale_size)
133
+ else:
134
+ self.scale_worker = None
135
+ self.flip = flip
136
+
137
+ def __call__(self, img_group):
138
+
139
+ if self.scale_worker is not None:
140
+ img_group = self.scale_worker(img_group)
141
+
142
+ image_w, image_h = img_group[0].size
143
+ crop_w, crop_h = self.crop_size
144
+
145
+ offsets = GroupMultiScaleCrop.fill_fix_offset(
146
+ False, image_w, image_h, crop_w, crop_h)
147
+ oversample_group = list()
148
+ for o_w, o_h in offsets:
149
+ normal_group = list()
150
+ flip_group = list()
151
+ for i, img in enumerate(img_group):
152
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
153
+ normal_group.append(crop)
154
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
155
+
156
+ if img.mode == 'L' and i % 2 == 0:
157
+ flip_group.append(ImageOps.invert(flip_crop))
158
+ else:
159
+ flip_group.append(flip_crop)
160
+
161
+ oversample_group.extend(normal_group)
162
+ if self.flip:
163
+ oversample_group.extend(flip_group)
164
+ return oversample_group
165
+
166
+
167
+ class GroupFullResSample(object):
168
+ def __init__(self, crop_size, scale_size=None, flip=True):
169
+ self.crop_size = crop_size if not isinstance(
170
+ crop_size, int) else (crop_size, crop_size)
171
+
172
+ if scale_size is not None:
173
+ self.scale_worker = GroupScale(scale_size)
174
+ else:
175
+ self.scale_worker = None
176
+ self.flip = flip
177
+
178
+ def __call__(self, img_group):
179
+
180
+ if self.scale_worker is not None:
181
+ img_group = self.scale_worker(img_group)
182
+
183
+ image_w, image_h = img_group[0].size
184
+ crop_w, crop_h = self.crop_size
185
+
186
+ w_step = (image_w - crop_w) // 4
187
+ h_step = (image_h - crop_h) // 4
188
+
189
+ offsets = list()
190
+ offsets.append((0 * w_step, 2 * h_step)) # left
191
+ offsets.append((4 * w_step, 2 * h_step)) # right
192
+ offsets.append((2 * w_step, 2 * h_step)) # center
193
+
194
+ oversample_group = list()
195
+ for o_w, o_h in offsets:
196
+ normal_group = list()
197
+ flip_group = list()
198
+ for i, img in enumerate(img_group):
199
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
200
+ normal_group.append(crop)
201
+ if self.flip:
202
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
203
+
204
+ if img.mode == 'L' and i % 2 == 0:
205
+ flip_group.append(ImageOps.invert(flip_crop))
206
+ else:
207
+ flip_group.append(flip_crop)
208
+
209
+ oversample_group.extend(normal_group)
210
+ oversample_group.extend(flip_group)
211
+ return oversample_group
212
+
213
+
214
+ class GroupMultiScaleCrop(object):
215
+
216
+ def __init__(self, input_size, scales=None, max_distort=1,
217
+ fix_crop=True, more_fix_crop=True):
218
+ self.scales = scales if scales is not None else [1, .875, .75, .66]
219
+ self.max_distort = max_distort
220
+ self.fix_crop = fix_crop
221
+ self.more_fix_crop = more_fix_crop
222
+ self.input_size = input_size if not isinstance(input_size, int) else [
223
+ input_size, input_size]
224
+ self.interpolation = Image.BILINEAR
225
+
226
+ def __call__(self, img_group):
227
+
228
+ im_size = img_group[0].size
229
+
230
+ crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
231
+ crop_img_group = [
232
+ img.crop(
233
+ (offset_w,
234
+ offset_h,
235
+ offset_w +
236
+ crop_w,
237
+ offset_h +
238
+ crop_h)) for img in img_group]
239
+ ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
240
+ for img in crop_img_group]
241
+ return ret_img_group
242
+
243
+ def _sample_crop_size(self, im_size):
244
+ image_w, image_h = im_size[0], im_size[1]
245
+
246
+ # find a crop size
247
+ base_size = min(image_w, image_h)
248
+ crop_sizes = [int(base_size * x) for x in self.scales]
249
+ crop_h = [
250
+ self.input_size[1] if abs(
251
+ x - self.input_size[1]) < 3 else x for x in crop_sizes]
252
+ crop_w = [
253
+ self.input_size[0] if abs(
254
+ x - self.input_size[0]) < 3 else x for x in crop_sizes]
255
+
256
+ pairs = []
257
+ for i, h in enumerate(crop_h):
258
+ for j, w in enumerate(crop_w):
259
+ if abs(i - j) <= self.max_distort:
260
+ pairs.append((w, h))
261
+
262
+ crop_pair = random.choice(pairs)
263
+ if not self.fix_crop:
264
+ w_offset = random.randint(0, image_w - crop_pair[0])
265
+ h_offset = random.randint(0, image_h - crop_pair[1])
266
+ else:
267
+ w_offset, h_offset = self._sample_fix_offset(
268
+ image_w, image_h, crop_pair[0], crop_pair[1])
269
+
270
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
271
+
272
+ def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
273
+ offsets = self.fill_fix_offset(
274
+ self.more_fix_crop, image_w, image_h, crop_w, crop_h)
275
+ return random.choice(offsets)
276
+
277
+ @staticmethod
278
+ def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
279
+ w_step = (image_w - crop_w) // 4
280
+ h_step = (image_h - crop_h) // 4
281
+
282
+ ret = list()
283
+ ret.append((0, 0)) # upper left
284
+ ret.append((4 * w_step, 0)) # upper right
285
+ ret.append((0, 4 * h_step)) # lower left
286
+ ret.append((4 * w_step, 4 * h_step)) # lower right
287
+ ret.append((2 * w_step, 2 * h_step)) # center
288
+
289
+ if more_fix_crop:
290
+ ret.append((0, 2 * h_step)) # center left
291
+ ret.append((4 * w_step, 2 * h_step)) # center right
292
+ ret.append((2 * w_step, 4 * h_step)) # lower center
293
+ ret.append((2 * w_step, 0 * h_step)) # upper center
294
+
295
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
296
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
297
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
298
+ ret.append((3 * w_step, 3 * h_step)) # lower right quarter
299
+
300
+ return ret
301
+
302
+
303
+ class GroupRandomSizedCrop(object):
304
+ """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
305
+ and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
306
+ This is popularly used to train the Inception networks
307
+ size: size of the smaller edge
308
+ interpolation: Default: PIL.Image.BILINEAR
309
+ """
310
+
311
+ def __init__(self, size, interpolation=Image.BILINEAR):
312
+ self.size = size
313
+ self.interpolation = interpolation
314
+
315
+ def __call__(self, img_group):
316
+ for attempt in range(10):
317
+ area = img_group[0].size[0] * img_group[0].size[1]
318
+ target_area = random.uniform(0.08, 1.0) * area
319
+ aspect_ratio = random.uniform(3. / 4, 4. / 3)
320
+
321
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
322
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
323
+
324
+ if random.random() < 0.5:
325
+ w, h = h, w
326
+
327
+ if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
328
+ x1 = random.randint(0, img_group[0].size[0] - w)
329
+ y1 = random.randint(0, img_group[0].size[1] - h)
330
+ found = True
331
+ break
332
+ else:
333
+ found = False
334
+ x1 = 0
335
+ y1 = 0
336
+
337
+ if found:
338
+ out_group = list()
339
+ for img in img_group:
340
+ img = img.crop((x1, y1, x1 + w, y1 + h))
341
+ assert(img.size == (w, h))
342
+ out_group.append(
343
+ img.resize(
344
+ (self.size, self.size), self.interpolation))
345
+ return out_group
346
+ else:
347
+ # Fallback
348
+ scale = GroupScale(self.size, interpolation=self.interpolation)
349
+ crop = GroupRandomCrop(self.size)
350
+ return crop(scale(img_group))
351
+
352
+
353
+ class ConvertDataFormat(object):
354
+ def __init__(self, model_type):
355
+ self.model_type = model_type
356
+
357
+ def __call__(self, images):
358
+ if self.model_type == '2D':
359
+ return images
360
+ tc, h, w = images.size()
361
+ t = tc // 3
362
+ images = images.view(t, 3, h, w)
363
+ images = images.permute(1, 0, 2, 3)
364
+ return images
365
+
366
+
367
+ class Stack(object):
368
+
369
+ def __init__(self, roll=False):
370
+ self.roll = roll
371
+
372
+ def __call__(self, img_group):
373
+ if img_group[0].mode == 'L':
374
+ return np.concatenate([np.expand_dims(x, 2)
375
+ for x in img_group], axis=2)
376
+ elif img_group[0].mode == 'RGB':
377
+ if self.roll:
378
+ return np.concatenate([np.array(x)[:, :, ::-1]
379
+ for x in img_group], axis=2)
380
+ else:
381
+ #print(np.concatenate(img_group, axis=2).shape)
382
+ # print(img_group[0].shape)
383
+ return np.concatenate(img_group, axis=2)
384
+
385
+
386
+ class ToTorchFormatTensor(object):
387
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
388
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
389
+
390
+ def __init__(self, div=True):
391
+ self.div = div
392
+
393
+ def __call__(self, pic):
394
+ if isinstance(pic, np.ndarray):
395
+ # handle numpy array
396
+ img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
397
+ else:
398
+ # handle PIL Image
399
+ img = torch.ByteTensor(
400
+ torch.ByteStorage.from_buffer(
401
+ pic.tobytes()))
402
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
403
+ # put it from HWC to CHW format
404
+ # yikes, this transpose takes 80% of the loading time/CPU
405
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
406
+ return img.float().div(255) if self.div else img.float()
407
+
408
+
409
+ class IdentityTransform(object):
410
+
411
+ def __call__(self, data):
412
+ return data
413
+
414
+
415
+ if __name__ == "__main__":
416
+ trans = torchvision.transforms.Compose([
417
+ GroupScale(256),
418
+ GroupRandomCrop(224),
419
+ Stack(),
420
+ ToTorchFormatTensor(),
421
+ GroupNormalize(
422
+ mean=[.485, .456, .406],
423
+ std=[.229, .224, .225]
424
+ )]
425
+ )
426
+
427
+ im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
428
+
429
+ color_group = [im] * 3
430
+ rst = trans(color_group)
431
+
432
+ gray_group = [im.convert('L')] * 9
433
+ gray_rst = trans(gray_group)
434
+
435
+ trans2 = torchvision.transforms.Compose([
436
+ GroupRandomSizedCrop(256),
437
+ Stack(),
438
+ ToTorchFormatTensor(),
439
+ GroupNormalize(
440
+ mean=[.485, .456, .406],
441
+ std=[.229, .224, .225])
442
+ ])
443
+ print(trans2(color_group))
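Usage note (not part of the committed file): a minimal sketch of how these group transforms are typically chained to turn a list of RGB PIL frames into a 3 x T x H x W clip tensor via the '3D' path of ConvertDataFormat. The frame list, the 8-frame clip length, and the 256/224 sizes are illustrative assumptions, and the sketch assumes the classes above are in scope (e.g. from transforms import *).

# Illustrative sketch only; the frames below are stand-ins, not real data.
import torchvision
from PIL import Image

clip_transform = torchvision.transforms.Compose([
    GroupScale(256),                 # resize each frame so its shorter edge is 256
    GroupCenterCrop(224),            # identical 224x224 center crop for every frame
    Stack(roll=False),               # concatenate frames into an H x W x (3*T) array
    ToTorchFormatTensor(div=True),   # (3*T) x H x W float tensor in [0, 1]
    GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ConvertDataFormat('3D'),         # reshape and permute to 3 x T x H x W
])

frames = [Image.new('RGB', (320, 240)) for _ in range(8)]  # assumed dummy frames
clip = clip_transform(frames)
print(clip.shape)                    # torch.Size([3, 8, 224, 224])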
uniformer_light_image.py ADDED
@@ -0,0 +1,535 @@
1
+ # All rights reserved.
2
+ from collections import OrderedDict
3
+ import torch
4
+ import torch.nn as nn
5
+ from functools import partial
6
+ import torch.nn.functional as F
7
+ import math
8
+ from timm.models.vision_transformer import _cfg
9
+ from timm.models.registry import register_model
10
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
11
+
12
+
13
+ layer_scale = False
14
+ init_value = 1e-6
15
+ global_attn = None
16
+ token_indices = None
17
+
18
+
19
+ # code is from https://github.com/YifanXu74/Evo-ViT
20
+ def easy_gather(x, indices):
21
+ # x => B x N x C
22
+ # indices => B x N
23
+ B, N, C = x.shape
24
+ N_new = indices.shape[1]
25
+ offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N
26
+ indices = indices + offset
27
+ # only select the informative tokens
28
+ out = x.reshape(B * N, C)[indices.view(-1)].reshape(B, N_new, C)
29
+ return out
30
+
31
+
32
+ # code is from https://github.com/YifanXu74/Evo-ViT
33
+ def merge_tokens(x_drop, score):
34
+ # x_drop => B x N_drop x C
35
+ # score => B x N_drop
36
+ weight = score / torch.sum(score, dim=1, keepdim=True)
37
+ x_drop = weight.unsqueeze(-1) * x_drop
38
+ return torch.sum(x_drop, dim=1, keepdim=True)
39
+
40
+
41
+ class Mlp(nn.Module):
42
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
43
+ super().__init__()
44
+ out_features = out_features or in_features
45
+ hidden_features = hidden_features or in_features
46
+ self.fc1 = nn.Linear(in_features, hidden_features)
47
+ self.act = act_layer()
48
+ self.fc2 = nn.Linear(hidden_features, out_features)
49
+ self.drop = nn.Dropout(drop)
50
+
51
+ def forward(self, x):
52
+ x = self.fc1(x)
53
+ x = self.act(x)
54
+ x = self.drop(x)
55
+ x = self.fc2(x)
56
+ x = self.drop(x)
57
+ return x
58
+
59
+
60
+ class CMlp(nn.Module):
61
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
62
+ super().__init__()
63
+ out_features = out_features or in_features
64
+ hidden_features = hidden_features or in_features
65
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
66
+ self.act = act_layer()
67
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
68
+ self.drop = nn.Dropout(drop)
69
+
70
+ def forward(self, x):
71
+ x = self.fc1(x)
72
+ x = self.act(x)
73
+ x = self.drop(x)
74
+ x = self.fc2(x)
75
+ x = self.drop(x)
76
+ return x
77
+
78
+
79
+ class Attention(nn.Module):
80
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., trade_off=1):
81
+ super().__init__()
82
+ self.num_heads = num_heads
83
+ head_dim = dim // num_heads
84
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
85
+ self.scale = qk_scale or head_dim ** -0.5
86
+
87
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
88
+ self.attn_drop = nn.Dropout(attn_drop)
89
+ self.proj = nn.Linear(dim, dim)
90
+ self.proj_drop = nn.Dropout(proj_drop)
91
+ # updating weight for global score
92
+ self.trade_off = trade_off
93
+
94
+ def forward(self, x):
95
+ B, N, C = x.shape
96
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
97
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
98
+
99
+ attn = (q @ k.transpose(-2, -1)) * self.scale
100
+ attn = attn.softmax(dim=-1)
101
+
102
+ # update global score
103
+ global global_attn
104
+ tradeoff = self.trade_off
105
+ if isinstance(global_attn, int):
106
+ global_attn = torch.mean(attn[:, :, 0, 1:], dim=1)
107
+ elif global_attn.shape[1] == N - 1:
108
+ # no additional token and no pruning, update all global scores
109
+ cls_attn = torch.mean(attn[:, :, 0, 1:], dim=1)
110
+ global_attn = (1 - tradeoff) * global_attn + tradeoff * cls_attn
111
+ else:
112
+ # only update the informative tokens
113
+ # the first one is class token
114
+ # the last one is the representative token
115
+ cls_attn = torch.mean(attn[:, :, 0, 1:-1], dim=1)
116
+ if self.training:
117
+ temp_attn = (1 - tradeoff) * global_attn[:, :(N - 2)] + tradeoff * cls_attn
118
+ global_attn = torch.cat((temp_attn, global_attn[:, (N - 2):]), dim=1)
119
+ else:
120
+ # avoid torch.cat() for faster inference
121
+ global_attn[:, :(N - 2)] = (1 - tradeoff) * global_attn[:, :(N - 2)] + tradeoff * cls_attn
122
+
123
+ attn = self.attn_drop(attn)
124
+
125
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
126
+ x = self.proj(x)
127
+ x = self.proj_drop(x)
128
+ return x
129
+
130
+
131
+ class CBlock(nn.Module):
132
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
133
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
134
+ super().__init__()
135
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
136
+ self.norm1 = nn.BatchNorm2d(dim)
137
+ self.conv1 = nn.Conv2d(dim, dim, 1)
138
+ self.conv2 = nn.Conv2d(dim, dim, 1)
139
+ self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
140
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
141
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
142
+ self.norm2 = nn.BatchNorm2d(dim)
143
+ mlp_hidden_dim = int(dim * mlp_ratio)
144
+ self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
145
+ global layer_scale
146
+ self.ls = layer_scale
147
+ if self.ls:
148
+ global init_value
149
+ print(f"Use layer_scale: {layer_scale}, init_values: {init_value}")
150
+ self.gamma_1 = nn.Parameter(init_value * torch.ones((1, dim, 1, 1)),requires_grad=True)
151
+ self.gamma_2 = nn.Parameter(init_value * torch.ones((1, dim, 1, 1)),requires_grad=True)
152
+
153
+ def forward(self, x):
154
+ x = x + self.pos_embed(x)
155
+ if self.ls:
156
+ x = x + self.drop_path(self.gamma_1 * self.conv2(self.attn(self.conv1(self.norm1(x)))))
157
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
158
+ else:
159
+ x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
160
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
161
+ return x
162
+
163
+
164
+ class EvoSABlock(nn.Module):
165
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
166
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, prune_ratio=1,
167
+ trade_off=0, downsample=False):
168
+ super().__init__()
169
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
170
+ self.norm1 = norm_layer(dim)
171
+ self.attn = Attention(
172
+ dim,
173
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
174
+ attn_drop=attn_drop, proj_drop=drop, trade_off=trade_off)
175
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
176
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
177
+ self.norm2 = norm_layer(dim)
178
+ mlp_hidden_dim = int(dim * mlp_ratio)
179
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
180
+ self.prune_ratio = prune_ratio
181
+ self.downsample = downsample
182
+ if downsample:
183
+ self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
184
+ global layer_scale
185
+ self.ls = layer_scale
186
+ if self.ls:
187
+ global init_value
188
+ print(f"Use layer_scale: {layer_scale}, init_values: {init_value}")
189
+ self.gamma_1 = nn.Parameter(init_value * torch.ones((dim)),requires_grad=True)
190
+ self.gamma_2 = nn.Parameter(init_value * torch.ones((dim)),requires_grad=True)
191
+ if self.prune_ratio != 1:
192
+ self.gamma_3 = nn.Parameter(init_value * torch.ones((dim)),requires_grad=True)
193
+
194
+ def forward(self, cls_token, x):
195
+ x = x + self.pos_embed(x)
196
+ B, C, H, W = x.shape
197
+ x = x.flatten(2).transpose(1, 2)
198
+
199
+ if self.prune_ratio == 1:
200
+ x = torch.cat([cls_token, x], dim=1)
201
+ if self.ls:
202
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
203
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
204
+ else:
205
+ x = x + self.drop_path(self.attn(self.norm1(x)))
206
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
207
+ cls_token, x = x[:, :1], x[:, 1:]
208
+ x = x.transpose(1, 2).reshape(B, C, H, W)
209
+ return cls_token, x
210
+ else:
211
+ global global_attn, token_indices
212
+ # calculate the number of informative tokens
213
+ N = x.shape[1]
214
+ N_ = int(N * self.prune_ratio)
215
+ # sort global attention
216
+ indices = torch.argsort(global_attn, dim=1, descending=True)
217
+
218
+ # concatenate x, global attention and token indices => x_ga_ti
219
+ # rearrange the tensor according to new indices
220
+ x_ga_ti = torch.cat((x, global_attn.unsqueeze(-1), token_indices.unsqueeze(-1)), dim=-1)
221
+ x_ga_ti = easy_gather(x_ga_ti, indices)
222
+ x_sorted, global_attn, token_indices = x_ga_ti[:, :, :-2], x_ga_ti[:, :, -2], x_ga_ti[:, :, -1]
223
+
224
+ # informative tokens
225
+ x_info = x_sorted[:, :N_]
226
+ # merge dropped tokens
227
+ x_drop = x_sorted[:, N_:]
228
+ score = global_attn[:, N_:]
229
+ # B x N_drop x C => B x 1 x C
230
+ rep_token = merge_tokens(x_drop, score)
231
+ # concatenate new tokens
232
+ x = torch.cat((cls_token, x_info, rep_token), dim=1)
233
+
234
+ if self.ls:
235
+ # slow update
236
+ fast_update = 0
237
+ tmp_x = self.attn(self.norm1(x))
238
+ fast_update = fast_update + tmp_x[:, -1:]
239
+ x = x + self.drop_path(self.gamma_1 * tmp_x)
240
+ tmp_x = self.mlp(self.norm2(x))
241
+ fast_update = fast_update + tmp_x[:, -1:]
242
+ x = x + self.drop_path(self.gamma_2 * tmp_x)
243
+ # fast update
244
+ x_drop = x_drop + self.gamma_3 * fast_update.expand(-1, N - N_, -1)
245
+ else:
246
+ # slow update
247
+ fast_update = 0
248
+ tmp_x = self.attn(self.norm1(x))
249
+ fast_update = fast_update + tmp_x[:, -1:]
250
+ x = x + self.drop_path(tmp_x)
251
+ tmp_x = self.mlp(self.norm2(x))
252
+ fast_update = fast_update + tmp_x[:, -1:]
253
+ x = x + self.drop_path(tmp_x)
254
+ # fast update
255
+ x_drop = x_drop + fast_update.expand(-1, N - N_, -1)
256
+
257
+ cls_token, x = x[:, :1, :], x[:, 1:-1, :]
258
+ if self.training:
259
+ x_sorted = torch.cat((x, x_drop), dim=1)
260
+ else:
261
+ x_sorted[:, N_:] = x_drop
262
+ x_sorted[:, :N_] = x
263
+
264
+ # recover token
265
+ # scale for normalization
266
+ old_global_scale = torch.sum(global_attn, dim=1, keepdim=True)
267
+ # recover order
268
+ indices = torch.argsort(token_indices, dim=1)
269
+ x_ga_ti = torch.cat((x_sorted, global_attn.unsqueeze(-1), token_indices.unsqueeze(-1)), dim=-1)
270
+ x_ga_ti = easy_gather(x_ga_ti, indices)
271
+ x_patch, global_attn, token_indices = x_ga_ti[:, :, :-2], x_ga_ti[:, :, -2], x_ga_ti[:, :, -1]
272
+ x_patch = x_patch.transpose(1, 2).reshape(B, C, H, W)
273
+
274
+ if self.downsample:
275
+ # downsample global attention
276
+ global_attn = global_attn.reshape(B, 1, H, W)
277
+ global_attn = self.avgpool(global_attn).view(B, -1)
278
+ # normalize global attention
279
+ new_global_scale = torch.sum(global_attn, dim=1, keepdim=True)
280
+ scale = old_global_scale / new_global_scale
281
+ global_attn = global_attn * scale
282
+
283
+ return cls_token, x_patch
284
+
285
+
286
+ class PatchEmbed(nn.Module):
287
+ """ Image to Patch Embedding
288
+ """
289
+ def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
290
+ super().__init__()
291
+ self.norm = nn.LayerNorm(embed_dim)
292
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
293
+
294
+ def forward(self, x):
295
+ x = self.proj(x)
296
+ B, C, H, W = x.shape
297
+ x = x.flatten(2).transpose(1, 2)
298
+ x = self.norm(x)
299
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
300
+ return x
301
+
302
+
303
+ class head_embedding(nn.Module):
304
+ def __init__(self, in_channels, out_channels):
305
+ super(head_embedding, self).__init__()
306
+ self.proj = nn.Sequential(
307
+ nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
308
+ nn.BatchNorm2d(out_channels // 2),
309
+ nn.GELU(),
310
+ nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
311
+ nn.BatchNorm2d(out_channels),
312
+ )
313
+
314
+ def forward(self, x):
315
+ x = self.proj(x)
316
+ return x
317
+
318
+
319
+ class middle_embedding(nn.Module):
320
+ def __init__(self, in_channels, out_channels):
321
+ super(middle_embedding, self).__init__()
322
+
323
+ self.proj = nn.Sequential(
324
+ nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
325
+ nn.BatchNorm2d(out_channels),
326
+ )
327
+
328
+ def forward(self, x):
329
+ x = self.proj(x)
330
+ return x
331
+
332
+
333
+ class UniFormer_Light(nn.Module):
334
+ """ Vision Transformer
335
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
336
+ https://arxiv.org/abs/2010.11929
337
+ """
338
+ def __init__(self, depth=[3, 4, 8, 3], in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512],
339
+ head_dim=64, mlp_ratio=[4., 4., 4., 4.], qkv_bias=True, qk_scale=None, representation_size=None,
340
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, conv_stem=False,
341
+ prune_ratio=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
342
+ trade_off=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]):
343
+ """
344
+ Args:
345
+ img_size (int, tuple): input image size
346
+ patch_size (int, tuple): patch size
347
+ in_chans (int): number of input channels
348
+ num_classes (int): number of classes for classification head
349
+ embed_dim (int): embedding dimension
350
+ depth (int): depth of transformer
351
+ head_dim (int): head dimension
352
+ mlp_ratio (list): ratio of mlp hidden dim to embedding dim
353
+ qkv_bias (bool): enable bias for qkv if True
354
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
355
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
356
+ drop_rate (float): dropout rate
357
+ attn_drop_rate (float): attention dropout rate
358
+ drop_path_rate (float): stochastic depth rate
359
+ norm_layer: (nn.Module): normalization layer
360
+ """
361
+ super().__init__()
362
+ self.num_classes = num_classes
363
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
364
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
365
+ if conv_stem:
366
+ self.patch_embed1 = head_embedding(in_channels=in_chans, out_channels=embed_dim[0])
367
+ self.patch_embed2 = PatchEmbed(
368
+ patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
369
+ self.patch_embed3 = PatchEmbed(
370
+ patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
371
+ self.patch_embed4 = PatchEmbed(
372
+ patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])
373
+ else:
374
+ self.patch_embed1 = PatchEmbed(
375
+ patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
376
+ self.patch_embed2 = PatchEmbed(
377
+ patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
378
+ self.patch_embed3 = PatchEmbed(
379
+ patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
380
+ self.patch_embed4 = PatchEmbed(
381
+ patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])
382
+
383
+ # class token
384
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim[2]))
385
+ self.cls_upsample = nn.Linear(embed_dim[2], embed_dim[3])
386
+
387
+ self.pos_drop = nn.Dropout(p=drop_rate)
388
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] # stochastic depth decay rule
389
+ num_heads = [dim // head_dim for dim in embed_dim]
390
+ self.blocks1 = nn.ModuleList([
391
+ CBlock(
392
+ dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
393
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
394
+ for i in range(depth[0])])
395
+ self.blocks2 = nn.ModuleList([
396
+ CBlock(
397
+ dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
398
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]], norm_layer=norm_layer)
399
+ for i in range(depth[1])])
400
+ self.blocks3 = nn.ModuleList([
401
+ EvoSABlock(
402
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
403
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer,
404
+ prune_ratio=prune_ratio[2][i], trade_off=trade_off[2][i],
405
+ downsample=True if i == depth[2] - 1 else False)
406
+ for i in range(depth[2])])
407
+ self.blocks4 = nn.ModuleList([
408
+ EvoSABlock(
409
+ dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
410
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer,
411
+ prune_ratio=prune_ratio[3][i], trade_off=trade_off[3][i])
412
+ for i in range(depth[3])])
413
+ self.norm = nn.BatchNorm2d(embed_dim[-1])
414
+ self.norm_cls = nn.LayerNorm(embed_dim[-1])
415
+
416
+ # Representation layer
417
+ if representation_size:
418
+ self.num_features = representation_size
419
+ self.pre_logits = nn.Sequential(OrderedDict([
420
+ ('fc', nn.Linear(embed_dim, representation_size)),
421
+ ('act', nn.Tanh())
422
+ ]))
423
+ else:
424
+ self.pre_logits = nn.Identity()
425
+
426
+ # Classifier head
427
+ self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
428
+ self.head_cls = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
429
+
430
+ self.apply(self._init_weights)
431
+
432
+ def _init_weights(self, m):
433
+ if isinstance(m, nn.Linear):
434
+ trunc_normal_(m.weight, std=.02)
435
+ if isinstance(m, nn.Linear) and m.bias is not None:
436
+ nn.init.constant_(m.bias, 0)
437
+ elif isinstance(m, nn.LayerNorm):
438
+ nn.init.constant_(m.bias, 0)
439
+ nn.init.constant_(m.weight, 1.0)
440
+
441
+ @torch.jit.ignore
442
+ def no_weight_decay(self):
443
+ return {'pos_embed', 'cls_token'}
444
+
445
+ def get_classifier(self):
446
+ return self.head
447
+
448
+ def reset_classifier(self, num_classes, global_pool=''):
449
+ self.num_classes = num_classes
450
+ self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
451
+
452
+ def forward_features(self, x):
453
+ B = x.shape[0]
454
+ x = self.patch_embed1(x)
455
+ x = self.pos_drop(x)
456
+ for blk in self.blocks1:
457
+ x = blk(x)
458
+ x = self.patch_embed2(x)
459
+ for blk in self.blocks2:
460
+ x = blk(x)
461
+ x = self.patch_embed3(x)
462
+ # add cls_token in stage3
463
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1)
464
+ global global_attn, token_indices
465
+ global_attn = 0
466
+ token_indices = torch.arange(x.shape[2] * x.shape[3], dtype=torch.long, device=x.device).unsqueeze(0)
467
+ token_indices = token_indices.expand(x.shape[0], -1)
468
+ for blk in self.blocks3:
469
+ cls_token, x = blk(cls_token, x)
470
+ # upsample cls_token before stage4
471
+ cls_token = self.cls_upsample(cls_token)
472
+ x = self.patch_embed4(x)
473
+ # should global attention be reset here? For now it is simply average-pooled
474
+ token_indices = torch.arange(x.shape[2] * x.shape[3], dtype=torch.long, device=x.device).unsqueeze(0)
475
+ token_indices = token_indices.expand(x.shape[0], -1)
476
+ for blk in self.blocks4:
477
+ cls_token, x = blk(cls_token, x)
478
+ if self.training:
479
+ # layer normalization for cls_token
480
+ cls_token = self.norm_cls(cls_token)
481
+ x = self.norm(x)
482
+ x = self.pre_logits(x)
483
+ return cls_token, x
484
+
485
+ def forward(self, x):
486
+ cls_token, x = self.forward_features(x)
487
+ x = x.flatten(2).mean(-1)
488
+ if self.training:
489
+ x = self.head(x), self.head_cls(cls_token.squeeze(1))
490
+ else:
491
+ x = self.head(x)
492
+ return x
493
+
494
+
495
+ def uniformer_xxs_image(**kwargs):
496
+ model = UniFormer_Light(
497
+ depth=[2, 5, 8, 2], conv_stem=True,
498
+ prune_ratio=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5]],
499
+ trade_off=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5]],
500
+ embed_dim=[56, 112, 224, 448], head_dim=28, mlp_ratio=[3, 3, 3, 3], qkv_bias=True,
501
+ **kwargs)
502
+ model.default_cfg = _cfg()
503
+ return model
504
+
505
+
506
+ def uniformer_xs_image(**kwargs):
507
+ model = UniFormer_Light(
508
+ depth=[3, 5, 9, 3], conv_stem=True,
509
+ prune_ratio=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
510
+ trade_off=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
511
+ embed_dim=[64, 128, 256, 512], head_dim=32, mlp_ratio=[3, 3, 3, 3], qkv_bias=True,
512
+ **kwargs)
513
+ model.default_cfg = _cfg()
514
+ return model
515
+
516
+
517
+ if __name__ == '__main__':
518
+ import time
519
+ from fvcore.nn import FlopCountAnalysis
520
+ from fvcore.nn import flop_count_table
521
+ import numpy as np
522
+
523
+ seed = 4217
524
+ np.random.seed(seed)
525
+ torch.manual_seed(seed)
526
+ torch.cuda.manual_seed(seed)
527
+ torch.cuda.manual_seed_all(seed)
528
+
529
+ model = uniformer_xxs_image()
530
+ # print(model)
531
+
532
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, 160, 160))
533
+ s = time.time()
534
+ print(flop_count_table(flops, max_depth=1))
535
+ print(time.time()-s)
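Usage note (not part of the committed file): a minimal sketch of the image model's interface, showing that forward() returns a (head, head_cls) pair in training mode, because the auxiliary cls-token head only supplies an extra training signal, and plain logits in eval mode. The batch size and the random input tensor are illustrative assumptions; 160x160 matches the FLOP-count example above.

# Illustrative sketch only; the input tensor is random, not real data.
import torch

model = uniformer_xxs_image(num_classes=1000)
dummy = torch.rand(2, 3, 160, 160)       # assumed batch of 2 images at 160x160

model.train()
logits, cls_logits = model(dummy)        # main head + auxiliary cls-token head
print(logits.shape, cls_logits.shape)    # torch.Size([2, 1000]) for both

model.eval()
with torch.no_grad():
    logits = model(dummy)                # single head at inference
print(logits.shape)                      # torch.Size([2, 1000])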
uniformer_light_video.py ADDED
@@ -0,0 +1,595 @@
1
+ # All rights reserved.
2
+ from math import ceil, sqrt
3
+ from collections import OrderedDict
4
+ import torch
5
+ import torch.nn as nn
6
+ from functools import partial
7
+ from timm.models.vision_transformer import _cfg
8
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
9
+ import os
10
+
11
+
12
+ global_attn = None
13
+ token_indices = None
14
+
15
+ model_path = 'path_to_models'
16
+ model_path = {
17
+ 'uniformer_xxs_128_in1k': os.path.join(model_path, 'uniformer_xxs_128_in1k.pth'),
18
+ 'uniformer_xxs_160_in1k': os.path.join(model_path, 'uniformer_xxs_160_in1k.pth'),
19
+ 'uniformer_xxs_192_in1k': os.path.join(model_path, 'uniformer_xxs_192_in1k.pth'),
20
+ 'uniformer_xxs_224_in1k': os.path.join(model_path, 'uniformer_xxs_224_in1k.pth'),
21
+ 'uniformer_xs_192_in1k': os.path.join(model_path, 'uniformer_xs_192_in1k.pth'),
22
+ 'uniformer_xs_224_in1k': os.path.join(model_path, 'uniformer_xs_224_in1k.pth'),
23
+ }
24
+
25
+
26
+ def conv_3xnxn(inp, oup, kernel_size=3, stride=3, groups=1):
27
+ return nn.Conv3d(inp, oup, (3, kernel_size, kernel_size), (2, stride, stride), (1, 0, 0), groups=groups)
28
+
29
+ def conv_1xnxn(inp, oup, kernel_size=3, stride=3, groups=1):
30
+ return nn.Conv3d(inp, oup, (1, kernel_size, kernel_size), (1, stride, stride), (0, 0, 0), groups=groups)
31
+
32
+ def conv_3xnxn_std(inp, oup, kernel_size=3, stride=3, groups=1):
33
+ return nn.Conv3d(inp, oup, (3, kernel_size, kernel_size), (1, stride, stride), (1, 0, 0), groups=groups)
34
+
35
+ def conv_1x1x1(inp, oup, groups=1):
36
+ return nn.Conv3d(inp, oup, (1, 1, 1), (1, 1, 1), (0, 0, 0), groups=groups)
37
+
38
+ def conv_3x3x3(inp, oup, groups=1):
39
+ return nn.Conv3d(inp, oup, (3, 3, 3), (1, 1, 1), (1, 1, 1), groups=groups)
40
+
41
+ def conv_5x5x5(inp, oup, groups=1):
42
+ return nn.Conv3d(inp, oup, (5, 5, 5), (1, 1, 1), (2, 2, 2), groups=groups)
43
+
44
+ def bn_3d(dim):
45
+ return nn.BatchNorm3d(dim)
46
+
47
+
48
+ # code is from https://github.com/YifanXu74/Evo-ViT
49
+ def easy_gather(x, indices):
50
+ # x => B x N x C
51
+ # indices => B x N
52
+ B, N, C = x.shape
53
+ N_new = indices.shape[1]
54
+ offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N
55
+ indices = indices + offset
56
+ # only select the informative tokens
57
+ out = x.reshape(B * N, C)[indices.view(-1)].reshape(B, N_new, C)
58
+ return out
59
+
60
+
61
+ # code is from https://github.com/YifanXu74/Evo-ViT
62
+ def merge_tokens(x_drop, score):
63
+ # x_drop => B x N_drop x C
64
+ # score => B x N_drop
65
+ weight = score / torch.sum(score, dim=1, keepdim=True)
66
+ x_drop = weight.unsqueeze(-1) * x_drop
67
+ return torch.sum(x_drop, dim=1, keepdim=True)
68
+
69
+
70
+ class Mlp(nn.Module):
71
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
72
+ super().__init__()
73
+ out_features = out_features or in_features
74
+ hidden_features = hidden_features or in_features
75
+ self.fc1 = nn.Linear(in_features, hidden_features)
76
+ self.act = act_layer()
77
+ self.fc2 = nn.Linear(hidden_features, out_features)
78
+ self.drop = nn.Dropout(drop)
79
+
80
+ def forward(self, x):
81
+ x = self.fc1(x)
82
+ x = self.act(x)
83
+ x = self.drop(x)
84
+ x = self.fc2(x)
85
+ x = self.drop(x)
86
+ return x
87
+
88
+
89
+ class Attention(nn.Module):
90
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., trade_off=1):
91
+ super().__init__()
92
+ self.num_heads = num_heads
93
+ head_dim = dim // num_heads
94
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
95
+ self.scale = qk_scale or head_dim ** -0.5
96
+
97
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
98
+ self.attn_drop = nn.Dropout(attn_drop)
99
+ self.proj = nn.Linear(dim, dim)
100
+ self.proj_drop = nn.Dropout(proj_drop)
101
+ # updating weight for global score
102
+ self.trade_off = trade_off
103
+
104
+ def forward(self, x):
105
+ B, N, C = x.shape
106
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
107
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
108
+
109
+ attn = (q @ k.transpose(-2, -1)) * self.scale
110
+ attn = attn.softmax(dim=-1)
111
+
112
+ # update global score
113
+ global global_attn
114
+ tradeoff = self.trade_off
115
+ if isinstance(global_attn, int):
116
+ global_attn = torch.mean(attn[:, :, 0, 1:], dim=1)
117
+ elif global_attn.shape[1] == N - 1:
118
+ # no additional token and no pruning, update all global scores
119
+ cls_attn = torch.mean(attn[:, :, 0, 1:], dim=1)
120
+ global_attn = (1 - tradeoff) * global_attn + tradeoff * cls_attn
121
+ else:
122
+ # only update the informative tokens
123
+ # the first one is class token
124
+ # the last one is the representative token
125
+ cls_attn = torch.mean(attn[:, :, 0, 1:-1], dim=1)
126
+ if self.training:
127
+ temp_attn = (1 - tradeoff) * global_attn[:, :(N - 2)] + tradeoff * cls_attn
128
+ global_attn = torch.cat((temp_attn, global_attn[:, (N - 2):]), dim=1)
129
+ else:
130
+ # avoid torch.cat() for faster inference
131
+ global_attn[:, :(N - 2)] = (1 - tradeoff) * global_attn[:, :(N - 2)] + tradeoff * cls_attn
132
+
133
+ attn = self.attn_drop(attn)
134
+
135
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
136
+ x = self.proj(x)
137
+ x = self.proj_drop(x)
138
+ return x
139
+
140
+
141
+ class CMlp(nn.Module):
142
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
143
+ super().__init__()
144
+ out_features = out_features or in_features
145
+ hidden_features = hidden_features or in_features
146
+ self.fc1 = conv_1x1x1(in_features, hidden_features)
147
+ self.act = act_layer()
148
+ self.fc2 = conv_1x1x1(hidden_features, out_features)
149
+ self.drop = nn.Dropout(drop)
150
+
151
+ def forward(self, x):
152
+ x = self.fc1(x)
153
+ x = self.act(x)
154
+ x = self.drop(x)
155
+ x = self.fc2(x)
156
+ x = self.drop(x)
157
+ return x
158
+
159
+
160
+ class CBlock(nn.Module):
161
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
162
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
163
+ super().__init__()
164
+ self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
165
+ self.norm1 = bn_3d(dim)
166
+ self.conv1 = conv_1x1x1(dim, dim, 1)
167
+ self.conv2 = conv_1x1x1(dim, dim, 1)
168
+ self.attn = conv_5x5x5(dim, dim, groups=dim)
169
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
170
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
171
+ self.norm2 = bn_3d(dim)
172
+ mlp_hidden_dim = int(dim * mlp_ratio)
173
+ self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
174
+
175
+ def forward(self, x):
176
+ x = x + self.pos_embed(x)
177
+ x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
178
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
179
+ return x
180
+
181
+
182
+ class EvoSABlock(nn.Module):
183
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
184
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, prune_ratio=1,
185
+ trade_off=0, downsample=False):
186
+ super().__init__()
187
+ self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
188
+ self.norm1 = norm_layer(dim)
189
+ self.attn = Attention(
190
+ dim,
191
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
192
+ attn_drop=attn_drop, proj_drop=drop, trade_off=trade_off)
193
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
194
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
195
+ self.norm2 = norm_layer(dim)
196
+ mlp_hidden_dim = int(dim * mlp_ratio)
197
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
198
+ self.prune_ratio = prune_ratio
199
+ self.downsample = downsample
200
+ if downsample:
201
+ self.avgpool = nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
202
+
203
+ def forward(self, cls_token, x):
204
+ x = x + self.pos_embed(x)
205
+ B, C, T, H, W = x.shape
206
+ x = x.flatten(2).transpose(1, 2)
207
+
208
+ if self.prune_ratio == 1:
209
+ x = torch.cat([cls_token, x], dim=1)
210
+ x = x + self.drop_path(self.attn(self.norm1(x)))
211
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
212
+ cls_token, x = x[:, :1], x[:, 1:]
213
+ x = x.transpose(1, 2).reshape(B, C, T, H, W)
214
+ return cls_token, x
215
+ else:
216
+ global global_attn, token_indices
217
+ # calculate the number of informative tokens
218
+ N = x.shape[1]
219
+ N_ = int(N * self.prune_ratio)
220
+ # sort global attention
221
+ indices = torch.argsort(global_attn, dim=1, descending=True)
222
+
223
+ # concatenate x, global attention and token indices => x_ga_ti
224
+ # rearrange the tensor according to new indices
225
+ x_ga_ti = torch.cat((x, global_attn.unsqueeze(-1), token_indices.unsqueeze(-1)), dim=-1)
226
+ x_ga_ti = easy_gather(x_ga_ti, indices)
227
+ x_sorted, global_attn, token_indices = x_ga_ti[:, :, :-2], x_ga_ti[:, :, -2], x_ga_ti[:, :, -1]
228
+
229
+ # informative tokens
230
+ x_info = x_sorted[:, :N_]
231
+ # merge dropped tokens
232
+ x_drop = x_sorted[:, N_:]
233
+ score = global_attn[:, N_:]
234
+ # B x N_drop x C => B x 1 x C
235
+ rep_token = merge_tokens(x_drop, score)
236
+ # concatenate new tokens
237
+ x = torch.cat((cls_token, x_info, rep_token), dim=1)
238
+
239
+ # slow update
240
+ fast_update = 0
241
+ tmp_x = self.attn(self.norm1(x))
242
+ fast_update = fast_update + tmp_x[:, -1:]
243
+ x = x + self.drop_path(tmp_x)
244
+ tmp_x = self.mlp(self.norm2(x))
245
+ fast_update = fast_update + tmp_x[:, -1:]
246
+ x = x + self.drop_path(tmp_x)
247
+ # fast update
248
+ x_drop = x_drop + fast_update.expand(-1, N - N_, -1)
249
+
250
+ cls_token, x = x[:, :1, :], x[:, 1:-1, :]
251
+ if self.training:
252
+ x_sorted = torch.cat((x, x_drop), dim=1)
253
+ else:
254
+ x_sorted[:, N_:] = x_drop
255
+ x_sorted[:, :N_] = x
256
+
257
+ # recover token
258
+ # scale for normalization
259
+ old_global_scale = torch.sum(global_attn, dim=1, keepdim=True)
260
+ # recover order
261
+ indices = torch.argsort(token_indices, dim=1)
262
+ x_ga_ti = torch.cat((x_sorted, global_attn.unsqueeze(-1), token_indices.unsqueeze(-1)), dim=-1)
263
+ x_ga_ti = easy_gather(x_ga_ti, indices)
264
+ x_patch, global_attn, token_indices = x_ga_ti[:, :, :-2], x_ga_ti[:, :, -2], x_ga_ti[:, :, -1]
265
+ x_patch = x_patch.transpose(1, 2).reshape(B, C, T, H, W)
266
+
267
+ if self.downsample:
268
+ # downsample global attention
269
+ global_attn = global_attn.reshape(B, 1, T, H, W)
270
+ global_attn = self.avgpool(global_attn).view(B, -1)
271
+ # normalize global attention
272
+ new_global_scale = torch.sum(global_attn, dim=1, keepdim=True)
273
+ scale = old_global_scale / new_global_scale
274
+ global_attn = global_attn * scale
275
+
276
+ return cls_token, x_patch
277
+
278
+
279
+ class SABlock(nn.Module):
280
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
281
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
282
+ super().__init__()
283
+ self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
284
+ self.norm1 = norm_layer(dim)
285
+ self.attn = Attention(
286
+ dim,
287
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
288
+ attn_drop=attn_drop, proj_drop=drop)
289
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
290
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
291
+ self.norm2 = norm_layer(dim)
292
+ mlp_hidden_dim = int(dim * mlp_ratio)
293
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
294
+
295
+ def forward(self, x):
296
+ x = x + self.pos_embed(x)
297
+ B, C, T, H, W = x.shape
298
+ x = x.flatten(2).transpose(1, 2)
299
+ x = x + self.drop_path(self.attn(self.norm1(x)))
300
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
301
+ x = x.transpose(1, 2).reshape(B, C, T, H, W)
302
+ return x
303
+
304
+
305
+ class SplitSABlock(nn.Module):
306
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
307
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
308
+ super().__init__()
309
+ self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
310
+ self.t_norm = norm_layer(dim)
311
+ self.t_attn = Attention(
312
+ dim,
313
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
314
+ attn_drop=attn_drop, proj_drop=drop)
315
+ self.norm1 = norm_layer(dim)
316
+ self.attn = Attention(
317
+ dim,
318
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
319
+ attn_drop=attn_drop, proj_drop=drop)
320
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
321
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
322
+ self.norm2 = norm_layer(dim)
323
+ mlp_hidden_dim = int(dim * mlp_ratio)
324
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
325
+
326
+ def forward(self, x):
327
+ x = x + self.pos_embed(x)
328
+ B, C, T, H, W = x.shape
329
+ attn = x.view(B, C, T, H * W).permute(0, 3, 2, 1).contiguous()
330
+ attn = attn.view(B * H * W, T, C)
331
+ attn = attn + self.drop_path(self.t_attn(self.t_norm(attn)))
332
+ attn = attn.view(B, H * W, T, C).permute(0, 2, 1, 3).contiguous()
333
+ attn = attn.view(B * T, H * W, C)
334
+ residual = x.view(B, C, T, H * W).permute(0, 2, 3, 1).contiguous()
335
+ residual = residual.view(B * T, H * W, C)
336
+ attn = residual + self.drop_path(self.attn(self.norm1(attn)))
337
+ attn = attn.view(B, T * H * W, C)
338
+ out = attn + self.drop_path(self.mlp(self.norm2(attn)))
339
+ out = out.transpose(1, 2).reshape(B, C, T, H, W)
340
+ return out
341
+
342
+
343
+ class SpeicalPatchEmbed(nn.Module):
344
+ """ Image to Patch Embedding
345
+ """
346
+ def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
347
+ super().__init__()
348
+ patch_size = to_2tuple(patch_size)
349
+ self.patch_size = patch_size
350
+
351
+ self.proj = nn.Sequential(
352
+ nn.Conv3d(in_chans, embed_dim // 2, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
353
+ nn.BatchNorm3d(embed_dim // 2),
354
+ nn.GELU(),
355
+ nn.Conv3d(embed_dim // 2, embed_dim, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
356
+ nn.BatchNorm3d(embed_dim),
357
+ )
358
+
359
+ def forward(self, x):
360
+ B, C, T, H, W = x.shape
361
+ # FIXME look at relaxing size constraints
362
+ # assert H == self.img_size[0] and W == self.img_size[1], \
363
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
364
+ x = self.proj(x)
365
+ B, C, T, H, W = x.shape
366
+ x = x.flatten(2).transpose(1, 2)
367
+ x = x.reshape(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
368
+ return x
369
+
370
+
371
+ class PatchEmbed(nn.Module):
372
+ """ Image to Patch Embedding
373
+ """
374
+ def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
375
+ super().__init__()
376
+ patch_size = to_2tuple(patch_size)
377
+ self.patch_size = patch_size
378
+ self.norm = nn.LayerNorm(embed_dim)
379
+ self.proj = conv_1xnxn(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size[0])
380
+
381
+ def forward(self, x):
382
+ B, C, T, H, W = x.shape
383
+ # FIXME look at relaxing size constraints
384
+ # assert H == self.img_size[0] and W == self.img_size[1], \
385
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
386
+ x = self.proj(x)
387
+ B, C, T, H, W = x.shape
388
+ x = x.flatten(2).transpose(1, 2)
389
+ x = self.norm(x)
390
+ x = x.reshape(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
391
+ return x
392
+
393
+
394
+ class Uniformer_light(nn.Module):
395
+ """ Vision Transformer
396
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
397
+ https://arxiv.org/abs/2010.11929
398
+ """
399
+ def __init__(self, depth=[3, 4, 8, 3], in_chans=3, num_classes=400, embed_dim=[64, 128, 320, 512],
400
+ head_dim=64, mlp_ratio=[4., 4., 4., 4.], qkv_bias=True, qk_scale=None, representation_size=None,
401
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
402
+ prune_ratio=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
403
+ trade_off=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
404
+ ):
405
+ super().__init__()
406
+
407
+ self.num_classes = num_classes
408
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
409
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
410
+
411
+ self.patch_embed1 = SpeicalPatchEmbed(
412
+ patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
413
+ self.patch_embed2 = PatchEmbed(
414
+ patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
415
+ self.patch_embed3 = PatchEmbed(
416
+ patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
417
+ self.patch_embed4 = PatchEmbed(
418
+ patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])
419
+
420
+ # class token
421
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim[2]))
422
+ self.cls_upsample = nn.Linear(embed_dim[2], embed_dim[3])
423
+
424
+ self.pos_drop = nn.Dropout(p=drop_rate)
425
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] # stochastic depth decay rule
426
+ num_heads = [dim // head_dim for dim in embed_dim]
427
+ self.blocks1 = nn.ModuleList([
428
+ CBlock(
429
+ dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
430
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
431
+ for i in range(depth[0])])
432
+ self.blocks2 = nn.ModuleList([
433
+ CBlock(
434
+ dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
435
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]], norm_layer=norm_layer)
436
+ for i in range(depth[1])])
437
+ self.blocks3 = nn.ModuleList([
438
+ EvoSABlock(
439
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
440
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer,
441
+ prune_ratio=prune_ratio[2][i], trade_off=trade_off[2][i],
442
+ downsample=True if i == depth[2] - 1 else False)
443
+ for i in range(depth[2])])
444
+ self.blocks4 = nn.ModuleList([
445
+ EvoSABlock(
446
+ dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
447
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer,
448
+ prune_ratio=prune_ratio[3][i], trade_off=trade_off[3][i])
449
+ for i in range(depth[3])])
450
+ self.norm = bn_3d(embed_dim[-1])
451
+ self.norm_cls = nn.LayerNorm(embed_dim[-1])
452
+
453
+ # Representation layer
454
+ if representation_size:
455
+ self.num_features = representation_size
456
+ self.pre_logits = nn.Sequential(OrderedDict([
457
+ ('fc', nn.Linear(embed_dim, representation_size)),
458
+ ('act', nn.Tanh())
459
+ ]))
460
+ else:
461
+ self.pre_logits = nn.Identity()
462
+
463
+ # Classifier head
464
+ self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
465
+ self.head_cls = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
466
+
467
+ self.apply(self._init_weights)
468
+
469
+ for name, p in self.named_parameters():
470
+ # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
471
+ # are multiplied by 0*0, which is hard for the model to move out of.
472
+ if 't_attn.qkv.weight' in name:
473
+ nn.init.constant_(p, 0)
474
+ if 't_attn.qkv.bias' in name:
475
+ nn.init.constant_(p, 0)
476
+ if 't_attn.proj.weight' in name:
477
+ nn.init.constant_(p, 1)
478
+ if 't_attn.proj.bias' in name:
479
+ nn.init.constant_(p, 0)
480
+
481
+ def _init_weights(self, m):
482
+ if isinstance(m, nn.Linear):
483
+ trunc_normal_(m.weight, std=.02)
484
+ if isinstance(m, nn.Linear) and m.bias is not None:
485
+ nn.init.constant_(m.bias, 0)
486
+ elif isinstance(m, nn.LayerNorm):
487
+ nn.init.constant_(m.bias, 0)
488
+ nn.init.constant_(m.weight, 1.0)
489
+
490
+ @torch.jit.ignore
491
+ def no_weight_decay(self):
492
+ return {'pos_embed', 'cls_token'}
493
+
494
+ def get_classifier(self):
495
+ return self.head
496
+
497
+ def reset_classifier(self, num_classes, global_pool=''):
498
+ self.num_classes = num_classes
499
+ self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
500
+
501
+ def inflate_weight(self, weight_2d, time_dim, center=False):
502
+ if center:
503
+ weight_3d = torch.zeros(*weight_2d.shape)
504
+ weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
505
+ middle_idx = time_dim // 2
506
+ weight_3d[:, :, middle_idx, :, :] = weight_2d
507
+ else:
508
+ weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
509
+ weight_3d = weight_3d / time_dim
510
+ return weight_3d
511
+
512
+ def forward_features(self, x):
513
+ x = self.patch_embed1(x)
514
+ x = self.pos_drop(x)
515
+ for blk in self.blocks1:
516
+ x = blk(x)
517
+ x = self.patch_embed2(x)
518
+ for blk in self.blocks2:
519
+ x = blk(x)
520
+ x = self.patch_embed3(x)
521
+ # add cls_token in stage3
522
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1)
523
+ global global_attn, token_indices
524
+ global_attn = 0
525
+ token_indices = torch.arange(x.shape[2] * x.shape[3] * x.shape[4], dtype=torch.long, device=x.device).unsqueeze(0)
526
+ token_indices = token_indices.expand(x.shape[0], -1)
527
+ for blk in self.blocks3:
528
+ cls_token, x = blk(cls_token, x)
529
+ # upsample cls_token before stage4
530
+ cls_token = self.cls_upsample(cls_token)
531
+ x = self.patch_embed4(x)
532
+ # should global attention be reset here? For now it is simply average-pooled
533
+ token_indices = torch.arange(x.shape[2] * x.shape[3] * x.shape[4], dtype=torch.long, device=x.device).unsqueeze(0)
534
+ token_indices = token_indices.expand(x.shape[0], -1)
535
+ for blk in self.blocks4:
536
+ cls_token, x = blk(cls_token, x)
537
+ if self.training:
538
+ # layer normalization for cls_token
539
+ cls_token = self.norm_cls(cls_token)
540
+ x = self.norm(x)
541
+ x = self.pre_logits(x)
542
+ return cls_token, x
543
+
544
+ def forward(self, x):
545
+ cls_token, x = self.forward_features(x)
546
+ x = x.flatten(2).mean(-1)
547
+ if self.training:
548
+ x = self.head(x), self.head_cls(cls_token.squeeze(1))
549
+ else:
550
+ x = self.head(x)
551
+ return x
552
+
553
+
554
+ def uniformer_xxs_video(**kwargs):
555
+ model = Uniformer_light(
556
+ depth=[2, 5, 8, 2],
557
+ prune_ratio=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5]],
558
+ trade_off=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5]],
559
+ embed_dim=[56, 112, 224, 448], head_dim=28, mlp_ratio=[3, 3, 3, 3], qkv_bias=True,
560
+ **kwargs)
561
+ model.default_cfg = _cfg()
562
+ return model
563
+
564
+
565
+ def uniformer_xs_video(**kwargs):
566
+ model = Uniformer_light(
567
+ depth=[3, 5, 9, 3],
568
+ prune_ratio=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
569
+ trade_off=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
570
+ embed_dim=[64, 128, 256, 512], head_dim=32, mlp_ratio=[3, 3, 3, 3], qkv_bias=True,
571
+ **kwargs)
572
+ model.default_cfg = _cfg()
573
+ return model
574
+
575
+
576
+ if __name__ == '__main__':
577
+ import time
578
+ from fvcore.nn import FlopCountAnalysis
579
+ from fvcore.nn import flop_count_table
580
+ import numpy as np
581
+
582
+ seed = 4217
583
+ np.random.seed(seed)
584
+ torch.manual_seed(seed)
585
+ torch.cuda.manual_seed(seed)
586
+ torch.cuda.manual_seed_all(seed)
587
+ num_frames = 16
588
+
589
+ model = uniformer_xxs_video()
590
+ # print(model)
591
+
592
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 160, 160))
593
+ s = time.time()
594
+ print(flop_count_table(flops, max_depth=1))
595
+ print(time.time()-s)
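Usage note (not part of the committed file): a minimal sketch of the video model's interface. It consumes a 5-D clip tensor (batch, channels, frames, height, width) and mirrors the image variant: a (head, head_cls) pair in training mode, plain logits in eval mode. The batch size and the random clip are illustrative assumptions; 16 frames at 160x160 match the FLOP-count example above.

# Illustrative sketch only; the clip tensor is random, not real video data.
import torch

model = uniformer_xxs_video(num_classes=400)
clip = torch.rand(2, 3, 16, 160, 160)    # assumed batch of 2 clips, 16 RGB frames each

model.eval()
with torch.no_grad():
    logits = model(clip)                 # plain logits at inference
print(logits.shape)                      # torch.Size([2, 400])

model.train()
logits, cls_logits = model(clip)         # main head + auxiliary cls-token head
print(logits.shape, cls_logits.shape)    # torch.Size([2, 400]) for both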