|
import math
import warnings

import numpy as np
import torch
import torch.distributed.nn  # needed for torch.distributed.nn.all_gather in postprocess()
from torch import nn
from torch.nn import init

from uniperceiver.utils import comm
|
|
|
INIT_STD = 0.02 |
|
INIT_EMBEDDING_STD = 0.02 |
|
|
|
def null_loss_check(outputs_dict): |
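    """Collect the model's 'null_loss' (or create a zero one) and add a zero-valued term for
    every shared target tensor so those tensors stay part of the computation graph."""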
|
ret = {} |
|
if 'null_loss' in outputs_dict: |
|
null_loss = outputs_dict['null_loss'] |
|
else: |
|
null_loss = 0 |
|
for shared_target in outputs_dict['shared_target_sets'].values(): |
|
null_loss += torch.sum(shared_target[0]['data']*0) |
|
ret.update({'null_loss': null_loss}) |
|
return ret |
|
|
|
def build_2d_sincos_position_embedding(cfg, video_embed, cls_token=False, temperature=10000., pos_emd_fix=False): |
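    """Build fixed 2D sin-cos position embeddings over the `max_spatial_size` grid and write them
    into `video_embed.embeddings_st_pos.spatial_pos_embed` (optionally with a zero row for the
    cls token); the table is frozen when cfg.MODEL.POSEMBEDFIX is set."""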
|
h, w = int(video_embed.max_spatial_size**.5), int(video_embed.max_spatial_size**.5) |
|
|
|
grid_w = torch.arange(w, dtype=torch.float32) |
|
grid_h = torch.arange(h, dtype=torch.float32) |
|
grid_w, grid_h = torch.meshgrid(grid_w, grid_h) |
|
if cfg.MODEL.POSEMBED_SCALE != 1.0: |
|
grid_w = grid_w * cfg.MODEL.POSEMBED_SCALE |
|
grid_h = grid_h * cfg.MODEL.POSEMBED_SCALE |
|
|
|
assert cfg.MODEL.BERT.HIDDEN_SIZE % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' |
|
pos_dim = cfg.MODEL.BERT.HIDDEN_SIZE // 4 |
|
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim |
|
omega = 1. / (temperature**omega) |
|
out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) |
|
out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) |
|
    pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)

if cls_token: |
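        # Prepend an all-zero row for the cls token.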
|
pe_token = torch.zeros([ 1, cfg.MODEL.BERT.HIDDEN_SIZE], dtype=torch.float32) |
|
video_embed.embeddings_st_pos.spatial_pos_embed.weight = nn.Parameter(torch.cat([pe_token, pos_emb], dim=0)) |
|
else: |
|
video_embed.embeddings_st_pos.spatial_pos_embed.weight = nn.Parameter(pos_emb) |
|
if cfg.MODEL.POSEMBEDFIX: |
|
video_embed.embeddings_st_pos.spatial_pos_embed.weight.requires_grad = False |
|
|
|
|
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # In-place truncated normal init (cf. timm / torch.nn.init.trunc_normal_): fills `tensor`
    # with values drawn from N(mean, std) truncated to the interval [a, b].
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|
warnings.warn( |
|
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|
"The distribution of values may be incorrect.", |
|
stacklevel=2) |
|
|
|
    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values.
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated standard normal.
        tensor.erfinv_()

        # Transform to proper mean, std.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range.
        tensor.clamp_(min=a, max=b)
        return tensor
|
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): |
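    """Truncated normal initialization; thin wrapper around `_no_grad_trunc_normal_`.

    Example:
        w = torch.empty(3, 5)
        trunc_normal_(w, std=0.02)
    """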
|
return _no_grad_trunc_normal_(tensor, mean, std, a, b) |
|
|
|
def truncated_normal_(tensor, mode='fan_in'):
    # Fan-scaled truncated normal: std = sqrt(gain / fan) with gain fixed to 0.1,
    # as used by `init_switchtransformer_params` below.
    fan = init._calculate_correct_fan(tensor, mode=mode)
    gain = 0.1
    std = math.sqrt(gain / fan)
    init.trunc_normal_(tensor, mean=0.0, std=std)
|
|
|
def normal_(data):
    # Sample on CPU and copy back in place, so the drawn values do not depend on
    # which device the parameter currently lives on.
    data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
|
|
def init_bert_params(module): |
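    """BERT-style initialization (fairseq-like): N(0, 0.02) for Linear/Embedding/attention
    weights, zeros for biases and the padding embedding."""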
|
if isinstance(module, nn.Linear): |
|
normal_(module.weight.data) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
if isinstance(module, nn.Embedding): |
|
normal_(module.weight.data) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
    if isinstance(module, nn.MultiheadAttention):
        # With equal q/k/v dims, nn.MultiheadAttention keeps a single fused in_proj_weight;
        # initialize it the same way as Linear weights.
        normal_(module.in_proj_weight.data)
|
|
|
def init_switchtransformer_params(module): |
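    """Switch-Transformer-style initialization: fan-scaled truncated normal for Linear weights,
    N(0, 0.02) for Embedding weights, zeros for biases and the padding embedding."""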
|
if isinstance(module, nn.Linear): |
|
truncated_normal_(module.weight) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
if isinstance(module, nn.Embedding): |
|
normal_(module.weight.data) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
def init_timm_params(m): |
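    """timm-style initialization: truncated normal (std 0.02) for Linear, Embedding and
    attention projection weights, zero biases and padding embedding."""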
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=INIT_STD) |
|
if m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
if isinstance(m, nn.Embedding): |
|
trunc_normal_(m.weight.data, std=INIT_EMBEDDING_STD) |
|
if m.padding_idx is not None: |
|
m.weight.data[m.padding_idx].zero_() |
|
    if isinstance(m, nn.MultiheadAttention):
        # nn.MultiheadAttention only keeps separate q/k/v projection weights when the
        # q/k/v dims differ; otherwise the fused in_proj_weight is used.
        if m.q_proj_weight is not None:
            trunc_normal_(m.q_proj_weight.data, std=INIT_STD)
            trunc_normal_(m.k_proj_weight.data, std=INIT_STD)
            trunc_normal_(m.v_proj_weight.data, std=INIT_STD)
        else:
            trunc_normal_(m.in_proj_weight.data, std=INIT_STD)
|
|
|
def initialize_weights_as_mae(model):
    """MAE-style initialization: xavier-uniform / normal init for the transformer modules
    (see `init_weights_mae`), 2D sin-cos spatial position embeddings, and patch-embedding
    weights initialized like nn.Linear."""
    model.apply(init_weights_mae)

    if model.video_embed is not None:
        # Fixed sin-cos position embedding for the spatial (patch) positions.
        build_2d_sincos_position_embedding(model.cfg, model.video_embed)

        # Initialize the patch embedding like nn.Linear (instead of nn.Conv2d).
        w = model.video_embed.embeddings.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        if model.video_embed.embeddings.bias is not None:
            nn.init.zeros_(model.video_embed.embeddings.bias)
|
|
|
|
|
def initialize_weights_as_mocov3(model):
    """MoCo-v3-style initialization: MAE init plus a near-zero init for the last token embedding."""
    model.initialize_weights_as_mae()

    # Re-initialize the last token embedding with a very small std (as MoCo v3 does for the cls token).
    nn.init.normal_(model.token_embed.embeddings.weight[-1, :], std=1e-6)
|
|
|
|
|
def init_weights_mae(m):
    if isinstance(m, nn.Linear):
        if m.weight.shape[0] == m.weight.shape[1] * 3:
            # Fused qkv projection: use the xavier-uniform bound of a single
            # (in_features x in_features) projection rather than of the fused weight.
            val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
            nn.init.uniform_(m.weight, -val, val)
        else:
            torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.Embedding):
        torch.nn.init.normal_(m.weight.data, std=INIT_EMBEDDING_STD)
        if m.padding_idx is not None:
            m.weight.data[m.padding_idx].zero_()

    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

    elif isinstance(m, nn.MultiheadAttention):
        if m.q_proj_weight is not None:
            torch.nn.init.xavier_uniform_(m.q_proj_weight.data)
            torch.nn.init.xavier_uniform_(m.k_proj_weight.data)
            torch.nn.init.xavier_uniform_(m.v_proj_weight.data)
        else:
            # Fused in_proj_weight: same per-projection xavier bound as for fused qkv Linear weights.
            val = math.sqrt(6. / float(m.in_proj_weight.shape[0] // 3 + m.in_proj_weight.shape[1]))
            nn.init.uniform_(m.in_proj_weight, -val, val)
|
|
|
def data_half(fp16, bf16, data): |
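    """Cast every float32 tensor value of `data` to fp16 (or bf16) in place and return the dict.

    Example (hypothetical data dict):
        data = {'data': torch.randn(2, 4), 'ids': torch.arange(2)}
        data = data_half(fp16=True, bf16=False, data=data)  # 'data' becomes float16, 'ids' unchanged
    """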
|
if fp16: |
|
for k, v in data.items(): |
|
if isinstance(v, torch.Tensor) and v.dtype == torch.float32: |
|
data[k] = v.half() |
|
|
|
|
|
elif bf16: |
|
for k, v in data.items(): |
|
if isinstance(v, torch.Tensor) and v.dtype == torch.float32: |
|
data[k] = v.to(torch.bfloat16) |
|
|
|
|
|
return data |
|
|
|
def postprocess(data_dict: dict, task_info: dict):
    """When `sample_info` marks the batch as distributed, all-gather the first-position hidden
    states across ranks and trim padded samples so exactly `total_num` entries remain."""
    if data_dict.get('sample_info', None) is not None and data_dict['sample_info'].get('distributed', False):
        data = data_dict['data']
        hidden_states = data[:, 0].contiguous()
        hidden_states = torch.cat(torch.distributed.nn.all_gather(hidden_states))

        total_length = data_dict['sample_info']['total_num']
        if hidden_states.shape[0] > total_length:
            hidden_states = hidden_states[:total_length]

        data_dict['data'] = hidden_states.unsqueeze(1)
|
|
|
|
|
def get_spe_token(tokenizer, token_embed): |
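    """Return the embedding of the special SPE token ('<|spe|>' for old-style checkpoints,
    'spe' otherwise)."""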
|
if comm.old_checkpoint: |
|
a = torch.tensor(tokenizer.encode('<|spe|>')).cuda().unsqueeze(0) |
|
return token_embed(a, type_embed=False, pos_embed=False) |
|
else: |
|
a = torch.tensor(tokenizer.encode('spe')).cuda().unsqueeze(0) |
|
return token_embed(a) |
|
|
|
def preprocess(tokenizer, token_embed, data_list:list, task_info:dict): |
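    """Fuse the per-modality embedding dicts in `data_list` into one sequence: optionally prefix
    a SPE token, merge the per-modality invalid masks, and aggregate `sample_info` (including the
    cumulative lengths of the fused segments)."""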
|
|
|
bs = data_list[0]['data'].shape[0] |
|
device = data_list[0]['data'].device |
|
mask_dtype = torch.uint8 |
|
|
|
|
|
|
|
prefix_spe_before_fuse = task_info.get('prefix_spe_before_fuse', True) |
|
|
|
combined_data = [] |
|
|
|
spe_token = get_spe_token(tokenizer, token_embed).expand(bs, -1, -1) |
|
|
|
length = [ data_dict['data'].shape[1] for data_dict in data_list] |
|
if prefix_spe_before_fuse: |
|
length = [1] + length |
|
combined_data.append(spe_token) |
|
|
|
cum_length = np.cumsum(length).tolist() |
|
|
|
invalid_mask_active = any([ data_dict.get('invalid_mask', None) is not None for data_dict in data_list]) |
|
if invalid_mask_active: |
|
|
|
combined_valid_mask = torch.zeros((bs, cum_length[-1]), dtype=mask_dtype, device=device) |
|
else: |
|
combined_valid_mask = None |
|
|
|
for i, data_dict in enumerate(data_list): |
|
combined_data.append(data_dict['data']) |
|
if data_dict.get('invalid_mask', None) is not None: |
|
combined_valid_mask[:, cum_length[i]:cum_length[i+1]] = data_dict['invalid_mask'] |
|
|
|
combined_data = torch.cat(combined_data, dim=1) |
|
|
|
sample_info = { |
|
'data_length': length, |
|
'data_cum_length': cum_length, |
|
'sample_info_per_sample': []} |
|
|
|
|
|
    if comm._CAPTION_GEN_MODE:
        # In caption-generation mode, reuse the cumulative lengths recorded in the first input's sample_info.
        sample_info['data_cum_length'] = data_list[0]['sample_info']['data_cum_length']
|
|
|
|
|
for data_dict in data_list: |
|
if data_dict.get('sample_info', None) is not None: |
|
if isinstance(data_dict['sample_info'], dict): |
|
sample_info.update(data_dict['sample_info']) |
|
elif isinstance(data_dict['sample_info'], list): |
|
if isinstance(data_dict['sample_info'][0], dict): |
|
sample_info.update(data_dict['sample_info'][0]) |
|
sample_info['sample_info_per_sample'].append(data_dict['sample_info']) |
|
|
|
    data_type = None
    moe_embedding = None
|
for data_dict in data_list: |
|
if 'data_type' in data_dict: |
|
data_type = data_dict['data_type'] |
|
if 'moe_embedding' in data_dict: |
|
moe_embedding = data_dict['moe_embedding'] |
|
|
|
return { |
|
'data': combined_data, |
|
'invalid_mask': combined_valid_mask, |
|
'data_type': data_type, |
|
'sample_info': sample_info, |
|
'moe_embedding': moe_embedding, |
|
} |
|
|
|
def share_token_embed_ln(video_embed, token_embed): |
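    """Make the video embedding share the token embedding's LayerNorm: its own
    `embeddings_norm` is deleted and re-bound to `token_embed.embeddings_norm`."""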
|
if video_embed is not None and token_embed is not None: |
|
del video_embed.embeddings_norm |
|
video_embed.embeddings_norm = token_embed.embeddings_norm |
|
|