Spaces:
Sleeping
Sleeping
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License. | |
import torch | |
import torch.nn as nn | |
from nncore.nn import MODELS, build_model | |
class R2Block(nn.Module): | |
def __init__(self, | |
dims, | |
in_dims, | |
k=4, | |
dropout=0.5, | |
use_tef=True, | |
pos_cfg=None, | |
tem_cfg=None): | |
super(R2Block, self).__init__() | |
# yapf:disable | |
self.video_map = nn.Sequential( | |
nn.LayerNorm((in_dims[0] + 2) if use_tef else in_dims[0]), | |
nn.Dropout(dropout), | |
nn.Linear((in_dims[0] + 2) if use_tef else in_dims[0], dims), | |
nn.ReLU(inplace=True), | |
nn.LayerNorm(dims), | |
nn.Dropout(dropout), | |
nn.Linear(dims, dims)) | |
self.query_map = nn.Sequential( | |
nn.LayerNorm(in_dims[1]), | |
nn.Dropout(dropout), | |
nn.Linear(in_dims[1], dims), | |
nn.ReLU(inplace=True), | |
nn.LayerNorm(dims), | |
nn.Dropout(dropout), | |
nn.Linear(dims, dims)) | |
# yapf:enable | |
if k > 1: | |
self.gate = nn.Parameter(torch.zeros([k - 1])) | |
self.v_map = nn.Linear(dims, dims) | |
self.q_map = nn.Linear(dims, dims) | |
self.scale = nn.Parameter(torch.zeros([k])) | |
self.pos = build_model(pos_cfg, dims=dims) | |
self.tem = build_model(tem_cfg, dims=dims) | |
self.dims = dims | |
self.in_dims = in_dims | |
self.k = k | |
self.dropout = dropout | |
self.use_tef = use_tef | |
def forward(self, video_emb, query_emb, video_msk, query_msk): | |
video_emb = video_emb[-self.k:] | |
query_emb = query_emb[-self.k:] | |
_, b, t, p, _ = video_emb.size() | |
if self.use_tef: | |
tef_s = torch.arange(0, 1, 1 / t, device=video_emb.device) | |
tef_e = tef_s + 1.0 / t | |
tef = torch.stack((tef_s, tef_e), dim=1) | |
tef = tef.unsqueeze(1).unsqueeze(0).unsqueeze(0).repeat(self.k, b, 1, p, 1) | |
video_emb = torch.cat((video_emb, tef[:, :, :video_emb.size(2)]), dim=-1) | |
coll_v, coll_q, last = [], [], None | |
for i in range(self.k - 1, -1, -1): | |
v_emb = self.video_map(video_emb[i]) # B * T * P * C | |
q_emb = self.query_map(query_emb[i]) # B * L * C | |
coll_v.append(v_emb[:, :, 0]) | |
coll_q.append(q_emb) | |
v_pool = v_emb.view(b * t, -1, self.dims) # BT * P * C | |
q_pool = q_emb.repeat_interleave(t, dim=0) # BT * L * C | |
v_pool_map = self.v_map(v_pool) # BT * P * C | |
q_pool_map = self.q_map(q_pool) # BT * L * C | |
att = torch.bmm(q_pool_map, v_pool_map.transpose(1, 2)) / self.dims**0.5 | |
att = att.softmax(-1) # BT * L * P | |
o_pool = torch.bmm(att, v_pool) + q_pool # BT * L * C | |
o_pool = o_pool.amax(dim=1, keepdim=True) # BT * 1 * C | |
v_emb = v_pool[:, 0, None] + o_pool * self.scale[i].tanh() | |
v_emb = v_emb.view(b, t, self.dims) # B * T * C | |
if i < self.k - 1: | |
gate = self.gate[i].sigmoid() | |
v_emb = gate * v_emb + (1 - gate) * last | |
v_pe = self.pos(v_emb) | |
last = self.tem(v_emb, q_emb, q_pe=v_pe, q_mask=video_msk, k_mask=query_msk) | |
return last, q_emb, coll_v, coll_q | |