KeerthiVM committed
Commit 0d6f593 · 1 Parent(s): a70cf1d

Initial commit

Files changed (2)
  1. dist_utils.py +137 -0
  2. eva_vit.py +451 -0
dist_utils.py ADDED
@@ -0,0 +1,137 @@
+ """
+  Copyright (c) 2022, salesforce.com, inc.
+  All rights reserved.
+  SPDX-License-Identifier: BSD-3-Clause
+  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import datetime
+ import functools
+ import os
+
+ import torch
+ import torch.distributed as dist
+ import timm.models.hub as timm_hub
+
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in master process
+     """
+     import builtins as __builtin__
+
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop("force", False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def init_distributed_mode(args):
+     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ["WORLD_SIZE"])
+         args.gpu = int(os.environ["LOCAL_RANK"])
+     elif "SLURM_PROCID" in os.environ:
+         args.rank = int(os.environ["SLURM_PROCID"])
+         args.gpu = args.rank % torch.cuda.device_count()
+     else:
+         print("Not using distributed mode")
+         args.distributed = False
+         return
+
+     args.distributed = True
+
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = "nccl"
+     print(
+         "| distributed init (rank {}, world {}): {}".format(
+             args.rank, args.world_size, args.dist_url
+         ),
+         flush=True,
+     )
+     torch.distributed.init_process_group(
+         backend=args.dist_backend,
+         init_method=args.dist_url,
+         world_size=args.world_size,
+         rank=args.rank,
+         timeout=datetime.timedelta(
+             days=365
+         ),  # allow auto-downloading and de-compressing
+     )
+     torch.distributed.barrier()
+     setup_for_distributed(args.rank == 0)
+
+
+ def get_dist_info():
+     if torch.__version__ < "1.0":
+         initialized = dist._initialized
+     else:
+         initialized = dist.is_initialized()
+     if initialized:
+         rank = dist.get_rank()
+         world_size = dist.get_world_size()
+     else:  # non-distributed training
+         rank = 0
+         world_size = 1
+     return rank, world_size
+
+
+ def main_process(func):
+     @functools.wraps(func)
+     def wrapper(*args, **kwargs):
+         rank, _ = get_dist_info()
+         if rank == 0:
+             return func(*args, **kwargs)
+
+     return wrapper
+
+
+ def download_cached_file(url, check_hash=True, progress=False):
+     """
+     Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+     If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+     """
+
+     def get_cached_file_path():
+         # a hack to sync the file path across processes
+         parts = torch.hub.urlparse(url)
+         filename = os.path.basename(parts.path)
+         cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+         return cached_file
+
+     if is_main_process():
+         timm_hub.download_cached_file(url, check_hash, progress)
+
+     if is_dist_avail_and_initialized():
+         dist.barrier()
+
+     return get_cached_file_path()
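The helpers in dist_utils.py are meant to be used together: init_distributed_mode reads the launcher's environment variables and initializes the process group, after which download_cached_file lets only rank 0 hit the network while the other ranks wait at a barrier and resolve the same cache path. A minimal usage sketch follows; the argparse.Namespace stand-in and the checkpoint URL are hypothetical, and a torchrun launch is assumed for the distributed case.

import argparse

from dist_utils import download_cached_file, init_distributed_mode, is_main_process

# Hypothetical argument object; dist_url is only read when a process group is created.
args = argparse.Namespace(dist_url="env://", distributed=False, rank=0, world_size=1, gpu=0)
init_distributed_mode(args)  # falls back to single-process mode if RANK/WORLD_SIZE are unset

# Every rank gets the same local path; only rank 0 actually downloads.
ckpt_path = download_cached_file("https://example.com/checkpoint.pth", check_hash=False, progress=True)
if is_main_process():
    print("checkpoint cached at", ckpt_path)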
eva_vit.py ADDED
@@ -0,0 +1,451 @@
+ # Based on EVA, BEIT, timm and DeiT code bases
+ # https://github.com/baaivision/EVA
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
+ # https://github.com/microsoft/unilm/tree/master/beit
+ # https://github.com/facebookresearch/deit/
+ # https://github.com/facebookresearch/dino
+ # --------------------------------------------------------'
+ import math
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+ from timm.models.registry import register_model
+
+ from dist_utils import download_cached_file
+
+
+ def _cfg(url='', **kwargs):
+     return {
+         'url': url,
+         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+         'crop_pct': .9, 'interpolation': 'bicubic',
+         'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+         **kwargs
+     }
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+     def extra_repr(self) -> str:
+         return 'p={}'.format(self.drop_prob)
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         # x = self.drop(x)
+         # dropout after the activation is commented out, following the original BERT implementation
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(
+             self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+             proj_drop=0., window_size=None, attn_head_dim=None):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         if attn_head_dim is not None:
+             head_dim = attn_head_dim
+         all_head_dim = head_dim * self.num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+         if qkv_bias:
+             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+         else:
+             self.q_bias = None
+             self.v_bias = None
+
+         if window_size:
+             self.window_size = window_size
+             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+             self.relative_position_bias_table = nn.Parameter(
+                 torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+             # cls to token & token 2 cls & cls to cls
+
+             # get pair-wise relative position index for each token inside the window
+             coords_h = torch.arange(window_size[0])
+             coords_w = torch.arange(window_size[1])
+             coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+             coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+             relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+             relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+             relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+             relative_coords[:, :, 1] += window_size[1] - 1
+             relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+             relative_position_index = \
+                 torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+             relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+             relative_position_index[0, 0:] = self.num_relative_distance - 3
+             relative_position_index[0:, 0] = self.num_relative_distance - 2
+             relative_position_index[0, 0] = self.num_relative_distance - 1
+
+             self.register_buffer("relative_position_index", relative_position_index)
+         else:
+             self.window_size = None
+             self.relative_position_bias_table = None
+             self.relative_position_index = None
+
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(all_head_dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x, rel_pos_bias=None):
+         B, N, C = x.shape
+         qkv_bias = None
+         if self.q_bias is not None:
+             qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+         # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+         q = q * self.scale
+         attn = (q @ k.transpose(-2, -1))
+
+         if self.relative_position_bias_table is not None:
+             relative_position_bias = \
+                 self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                     self.window_size[0] * self.window_size[1] + 1,
+                     self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+             relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+             attn = attn + relative_position_bias.unsqueeze(0)
+
+         if rel_pos_bias is not None:
+             attn = attn + rel_pos_bias
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                  window_size=None, attn_head_dim=None):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+         if init_values is not None and init_values > 0:
+             self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+             self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+         else:
+             self.gamma_1, self.gamma_2 = None, None
+
+     def forward(self, x, rel_pos_bias=None):
+         if self.gamma_1 is None:
+             x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+             x = x + self.drop_path(self.mlp(self.norm2(x)))
+         else:
+             x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+             x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x, **kwargs):
+         B, C, H, W = x.shape
+         # FIXME look at relaxing size constraints
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+
+ class RelativePositionBias(nn.Module):
+
+     def __init__(self, window_size, num_heads):
+         super().__init__()
+         self.window_size = window_size
+         self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+         self.relative_position_bias_table = nn.Parameter(
+             torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+         # cls to token & token 2 cls & cls to cls
+
+         # get pair-wise relative position index for each token inside the window
+         coords_h = torch.arange(window_size[0])
+         coords_w = torch.arange(window_size[1])
+         coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+         relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+         relative_coords[:, :, 1] += window_size[1] - 1
+         relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+         relative_position_index = \
+             torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+         relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+         relative_position_index[0, 0:] = self.num_relative_distance - 3
+         relative_position_index[0:, 0] = self.num_relative_distance - 2
+         relative_position_index[0, 0] = self.num_relative_distance - 1
+
+         self.register_buffer("relative_position_index", relative_position_index)
+
+         # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+     def forward(self):
+         relative_position_bias = \
+             self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                 self.window_size[0] * self.window_size[1] + 1,
+                 self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+         return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+
+ class VisionTransformer(nn.Module):
+     """ Vision Transformer with support for patch or hybrid CNN input stage
+     """
+
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
+                  use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+                  use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
+         super().__init__()
+         self.image_size = img_size
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         if use_abs_pos_emb:
+             self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         else:
+             self.pos_embed = None
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         if use_shared_rel_pos_bias:
+             self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+         else:
+             self.rel_pos_bias = None
+         self.use_checkpoint = use_checkpoint
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         self.use_rel_pos_bias = use_rel_pos_bias
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                 init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+             for i in range(depth)])
+         # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+         # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+         # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+         if self.pos_embed is not None:
+             trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         # trunc_normal_(self.mask_token, std=.02)
+         # if isinstance(self.head, nn.Linear):
+         #     trunc_normal_(self.head.weight, std=.02)
+         self.apply(self._init_weights)
+         self.fix_init_weight()
+
+         # if isinstance(self.head, nn.Linear):
+         #     self.head.weight.data.mul_(init_scale)
+         #     self.head.bias.data.mul_(init_scale)
+
+     def fix_init_weight(self):
+         def rescale(param, layer_id):
+             param.div_(math.sqrt(2.0 * layer_id))
+
+         for layer_id, layer in enumerate(self.blocks):
+             rescale(layer.attn.proj.weight.data, layer_id + 1)
+             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def get_classifier(self):
+         # NOTE: self.head is only created by reset_classifier(); by default this ViT is used as a
+         # feature extractor without a classification head.
+         return self.head
+
+     def reset_classifier(self, num_classes, global_pool=''):
+         self.num_classes = num_classes
+         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+     def forward_features(self, x):
+         x = self.patch_embed(x)
+         batch_size, seq_len, _ = x.size()
+
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         if self.pos_embed is not None:
+             x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+         for blk in self.blocks:
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+             else:
+                 x = blk(x, rel_pos_bias)
+         return x
+
+         # x = self.norm(x)
+
+         # if self.fc_norm is not None:
+         #     t = x[:, 1:, :]
+         #     return self.fc_norm(t.mean(1))
+         # else:
+         #     return x[:, 0]
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         # x = self.head(x)
+         return x
+
+     def get_intermediate_layers(self, x):
+         x = self.patch_embed(x)
+         batch_size, seq_len, _ = x.size()
+
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         if self.pos_embed is not None:
+             x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         features = []
+         rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+         for blk in self.blocks:
+             x = blk(x, rel_pos_bias)
+             features.append(x)
+
+         return features
+
+
+ def interpolate_pos_embed(model, checkpoint_model):
+     if 'pos_embed' in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int(num_patches ** 0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model['pos_embed'] = new_pos_embed
+
+
+ def convert_weights_to_fp16(model: nn.Module):
+     """Convert applicable model parameters to fp16"""
+
+     def _convert_weights_to_fp16(l):
+         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+             l.weight.data = l.weight.data.half()
+             if l.bias is not None:
+                 l.bias.data = l.bias.data.half()
+
+         # if isinstance(l, (nn.MultiheadAttention, Attention)):
+         #     for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+         #         tensor = getattr(l, attr)
+         #         if tensor is not None:
+         #             tensor.data = tensor.data.half()
+
+     model.apply(_convert_weights_to_fp16)
+
+
+ # def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
+ def create_eva_vit_g(img_size=(224, 224), patch_size=14, embed_dim=1408, depth=39,
+                      num_heads=1408 // 88, mlp_ratio=4.3637, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                      init_values=1e-5, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
+     model = VisionTransformer(
+         img_size=img_size[0],
+         patch_size=patch_size,
+         use_mean_pooling=False,
+         embed_dim=embed_dim,
+         depth=depth,
+         num_heads=num_heads,
+         mlp_ratio=mlp_ratio,
+         qkv_bias=qkv_bias,
+         drop_path_rate=drop_path_rate,
+         norm_layer=norm_layer,
+         use_checkpoint=use_checkpoint,
+     )
+     url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
+     cached_file = download_cached_file(
+         url, check_hash=False, progress=True
+     )
+     state_dict = torch.load(cached_file, map_location="cpu")
+     interpolate_pos_embed(model, state_dict)
+
+     incompatible_keys = model.load_state_dict(state_dict, strict=False)
+     # print(incompatible_keys)
+
+     if precision == "fp16":
+         # model.to("cuda")
+         convert_weights_to_fp16(model)
+     return model
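For reference, a minimal sketch of how the factory above might be called. The input shape and the expected output shape are inferred from the defaults (patch size 14 at 224x224 gives 256 patch tokens plus one cls token at embed dim 1408); passing precision="fp32" simply skips the fp16 conversion branch, and the first call downloads the EVA ViT-g checkpoint via download_cached_file.

import torch

from eva_vit import create_eva_vit_g

# Build the encoder; fp32 keeps this sketch CPU-friendly (fp16 is intended for GPU use).
vit = create_eva_vit_g(use_checkpoint=False, precision="fp32")
vit.eval()

with torch.no_grad():
    images = torch.randn(1, 3, 224, 224)  # assumed NCHW input at the default 224x224 resolution
    features = vit(images)                # forward() returns per-token features, not logits
    print(features.shape)                 # expected: torch.Size([1, 257, 1408]) = 1 cls + (224/14)^2 patch tokens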