from itertools import product import torch import torch.nn as nn import torch.nn.functional as F try: import cupy as cp from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix, eye as cp_eye, diags as cp_diags from cupyx.scipy.sparse import linalg as cp_s_linalg except ImportError: print("Cupy not installed") import numpy as np from scipy.sparse import csr_matrix, eye, diags from scipy.sparse import linalg as s_linalg from kornia.color import rgb_to_lab def make_input_divisible(x: torch.Tensor, patch_size=16) -> torch.Tensor: """Pad some pixels to make the input size divisible by the patch size.""" B, _, H_0, W_0 = x.shape pad_w = (patch_size - W_0 % patch_size) % patch_size pad_h = (patch_size - H_0 % patch_size) % patch_size x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0) return x def reshape_windows(x): height_width = [(y.shape[0], y.shape[1]) for y in x] dim = x[0].shape[-1] x = [torch.reshape(y, (-1, dim)) for y in x] return torch.cat(x, dim=0), height_width def normalize_connection_graph_cupy(G): W = cp_csr_matrix(G) W = W - cp_diags(W.diagonal(), 0) S = W.sum(axis=1) # breakpoint() S[S == 0] = 1 D = cp.array(1.0 / cp.sqrt(S)) D[cp.isnan(D)] = 0 D[cp.isinf(D)] = 0 D_mh = cp_diags(D.reshape(-1), 0) Wn = D_mh * W * D_mh return Wn def normalize_connection_graph(G): W = csr_matrix(G) W = W - diags(W.diagonal(), 0) S = W.sum(axis=1) S[S == 0] = 1 D = np.array(1.0 / np.sqrt(S)) D[np.isnan(D)] = 0 D[np.isinf(D)] = 0 D_mh = diags(D.reshape(-1), 0) Wn = D_mh * W * D_mh return Wn def cp_dfs_search(L, Y, tol=1e-6, maxiter=10): out = cp_s_linalg.cg(L, Y, tol=tol, maxiter=maxiter)[0] return out def dfs_search(L, Y, tol=1e-6, maxiter=10): out = s_linalg.cg(L, Y, rtol=tol, maxiter=maxiter)[0] return out def perform_lp(L, preds): if torch.cuda.is_available(): lp_preds = cp.zeros(preds.shape) preds = cp.asarray(preds) for cls_idx, y_cls in enumerate(preds.T): Y = y_cls lp_preds[:, cls_idx] = cp_dfs_search(L, Y) lp_preds = torch.as_tensor(lp_preds, device="cuda") else: lp_preds = np.zeros(preds.shape) for cls_idx, y_cls in enumerate(preds.T): Y = y_cls lp_preds[:, cls_idx] = dfs_search(L, Y) lp_preds = torch.as_tensor(lp_preds, device="cpu") return lp_preds def get_lposs_laplacian(feats, locations, height_width, sigma=0.0, pix_dist_pow=2, k=100, gamma=1.0, alpha=0.95, patch_size=16): idx_window = torch.cat([window * torch.ones((h*w, ), device=feats.device, dtype=torch.int64) for window, (h, w) in enumerate(height_width)]) idx_h = torch.cat([torch.arange(h).view(-1,1).repeat(1, w).flatten() for h, w in height_width]).to(feats.device) idx_w = torch.cat([torch.arange(w).view(1,-1).repeat(h, 1).flatten() for h, w in height_width]).to(feats.device) loc_h = locations[idx_window, 0] + (patch_size // 2) + idx_h * patch_size loc_w = locations[idx_window, 2] + (patch_size // 2) + idx_w * patch_size locs = torch.stack((loc_h, loc_w), 1) locs = torch.unsqueeze(locs, 0) dist = torch.cdist(locs, locs, p=2) dist = dist[0, ...] dist = dist ** pix_dist_pow geometry_affinity = torch.exp(-sigma * dist) N = feats.shape[0] affinity = feats @ feats.T sims, ks = torch.topk(affinity, k=k, dim=1) sims[sims < 0] = 0 sims = sims ** gamma geometry_affinity = geometry_affinity.gather(1, ks).flatten() sims = sims.flatten() sims = sims * geometry_affinity ks = ks.flatten() rows = torch.arange(N).repeat_interleave(k) if torch.cuda.is_available(): W = cp_csr_matrix( (cp.asarray(sims), (cp.asarray(rows), cp.asarray(ks))), shape=(N, N), ) W = W + W.T Wn = normalize_connection_graph_cupy(W) L = cp_eye(Wn.shape[0]) - alpha * Wn else: W = csr_matrix( (sims.cpu().numpy(), (rows.cpu().numpy(), ks.cpu().numpy())), shape=(N, N), ) W = W + W.T Wn = normalize_connection_graph(W) L = eye(Wn.shape[0]) - alpha * Wn return L def lposs(clip, dino, img, classnames, window_size=(224,224), window_stride=(112, 112), sigma=0.01, pix_dist_pow=1, lp_k_image=400, lp_gamma=3.0, lp_alpha=0.95): h_stride, w_stride = window_stride h_crop, w_crop = window_size batch_size, _, h_img, w_img = img.size() h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 clf = clip.get_classifier(classnames) locations = img.new_zeros((h_grids*w_grids, 4)) dino_feats = [] clip_feats = [] for h_idx in range(h_grids): for w_idx in range(w_grids): y1 = h_idx * h_stride x1 = w_idx * w_stride y2 = min(y1 + h_crop, h_img) x2 = min(x1 + w_crop, w_img) y1 = max(y2 - h_crop, 0) x1 = max(x2 - w_crop, 0) crop_img = img[:, :, y1:y2, x1:x2] img_dino_feats, (h_dino, w_dino) = dino(make_input_divisible(crop_img, dino.patch_size)) # (1, 768, N) img_dino_feats = img_dino_feats.reshape((batch_size, -1, h_dino, w_dino)).permute(0, 2, 3, 1) # (1, h_dino, w_dino, 768) img_clip_feats = clip(make_input_divisible(crop_img, clip.patch_size)) # (1, 512, h, w) if img_clip_feats.shape[1] != img_dino_feats.shape[1] or img_clip_feats.shape[2] != img_dino_feats.shape[2]: img_clip_feats = F.interpolate(img_clip_feats, size=(img_dino_feats.shape[1], img_dino_feats.shape[2]), mode='bilinear', align_corners=False) img_clip_feats = img_clip_feats.permute(0, 2, 3, 1) # (1, h, w, 512) dino_feats.append(img_dino_feats[0, ...]) clip_feats.append(img_clip_feats[0, ...]) locations[h_idx*w_grids + w_idx, 0] = y1 locations[h_idx*w_grids + w_idx, 1] = y2 locations[h_idx*w_grids + w_idx, 2] = x1 locations[h_idx*w_grids + w_idx, 3] = x2 num_classes = clf.shape[0] patch_size = dino.patch_size dino_feats, height_width = reshape_windows(dino_feats) clip_feats, _ = reshape_windows(clip_feats) dino_feats = F.normalize(dino_feats, p=2, dim=-1) clip_feats = F.normalize(clip_feats, p=2, dim=-1) L = get_lposs_laplacian(dino_feats, locations, height_width, sigma=sigma, pix_dist_pow=pix_dist_pow, k=lp_k_image, gamma=lp_gamma, alpha=lp_alpha, patch_size=patch_size) clip_preds = clip_feats @ clf.T lp_preds = perform_lp(L, clip_preds) preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) idx_window = torch.cat([window * torch.ones((h*w, ), device=dino_feats.device, dtype=torch.int64) for window, (h, w) in enumerate(height_width)]) for h_idx in range(h_grids): for w_idx in range(w_grids): y1 = h_idx * h_stride x1 = w_idx * w_stride y2 = min(y1 + h_crop, h_img) x2 = min(x1 + w_crop, w_img) y1 = max(y2 - h_crop, 0) x1 = max(x2 - w_crop, 0) win_id = h_idx*w_grids + w_idx crop_seg_logit = lp_preds[torch.where(idx_window == win_id)[0], :] crop_seg_logit = torch.reshape(crop_seg_logit, height_width[win_id]+(num_classes, )) crop_seg_logit = torch.unsqueeze(crop_seg_logit, 0) crop_seg_logit = torch.permute(crop_seg_logit, (0, 3, 1, 2)) crop_seg_logit = F.interpolate( input=crop_seg_logit, size=(y2-y1, x2-x1), mode='bilinear', align_corners=False ) assert crop_seg_logit.shape[2] == (y2 - y1) and crop_seg_logit.shape[3] == (x2 - x1) preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) count_mat[:, :, y1:y2, x1:x2] += 1 assert (count_mat == 0).sum() == 0 preds = preds / count_mat return preds def get_pixel_connections(img, neigh=1): img = img[0, ...] img_lab = rgb_to_lab(img) img_lab = img_lab.permute((1, 2, 0)) img_lab /= torch.tensor([100, 128, 128], device=img.device) # project Lab values to 0-1 range img_h, img_w, _ = img_lab.shape img_lab = img_lab.reshape((img_h*img_w, -1)) idx = torch.arange(img_h * img_w).to(img.device) loc_h = idx // img_w loc_w = idx % img_w locs = torch.stack((loc_h, loc_w), 1) rows, cols = [], [] for mov in product(range(-neigh, neigh+1), range(-neigh, neigh+1)): if mov[0] == 0 and mov[1] == 0: continue new_locs = locs + torch.tensor(mov).to(img.device) mask = torch.logical_and(torch.logical_and(torch.logical_and(new_locs[:, 0] >= 0, new_locs[:, 1] >= 0), new_locs[:, 0] < img_h), new_locs[:, 1] < img_w) rows.append(torch.where(mask)[0]) col = new_locs[mask, :] col = col[:, 0] * img_w + col[:, 1] cols.append(col) rows = torch.cat(rows) cols = torch.cat(cols) pixel_pixel_data = ((img_lab[rows, :] - img_lab[cols, :]) ** 2).sum(dim=-1) return rows, cols, pixel_pixel_data, locs def get_laplacian(rows, cols, data, N, alpha=0.99): if torch.cuda.is_available(): rows = cp.asarray(rows) cols = cp.asarray(cols) data = cp.asarray(data) W = cp_csr_matrix( (data, (rows, cols)), shape=(N, N), ) Wn = normalize_connection_graph_cupy(W) L = cp_eye(Wn.shape[0]) - alpha * Wn else: W = csr_matrix( (data.cpu().numpy(), (rows.cpu().numpy(), cols.cpu().numpy())), shape=(N, N), ) Wn = normalize_connection_graph(W) L = eye(Wn.shape[0]) - alpha * Wn return L def lposs_plus(img, preds, tau=0.01, alpha=0.95, r=13): preds = preds[0, ...] num_classes, h_img, w_img = preds.shape preds = preds.permute((1, 2, 0)) preds = preds.reshape((h_img*w_img, -1)) rows, cols, pixel_pixel_data, locs = get_pixel_connections(img, neigh=r//2) pixel_pixel_data = torch.sqrt(pixel_pixel_data) pixel_pixel_data = torch.exp(-pixel_pixel_data / tau) L = get_laplacian(rows, cols, pixel_pixel_data, preds.shape[0], alpha=alpha) lp_preds = perform_lp(L, preds) return lp_preds.reshape((h_img, w_img, num_classes)).permute((2, 0, 1)).unsqueeze(0)