Spaces:
Sleeping
Sleeping
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
import math | |
import clip | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from nncore.nn import MODELS, build_loss, build_model | |
from .generator import PointGenerator | |
_CLIP_ARCHS = { | |
'ViT-B/32': (768, 512, 50), | |
'ViT-B/16': (768, 512, 197), | |
'ViT-L/14': (1024, 768, 50), | |
'ViT-L/14-336px': (1024, 768, 577) | |
} | |
class R2Tuning(nn.Module):
    """Video-query model built on top of a frozen CLIP backbone.

    The per-layer CLIP features of the video frames and of the query tokens
    are passed through an adapter, turned into a temporal feature pyramid and
    fed to a classification head and (optionally) a coordinate head. At
    inference time the model produces per-frame saliency scores and, when a
    coordinate head is configured, a ranked list of moment boundaries.

    Args:
        arch: Key into ``_CLIP_ARCHS``; also the model name given to
            ``clip.load`` when ``init`` is true.
        init: Whether to load the CLIP backbone. When false, ``forward``
            expects pre-extracted CLIP features instead of raw inputs.
        dims: Hidden size passed to every sub-module that is built here.
        strides: Pyramid strides, passed to the pyramid and the point
            generator. One learnable coordinate coefficient is created per
            stride.
        buffer_size: Passed to ``PointGenerator``.
        max_num_moment: Maximum number of boundaries kept at inference.
        merge_cls_sal: Whether to multiply the saliency scores by the
            video-query cosine similarity at inference.
        adapter_cfg: Config used to build the adapter.
        pyramid_cfg: Config used to build the temporal pyramid.
        pooling_cfg: Config used to build the query pooling module.
        class_head_cfg: Config used to build the classification head.
        coord_head_cfg: Config used to build the coordinate head. The head
            may end up being ``None`` (presumably when this config is
            ``None``), in which case no boundaries are predicted.
        loss_cfg: Config used to build the loss.
    """

    def __init__(self,
                 arch='ViT-B/32',
                 init=True,
                 dims=256,
                 strides=(1, 2, 4, 8),
                 buffer_size=1024,
                 max_num_moment=50,
                 merge_cls_sal=True,
                 adapter_cfg=None,
                 pyramid_cfg=None,
                 pooling_cfg=None,
                 class_head_cfg=None,
                 coord_head_cfg=None,
                 loss_cfg=None):
        super(R2Tuning, self).__init__()

        if init:
            # The CLIP backbone is only a feature extractor: load it on CPU
            # and freeze all of its parameters.
            self.clip, _ = clip.load(arch, device='cpu')
            for param in self.clip.parameters():
                param.requires_grad = False

        # Needed regardless of `init`: forward() uses it to reshape
        # pre-extracted features when the backbone is absent.
        self.cfg = _CLIP_ARCHS[arch]

        self.adapter = build_model(adapter_cfg, dims, self.cfg[:2])
        self.pyramid = build_model(pyramid_cfg, dims, strides)
        self.pooling = build_model(pooling_cfg, dims)

        self.class_head = build_model(class_head_cfg, dims, 1)
        self.coord_head = build_model(coord_head_cfg, dims, 2)

        self.generator = PointGenerator(strides, buffer_size)

        # One learnable scale per pyramid level for the coordinate outputs.
        self.coef = nn.Parameter(torch.ones(len(strides)))
        self.loss = build_loss(loss_cfg)

        self.max_num_moment = max_num_moment
        self.merge_cls_sal = merge_cls_sal

    def train(self, mode=True):
        """Set the training mode while keeping the CLIP backbone in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            This module, as ``nn.Module.train`` does, so that chained calls
            and the inherited ``eval()`` keep returning the module.
        """
        super(R2Tuning, self).train(mode=mode)
        if hasattr(self, 'clip'):
            # The frozen backbone must never switch to training behaviour.
            self.clip.eval()
        return self

    def clip_video_tower(self, video):
        """Run frames through the CLIP visual tower, keeping every layer.

        Args:
            video: Batch of frames accepted by ``self.clip.visual.conv1``.

        Returns:
            Tensor stacking the input of the first transformer block and the
            output of every block, each permuted back to batch-first layout.
        """
        video = video.type(self.clip.dtype)
        video = self.clip.visual.conv1(video)
        # Flatten the spatial grid into a token axis: (frames, tokens, width).
        video = video.reshape(video.size(0), video.size(1), -1).permute(0, 2, 1)
        # Prepend the class embedding to every frame.
        c_emb = video.new_zeros(video.size(0), 1, video.size(-1))
        c_emb = self.clip.visual.class_embedding.to(video.dtype) + c_emb
        video = torch.cat((c_emb, video), dim=1)
        video = video + self.clip.visual.positional_embedding.to(video.dtype)
        # The transformer blocks are fed sequence-first tensors.
        video = self.clip.visual.ln_pre(video).permute(1, 0, 2)

        emb = [video]
        for blk in self.clip.visual.transformer.resblocks:
            emb.append(blk(emb[-1]))

        video = torch.stack([e.permute(1, 0, 2) for e in emb])
        return video

    def clip_query_tower(self, query):
        """Run query tokens through the CLIP text tower, keeping every layer.

        Args:
            query: Token ids accepted by ``self.clip.token_embedding``.

        Returns:
            Tensor stacking the input of the first transformer block and the
            output of every block, each permuted back to batch-first layout.
        """
        query = self.clip.token_embedding(query).type(self.clip.dtype)
        query = query + self.clip.positional_embedding.type(self.clip.dtype)
        # The transformer blocks are fed sequence-first tensors.
        query = query.permute(1, 0, 2)

        emb = [query]
        for blk in self.clip.transformer.resblocks:
            emb.append(blk(emb[-1]))

        query = torch.stack([e.permute(1, 0, 2) for e in emb])
        return query

    def forward(self, data, mode='test'):
        """Compute losses and/or predictions for a batch.

        Args:
            data: Dict with at least ``'video'`` and ``'query'``. Padded
                video positions are marked with non-finite values. With the
                CLIP backbone loaded, ``video`` is indexed as
                ``(B, T, 3 * D * D)`` flattened frames and ``query`` holds
                token ids with ``0`` as padding. Without it, both hold
                pre-extracted features with non-finite padding. ``'fps'`` is
                read at inference when a coordinate head exists, and
                ``'label'`` is read if present. NOTE: ``data['video']`` (and
                ``data['query']`` for pre-extracted features) are modified in
                place, and intermediate tensors are written back into
                ``data`` when ``mode != 'test'``.
            mode: ``'train'`` computes the loss only, ``'test'`` computes
                predictions only, any other value computes both.

        Returns:
            Dict containing ``_avg_factor`` (the batch size), whatever the
            loss adds when ``mode != 'test'``, and, when ``mode != 'train'``,
            an ``'_out'`` dict with ``label``, ``saliency`` and (if a
            coordinate head exists) ``boundary``.

        Raises:
            AssertionError: If predictions are requested for a batch size
                other than 1.
        """
        video, query = data['video'], data['query']

        if hasattr(self, 'clip'):
            # A frame is valid when its first value is finite; a query token
            # is valid when its id is non-zero.
            video_msk = torch.where(video[:, :, 0].isfinite(), 1, 0)
            query_msk = torch.where(query == 0, 0, 1)

            video[~video.isfinite()] = 0

            # Recover the square frame size D from the flattened 3 * D * D.
            (b, t), d = video.size()[:2], int(math.sqrt(video.size(2) / 3))
            video = video.view(b * t, 3, d, d)

            video_emb = self.clip_video_tower(video)
            query_emb = self.clip_query_tower(query)

            n, _, p, c = video_emb.size()
            video_emb = video_emb.view(n, b, t, p, c)
        else:
            video_msk = torch.where(video[:, :, 0].isfinite(), 1, 0)
            query_msk = torch.where(query[:, :, 0].isfinite(), 1, 0)

            video[~video.isfinite()] = 0
            query[~query.isfinite()] = 0

            # Unflatten the pre-extracted features and move the layer axis
            # to the front.
            (b, t), l = video.size()[:2], query.size(1)
            video = video.view(b, t, -1, self.cfg[2], self.cfg[0]).permute(2, 0, 1, 3, 4)
            query = query.view(b, l, -1, self.cfg[1]).permute(2, 0, 1, 3)

            video_emb = video.float()
            query_emb = query.float()

        # video_emb: N * B * T * P * C
        # query_emb: N * B * L * C
        video_emb, query_emb, coll_v, coll_q = self.adapter(video_emb, query_emb,
                                                            video_msk, query_msk)

        pymid, pymid_msk = self.pyramid(video_emb, video_msk, return_mask=mode != 'test')
        point = self.generator(pymid)

        # Heads, loss and post-processing run in full precision.
        with torch.autocast('cuda', enabled=False):
            video_emb = video_emb.float()
            query_emb = self.pooling(query_emb.float(), query_msk)

            out_class = [self.class_head(e.float()) for e in pymid]
            out_class = torch.cat(out_class, dim=1)

            if self.coord_head is not None:
                # exp() keeps the predicted offsets positive; each pyramid
                # level has its own learnable scale.
                out_coord = [
                    self.coord_head(e.float()).exp() * self.coef[i]
                    for i, e in enumerate(pymid)
                ]
                out_coord = torch.cat(out_coord, dim=1)
            else:
                out_coord = None

            output = dict(_avg_factor=b)

            if mode != 'test':
                data['coll_v'] = [e.float() for e in coll_v]
                data['coll_q'] = [self.pooling(e.float(), query_msk) for e in coll_q]

                data['point'] = point
                data['video_emb'] = video_emb
                data['query_emb'] = query_emb
                data['video_msk'] = video_msk
                data['pymid_msk'] = pymid_msk
                data['out_class'] = out_class
                data['out_coord'] = out_coord

                output = self.loss(data, output)

            if mode != 'train':
                assert b == 1, 'batch size larger than 1 is not supported for inference'
                out_class = out_class.sigmoid()
                out_score = F.cosine_similarity(video_emb, query_emb, dim=-1)

                output['_out'] = dict(label=data.get('label', [None])[0])

                # Split the concatenated class scores back into one chunk
                # per pyramid level.
                pyd_shape = [e.size(1) for e in pymid]
                pyd_class = out_class[0, :, 0].split(pyd_shape)

                saliency = []
                for shape, score in zip(pyd_shape, pyd_class):
                    if t >= shape:
                        # Stretch a coarser level to length t, padding the
                        # tail with its last score.
                        score = score.repeat_interleave(int(t / shape))
                        postfix = score[-1:].repeat(t - score.size(0))
                        score = torch.cat((score, postfix))
                    else:
                        # Shrink a finer level by max pooling.
                        scale = int(shape / t)
                        score = F.max_pool1d(score.unsqueeze(0), scale, stride=scale)[0]
                    saliency.append(score)

                # Per-frame saliency is the maximum over pyramid levels.
                saliency = torch.stack(saliency).amax(dim=0)

                if self.merge_cls_sal:
                    saliency *= out_score[0]

                output['_out']['saliency'] = saliency

                if self.coord_head is not None:
                    # NOTE: the operations below edit `out_coord` in place.
                    boundary = out_coord[0]
                    # Column 0 is the offset to the left of the point,
                    # column 1 the offset to the right.
                    boundary[:, 0] *= -1
                    # Scale by point column 3 and shift by point column 0
                    # (presumably the level scale and the point position;
                    # verify against PointGenerator).
                    boundary *= point[:, 3, None].repeat(1, 2)
                    boundary += point[:, 0, None].repeat(1, 2)
                    boundary /= data['fps'][0]
                    boundary = torch.cat((boundary, out_class[0]), dim=-1)

                    # Keep only the highest-scoring moments.
                    _, inds = out_class[0, :, 0].sort(descending=True)
                    boundary = boundary[inds[:self.max_num_moment]]

                    output['_out']['boundary'] = boundary

        return output