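"""LPOSS: open-vocabulary semantic segmentation by label propagation.

`lposs` propagates CLIP classifier scores over a k-NN graph built from DINO
patch features extracted with a sliding window; `lposs_plus` refines the
result with a second propagation over a pixel-level Lab color graph. Sparse
linear algebra runs through CuPy on GPU when available, otherwise through
SciPy on CPU.
"""
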
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.sparse import csr_matrix, eye, diags
from scipy.sparse import linalg as s_linalg
from kornia.color import rgb_to_lab

# Optional GPU backend: sparse ops go through CuPy when it is installed and a
# CUDA device is available; otherwise the SciPy code paths below are used.
try:
    import cupy as cp
    from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix, eye as cp_eye, diags as cp_diags
    from cupyx.scipy.sparse import linalg as cp_s_linalg
except ImportError:
    print("Cupy not installed")


def make_input_divisible(x: torch.Tensor, patch_size=16) -> torch.Tensor:
    """Pad some pixels to make the input size divisible by the patch size."""
    B, _, H_0, W_0 = x.shape
    pad_w = (patch_size - W_0 % patch_size) % patch_size
    pad_h = (patch_size - H_0 % patch_size) % patch_size
    x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
    return x


def reshape_windows(x):
    """Flatten a list of (h, w, dim) feature maps into one (N, dim) tensor.

    Returns the concatenated features and the per-window (h, w) shapes needed
    to undo the flattening.
    """
    height_width = [(y.shape[0], y.shape[1]) for y in x]
    dim = x[0].shape[-1]
    x = [torch.reshape(y, (-1, dim)) for y in x]
    return torch.cat(x, dim=0), height_width


def normalize_connection_graph_cupy(G):
    """Symmetrically normalize an affinity matrix on GPU: W_n = D^{-1/2} W D^{-1/2}."""
    W = cp_csr_matrix(G)
    W = W - cp_diags(W.diagonal(), 0)  # remove self-loops
    S = W.sum(axis=1)
    S[S == 0] = 1  # avoid division by zero for isolated nodes
    D = cp.array(1.0 / cp.sqrt(S))
    D[cp.isnan(D)] = 0
    D[cp.isinf(D)] = 0
    D_mh = cp_diags(D.reshape(-1), 0)
    Wn = D_mh * W * D_mh
    return Wn


def normalize_connection_graph(G):
    """CPU counterpart of `normalize_connection_graph_cupy`, using SciPy."""
    W = csr_matrix(G)
    W = W - diags(W.diagonal(), 0)  # remove self-loops
    S = W.sum(axis=1)
    S[S == 0] = 1  # avoid division by zero for isolated nodes
    D = np.array(1.0 / np.sqrt(S))
    D[np.isnan(D)] = 0
    D[np.isinf(D)] = 0
    D_mh = diags(D.reshape(-1), 0)
    Wn = D_mh * W * D_mh
    return Wn


def cp_dfs_search(L, Y, tol=1e-6, maxiter=10):
    # Despite the name, this solves L x = Y with the conjugate gradient method (CuPy).
    out = cp_s_linalg.cg(L, Y, tol=tol, maxiter=maxiter)[0]
    return out


def dfs_search(L, Y, tol=1e-6, maxiter=10):
    # SciPy counterpart; recent SciPy versions name the relative tolerance `rtol`.
    out = s_linalg.cg(L, Y, rtol=tol, maxiter=maxiter)[0]
    return out


def perform_lp(L, preds):
    """Label propagation: solve L x = y for each class column of `preds`."""
    if torch.cuda.is_available():
        lp_preds = cp.zeros(preds.shape)
        preds = cp.asarray(preds)
        for cls_idx, y_cls in enumerate(preds.T):
            lp_preds[:, cls_idx] = cp_dfs_search(L, y_cls)
        lp_preds = torch.as_tensor(lp_preds, device="cuda")
    else:
        lp_preds = np.zeros(preds.shape)
        for cls_idx, y_cls in enumerate(preds.T):
            lp_preds[:, cls_idx] = dfs_search(L, y_cls)
        lp_preds = torch.as_tensor(lp_preds, device="cpu")
    return lp_preds


def get_lposs_laplacian(feats, locations, height_width, sigma=0.0, pix_dist_pow=2, k=100, gamma=1.0, alpha=0.95, patch_size=16):
    """Build the propagation operator L = I - alpha * W_n over all window patches.

    Edge weights combine feature similarity (top-k affinities, raised to
    `gamma`) with a spatial kernel exp(-sigma * pixel_distance ** pix_dist_pow)
    computed between patch centers in original-image coordinates.
    """
    # Per-patch window index and (row, col) position inside its window.
    idx_window = torch.cat([window * torch.ones((h * w,), device=feats.device, dtype=torch.int64) for window, (h, w) in enumerate(height_width)])
    idx_h = torch.cat([torch.arange(h).view(-1, 1).repeat(1, w).flatten() for h, w in height_width]).to(feats.device)
    idx_w = torch.cat([torch.arange(w).view(1, -1).repeat(h, 1).flatten() for h, w in height_width]).to(feats.device)
    # Patch-center coordinates in the full image.
    loc_h = locations[idx_window, 0] + (patch_size // 2) + idx_h * patch_size
    loc_w = locations[idx_window, 2] + (patch_size // 2) + idx_w * patch_size
    locs = torch.stack((loc_h, loc_w), 1)
    locs = torch.unsqueeze(locs, 0)
    dist = torch.cdist(locs, locs, p=2)[0, ...] ** pix_dist_pow
    geometry_affinity = torch.exp(-sigma * dist)
    N = feats.shape[0]
    # k-NN graph from feature similarity, modulated by the spatial kernel.
    affinity = feats @ feats.T
    sims, ks = torch.topk(affinity, k=k, dim=1)
    sims[sims < 0] = 0
    sims = sims ** gamma
    geometry_affinity = geometry_affinity.gather(1, ks).flatten()
    sims = sims.flatten() * geometry_affinity
    ks = ks.flatten()
    rows = torch.arange(N).repeat_interleave(k)
    if torch.cuda.is_available():
        W = cp_csr_matrix(
            (cp.asarray(sims), (cp.asarray(rows), cp.asarray(ks))),
            shape=(N, N),
        )
        W = W + W.T  # symmetrize
        Wn = normalize_connection_graph_cupy(W)
        L = cp_eye(Wn.shape[0]) - alpha * Wn
    else:
        W = csr_matrix(
            (sims.cpu().numpy(), (rows.cpu().numpy(), ks.cpu().numpy())),
            shape=(N, N),
        )
        W = W + W.T  # symmetrize
        Wn = normalize_connection_graph(W)
        L = eye(Wn.shape[0]) - alpha * Wn
    return L


def lposs(clip, dino, img, classnames, window_size=(224, 224), window_stride=(112, 112), sigma=0.01, pix_dist_pow=1, lp_k_image=400, lp_gamma=3.0, lp_alpha=0.95):
    """Patch-level LPOSS: propagate CLIP class scores over a DINO feature graph.

    The image is processed in overlapping sliding windows; DINO features define
    the graph, CLIP features scored against the text classifier provide the
    initial labels, and the propagated logits are stitched back to full
    resolution with overlap averaging.
    """
    h_stride, w_stride = window_stride
    h_crop, w_crop = window_size
    batch_size, _, h_img, w_img = img.size()
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    clf = clip.get_classifier(classnames)
    locations = img.new_zeros((h_grids * w_grids, 4))  # (y1, y2, x1, x2) per window
    dino_feats = []
    clip_feats = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            # Window coordinates, clamped so every crop keeps the full window size.
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            img_dino_feats, (h_dino, w_dino) = dino(make_input_divisible(crop_img, dino.patch_size))  # (1, 768, N)
            img_dino_feats = img_dino_feats.reshape((batch_size, -1, h_dino, w_dino)).permute(0, 2, 3, 1)  # (1, h_dino, w_dino, 768)
            img_clip_feats = clip(make_input_divisible(crop_img, clip.patch_size))  # (1, 512, h, w)
            if img_clip_feats.shape[2] != img_dino_feats.shape[1] or img_clip_feats.shape[3] != img_dino_feats.shape[2]:
                # Resample CLIP features onto the DINO patch grid.
                img_clip_feats = F.interpolate(img_clip_feats, size=(img_dino_feats.shape[1], img_dino_feats.shape[2]), mode='bilinear', align_corners=False)
            img_clip_feats = img_clip_feats.permute(0, 2, 3, 1)  # (1, h, w, 512)
            dino_feats.append(img_dino_feats[0, ...])
            clip_feats.append(img_clip_feats[0, ...])
            locations[h_idx * w_grids + w_idx, 0] = y1
            locations[h_idx * w_grids + w_idx, 1] = y2
            locations[h_idx * w_grids + w_idx, 2] = x1
            locations[h_idx * w_grids + w_idx, 3] = x2
    num_classes = clf.shape[0]
    patch_size = dino.patch_size
    dino_feats, height_width = reshape_windows(dino_feats)
    clip_feats, _ = reshape_windows(clip_feats)
    dino_feats = F.normalize(dino_feats, p=2, dim=-1)
    clip_feats = F.normalize(clip_feats, p=2, dim=-1)
    # Graph from DINO features; initial labels from CLIP-vs-text similarity.
    L = get_lposs_laplacian(dino_feats, locations, height_width, sigma=sigma, pix_dist_pow=pix_dist_pow, k=lp_k_image, gamma=lp_gamma, alpha=lp_alpha, patch_size=patch_size)
    clip_preds = clip_feats @ clf.T
    lp_preds = perform_lp(L, clip_preds)
    # Stitch per-window logits back to image resolution, averaging overlaps.
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
    idx_window = torch.cat([window * torch.ones((h * w,), device=dino_feats.device, dtype=torch.int64) for window, (h, w) in enumerate(height_width)])
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            win_id = h_idx * w_grids + w_idx
            crop_seg_logit = lp_preds[torch.where(idx_window == win_id)[0], :]
            crop_seg_logit = torch.reshape(crop_seg_logit, height_width[win_id] + (num_classes,))
            crop_seg_logit = torch.unsqueeze(crop_seg_logit, 0)
            crop_seg_logit = torch.permute(crop_seg_logit, (0, 3, 1, 2))
            crop_seg_logit = F.interpolate(
                input=crop_seg_logit,
                size=(y2 - y1, x2 - x1),
                mode='bilinear',
                align_corners=False,
            )
            assert crop_seg_logit.shape[2] == (y2 - y1) and crop_seg_logit.shape[3] == (x2 - x1)
            preds += F.pad(crop_seg_logit,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
    assert (count_mat == 0).sum() == 0
    preds = preds / count_mat
    return preds


def get_pixel_connections(img, neigh=1):
    """Build a pixel grid graph: edges to all neighbors within `neigh` steps.

    Returns edge endpoints (rows, cols), the squared Lab color distance per
    edge, and the (row, col) coordinates of every pixel.
    """
    img = img[0, ...]
    img_lab = rgb_to_lab(img)
    img_lab = img_lab.permute((1, 2, 0))
    img_lab /= torch.tensor([100, 128, 128], device=img.device)  # scale Lab to roughly unit range (L: [0, 1], a/b: ~[-1, 1])
    img_h, img_w, _ = img_lab.shape
    img_lab = img_lab.reshape((img_h * img_w, -1))
    idx = torch.arange(img_h * img_w).to(img.device)
    loc_h = idx // img_w
    loc_w = idx % img_w
    locs = torch.stack((loc_h, loc_w), 1)
    rows, cols = [], []
    for mov in product(range(-neigh, neigh + 1), range(-neigh, neigh + 1)):
        if mov[0] == 0 and mov[1] == 0:
            continue  # skip self-loops
        new_locs = locs + torch.tensor(mov).to(img.device)
        # Keep only offsets that stay inside the image.
        mask = torch.logical_and(torch.logical_and(torch.logical_and(new_locs[:, 0] >= 0, new_locs[:, 1] >= 0), new_locs[:, 0] < img_h), new_locs[:, 1] < img_w)
        rows.append(torch.where(mask)[0])
        col = new_locs[mask, :]
        col = col[:, 0] * img_w + col[:, 1]
        cols.append(col)
    rows = torch.cat(rows)
    cols = torch.cat(cols)
    pixel_pixel_data = ((img_lab[rows, :] - img_lab[cols, :]) ** 2).sum(dim=-1)
    return rows, cols, pixel_pixel_data, locs


def get_laplacian(rows, cols, data, N, alpha=0.99):
    """Assemble L = I - alpha * W_n from an edge list; the pixel graph is
    already symmetric, so no extra symmetrization is needed."""
    if torch.cuda.is_available():
        rows = cp.asarray(rows)
        cols = cp.asarray(cols)
        data = cp.asarray(data)
        W = cp_csr_matrix(
            (data, (rows, cols)),
            shape=(N, N),
        )
        Wn = normalize_connection_graph_cupy(W)
        L = cp_eye(Wn.shape[0]) - alpha * Wn
    else:
        W = csr_matrix(
            (data.cpu().numpy(), (rows.cpu().numpy(), cols.cpu().numpy())),
            shape=(N, N),
        )
        Wn = normalize_connection_graph(W)
        L = eye(Wn.shape[0]) - alpha * Wn
    return L


def lposs_plus(img, preds, tau=0.01, alpha=0.95, r=13):
    """Pixel-level refinement (LPOSS+): re-propagate logits over a color graph.

    Edge weights are exp(-||Lab_i - Lab_j|| / tau) within an r x r pixel
    neighborhood (neigh = r // 2 steps in each direction).
    """
    preds = preds[0, ...]
    num_classes, h_img, w_img = preds.shape
    preds = preds.permute((1, 2, 0))
    preds = preds.reshape((h_img * w_img, -1))
    rows, cols, pixel_pixel_data, locs = get_pixel_connections(img, neigh=r // 2)
    pixel_pixel_data = torch.sqrt(pixel_pixel_data)  # Euclidean Lab distance
    pixel_pixel_data = torch.exp(-pixel_pixel_data / tau)
    L = get_laplacian(rows, cols, pixel_pixel_data, preds.shape[0], alpha=alpha)
    lp_preds = perform_lp(L, preds)
    return lp_preds.reshape((h_img, w_img, num_classes)).permute((2, 0, 1)).unsqueeze(0)
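

# A minimal usage sketch. `clip` and `dino` are assumed to be model wrappers
# with the interface used above (a `patch_size` attribute, `clip.get_classifier`
# mapping class names to a text classifier, and feature-extracting forward
# passes); they are not defined in this file, so the snippet is illustrative
# only, not the repository's entry point.
#
#     img = torch.rand(1, 3, 512, 512)           # hypothetical RGB input in [0, 1]
#     classnames = ["cat", "dog", "background"]   # hypothetical label set
#     preds = lposs(clip, dino, img, classnames)  # patch-level propagation
#     preds = lposs_plus(img, preds)              # pixel-level refinement
#     segmentation = preds.argmax(dim=1)          # (1, H, W) class indices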