import math
import warnings

import numpy as np
import torch
import torch.distributed.nn  # provides the differentiable all_gather used in postprocess
from torch import nn
from torch.nn import init

from uniperceiver.utils import comm

INIT_STD = 0.02
INIT_EMBEDDING_STD = 0.02


def null_loss_check(outputs_dict):
    # Keep a zero-valued loss connected to every shared target tensor so that
    # all parameters stay in the autograd graph.
    ret = {}
    if 'null_loss' in outputs_dict:
        null_loss = outputs_dict['null_loss']
    else:
        null_loss = 0
    for shared_target in outputs_dict['shared_target_sets'].values():
        null_loss += torch.sum(shared_target[0]['data'] * 0)
    ret.update({'null_loss': null_loss})
    return ret


def build_2d_sincos_position_embedding(cfg, video_embed, cls_token=False, temperature=10000., pos_emd_fix=False):
    h = w = int(video_embed.max_spatial_size ** .5)
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    if cfg.MODEL.POSEMBED_SCALE != 1.0:
        grid_w = grid_w * cfg.MODEL.POSEMBED_SCALE
        grid_h = grid_h * cfg.MODEL.POSEMBED_SCALE

    assert cfg.MODEL.BERT.HIDDEN_SIZE % 4 == 0, \
        'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
    pos_dim = cfg.MODEL.BERT.HIDDEN_SIZE // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
    pos_emb = torch.cat(
        [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)

    if cls_token:
        pe_token = torch.zeros([1, cfg.MODEL.BERT.HIDDEN_SIZE], dtype=torch.float32)
        video_embed.embeddings_st_pos.spatial_pos_embed.weight = nn.Parameter(
            torch.cat([pe_token, pos_emb], dim=0))
    else:
        video_embed.embeddings_st_pos.spatial_pos_embed.weight = nn.Parameter(pos_emb)

    if cfg.MODEL.POSEMBEDFIX:
        video_embed.embeddings_st_pos.spatial_pos_embed.weight.requires_grad = False
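# Illustrative sketch (not part of this module's API): the core of the sin-cos
# construction above, for a hypothetical 4x4 grid and embed dim 8. Each of
# sin(w), cos(w), sin(h), cos(h) contributes dim // 4 channels:
#   >>> h = w = 4; dim = 8; pos_dim = dim // 4
#   >>> gw, gh = torch.meshgrid(torch.arange(w, dtype=torch.float32),
#   ...                         torch.arange(h, dtype=torch.float32))
#   >>> omega = 1. / (10000. ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
#   >>> out_w = torch.einsum('m,d->md', gw.flatten(), omega)
#   >>> out_h = torch.einsum('m,d->md', gh.flatten(), omega)
#   >>> pe = torch.cat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)
#   >>> pe.shape
#   torch.Size([16, 8])
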
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def truncated_normal_(tensor, mode='fan_in'):
    # Fan-scaled truncated normal: std = sqrt(gain / fan) with gain = 0.1.
    fan = init._calculate_correct_fan(tensor, mode=mode)
    gain = 0.1
    std = math.sqrt(gain / fan)
    init.trunc_normal_(tensor, mean=0.0, std=std)


def normal_(data):
    # with FSDP, module params will be on CUDA, so we cast them back to CPU
    # so that the RNG is consistent with and without FSDP
    data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))


def init_bert_params(module):
    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, nn.MultiheadAttention):
        # q/k/v projections are fused into a single in_proj weight here
        normal_(module.in_proj_weight.data)


def init_switchtransformer_params(module):
    if isinstance(module, nn.Linear):
        truncated_normal_(module.weight)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()


def init_timm_params(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=INIT_STD)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Embedding):
        trunc_normal_(m.weight.data, std=INIT_EMBEDDING_STD)
        if m.padding_idx is not None:
            m.weight.data[m.padding_idx].zero_()
    if isinstance(m, nn.MultiheadAttention):
        trunc_normal_(m.q_proj.weight.data, std=INIT_STD)
        trunc_normal_(m.k_proj.weight.data, std=INIT_STD)
        trunc_normal_(m.v_proj.weight.data, std=INIT_STD)


def initialize_weights_as_mae(model):
    # initialize nn.Linear and nn.LayerNorm
    model.apply(init_weights_mae)

    # initialize (and optionally freeze) pos_embed by sin-cos embedding
    if model.video_embed is not None:
        build_2d_sincos_position_embedding(model.cfg, model.video_embed)

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = model.video_embed.embeddings.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        if model.video_embed.embeddings.bias is not None:
            nn.init.zeros_(model.video_embed.embeddings.bias)


def initialize_weights_as_mocov3(model):
    model.initialize_weights_as_mae()
    # cls token with smaller std
    nn.init.normal_(model.token_embed.embeddings.weight[-1, :], std=1e-6)
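# Hypothetical usage sketch: the initializers above are meant to be driven by
# `Module.apply`, which visits every submodule recursively, e.g.:
#   >>> toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
#   >>> _ = toy.apply(init_timm_params)  # trunc_normal_(std=0.02) weights, zero biases
#   >>> float(toy[0].bias.abs().sum())
#   0.0
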
def init_weights_mae(m):
    # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02), as the
    # cutoff (2.) is too big to matter
    if isinstance(m, nn.Linear):
        # we use xavier_uniform following official JAX ViT:
        if m.weight.shape[0] == m.weight.shape[1] * 3:
            # fused qkv projection: treat the weights of Q, K, V separately
            val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
            nn.init.uniform_(m.weight, -val, val)
        else:
            torch.nn.init.xavier_uniform_(m.weight)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    # all word embeddings, e.g. word / spe / type / position embeddings;
    # MAE itself only has embeddings like cls_token and mask tokens
    elif isinstance(m, nn.Embedding):
        torch.nn.init.normal_(m.weight.data, std=INIT_EMBEDDING_STD)
        if m.padding_idx is not None:
            m.weight.data[m.padding_idx].zero_()
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.MultiheadAttention):
        if m.q_proj_weight is not None:
            torch.nn.init.xavier_uniform_(m.q_proj_weight.data)
            torch.nn.init.xavier_uniform_(m.k_proj_weight.data)
            torch.nn.init.xavier_uniform_(m.v_proj_weight.data)
        else:
            # fused in_proj weight: treat the weights of Q, K, V separately
            val = math.sqrt(6. / float(m.in_proj_weight.shape[0] // 3 + m.in_proj_weight.shape[1]))
            nn.init.uniform_(m.in_proj_weight, -val, val)


def data_half(fp16, bf16, data):
    # Cast every float32 tensor in `data` to the requested half precision.
    if fp16:
        for k, v in data.items():
            if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
                data[k] = v.half()
    elif bf16:
        for k, v in data.items():
            if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
                data[k] = v.to(torch.bfloat16)
    return data


def postprocess(data_dict: dict, task_info: dict):
    if data_dict.get('sample_info', None) is not None and data_dict['sample_info'].get('distributed', False):
        data = data_dict['data']
        # HERE only use the spe token feature!
        hidden_states = data[:, 0].contiguous()
        hidden_states = torch.cat(torch.distributed.nn.all_gather(hidden_states))
        total_length = data_dict['sample_info']['total_num']
        if hidden_states.shape[0] > total_length:
            hidden_states = hidden_states[:total_length]
        data_dict['data'] = hidden_states.unsqueeze(1)


def get_spe_token(tokenizer, token_embed):
    if comm.old_checkpoint:
        a = torch.tensor(tokenizer.encode('<|spe|>')).cuda().unsqueeze(0)  # bs, 1
        return token_embed(a, type_embed=False, pos_embed=False)
    else:
        a = torch.tensor(tokenizer.encode('spe')).cuda().unsqueeze(0)  # bs, 1
        return token_embed(a)
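# Illustrative sketch: data_half casts only float32 tensors and leaves
# everything else (integer tensors, masks, non-tensors) untouched:
#   >>> batch = {'data': torch.zeros(2, 3), 'ids': torch.zeros(2, dtype=torch.long)}
#   >>> out = data_half(fp16=True, bf16=False, data=batch)
#   >>> out['data'].dtype, out['ids'].dtype
#   (torch.float16, torch.int64)
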
def preprocess(tokenizer, token_embed, data_list: list, task_info: dict):
    # preparation for fused_encoder input
    bs = data_list[0]['data'].shape[0]
    device = data_list[0]['data'].device
    mask_dtype = torch.uint8

    # TODO: prompt embedding
    prefix_spe_before_fuse = task_info.get('prefix_spe_before_fuse', True)

    combined_data = []

    # spe embedding
    spe_token = get_spe_token(tokenizer, token_embed).expand(bs, -1, -1)

    length = [data_dict['data'].shape[1] for data_dict in data_list]
    if prefix_spe_before_fuse:
        length = [1] + length
        combined_data.append(spe_token)
    cum_length = np.cumsum(length).tolist()

    invalid_mask_active = any(
        data_dict.get('invalid_mask', None) is not None for data_dict in data_list)
    if invalid_mask_active:
        combined_valid_mask = torch.zeros((bs, cum_length[-1]), dtype=mask_dtype, device=device)
    else:
        combined_valid_mask = None

    for i, data_dict in enumerate(data_list):
        combined_data.append(data_dict['data'])
        if data_dict.get('invalid_mask', None) is not None:
            combined_valid_mask[:, cum_length[i]:cum_length[i + 1]] = data_dict['invalid_mask']
    combined_data = torch.cat(combined_data, dim=1)

    sample_info = {
        'data_length': length,
        'data_cum_length': cum_length,
        'sample_info_per_sample': [],
    }
    # for caption task inference
    if comm._CAPTION_GEN_MODE:
        sample_info['data_cum_length'] = data_list[0]['sample_info']['data_cum_length']
    for data_dict in data_list:
        if data_dict.get('sample_info', None) is not None:
            if isinstance(data_dict['sample_info'], dict):
                sample_info.update(data_dict['sample_info'])
            elif isinstance(data_dict['sample_info'], list):
                if isinstance(data_dict['sample_info'][0], dict):
                    sample_info.update(data_dict['sample_info'][0])
                sample_info['sample_info_per_sample'].append(data_dict['sample_info'])

    data_type = None
    moe_embedding = None
    for data_dict in data_list:
        if 'data_type' in data_dict:
            data_type = data_dict['data_type']
        if 'moe_embedding' in data_dict:
            moe_embedding = data_dict['moe_embedding']

    return {
        'data': combined_data,
        'invalid_mask': combined_valid_mask,
        'data_type': data_type,
        'sample_info': sample_info,
        'moe_embedding': moe_embedding,
    }


def share_token_embed_ln(video_embed, token_embed):
    # share the embedding LayerNorm between the video and token embedders
    if video_embed is not None and token_embed is not None:
        del video_embed.embeddings_norm
        video_embed.embeddings_norm = token_embed.embeddings_norm
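# Illustrative sketch of the fused layout produced by preprocess above,
# assuming prefix_spe_before_fuse=True and two hypothetical modality streams
# of lengths 5 and 7: lengths become [1, 5, 7], the fused sequence is
# [SPE | stream_0 tokens | stream_1 tokens], and cum_length marks the
# boundaries used to scatter each stream's invalid_mask:
#   >>> length = [1, 5, 7]
#   >>> np.cumsum(length).tolist()
#   [1, 6, 13]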