|
from itertools import product |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
try: |
|
import cupy as cp |
|
from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix, eye as cp_eye, diags as cp_diags |
|
from cupyx.scipy.sparse import linalg as cp_s_linalg |
|
except ImportError: |
|
print("Cupy not installed") |
|
import numpy as np |
|
from scipy.sparse import csr_matrix, eye, diags |
|
from scipy.sparse import linalg as s_linalg |
|
from kornia.color import rgb_to_lab |
|
|
|
|
|
def make_input_divisible(x: torch.Tensor, patch_size=16) -> torch.Tensor:
    """Zero-pad the bottom and right edges so H and W become multiples of ``patch_size``.

    Args:
        x: Tensor whose shape is unpacked as ``(B, C, H, W)``.
        patch_size: Value the last two dimensions must be divisible by.

    Returns:
        ``x`` with zeros appended after its last row and last column; the
        padding amounts are zero when both sizes already divide evenly.
    """
    _, _, height, width = x.shape
    # ``-n % p`` is the distance from n up to the next multiple of p (0 if n already is one).
    extra_rows = -height % patch_size
    extra_cols = -width % patch_size
    # F.pad lists pads for the last dimension first: (left, right, top, bottom).
    return F.pad(x, (0, extra_cols, 0, extra_rows), value=0)
|
|
|
|
|
def reshape_windows(x):
    """Flatten a list of per-window feature maps into one row-per-location matrix.

    Args:
        x: Non-empty sequence of tensors. The first two dimensions of each
            are recorded as its ``(height, width)``; the feature size is taken
            from the last dimension of the first tensor.

    Returns:
        A pair ``(flat, height_width)``: ``flat`` is every window reshaped to
        ``(-1, dim)`` and concatenated along dimension 0 in input order, and
        ``height_width`` is the list of ``(height, width)`` tuples.
    """
    feat_dim = x[0].shape[-1]
    height_width = []
    flattened = []
    for window in x:
        height_width.append((window.shape[0], window.shape[1]))
        flattened.append(window.reshape(-1, feat_dim))
    return torch.cat(flattened, dim=0), height_width
|
|
|
|
|
def normalize_connection_graph_cupy(G):
    """Symmetrically normalise an affinity graph on the GPU (CuPy version).

    The diagonal of ``G`` is subtracted out, then each entry ``W[i, j]`` is
    scaled by ``1 / sqrt(s_i * s_j)`` where ``s`` holds the row sums of the
    diagonal-free matrix. Row sums equal to zero are replaced by 1, and any
    NaN or infinite scale factor is set to 0.

    Args:
        G: Anything accepted by ``cupyx.scipy.sparse.csr_matrix``.

    Returns:
        The normalised sparse matrix ``D^-1/2 W D^-1/2``.
    """
    adjacency = cp_csr_matrix(G)
    adjacency = adjacency - cp_diags(adjacency.diagonal(), 0)

    degree = adjacency.sum(axis=1)
    # Avoid dividing by zero for rows without any connection.
    degree[degree == 0] = 1

    inv_sqrt_degree = cp.array(1.0 / cp.sqrt(degree))
    # Non-finite covers exactly the NaN and +/-inf cases.
    inv_sqrt_degree[~cp.isfinite(inv_sqrt_degree)] = 0

    scaling = cp_diags(inv_sqrt_degree.reshape(-1), 0)
    return scaling * adjacency * scaling
|
|
|
|
|
def normalize_connection_graph(G):
    """Symmetrically normalise an affinity graph on the CPU (SciPy version).

    The diagonal of ``G`` is subtracted out, then each entry ``W[i, j]`` is
    scaled by ``1 / sqrt(s_i * s_j)`` where ``s`` holds the row sums of the
    diagonal-free matrix. Row sums equal to zero are replaced by 1, and any
    NaN or infinite scale factor is set to 0.

    Args:
        G: Anything accepted by ``scipy.sparse.csr_matrix``.

    Returns:
        The normalised sparse matrix ``D^-1/2 W D^-1/2``.
    """
    adjacency = csr_matrix(G)
    adjacency = adjacency - diags(adjacency.diagonal(), 0)

    degree = adjacency.sum(axis=1)
    # Avoid dividing by zero for rows without any connection.
    degree[degree == 0] = 1

    inv_sqrt_degree = np.array(1.0 / np.sqrt(degree))
    # Non-finite covers exactly the NaN and +/-inf cases.
    inv_sqrt_degree[~np.isfinite(inv_sqrt_degree)] = 0

    scaling = diags(inv_sqrt_degree.reshape(-1), 0)
    return scaling @ adjacency @ scaling
|
|
|
|
|
def cp_dfs_search(L, Y, tol=1e-6, maxiter=10):
    """Approximately solve ``L x = Y`` with CuPy's conjugate-gradient solver.

    Despite the name, this runs ``cupyx.scipy.sparse.linalg.cg``. The
    solver's status code is discarded, so the returned vector is whatever
    estimate was reached within ``maxiter`` iterations.

    Args:
        L: System matrix passed to ``cg``.
        Y: Right-hand side passed to ``cg``.
        tol: Forwarded as the solver's ``tol``.
        maxiter: Iteration cap forwarded to the solver.

    Returns:
        The solution estimate (first element of the solver's result).
    """
    solution, _status = cp_s_linalg.cg(L, Y, tol=tol, maxiter=maxiter)
    return solution
|
|
|
|
|
def dfs_search(L, Y, tol=1e-6, maxiter=10): |
|
out = s_linalg.cg(L, Y, rtol=tol, maxiter=maxiter)[0] |
|
|
|
return out |
|
|
|
|
|
def perform_lp(L, preds):
    """Run label propagation by solving ``L x = y`` once per prediction column.

    The backend is chosen by ``torch.cuda.is_available()``: when true, ``preds``
    is converted to a CuPy array and each column is solved with
    ``cp_dfs_search``; otherwise each column is solved with ``dfs_search``.
    ``L`` is handed unchanged to the selected solver.

    Args:
        L: System matrix for the solver.
        preds: 2-D scores with one column per class; its ``.T`` is iterated.

    Returns:
        A torch tensor with the shape of ``preds`` on ``"cuda"`` or ``"cpu"``
        respectively. The buffer is allocated with the default dtype of
        ``cp.zeros`` / ``np.zeros``.
    """
    if torch.cuda.is_available():
        propagated = cp.zeros(preds.shape)
        preds = cp.asarray(preds)
        for cls_idx, class_scores in enumerate(preds.T):
            propagated[:, cls_idx] = cp_dfs_search(L, class_scores)
        return torch.as_tensor(propagated, device="cuda")

    propagated = np.zeros(preds.shape)
    for cls_idx, class_scores in enumerate(preds.T):
        propagated[:, cls_idx] = dfs_search(L, class_scores)
    return torch.as_tensor(propagated, device="cpu")
|
|
|
|
|
def get_lposs_laplacian(feats, locations, height_width, sigma=0.0, pix_dist_pow=2, k=100, gamma=1.0, alpha=0.95, patch_size=16):
    """Build the sparse system matrix ``I - alpha * Wn`` over all window patches.

    Every row of ``feats`` is one patch. Each patch is linked to its ``k``
    highest dot-product neighbours; a kept similarity is clamped at zero,
    raised to ``gamma`` and multiplied by ``exp(-sigma * d ** pix_dist_pow)``,
    where ``d`` is the Euclidean distance between the two patch centres in
    image coordinates. The graph is symmetrised as ``W + W.T`` and normalised
    before the final matrix is formed.

    Args:
        feats: ``(N, dim)`` tensor of patch features, ordered window by
            window and row-major inside each window.
        locations: Per-window tensor; column 0 is read as the window's
            vertical offset and column 2 as its horizontal offset.
        height_width: List of ``(h, w)`` patch-grid sizes, one per window.
        sigma: Scale of the spatial decay; 0 disables it.
        pix_dist_pow: Power applied to the centre distance.
        k: Neighbours kept per patch. Values above ``N`` are reduced to ``N``
            because ``torch.topk`` rejects a larger ``k``.
        gamma: Exponent applied to the clamped similarities.
        alpha: Weight of the normalised graph in ``I - alpha * Wn``.
        patch_size: Pixel size of one patch, used to place patch centres.

    Returns:
        An ``(N, N)`` CuPy sparse matrix when ``torch.cuda.is_available()`` is
        true, otherwise a SciPy sparse matrix.
    """
    device = feats.device
    num_patches = feats.shape[0]

    # For every flattened patch: its source window and its row/column in that window's grid.
    idx_window = torch.cat([
        torch.full((h * w,), window, device=device, dtype=torch.int64)
        for window, (h, w) in enumerate(height_width)
    ])
    idx_h = torch.cat([torch.arange(h).repeat_interleave(w) for h, w in height_width]).to(device)
    idx_w = torch.cat([torch.arange(w).repeat(h) for h, w in height_width]).to(device)

    # Patch centres in full-image pixel coordinates.
    loc_h = locations[idx_window, 0] + (patch_size // 2) + idx_h * patch_size
    loc_w = locations[idx_window, 2] + (patch_size // 2) + idx_w * patch_size
    locs = torch.stack((loc_h, loc_w), 1).unsqueeze(0)
    dist = torch.cdist(locs, locs, p=2)[0] ** pix_dist_pow
    geometry_affinity = torch.exp(-sigma * dist)

    # A small image can have fewer patches than the requested neighbour count.
    k = min(k, num_patches)
    affinity = feats @ feats.T
    sims, ks = torch.topk(affinity, k=k, dim=1)

    sims = sims.clamp(min=0) ** gamma
    sims = sims.flatten() * geometry_affinity.gather(1, ks).flatten()
    ks = ks.flatten()
    rows = torch.arange(num_patches).repeat_interleave(k)

    if torch.cuda.is_available():
        W = cp_csr_matrix(
            (cp.asarray(sims), (cp.asarray(rows), cp.asarray(ks))),
            shape=(num_patches, num_patches),
        )
        W = W + W.T
        Wn = normalize_connection_graph_cupy(W)
        L = cp_eye(Wn.shape[0]) - alpha * Wn
    else:
        W = csr_matrix(
            (sims.cpu().numpy(), (rows.cpu().numpy(), ks.cpu().numpy())),
            shape=(num_patches, num_patches),
        )
        W = W + W.T
        Wn = normalize_connection_graph(W)
        L = eye(Wn.shape[0]) - alpha * Wn

    return L
|
|
|
|
|
def _sliding_window_boxes(h_img, w_img, window_size, window_stride):
    """Return the ``(y1, y2, x1, x2)`` box of every sliding window in row-major grid order."""
    h_crop, w_crop = window_size
    h_stride, w_stride = window_stride
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1

    boxes = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            # A window that would overrun the image is shifted back inside it
            # (it only shrinks when the image is smaller than the window).
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            boxes.append((y1, y2, x1, x2))
    return boxes


def lposs(clip, dino, img, classnames, window_size=(224,224), window_stride=(112, 112), sigma=0.01, pix_dist_pow=1, lp_k_image=400, lp_gamma=3.0, lp_alpha=0.95):
    """Predict per-pixel class scores with sliding windows and label propagation.

    The image is cut into overlapping windows. For each window, DINO and CLIP
    patch features are extracted; CLIP features are resized to DINO's patch
    grid when the grids differ. All patches of all windows form one graph
    (built from the DINO features), the CLIP-vs-classifier scores are
    propagated over it, and the per-window results are upsampled, summed into
    the image canvas and averaged over the number of covering windows.

    Args:
        clip: Callable returning a channels-first feature map for an image
            batch; must also provide ``get_classifier(classnames)`` and
            ``patch_size``.
        dino: Callable returning ``(features, (h, w))`` for an image batch;
            must also provide ``patch_size``.
        img: Image tensor of shape ``(batch, channels, height, width)``.
        classnames: Passed to ``clip.get_classifier``.
        window_size: ``(height, width)`` of a window in pixels.
        window_stride: ``(vertical, horizontal)`` step between windows.
        sigma: Spatial decay forwarded to ``get_lposs_laplacian``.
        pix_dist_pow: Distance power forwarded to ``get_lposs_laplacian``.
        lp_k_image: Neighbour count forwarded as ``k``.
        lp_gamma: Similarity exponent forwarded as ``gamma``.
        lp_alpha: Propagation weight forwarded as ``alpha``.

    Returns:
        Tensor of shape ``(batch, num_classes, height, width)``.

    Note:
        Only the first batch element's features are used; its result is
        broadcast to every batch entry of the output.
    """
    batch_size, _, h_img, w_img = img.size()
    boxes = _sliding_window_boxes(h_img, w_img, window_size, window_stride)

    clf = clip.get_classifier(classnames)
    num_classes = clf.shape[0]
    patch_size = dino.patch_size

    # One (y1, y2, x1, x2) row per window, on the image's device and dtype.
    locations = img.new_tensor(boxes)
    dino_feats = []
    clip_feats = []
    for y1, y2, x1, x2 in boxes:
        crop_img = img[:, :, y1:y2, x1:x2]

        img_dino_feats, (h_dino, w_dino) = dino(make_input_divisible(crop_img, dino.patch_size))
        img_dino_feats = img_dino_feats.reshape((batch_size, -1, h_dino, w_dino)).permute(0, 2, 3, 1)
        img_clip_feats = clip(make_input_divisible(crop_img, clip.patch_size))

        # DINO features are channels-last at this point while CLIP features are
        # still channels-first, so the grids to compare are dims (1, 2) of DINO
        # and the last two dims of CLIP.
        dino_grid = (img_dino_feats.shape[1], img_dino_feats.shape[2])
        if tuple(img_clip_feats.shape[-2:]) != dino_grid:
            img_clip_feats = F.interpolate(img_clip_feats, size=dino_grid, mode='bilinear', align_corners=False)
        img_clip_feats = img_clip_feats.permute(0, 2, 3, 1)

        dino_feats.append(img_dino_feats[0, ...])
        clip_feats.append(img_clip_feats[0, ...])

    dino_feats, height_width = reshape_windows(dino_feats)
    clip_feats, _ = reshape_windows(clip_feats)
    dino_feats = F.normalize(dino_feats, p=2, dim=-1)
    clip_feats = F.normalize(clip_feats, p=2, dim=-1)

    L = get_lposs_laplacian(dino_feats, locations, height_width, sigma=sigma, pix_dist_pow=pix_dist_pow, k=lp_k_image, gamma=lp_gamma, alpha=lp_alpha, patch_size=patch_size)
    clip_preds = clip_feats @ clf.T

    lp_preds = perform_lp(L, clip_preds)

    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))

    # Rows of lp_preds are stored window after window, so each window owns a
    # contiguous slice of h_win * w_win rows.
    row_start = 0
    for (y1, y2, x1, x2), (h_win, w_win) in zip(boxes, height_width):
        row_end = row_start + h_win * w_win
        crop_seg_logit = lp_preds[row_start:row_end, :].reshape(h_win, w_win, num_classes)
        row_start = row_end

        # (h, w, classes) -> (1, classes, h, w) for interpolation.
        crop_seg_logit = crop_seg_logit.permute(2, 0, 1).unsqueeze(0)
        crop_seg_logit = F.interpolate(
            input=crop_seg_logit,
            size=(y2 - y1, x2 - x1),
            mode='bilinear',
            align_corners=False,
        )
        assert crop_seg_logit.shape[2] == (y2 - y1) and crop_seg_logit.shape[3] == (x2 - x1)

        # Zero-pad the window back to full image size so it can be summed in place.
        preds += F.pad(crop_seg_logit,
                       (int(x1), int(preds.shape[3] - x2), int(y1),
                        int(preds.shape[2] - y2)))
        count_mat[:, :, y1:y2, x1:x2] += 1

    # Every pixel must be covered by at least one window before averaging.
    assert (count_mat == 0).sum() == 0
    preds = preds / count_mat

    return preds
|
|
|
|
|
def get_pixel_connections(img, neigh=1):
    """List all pixel-to-neighbour edges of an image with their colour distance.

    The first batch element of ``img`` is converted with kornia's
    ``rgb_to_lab`` and its three channels are divided by ``[100, 128, 128]``.
    Every pixel is then paired with each in-bounds pixel inside the
    ``(2 * neigh + 1)`` square around it, excluding itself.

    Args:
        img: Batched image tensor; only ``img[0]`` is used.
            NOTE(review): presumably ``(1, 3, H, W)`` RGB in the value range
            ``rgb_to_lab`` expects -- confirm against the caller.
        neigh: Neighbourhood radius in pixels.

    Returns:
        ``(rows, cols, pixel_pixel_data, locs)`` where ``rows`` / ``cols`` are
        flat (row-major) indices of the source and neighbour pixel of every
        edge, ``pixel_pixel_data`` is the squared distance between their scaled
        Lab colours, and ``locs`` holds the ``(row, col)`` of every pixel.
    """
    img = img[0, ...]
    img_lab = rgb_to_lab(img).permute((1, 2, 0))
    img_lab /= torch.tensor([100, 128, 128], device=img.device)
    img_h, img_w, _ = img_lab.shape
    img_lab = img_lab.reshape((img_h * img_w, -1))

    idx = torch.arange(img_h * img_w).to(img.device)
    locs = torch.stack((idx // img_w, idx % img_w), 1)

    rows, cols = [], []
    offsets = range(-neigh, neigh + 1)
    for dy, dx in product(offsets, offsets):
        if dy == 0 and dx == 0:
            continue  # a pixel is not its own neighbour
        new_locs = locs + torch.tensor((dy, dx)).to(img.device)
        in_bounds = (
            (new_locs[:, 0] >= 0)
            & (new_locs[:, 1] >= 0)
            & (new_locs[:, 0] < img_h)
            & (new_locs[:, 1] < img_w)
        )
        rows.append(torch.where(in_bounds)[0])
        kept = new_locs[in_bounds, :]
        # Back from (row, col) to a flat row-major index.
        cols.append(kept[:, 0] * img_w + kept[:, 1])

    rows = torch.cat(rows)
    cols = torch.cat(cols)
    pixel_pixel_data = ((img_lab[rows, :] - img_lab[cols, :]) ** 2).sum(dim=-1)

    return rows, cols, pixel_pixel_data, locs
|
|
|
|
|
def get_laplacian(rows, cols, data, N, alpha=0.99):
    """Assemble ``I - alpha * Wn`` from an edge list.

    The edges become an ``(N, N)`` CSR matrix, which is normalised with
    ``normalize_connection_graph_cupy`` when ``torch.cuda.is_available()`` is
    true and with ``normalize_connection_graph`` otherwise.

    Args:
        rows: Row index of every edge.
        cols: Column index of every edge.
        data: Weight of every edge.
        N: Number of nodes.
        alpha: Weight of the normalised graph.

    Returns:
        A CuPy sparse matrix on the CUDA path, a SciPy sparse matrix otherwise.
    """
    if torch.cuda.is_available():
        gpu_rows = cp.asarray(rows)
        gpu_cols = cp.asarray(cols)
        gpu_data = cp.asarray(data)
        graph = cp_csr_matrix((gpu_data, (gpu_rows, gpu_cols)), shape=(N, N))
        normalised = normalize_connection_graph_cupy(graph)
        return cp_eye(normalised.shape[0]) - alpha * normalised

    graph = csr_matrix(
        (data.cpu().numpy(), (rows.cpu().numpy(), cols.cpu().numpy())),
        shape=(N, N),
    )
    normalised = normalize_connection_graph(graph)
    return eye(normalised.shape[0]) - alpha * normalised
|
|
|
|
|
def lposs_plus(img, preds, tau=0.01, alpha=0.95, r=13):
    """Refine per-pixel predictions by label propagation over a pixel graph.

    Pixels are connected to their neighbours within radius ``r // 2``; an edge
    weight is ``exp(-d / tau)`` with ``d`` the square root of the colour
    distance returned by ``get_pixel_connections``. The first batch element of
    ``preds`` is propagated over that graph.

    Args:
        img: Image batch forwarded to ``get_pixel_connections``.
        preds: Tensor of shape ``(batch, num_classes, height, width)``; only
            ``preds[0]`` is used.
        tau: Temperature of the edge weights.
        alpha: Forwarded to ``get_laplacian``.
        r: Neighbourhood diameter; the radius used is ``r // 2``.

    Returns:
        Tensor of shape ``(1, num_classes, height, width)``.
    """
    first = preds[0, ...]
    num_classes, h_img, w_img = first.shape
    # One row per pixel (row-major), one column per class.
    flat_preds = first.permute((1, 2, 0)).reshape((h_img * w_img, -1))

    rows, cols, sq_dist, _ = get_pixel_connections(img, neigh=r // 2)
    weights = torch.exp(-torch.sqrt(sq_dist) / tau)
    L = get_laplacian(rows, cols, weights, flat_preds.shape[0], alpha=alpha)

    refined = perform_lp(L, flat_preds)

    refined = refined.reshape((h_img, w_img, num_classes))
    return refined.permute((2, 0, 1)).unsqueeze(0)
|
|