# unit_test/uniperceiver/utils/transformer_util.py
import torch
from torch import nn
import math
import warnings
from torch.nn import init
import numpy as np
from uniperceiver.utils import comm
INIT_STD = 0.02
INIT_EMBEDDING_STD = 0.02
def null_loss_check(outputs_dict):
    ret = {}
    if 'null_loss' in outputs_dict:
        null_loss = outputs_dict['null_loss']
    else:
        # zero-valued loss that touches every shared target tensor, keeping the
        # corresponding parameters in the autograd graph (e.g. so DDP does not
        # complain about unused parameters)
        null_loss = 0
        for shared_target in outputs_dict['shared_target_sets'].values():
            null_loss += torch.sum(shared_target[0]['data'] * 0)
    ret.update({'null_loss': null_loss})
    return ret
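
# A minimal sketch of the zero-valued loss trick above (the dict contents are
# hypothetical): multiplying the shared target tensors by 0 and summing yields
# a loss of 0 that still connects those tensors to the autograd graph.
def _demo_null_loss():
    outputs = {'shared_target_sets': {'text': [{'data': torch.randn(2, 3, requires_grad=True)}]}}
    return null_loss_check(outputs)['null_loss']  # tensor(0.) with a grad_fn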
def build_2d_sincos_position_embedding(cfg, video_embed, cls_token=False, temperature=10000., pos_emd_fix=False):
    # note: the 'pos_emd_fix' argument is unused; freezing is controlled by cfg.MODEL.POSEMBEDFIX below
    h, w = int(video_embed.max_spatial_size**.5), int(video_embed.max_spatial_size**.5)
grid_w = torch.arange(w, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
if cfg.MODEL.POSEMBED_SCALE != 1.0:
grid_w = grid_w * cfg.MODEL.POSEMBED_SCALE
grid_h = grid_h * cfg.MODEL.POSEMBED_SCALE
assert cfg.MODEL.BERT.HIDDEN_SIZE % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
pos_dim = cfg.MODEL.BERT.HIDDEN_SIZE // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1. / (temperature**omega)
    out_w = torch.einsum('m,d->md', grid_w.flatten(), omega)
    out_h = torch.einsum('m,d->md', grid_h.flatten(), omega)
    pos_emb = torch.cat(
        [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)],
        dim=1)
    # assuming one and only one extra token, [cls]; it gets an all-zero row
    if cls_token:
        pe_token = torch.zeros([1, cfg.MODEL.BERT.HIDDEN_SIZE], dtype=torch.float32)
        video_embed.embeddings_st_pos.spatial_pos_embed.weight = nn.Parameter(torch.cat([pe_token, pos_emb], dim=0))
else:
video_embed.embeddings_st_pos.spatial_pos_embed.weight = nn.Parameter(pos_emb)
if cfg.MODEL.POSEMBEDFIX:
video_embed.embeddings_st_pos.spatial_pos_embed.weight.requires_grad = False
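
# A minimal standalone sketch of the same 2D sin-cos construction, assuming a
# hypothetical grid size and embedding dim (no cfg or video_embed needed);
# useful for eyeballing the table the function above writes into spatial_pos_embed.
def _demo_2d_sincos_table(h=4, w=4, dim=8, temperature=10000.):
    grid_w, grid_h = torch.meshgrid(torch.arange(w, dtype=torch.float32),
                                    torch.arange(h, dtype=torch.float32))
    pos_dim = dim // 4  # dim must be divisible by 4
    omega = 1. / (temperature**(torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = torch.einsum('m,d->md', grid_w.flatten(), omega)
    out_h = torch.einsum('m,d->md', grid_h.flatten(), omega)
    return torch.cat([torch.sin(out_w), torch.cos(out_w),
                      torch.sin(out_h), torch.cos(out_h)], dim=1)  # (h*w, dim)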
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
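
# Usage sketch (hypothetical shape): in-place init of a fresh weight tensor.
# With the default cutoffs a=-2., b=2. and std=0.02, truncation happens ~100
# standard deviations out, which is why the comment in init_weights_mae below
# calls timm-style trunc_normal_ effectively a plain normal init.
def _demo_trunc_normal(shape=(768, 768)):
    weight = torch.empty(shape)
    trunc_normal_(weight, mean=0., std=0.02)  # values clamped to [-2., 2.]
    return weight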
def truncated_normal_(tensor, mode='fan_in'):
    # fan-scaled truncated normal, as used by init_switchtransformer_params:
    # std = sqrt(gain / fan) with a small gain of 0.1
    fan = init._calculate_correct_fan(tensor, mode=mode)
    gain = 0.1
    std = math.sqrt(gain / fan)
    init.trunc_normal_(tensor, mean=0.0, std=std)
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
def init_bert_params(module):
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
    if isinstance(module, nn.MultiheadAttention):
        # in_proj_weight holds the fused Q, K and V projections
        normal_(module.in_proj_weight.data)
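
# Usage sketch: these init_* hooks are meant to be passed to nn.Module.apply,
# which calls them on every submodule. The toy model here is illustrative only.
def _demo_apply_bert_init():
    toy = nn.Sequential(nn.Linear(16, 16), nn.Embedding(10, 16))
    toy.apply(init_bert_params)  # re-initializes the Linear and Embedding weights
    return toy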
def init_switchtransformer_params(module):
if isinstance(module, nn.Linear):
truncated_normal_(module.weight)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def init_timm_params(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=INIT_STD)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Embedding):
trunc_normal_(m.weight.data, std=INIT_EMBEDDING_STD)
if m.padding_idx is not None:
m.weight.data[m.padding_idx].zero_()
    if isinstance(m, nn.MultiheadAttention):
        # nn.MultiheadAttention stores either separate q/k/v projection weights
        # or a single fused in_proj_weight, depending on the embed dims
        if m.q_proj_weight is not None:
            trunc_normal_(m.q_proj_weight.data, std=INIT_STD)
            trunc_normal_(m.k_proj_weight.data, std=INIT_STD)
            trunc_normal_(m.v_proj_weight.data, std=INIT_STD)
        else:
            trunc_normal_(m.in_proj_weight.data, std=INIT_STD)
def initialize_weights_as_mae(model):
    # initialize nn.Linear, nn.Embedding and nn.LayerNorm
    model.apply(init_weights_mae)
    if model.video_embed is not None:
        # initialize (and optionally freeze) pos_embed with a sin-cos table
        build_2d_sincos_position_embedding(model.cfg, model.video_embed)
        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = model.video_embed.embeddings.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        if model.video_embed.embeddings.bias is not None:
            nn.init.zeros_(model.video_embed.embeddings.bias)
def initialize_weights_as_mocov3(model):
    initialize_weights_as_mae(model)
    # re-initialize the cls token with a much smaller std, following MoCo v3
    nn.init.normal_(model.token_embed.embeddings.weight[-1, :], std=1e-6)
def init_weights_mae(m):
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
# torch.nn.init.normal_(self.cls_token, std=.02)
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
if m.weight.shape[0] == m.weight.shape[1] * 3:
# treat the weights of Q, K, V separately
val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
nn.init.uniform_(m.weight, -val, val)
else:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
    # all word-like embeddings, e.g. word / spe / type / position embeddings;
    # MAE itself only has embeddings such as the cls token and mask tokens
elif isinstance(m, nn.Embedding):
torch.nn.init.normal_(m.weight.data, std=INIT_EMBEDDING_STD)
if m.padding_idx is not None:
m.weight.data[m.padding_idx].zero_()
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.MultiheadAttention):
if m.q_proj_weight is not None:
torch.nn.init.xavier_uniform_(m.q_proj_weight.data)
torch.nn.init.xavier_uniform_(m.k_proj_weight.data)
torch.nn.init.xavier_uniform_(m.v_proj_weight.data)
else:
# treat the weights of Q, K, V separately
val = math.sqrt(6. / float(m.in_proj_weight.shape[0] // 3 + m.in_proj_weight.shape[1]))
nn.init.uniform_(m.in_proj_weight, -val, val)
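
# Sketch of the fused-QKV heuristic above (hypothetical dim): a Linear whose
# output dim is exactly 3x its input dim is treated as a fused Q/K/V projection,
# so the uniform bound is computed per-projection rather than for the full matrix.
def _demo_fused_qkv_init(dim=8):
    qkv = nn.Linear(dim, dim * 3)  # weight shape (3*dim, dim) hits the fused branch
    init_weights_mae(qkv)
    return qkv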
def data_half(fp16, bf16, data):
    # cast every fp32 tensor in the dict to half / bfloat16 precision
    if fp16:
        for k, v in data.items():
            if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
                data[k] = v.half()
    elif bf16:
        for k, v in data.items():
            if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
                data[k] = v.to(torch.bfloat16)
    return data
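
# Usage sketch (hypothetical batch dict): only fp32 tensors are cast; integer
# tensors such as token ids pass through untouched.
def _demo_data_half():
    batch = {'data': torch.randn(2, 4), 'ids': torch.arange(4)}
    batch = data_half(False, True, batch)  # fp16=False, bf16=True
    return batch['data'].dtype             # torch.bfloat16; 'ids' stays int64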
def postprocess(data_dict: dict, task_info: dict):
    if data_dict.get('sample_info', None) is not None and data_dict['sample_info'].get('distributed', False):
        import torch.distributed.nn  # needed for the differentiable all_gather below
        data = data_dict['data']
        # HERE only the spe token feature (position 0) is used
        hidden_states = data[:, 0].contiguous()
        hidden_states = torch.cat(torch.distributed.nn.all_gather(hidden_states))
        total_length = data_dict['sample_info']['total_num']
        # truncate to the true dataset size (gathered batches may include padding)
        if hidden_states.shape[0] > total_length:
            hidden_states = hidden_states[:total_length]
        data_dict['data'] = hidden_states.unsqueeze(1)
def get_spe_token(tokenizer, token_embed):
    if comm.old_checkpoint:
        token_ids = torch.tensor(tokenizer.encode('<|spe|>')).cuda().unsqueeze(0)  # (1, num_tokens)
        return token_embed(token_ids, type_embed=False, pos_embed=False)
    else:
        token_ids = torch.tensor(tokenizer.encode('spe')).cuda().unsqueeze(0)  # (1, num_tokens)
        return token_embed(token_ids)
def preprocess(tokenizer, token_embed, data_list: list, task_info: dict):
    # preparation of the fused_encoder input
    bs = data_list[0]['data'].shape[0]
    device = data_list[0]['data'].device
    mask_dtype = torch.uint8
    # TODO: prompt embedding
    prefix_spe_before_fuse = task_info.get('prefix_spe_before_fuse', True)
combined_data = []
# spe embedding
spe_token = get_spe_token(tokenizer, token_embed).expand(bs, -1, -1)
    length = [data_dict['data'].shape[1] for data_dict in data_list]
if prefix_spe_before_fuse:
length = [1] + length
combined_data.append(spe_token)
cum_length = np.cumsum(length).tolist()
    invalid_mask_active = any(data_dict.get('invalid_mask', None) is not None for data_dict in data_list)
    if invalid_mask_active:
        combined_valid_mask = torch.zeros((bs, cum_length[-1]), dtype=mask_dtype, device=device)
    else:
        combined_valid_mask = None
    for i, data_dict in enumerate(data_list):
        combined_data.append(data_dict['data'])
        if data_dict.get('invalid_mask', None) is not None:
            # note: this indexing assumes the spe token is prefixed, so that
            # cum_length[i] is the start offset of the i-th data chunk
            combined_valid_mask[:, cum_length[i]:cum_length[i + 1]] = data_dict['invalid_mask']
combined_data = torch.cat(combined_data, dim=1)
sample_info = {
'data_length': length,
'data_cum_length': cum_length,
'sample_info_per_sample': []}
# for caption task inference
if comm._CAPTION_GEN_MODE:
sample_info['data_cum_length'] = data_list[0]['sample_info']['data_cum_length']
for data_dict in data_list:
if data_dict.get('sample_info', None) is not None:
if isinstance(data_dict['sample_info'], dict):
sample_info.update(data_dict['sample_info'])
elif isinstance(data_dict['sample_info'], list):
if isinstance(data_dict['sample_info'][0], dict):
sample_info.update(data_dict['sample_info'][0])
sample_info['sample_info_per_sample'].append(data_dict['sample_info'])
    data_type = None  # avoids a NameError if no entry carries 'data_type'
    moe_embedding = None
    for data_dict in data_list:
        if 'data_type' in data_dict:
            data_type = data_dict['data_type']
        if 'moe_embedding' in data_dict:
            moe_embedding = data_dict['moe_embedding']
return {
'data': combined_data,
'invalid_mask': combined_valid_mask,
'data_type': data_type,
'sample_info': sample_info,
'moe_embedding': moe_embedding,
}
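
# Layout sketch of what preprocess builds (hypothetical lengths, no tokenizer
# needed): with prefix_spe_before_fuse=True and two modalities of lengths 3 and
# 5, length = [1, 3, 5] and cum_length = [1, 4, 9], so modality i occupies
# columns [cum_length[i], cum_length[i+1]) of the combined sequence.
def _demo_preprocess_layout():
    length = [1, 3, 5]                       # spe token + two data chunks
    cum_length = np.cumsum(length).tolist()  # [1, 4, 9]
    spans = [(cum_length[i], cum_length[i + 1]) for i in range(len(length) - 1)]
    return spans                             # [(1, 4), (4, 9)]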
def share_token_embed_ln(video_embed, token_embed):
    # share a single embedding LayerNorm between the video and token embedders
    if video_embed is not None and token_embed is not None:
        del video_embed.embeddings_norm
        video_embed.embeddings_norm = token_embed.embeddings_norm