stojnvla committed
Commit 06d49db · 1 Parent(s): 0a3fdf2

initial commit

Files changed (6)
  1. app.py +42 -0
  2. lposs.py +293 -0
  3. models/dino.py +69 -0
  4. models/maskclip.py +154 -0
  5. models/utils.py +82 -0
  6. requrements.txt +47 -0
app.py ADDED
@@ -0,0 +1,42 @@
+ import gradio as gr
+ import PIL.Image
+ import numpy as np
+ from models.maskclip import MaskClip
+ from models.dino import DINO
+ import torchvision.transforms as T
+ import torch.nn.functional as F
+ from lposs import lposs, lposs_plus
+ import torch
+
+ device = "cpu"
+ if torch.cuda.is_available():
+     print("Using GPU")
+     device = "cuda"
+ # elif torch.backends.mps.is_available():
+ #     device = "mps"
+
+ print(f"Using device: {device}")
+
+ # Build the models once at startup so every request reuses them.
+ maskclip = MaskClip().to(device)
+ dino = DINO().to(device)
+ to_torch_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])
+
+ def segment_image(img: np.ndarray, classnames: str, use_lposs_plus: bool | None) -> tuple[np.ndarray | PIL.Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]]]:
+     img_tensor = to_torch_tensor(PIL.Image.fromarray(img)).unsqueeze(0).to(device)
+     classnames = [c.strip() for c in classnames.split(",")]
+     num_classes = len(classnames)
+
+     preds = lposs(maskclip, dino, img_tensor, classnames)
+     if use_lposs_plus:
+         preds = lposs_plus(img_tensor, preds)
+     # Upsample to the original image resolution and turn scores into per-class soft masks.
+     preds = F.interpolate(preds, size=img.shape[:-1], mode="bilinear", align_corners=False)
+     preds = F.softmax(preds * 100, dim=1).cpu().numpy()
+     return (img, [(preds[0, i, :, :], classnames[i]) for i in range(num_classes)])
+
+ demo = gr.Interface(
+     fn=segment_image,
+     inputs=["image", "text", "checkbox"],
+     outputs=["annotatedimage"],
+ )
+
+ demo.launch()
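
For context, the same pipeline can be exercised without the Gradio UI. A minimal sketch, assuming the repository root is on the import path and skipping the LPOSS+ refinement; the image path and class names are placeholders, not part of this commit:

import PIL.Image
import torch
import torchvision.transforms as T
from models.maskclip import MaskClip
from models.dino import DINO
from lposs import lposs

device = "cuda" if torch.cuda.is_available() else "cpu"
maskclip = MaskClip().to(device)   # CLIP ViT-B/16 patch features
dino = DINO().to(device)           # DINO ViT-B/16 patch features
to_tensor = T.Compose([T.Resize(size=448, max_size=2048), T.ToTensor()])

img = PIL.Image.open("example.jpg").convert("RGB")              # placeholder image path
x = to_tensor(img).unsqueeze(0).to(device)
preds = lposs(maskclip, dino, x, ["cat", "dog", "background"])  # (1, num_classes, h, w) scores
print(preds.argmax(dim=1).shape)                                # per-pixel class indices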
lposs.py ADDED
@@ -0,0 +1,293 @@
+ from itertools import product
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ try:
+     import cupy as cp
+     from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix, eye as cp_eye, diags as cp_diags
+     from cupyx.scipy.sparse import linalg as cp_s_linalg
+ except ImportError:
+     print("Cupy not installed")
+ import numpy as np
+ from scipy.sparse import csr_matrix, eye, diags
+ from scipy.sparse import linalg as s_linalg
+ from kornia.color import rgb_to_lab
+
+
+ def make_input_divisible(x: torch.Tensor, patch_size=16) -> torch.Tensor:
+     """Pad some pixels to make the input size divisible by the patch size."""
+     B, _, H_0, W_0 = x.shape
+     pad_w = (patch_size - W_0 % patch_size) % patch_size
+     pad_h = (patch_size - H_0 % patch_size) % patch_size
+
+     x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
+
+     return x
+
+
+ def reshape_windows(x):
+     """Flatten a list of (h, w, dim) window features into one (N, dim) tensor."""
+     height_width = [(y.shape[0], y.shape[1]) for y in x]
+     dim = x[0].shape[-1]
+     x = [torch.reshape(y, (-1, dim)) for y in x]
+
+     return torch.cat(x, dim=0), height_width
+
+
+ def normalize_connection_graph_cupy(G):
+     """Symmetrically normalize an affinity graph (GPU/cupy version)."""
+     W = cp_csr_matrix(G)
+     W = W - cp_diags(W.diagonal(), 0)
+     S = W.sum(axis=1)
+     S[S == 0] = 1
+     D = cp.array(1.0 / cp.sqrt(S))
+     D[cp.isnan(D)] = 0
+     D[cp.isinf(D)] = 0
+     D_mh = cp_diags(D.reshape(-1), 0)
+     Wn = D_mh * W * D_mh
+     return Wn
+
+
+ def normalize_connection_graph(G):
+     """Symmetrically normalize an affinity graph (CPU/scipy version)."""
+     W = csr_matrix(G)
+     W = W - diags(W.diagonal(), 0)
+     S = W.sum(axis=1)
+     S[S == 0] = 1
+     D = np.array(1.0 / np.sqrt(S))
+     D[np.isnan(D)] = 0
+     D[np.isinf(D)] = 0
+     D_mh = diags(D.reshape(-1), 0)
+     Wn = D_mh * W * D_mh
+     return Wn
+
+
+ def cp_dfs_search(L, Y, tol=1e-6, maxiter=10):
+     out = cp_s_linalg.cg(L, Y, tol=tol, maxiter=maxiter)[0]
+
+     return out
+
+
+ def dfs_search(L, Y, tol=1e-6, maxiter=10):
+     out = s_linalg.cg(L, Y, rtol=tol, maxiter=maxiter)[0]
+
+     return out
+
+
+ def perform_lp(L, preds):
+     """Label propagation: solve (I - alpha * Wn) z = y per class with a few CG iterations."""
+     if torch.cuda.is_available():
+         lp_preds = cp.zeros(preds.shape)
+         preds = cp.asarray(preds)
+         for cls_idx, y_cls in enumerate(preds.T):
+             Y = y_cls
+             lp_preds[:, cls_idx] = cp_dfs_search(L, Y)
+         lp_preds = torch.as_tensor(lp_preds, device="cuda").float()  # back to float32 for the in-place accumulation in lposs()
+     else:
+         lp_preds = np.zeros(preds.shape)
+         for cls_idx, y_cls in enumerate(preds.T):
+             Y = y_cls
+             lp_preds[:, cls_idx] = dfs_search(L, Y)
+         lp_preds = torch.as_tensor(lp_preds, device="cpu").float()  # back to float32 for the in-place accumulation in lposs()
+
+     return lp_preds
+
+
+ def get_lposs_laplacian(feats, locations, height_width, sigma=0.0, pix_dist_pow=2, k=100, gamma=1.0, alpha=0.95, patch_size=16):
+     """Build L = I - alpha * Wn over a k-NN patch-feature graph, modulated by spatial distance."""
+     idx_window = torch.cat([window * torch.ones((h*w, ), device=feats.device, dtype=torch.int64) for window, (h, w) in enumerate(height_width)])
+     idx_h = torch.cat([torch.arange(h).view(-1, 1).repeat(1, w).flatten() for h, w in height_width]).to(feats.device)
+     idx_w = torch.cat([torch.arange(w).view(1, -1).repeat(h, 1).flatten() for h, w in height_width]).to(feats.device)
+     loc_h = locations[idx_window, 0] + (patch_size // 2) + idx_h * patch_size
+     loc_w = locations[idx_window, 2] + (patch_size // 2) + idx_w * patch_size
+     locs = torch.stack((loc_h, loc_w), 1)
+     locs = torch.unsqueeze(locs, 0)
+     dist = torch.cdist(locs, locs, p=2)
+     dist = dist[0, ...]
+     dist = dist ** pix_dist_pow
+     geometry_affinity = torch.exp(-sigma * dist)
+
+     N = feats.shape[0]
+
+     affinity = feats @ feats.T
+     sims, ks = torch.topk(affinity, k=k, dim=1)
+
+     sims[sims < 0] = 0
+     sims = sims ** gamma
+     geometry_affinity = geometry_affinity.gather(1, ks).flatten()
+     sims = sims.flatten()
+     sims = sims * geometry_affinity
+     ks = ks.flatten()
+     rows = torch.arange(N).repeat_interleave(k)
+
+     if torch.cuda.is_available():
+         W = cp_csr_matrix(
+             (cp.asarray(sims), (cp.asarray(rows), cp.asarray(ks))),
+             shape=(N, N),
+         )
+         W = W + W.T
+         Wn = normalize_connection_graph_cupy(W)
+         L = cp_eye(Wn.shape[0]) - alpha * Wn
+     else:
+         W = csr_matrix(
+             (sims.cpu().numpy(), (rows.cpu().numpy(), ks.cpu().numpy())),
+             shape=(N, N),
+         )
+         W = W + W.T
+         Wn = normalize_connection_graph(W)
+         L = eye(Wn.shape[0]) - alpha * Wn
+
+     return L
+
+
+ def lposs(clip, dino, img, classnames, window_size=(224, 224), window_stride=(112, 112), sigma=0.01, pix_dist_pow=1, lp_k_image=400, lp_gamma=3.0, lp_alpha=0.95):
+     """Sliding-window LPOSS: propagate CLIP patch scores over a DINO feature graph."""
+     h_stride, w_stride = window_stride
+     h_crop, w_crop = window_size
+     batch_size, _, h_img, w_img = img.size()
+     h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+     w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+
+     clf = clip.get_classifier(classnames)
+
+     locations = img.new_zeros((h_grids*w_grids, 4))
+     dino_feats = []
+     clip_feats = []
+     for h_idx in range(h_grids):
+         for w_idx in range(w_grids):
+             y1 = h_idx * h_stride
+             x1 = w_idx * w_stride
+             y2 = min(y1 + h_crop, h_img)
+             x2 = min(x1 + w_crop, w_img)
+             y1 = max(y2 - h_crop, 0)
+             x1 = max(x2 - w_crop, 0)
+             crop_img = img[:, :, y1:y2, x1:x2]
+
+             img_dino_feats, (h_dino, w_dino) = dino(make_input_divisible(crop_img, dino.patch_size))  # (1, 768, N)
+             img_dino_feats = img_dino_feats.reshape((batch_size, -1, h_dino, w_dino)).permute(0, 2, 3, 1)  # (1, h_dino, w_dino, 768)
+             img_clip_feats = clip(make_input_divisible(crop_img, clip.patch_size))  # (1, 512, h, w)
+
+             if img_clip_feats.shape[1] != img_dino_feats.shape[1] or img_clip_feats.shape[2] != img_dino_feats.shape[2]:
+                 img_clip_feats = F.interpolate(img_clip_feats, size=(img_dino_feats.shape[1], img_dino_feats.shape[2]), mode='bilinear', align_corners=False)
+
+             img_clip_feats = img_clip_feats.permute(0, 2, 3, 1)  # (1, h, w, 512)
+
+             dino_feats.append(img_dino_feats[0, ...])
+             clip_feats.append(img_clip_feats[0, ...])
+             locations[h_idx*w_grids + w_idx, 0] = y1
+             locations[h_idx*w_grids + w_idx, 1] = y2
+             locations[h_idx*w_grids + w_idx, 2] = x1
+             locations[h_idx*w_grids + w_idx, 3] = x2
+
+     num_classes = clf.shape[0]
+
+     patch_size = dino.patch_size
+
+     dino_feats, height_width = reshape_windows(dino_feats)
+     clip_feats, _ = reshape_windows(clip_feats)
+     dino_feats = F.normalize(dino_feats, p=2, dim=-1)
+     clip_feats = F.normalize(clip_feats, p=2, dim=-1)
+
+     L = get_lposs_laplacian(dino_feats, locations, height_width, sigma=sigma, pix_dist_pow=pix_dist_pow, k=lp_k_image, gamma=lp_gamma, alpha=lp_alpha, patch_size=patch_size)
+     clip_preds = clip_feats @ clf.T
+
+     lp_preds = perform_lp(L, clip_preds)
+
+     # Paste the per-window predictions back onto the full image grid.
+     preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+     count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+     idx_window = torch.cat([window * torch.ones((h*w, ), device=dino_feats.device, dtype=torch.int64) for window, (h, w) in enumerate(height_width)])
+     for h_idx in range(h_grids):
+         for w_idx in range(w_grids):
+             y1 = h_idx * h_stride
+             x1 = w_idx * w_stride
+             y2 = min(y1 + h_crop, h_img)
+             x2 = min(x1 + w_crop, w_img)
+             y1 = max(y2 - h_crop, 0)
+             x1 = max(x2 - w_crop, 0)
+             win_id = h_idx*w_grids + w_idx
+             crop_seg_logit = lp_preds[torch.where(idx_window == win_id)[0], :]
+             crop_seg_logit = torch.reshape(crop_seg_logit, height_width[win_id] + (num_classes, ))
+             crop_seg_logit = torch.unsqueeze(crop_seg_logit, 0)
+             crop_seg_logit = torch.permute(crop_seg_logit, (0, 3, 1, 2))
+             crop_seg_logit = F.interpolate(
+                 input=crop_seg_logit,
+                 size=(y2-y1, x2-x1),
+                 mode='bilinear',
+                 align_corners=False
+             )
+             assert crop_seg_logit.shape[2] == (y2 - y1) and crop_seg_logit.shape[3] == (x2 - x1)
+             preds += F.pad(crop_seg_logit,
+                            (int(x1), int(preds.shape[3] - x2), int(y1),
+                             int(preds.shape[2] - y2)))
+
+             count_mat[:, :, y1:y2, x1:x2] += 1
+     assert (count_mat == 0).sum() == 0
+     preds = preds / count_mat
+
+     return preds
+
+
+ def get_pixel_connections(img, neigh=1):
+     """Colour affinities between each pixel and its (2*neigh+1)^2 - 1 neighbours in Lab space."""
+     img = img[0, ...]
+     img_lab = rgb_to_lab(img)
+     img_lab = img_lab.permute((1, 2, 0))
+     img_lab /= torch.tensor([100, 128, 128], device=img.device)  # scale L (0..100) and a/b (-128..127) to roughly unit range
+     img_h, img_w, _ = img_lab.shape
+     img_lab = img_lab.reshape((img_h*img_w, -1))
+
+     idx = torch.arange(img_h * img_w).to(img.device)
+     loc_h = idx // img_w
+     loc_w = idx % img_w
+     locs = torch.stack((loc_h, loc_w), 1)
+
+     rows, cols = [], []
+
+     for mov in product(range(-neigh, neigh+1), range(-neigh, neigh+1)):
+         if mov[0] == 0 and mov[1] == 0:
+             continue
+         new_locs = locs + torch.tensor(mov).to(img.device)
+         mask = torch.logical_and(torch.logical_and(torch.logical_and(new_locs[:, 0] >= 0, new_locs[:, 1] >= 0), new_locs[:, 0] < img_h), new_locs[:, 1] < img_w)
+         rows.append(torch.where(mask)[0])
+         col = new_locs[mask, :]
+         col = col[:, 0] * img_w + col[:, 1]
+         cols.append(col)
+
+     rows = torch.cat(rows)
+     cols = torch.cat(cols)
+     pixel_pixel_data = ((img_lab[rows, :] - img_lab[cols, :]) ** 2).sum(dim=-1)
+
+     return rows, cols, pixel_pixel_data, locs
+
+
+ def get_laplacian(rows, cols, data, N, alpha=0.99):
+     if torch.cuda.is_available():
+         rows = cp.asarray(rows)
+         cols = cp.asarray(cols)
+         data = cp.asarray(data)
+         W = cp_csr_matrix(
+             (data, (rows, cols)),
+             shape=(N, N),
+         )
+
+         Wn = normalize_connection_graph_cupy(W)
+         L = cp_eye(Wn.shape[0]) - alpha * Wn
+     else:
+         W = csr_matrix(
+             (data.cpu().numpy(), (rows.cpu().numpy(), cols.cpu().numpy())),
+             shape=(N, N),
+         )
+
+         Wn = normalize_connection_graph(W)
+         L = eye(Wn.shape[0]) - alpha * Wn
+     return L
+
+
+ def lposs_plus(img, preds, tau=0.01, alpha=0.95):
+     """Pixel-level refinement of LPOSS predictions over a Lab colour affinity graph."""
+     preds = preds[0, ...]
+     num_classes, h_img, w_img = preds.shape
+     preds = preds.permute((1, 2, 0))
+     preds = preds.reshape((h_img*w_img, -1))
+
+     rows, cols, pixel_pixel_data, locs = get_pixel_connections(img, neigh=6)
+     pixel_pixel_data = torch.sqrt(pixel_pixel_data)
+     pixel_pixel_data = torch.exp(-pixel_pixel_data / tau)
+     L = get_laplacian(rows, cols, pixel_pixel_data, preds.shape[0], alpha=alpha)
+
+     lp_preds = perform_lp(L, preds)
+
+     return lp_preds.reshape((h_img, w_img, num_classes)).permute((2, 0, 1)).unsqueeze(0)
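
The heart of both lposs and lposs_plus is perform_lp, which solves (I - alpha * Wn) z = y for each class with a few conjugate-gradient iterations over the normalized affinity graph. A self-contained toy illustration of that solve on a three-node graph, using synthetic numbers and scipy only (not code from this repo):

import numpy as np
from scipy.sparse import csr_matrix, diags, eye
from scipy.sparse import linalg as s_linalg

# A tiny symmetric affinity graph over three nodes.
W = csr_matrix(np.array([[0.0, 1.0, 0.5],
                         [1.0, 0.0, 0.2],
                         [0.5, 0.2, 0.0]]))
d = np.asarray(W.sum(axis=1)).ravel()
Wn = diags(1.0 / np.sqrt(d)) @ W @ diags(1.0 / np.sqrt(d))  # symmetric normalization
alpha = 0.95
L = eye(3) - alpha * Wn

y = np.array([1.0, 0.0, 0.0])                    # initial (CLIP-like) score of one class per node
z = s_linalg.cg(L, y, rtol=1e-6, maxiter=10)[0]  # propagated scores
print(z)                                          # neighbours of node 0 get pulled up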
models/dino.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as T
+
+ NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+
+ class DINO(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
+         self.hook_features = {}
+         def hook_fn_forward_qkv(module, input, output):
+             self.hook_features["qkv"] = output
+
+         self.backbone._modules["blocks"][-1]._modules["attn"]._modules[
+             "qkv"
+         ].register_forward_hook(hook_fn_forward_qkv)
+
+         self.patch_size = 16
+         self.enc_type_feats = "v"
+
+     @torch.no_grad()
+     def extract_feats(self, type_feats="k"):
+         """
+         Read the q/k/v features captured by the hook on the last attention layer.
+         :param type_feats: (string) - which of "q", "k" or "v" to return
+         """
+         nh = self.backbone.blocks[-1].attn.num_heads
+         nb_im, nb_tokens, C_qkv = self.hook_features["qkv"].shape
+
+         qkv = (
+             self.hook_features["qkv"]
+             .reshape(
+                 nb_im, nb_tokens, 3, nh, C_qkv // nh // 3
+             )  # 3 corresponding to |qkv|
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv[0], qkv[1], qkv[2]
+         if type_feats == "q":
+             return q.transpose(1, 2).float()
+         elif type_feats == "k":
+             return k.transpose(1, 2).float()
+         elif type_feats == "v":
+             return v.transpose(1, 2).float()
+         else:
+             raise ValueError("Unknown features")
+
+     @torch.no_grad()
+     def forward(self, x):
+         x = NORMALIZE(x)
+         h_featmap = x.shape[-2] // self.patch_size
+         w_featmap = x.shape[-1] // self.patch_size
+
+         # Encoder forward pass; the hook stores the last block's qkv output.
+         _ = self.backbone(x)
+
+         # Read the hooked features (value vectors by default).
+         feats = self.extract_feats(type_feats=self.enc_type_feats)
+         num_extra_tokens = 1
+
+         # (B, nbtokens+1, nh, dim) -> drop the CLS token -> (B, C, nbtokens)
+         feats = feats[:, num_extra_tokens:, :, :].flatten(-2, -1).permute(0, 2, 1)
+         feats = feats / feats.norm(dim=1, keepdim=True)  # L2-normalize per patch
+
+         return feats, (h_featmap, w_featmap)
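
For reference, a minimal sketch of what the forward pass yields, assuming the dino_vitb16 weights can be fetched via torch.hub; the input size is arbitrary as long as it is divisible by the 16-pixel patch size:

import torch
from models.dino import DINO

dino = DINO()
x = torch.rand(1, 3, 224, 224)   # random image, 14 x 14 = 196 patches
feats, (h, w) = dino(x)
print(feats.shape, (h, w))       # torch.Size([1, 768, 196]) (14, 14): L2-normalized value features per patch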
models/maskclip.py ADDED
@@ -0,0 +1,154 @@
+ # ---------------------------------------------------------------------------------------------------
+ # CLIP-DINOiser
+ # authors: Monika Wysoczanska, Warsaw University of Technology
+
+ # Copyright (c) OpenMMLab. All rights reserved.
+ # Modified version of the original MaskCLIP code: https://github.com/chongzhou96/MaskCLIP/tree/master
+ # ---------------------------------------------------------------------------------------------------
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import List
+ from torch import Tensor
+ from open_clip import get_tokenizer, create_model_from_pretrained
+ import torchvision.transforms as T
+ from .utils import imagenet_templates
+
+ OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+ class MaskClip(nn.Module):
+     def __init__(
+         self,
+         clip_model="ViT-B-16",
+         pretrained="laion2b_s34b_b88k",
+         patch_size=16,
+         img_size=(224, 224),
+         in_channels=768,
+         text_channels=512,
+     ):
+         super(MaskClip, self).__init__()
+
+         self.patch_size = patch_size
+         self.img_size = img_size
+         model, _ = create_model_from_pretrained(clip_model, pretrained=pretrained)
+         model.eval()
+         self.clip_T = OPENAI_NORMALIZE
+         self.hook_features = {}
+         self.backbone = model
+         def hook_fn_forward(module, input, output):
+             self.hook_features["v"] = output
+         self.backbone.visual.transformer.resblocks[-2].register_forward_hook(hook_fn_forward)
+         self._positional_embd = nn.Parameter(self.backbone.visual.positional_embedding.data.clone())
+         self.proj = nn.Conv2d(in_channels, text_channels, 1, bias=False)
+         self.proj.weight = nn.Parameter(model.visual.proj.t()[:, :, None, None])
+         self.tokenizer = get_tokenizer(clip_model)
+
+     @torch.no_grad()
+     def extract_feat(self, inputs: Tensor) -> Tensor:
+         """Extract per-patch features from images."""
+         pos_embed = self.backbone.visual.positional_embedding
+
+         B, C, H, W = inputs.shape
+         hw_shape = (H // self.patch_size, W // self.patch_size)
+         x_len, pos_len = hw_shape[0]*hw_shape[1], pos_embed.shape[0]
+
+         if x_len != pos_len:
+             if pos_len == (self.img_size[0] // self.patch_size) * (self.img_size[1] // self.patch_size) + 1:
+                 pos_h = self.img_size[0] // self.patch_size
+                 pos_w = self.img_size[1] // self.patch_size
+             else:
+                 raise ValueError(
+                     'Unexpected token/positional-embedding lengths: {}, {}'.format(x_len, pos_len))
+
+             self.backbone.visual.positional_embedding.data = self.resize_pos_embed(
+                 self._positional_embd[None], hw_shape, (pos_h, pos_w), 'bicubic')[0]
+
+         _ = self.backbone(inputs)
+         v = self.hook_features["v"]
+         v = self.extract_v(v, self.backbone.visual.transformer.resblocks[-1]).permute(1, 0, 2)
+         v = self.backbone.visual.ln_post(v)
+         # v = v[:, 1:]  # was there in the original code
+         v = v.permute(1, 0, 2)[:, 1:]  # drop the CLS token after permuting, as per https://github.com/wysoczanska/clip_dinoiser/issues/10
+         v = v.reshape(B, hw_shape[0], hw_shape[1], -1).permute(0, 3, 1, 2).contiguous()
+
+         self.backbone.visual.positional_embedding.data = self._positional_embd
+         return v
+
+     @torch.no_grad()
+     def extract_v(self, x, block):
+         """Run the last transformer block with value-value attention (MaskCLIP trick)."""
+         y = block.ln_1(x)
+         y = torch.nn.functional.linear(y, block.attn.in_proj_weight, block.attn.in_proj_bias)
+         B, N, C = y.shape
+         y = y.view(B, N, 3, C // 3).permute(2, 0, 1, 3).reshape(3 * B, N, C // 3)
+         y = F.linear(y, block.attn.out_proj.weight, block.attn.out_proj.bias)
+         q, k, v = y.tensor_split(3, dim=0)
+         v += x
+         v += block.mlp(block.ln_2(v))
+         return v
+
+     @staticmethod
+     def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
+         """Resize pos_embed weights.
+
+         Resize pos_embed using bicubic interpolation.
+         Args:
+             pos_embed (torch.Tensor): Position embedding weights.
+             input_shape (tuple): Tuple for (downsampled input image height,
+                 downsampled input image width).
+             pos_shape (tuple): The resolution of the downsampled original
+                 training image.
+             mode (str): Algorithm used for upsampling:
+                 ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+                 ``'trilinear'``.
+         Return:
+             torch.Tensor: The resized pos_embed of shape [B, L_new, C]
+         """
+         assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
+         pos_h, pos_w = pos_shape
+         cls_token_weight = pos_embed[:, 0]
+         pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+         pos_embed_weight = pos_embed_weight.reshape(
+             1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+         pos_embed_weight = F.interpolate(
+             pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
+         cls_token_weight = cls_token_weight.unsqueeze(1)
+         pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+         pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
+         return pos_embed
+
+     @torch.no_grad()
+     def decode_head(self, x: Tensor) -> Tensor:
+         """Project patch features into the CLIP text embedding space (1x1 conv with the visual projection)."""
+         feat = self.proj(x)
+
+         return feat
+
+     @torch.no_grad()
+     def forward(self, inputs: Tensor) -> Tensor:
+         """Encode images with the backbone and return one text-space feature
+         vector per patch (B, text_channels, H/patch, W/patch)."""
+         inputs = self.clip_T(inputs)
+         x = self.extract_feat(inputs)
+         feats = self.decode_head(x)
+         return feats
+
+     @torch.no_grad()
+     def get_classifier(self, classnames: List[str]) -> Tensor:
+         aug_embeddings = torch.stack([self._embed_label(label) for label in classnames])
+         aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
+         return aug_embeddings.squeeze(1)
+
+     @torch.no_grad()
+     def _embed_label(self, label: str) -> Tensor:
+         """Encode a label name into a single prompt-averaged vector."""
+         all_prompts = [self.tokenizer(template.format(label)) for template in imagenet_templates]
+         all_prompts = torch.cat(all_prompts)
+         all_prompts = all_prompts.to(self.backbone.visual.positional_embedding.device)
+         out = self.backbone.encode_text(all_prompts)
+         out /= out.norm(dim=-1, keepdim=True)
+         out = out.mean(dim=0)
+         return out
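
As a usage note, get_classifier averages each class name's text embedding over the prompt templates in models/utils.py and returns one L2-normalized row per class. A minimal sketch, with arbitrary class names and assuming the laion2b ViT-B-16 weights can be downloaded:

import torch
from models.maskclip import MaskClip

clip = MaskClip()
clf = clip.get_classifier(["cat", "dog", "background"])
print(clf.shape)   # torch.Size([3, 512]); dot these rows with the per-patch features from forward()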
models/utils.py ADDED
@@ -0,0 +1,82 @@
+ imagenet_templates = [
+     'a bad photo of a {}.',
+     'a photo of many {}.',
+     'a sculpture of a {}.',
+     'a photo of the hard to see {}.',
+     'a low resolution photo of the {}.',
+     'a rendering of a {}.',
+     'graffiti of a {}.',
+     'a bad photo of the {}.',
+     'a cropped photo of the {}.',
+     'a tattoo of a {}.',
+     'the embroidered {}.',
+     'a photo of a hard to see {}.',
+     'a bright photo of a {}.',
+     'a photo of a clean {}.',
+     'a photo of a dirty {}.',
+     'a dark photo of the {}.',
+     'a drawing of a {}.',
+     'a photo of my {}.',
+     'the plastic {}.',
+     'a photo of the cool {}.',
+     'a close-up photo of a {}.',
+     'a black and white photo of the {}.',
+     'a painting of the {}.',
+     'a painting of a {}.',
+     'a pixelated photo of the {}.',
+     'a sculpture of the {}.',
+     'a bright photo of the {}.',
+     'a cropped photo of a {}.',
+     'a plastic {}.',
+     'a photo of the dirty {}.',
+     'a jpeg corrupted photo of a {}.',
+     'a blurry photo of the {}.',
+     'a photo of the {}.',
+     'a good photo of the {}.',
+     'a rendering of the {}.',
+     'a {} in a video game.',
+     'a photo of one {}.',
+     'a doodle of a {}.',
+     'a close-up photo of the {}.',
+     'a photo of a {}.',
+     'the origami {}.',
+     'the {} in a video game.',
+     'a sketch of a {}.',
+     'a doodle of the {}.',
+     'a origami {}.',
+     'a low resolution photo of a {}.',
+     'the toy {}.',
+     'a rendition of the {}.',
+     'a photo of the clean {}.',
+     'a photo of a large {}.',
+     'a rendition of a {}.',
+     'a photo of a nice {}.',
+     'a photo of a weird {}.',
+     'a blurry photo of a {}.',
+     'a cartoon {}.',
+     'art of a {}.',
+     'a sketch of the {}.',
+     'a embroidered {}.',
+     'a pixelated photo of a {}.',
+     'itap of the {}.',
+     'a jpeg corrupted photo of the {}.',
+     'a good photo of a {}.',
+     'a plushie {}.',
+     'a photo of the nice {}.',
+     'a photo of the small {}.',
+     'a photo of the weird {}.',
+     'the cartoon {}.',
+     'art of the {}.',
+     'a drawing of the {}.',
+     'a photo of the large {}.',
+     'a black and white photo of a {}.',
+     'the plushie {}.',
+     'a dark photo of a {}.',
+     'itap of a {}.',
+     'graffiti of the {}.',
+     'a toy {}.',
+     'itap of my {}.',
+     'a photo of a cool {}.',
+     'a photo of a small {}.',
+     'a tattoo of the {}.',
+ ]
requrements.txt ADDED
@@ -0,0 +1,47 @@
+ certifi==2025.1.31
+ charset-normalizer==3.4.1
+ cupy-cuda12x==13.3.0
+ fastrlock==0.8.3
+ filelock==3.13.1
+ fsspec==2024.6.1
+ ftfy==6.3.1
+ huggingface-hub==0.28.1
+ idna==3.10
+ Jinja2==3.1.4
+ kornia==0.8.0
+ kornia_rs==0.1.8
+ MarkupSafe==2.1.5
+ mpmath==1.3.0
+ networkx==3.3
+ numpy==2.1.2
+ nvidia-cublas-cu12==12.6.4.1
+ nvidia-cuda-cupti-cu12==12.6.80
+ nvidia-cuda-nvrtc-cu12==12.6.77
+ nvidia-cuda-runtime-cu12==12.6.77
+ nvidia-cudnn-cu12==9.5.1.17
+ nvidia-cufft-cu12==11.3.0.4
+ nvidia-curand-cu12==10.3.7.77
+ nvidia-cusolver-cu12==11.7.1.2
+ nvidia-cusparse-cu12==12.5.4.2
+ nvidia-cusparselt-cu12==0.6.3
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.6.85
+ nvidia-nvtx-cu12==12.6.77
+ open_clip_torch==2.30.0
+ packaging==24.2
+ pillow==11.0.0
+ PyYAML==6.0.2
+ regex==2024.11.6
+ requests==2.32.3
+ safetensors==0.5.2
+ scipy==1.15.1
+ sympy==1.13.1
+ timm==1.0.14
+ torch==2.6.0+cu126
+ torchaudio==2.6.0+cu126
+ torchvision==0.21.0+cu126
+ tqdm==4.67.1
+ triton==3.2.0
+ typing_extensions==4.12.2
+ urllib3==2.3.0
+ wcwidth==0.2.13