diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..c6b00e35af90fc2b9ee764a2ff3c1b22d1b7ca68
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..86baa8344aba493b741928a8d58c5f49d5e09f86 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+res/effect.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.msc b/.msc
new file mode 100644
index 0000000000000000000000000000000000000000..e2bda352ec25d4ee57072a36f37628a87216a875
Binary files /dev/null and b/.msc differ
diff --git a/.mv b/.mv
new file mode 100644
index 0000000000000000000000000000000000000000..30986a2c85250072d42086d8828e5468a6091982
--- /dev/null
+++ b/.mv
@@ -0,0 +1 @@
+Revision:master,CreatedAt:1742457812
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..673e1c7e81ba3ad1f293a28c652b3d10150813be
--- /dev/null
+++ b/README.md
@@ -0,0 +1,94 @@
+---
+tasks:
+- multi-modal-embedding
+- image-text-retrieval
+domain:
+- multi-modal
+frameworks:
+- pytorch
+backbone:
+- transformers
+metrics:
+- R@1
+license: apache-2.0
+tags:
+- Ant Group
+- multi-modal-embedding
+widgets:
+  - inputs:
+    - validator:
+        max_words: 52
+      type: text
+      title: 查询文本
+    output:
+      maximize: false
+    examples:
+      - name: 1
+        inputs:
+        - data: 戴眼镜的猫
+      - name: 2
+        inputs:
+        - data: 一个在逛公园的女孩
+    task: multi-modal-embedding
+---
+
+## Model Description
+M2-Encoder is a strong bilingual (Chinese-English) multi-modal model. It is trained on BM-6B, a dataset we built containing 6 billion image-text pairs (3 billion Chinese + 3 billion English), and it supports zero-shot cross-modal image-text retrieval (text-to-image and image-to-text search) as well as zero-shot image classification.
+
+Model performance is shown below:
+
+![M2-Encoder](./res/effect.png)
+
+## Intended Use and Scope
+This model is mainly intended for:
+1. Image-to-text or text-to-image retrieval: taking text-to-image retrieval as an example, use M2-Encoder to extract features for the whole image gallery in advance; given a text query, extract its features with M2-Encoder and compute similarities against the stored gallery features.
+2. Zero-shot open-set image classification: given an image and a list of candidate labels, output the label that best matches the image according to image-label similarity.
+
+
+## How to Use
+
+### Code Example
+```
+# Create a new environment (Python 3.8)
+conda create -n m2-encoder python=3.8
+source activate m2-encoder
+
+# Clone the project
+cd /YourPath/
+git clone https://github.com/alipay/Ant-Multi-Modal-Framework
+
+# Install dependencies
+cd ./Ant-Multi-Modal-Framework/prj/M2_Encoder/
+pip install -r requirements.txt
+
+# Run the demo; the matching model weights are downloaded automatically via ModelScope
+python run.py
+```
+
+### Model Limitations and Potential Bias
+The model is trained on the dataset described below and may exhibit some bias; please evaluate it on your own data before deciding how to use it.
+
+## Training Data
+The BM-6B dataset contains 6 billion cleaned, high-quality Chinese-English bilingual image-text pairs, with the Chinese and English portions roughly balanced at 3 billion each. See the [technical report](https://arxiv.org/abs/2401.15896) for details on how the dataset was collected and constructed.
+
+## Model Training
+Training through the ModelScope interface is not supported yet; stay tuned.
+
+
+### Training
+Not supported yet.
+## Evaluation Results
+The model reaches SOTA on both zero-shot cross-modal image-text retrieval and zero-shot classification.
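+
+### Using the ModelScope pipeline
+Besides `run.py`, the weights in this repository can be used through the pipeline registered in `ms_wrapper.py`. The snippet below is a minimal sketch based on the demo at the bottom of that file; the model id, example image URL, and label list are copied from that demo and may need to be adapted to your own setup.
+```python
+from modelscope.pipelines import pipeline
+from modelscope.preprocessors.image import load_image
+
+pipe = pipeline("multi-modal-embeddings", model="M2Cognition/M2-Encoder")
+
+# Zero-shot classification: an image plus a comma-separated label list.
+result = pipe({
+    "image": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg",
+    "label_list": "杰尼龟,妙蛙种子,小火龙,皮卡丘",
+})
+print(result["text"], result["scores"])  # best-matching label and per-label probabilities
+
+# Text embeddings: returns a 2D tensor of shape [num_texts, embed_dim].
+text_out = pipe({"text": ["杰尼龟", "妙蛙种子"]})
+print(text_out["text_embedding"].shape)
+
+# Image embedding from a PIL image (local path or URL via load_image).
+img = load_image("https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg")
+img_out = pipe({"img": img})
+print(img_out["img_embedding"].shape)
+```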
+
+
+
+### Citation
+If you find this model helpful, please consider citing the following paper:
+```
+@misc{guo2024m2encoder,
+      title={M2-Encoder: Advancing Bilingual Image-Text Understanding by Large-scale Efficient Pretraining},
+      author={Qingpei Guo and Furong Xu and Hanxiao Zhang and Wang Ren and Ziping Ma and Lin Ju and Jian Wang and Jingdong Chen and Ming Yang},
+      year={2024},
+      url={https://arxiv.org/abs/2401.15896},
+}
+```
diff --git a/configuration.json b/configuration.json
new file mode 100644
index 0000000000000000000000000000000000000000..158ee6271dc50c0e50529cd9ae077b503ea764d2
--- /dev/null
+++ b/configuration.json
@@ -0,0 +1 @@
+{"framework":"Pytorch","task":"multi-modal-embeddings","pipeline":{"type":"multi-modal-embedding-pipeline"},"allow_remote": true}
\ No newline at end of file
diff --git a/m2_encoder_1B.ckpt b/m2_encoder_1B.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..2d941336c04793f25bb75abb3ed1fdb7c9208c38
--- /dev/null
+++ b/m2_encoder_1B.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac4c5d9a0e44fff05f0ccadf54617b7809f489bc401212abd836c8d075047e9b
+size 2921990385
diff --git a/ms_wrapper.py b/ms_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aa1e679c1ee4c71130f99064542c8fa39167c3b
--- /dev/null
+++ b/ms_wrapper.py
@@ -0,0 +1,219 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+import os
+
+from modelscope.models.base import TorchModel
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.pipelines.base import Model, Pipeline
+from modelscope.utils.config import Config
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors.builder import PREPROCESSORS
+from modelscope.models.builder import MODELS
+from modelscope.preprocessors.image import load_image
+
+
+from vlmo.utils.beit_utils import load_from_config
+
+
+@PIPELINES.register_module(
+    "multi-modal-embeddings", module_name="multi-modal-embedding-pipeline"
+)
+class MyCustomPipeline(Pipeline):
+    """ModelScope pipeline wrapper for M2-Encoder multi-modal embeddings and zero-shot classification.
+
+    Examples:
+
+    >>> from modelscope.pipelines import pipeline
+    >>> my_pipeline = pipeline('multi-modal-embeddings', 'M2Cognition/M2-Encoder')
+    >>> result = my_pipeline({'text': ['戴眼镜的猫']})
+
+    """
+
+    def __init__(self, model, preprocessor=None, **kwargs):
+        """
+        Use `model` and `preprocessor` to create a custom pipeline for prediction.
+        Args:
+            model: model id on modelscope hub.
+ preprocessor: the class of method be init_preprocessor + """ + super().__init__(model=model, auto_collate=False) + self.model_dir = model + self._device = "cuda" if torch.cuda.is_available() else "cpu" + # model_config = { + # "loss_names": {"itc": 1}, + # "encoder_layers": 9, + # "beit3_vl_layers": 3, + # "tokenizer_type": "GLMChineseTokenizer", + # "tokenizer": os.path.join(self.model_dir, "./vlmo/tokenizer"), + # "vocab_size": 115244, + # "whole_word_masking": True, + # "precision": 32, + # "test_only": True, + # "flash_attn": True, + # "model_path": os.path.join(self.model_dir, "m2_encoder_1B.ckpt"), + # "modelscope": {"model_id": "M2Cognition/M2-Encoder-Large"}, + # "model_file": "m2_encoder_1B.ckpt", + # } + model_config = { + "loss_names": {"itc": 1}, + "beit_version": "large", + "encoder_embed_dim": 1024, + "out_embed_dim": 1024, + "encoder_layers": 21, + "beit3_vl_layers": 3, + # "image_size": 224, + "visual_mask_size": 14, + "tokenizer_type": "GLMChineseTokenizer", + "tokenizer": os.path.join(self.model_dir, "./vlmo/tokenizer"), + "vocab_size": 115244, + "whole_word_masking": False, + "precision": 32, + "test_only": True, + "flash_attn": True, + "model_path": os.path.join(self.model_dir, "m2_encoder_1B.ckpt"), + "modelscope": { + "model_id": "M2Cognition/M2_Encoder_Large" + }, + "model_file": "m2_encoder_1B.ckpt" + } + model, processors = load_from_config(model_config) + self.model = model + self.model.to(self._device).eval() + self._tokenizer, self._img_processor = processors + + def _sanitize_parameters(self, **pipeline_parameters): + """ + this method should sanitize the keyword args to preprocessor params, + forward params and postprocess params on '__call__' or '_process_single' method + considered to be a normal classmethod with default implementation / output + + Default Returns: + Dict[str, str]: preprocess_params = {} + Dict[str, str]: forward_params = {} + Dict[str, str]: postprocess_params = pipeline_parameters + """ + return {}, pipeline_parameters, {} + + def _check_input(self, inputs): + pass + + def _check_output(self, outputs): + pass + + def forward(self, forward_params): + """Provide default implementation using self.model and user can reimplement it""" + # print("forward_params", forward_params) + labels = forward_params.get("label_list", "") + labels = labels.split(",") + if len(labels) > 1 and labels[0] != "": + txt_encoding = self._tokenizer( + labels, + padding="max_length", + truncation=True, + max_length=self.model.hparams.config["max_text_len"], + return_special_tokens_mask=True, + ) + txt_data = { + "text_ids": torch.tensor(txt_encoding["input_ids"]).to(self._device), + "text_masks": torch.tensor(txt_encoding["attention_mask"]).to( + self._device + ), + "text_labels": None, + } + txt_feats = self.model.infer_text(txt_data)["cls_vlffn_feats"] + image = forward_params["image"] + image = load_image(image) + img = self._img_processor(image).unsqueeze(0) + img_data = {"image": [img.to(self._device)]} + img_feats = self.model.infer_image(img_data)["cls_vlffn_feats"] + logits_per_image = self.model.logit_scale.exp() * img_feats @ txt_feats.t() + probs = logits_per_image.softmax(dim=-1).detach().cpu() + index = probs.max(dim=-1)[1][0] + label = labels[index] + return {"text": label, "scores": probs.numpy().tolist()[0]} + else: + rets = {} + if "text" in forward_params: + text = forward_params.get("text") + txt_encoding = self._tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.model.hparams.config["max_text_len"], + 
return_special_tokens_mask=True, + ) + txt_data = { + "text_ids": torch.tensor(txt_encoding["input_ids"]).to( + self._device + ), + "text_masks": torch.tensor(txt_encoding["attention_mask"]).to( + self._device + ), + "text_labels": None, + } + txt_feats = self.model.infer_text(txt_data)["cls_vlffn_feats"] + rets.update({"text_embedding": txt_feats.detach()}) + if "img" in forward_params: + input_img = forward_params["img"] + img = self._img_processor(input_img).unsqueeze(0) + img_data = {"image": [img.to(self._device)]} + img_feats = self.model.infer_image(img_data)["cls_vlffn_feats"] + rets.update({"img_embedding": img_feats.detach()}) + + return rets + + def preprocess(self, inputs): + return inputs + + def postprocess(self, inputs): + """If current pipeline support model reuse, common postprocess + code should be write here. + + Args: + inputs: input data + + Return: + dict of results: a dict containing outputs of model, each + output should have the standard output name. + """ + return inputs + + +""" +# Tips: usr_config_path is the temporary save configuration location, after upload modelscope hub, it is the model_id +usr_config_path = "/tmp/snapdown/" +config = Config( + { + "framework": "pytorch", + "task": "multi-modal-embeddings", + "model": {"type": "m2-encoder"}, + "pipeline": {"type": "multi-modal-embedding-pipeline"}, + "allow_remote": True, + } +) +config.dump("/tmp/snapdown/" + "configuration.json") +""" + +if __name__ == "__main__": + from modelscope.pipelines import pipeline + from modelscope.preprocessors.image import load_image + + model = "M2Cognition/M2-Encoder" + pipe = pipeline("multi-modal-embeddings", model=model) + input = { + "image": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg", + "label_list": "杰尼龟,妙蛙种子,小火龙,皮卡丘", + } + demo = pipe(input) + print("demo output", demo) + inputs = {"text": ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]} + output = pipe(inputs) + print("text output", output) + input_img = load_image( + "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg" + ) # 支持皮卡丘示例图片路径/本地图片 返回PIL.Image + inputs = {"img": input_img} + img_embedding = pipe(inputs) # 2D Tensor, [图片数, 特征维度] + print("image output", img_embedding) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..70617a7a4be6342c1728a31799d26c5725b1b947 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +torch +pytorch_lightning<=2.0.8 +transformers +Pillow +tqdm +einops +sacred +timm +torchvision +fairscale +numpy +opencv-python +sentencepiece +modelscope diff --git a/res/effect.png b/res/effect.png new file mode 100644 index 0000000000000000000000000000000000000000..11190086d3e0562068309de056d409cd9d552de8 --- /dev/null +++ b/res/effect.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d6a1950c1ab8770d7d6949d9164c03627fd6bfe673538e71ab4700a68aa6167 +size 1122404 diff --git a/vlmo/.DS_Store b/vlmo/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..dd1bdb07ddcc6d6812c3c51c6437fa9d1fd46628 Binary files /dev/null and b/vlmo/.DS_Store differ diff --git a/vlmo/Encoder_0.4B.json b/vlmo/Encoder_0.4B.json new file mode 100644 index 0000000000000000000000000000000000000000..3d9133c5f0e10bf73dc0289848f5400e8c506bf1 --- /dev/null +++ b/vlmo/Encoder_0.4B.json @@ -0,0 +1,17 @@ +{ + "loss_names": {"itc": 1}, + "encoder_layers": 9, + "beit3_vl_layers": 3, + "tokenizer_type": "GLMChineseTokenizer", + "tokenizer": "./vlmo/tokenizer", + "vocab_size": 115244, + 
"whole_word_masking": true, + "precision": 32, + "test_only": true, + "flash_attn": true, + "model_path": "m2_encoder_0.4B.ckpt", + "modelscope": { + "model_id": "M2Cognition/M2-Encoder" + }, + "model_file": "m2_encoder_0.4B.ckpt" +} diff --git a/vlmo/README.md b/vlmo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9bdca8b83358f4ef4e199eeb5fbc9e08a71863f6 --- /dev/null +++ b/vlmo/README.md @@ -0,0 +1,10 @@ +--- +license: Apache License 2.0 +--- +###### 该模型当前使用的是默认介绍模版,处于“预发布”阶段,页面仅限所有者可见。 +###### 请根据[模型贡献文档说明](https://www.modelscope.cn/docs/%E5%A6%82%E4%BD%95%E6%92%B0%E5%86%99%E5%A5%BD%E7%94%A8%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8D%A1%E7%89%87),及时完善模型卡片内容。ModelScope平台将在模型卡片完善后展示。谢谢您的理解。 + +#### Clone with HTTP +```bash + git clone https://www.modelscope.cn/M2Cognition/M2_Encoder_demo.git +``` diff --git a/vlmo/__init__.py b/vlmo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vlmo/config.py b/vlmo/config.py new file mode 100644 index 0000000000000000000000000000000000000000..240e9617b0b1cd985650f5919be8518577af4551 --- /dev/null +++ b/vlmo/config.py @@ -0,0 +1,165 @@ +from sacred import Experiment + +ex = Experiment("VLMo") + + +def _loss_names(d): + ret = { + "itm": 0, # image-text matching loss + "itc": 0, # image-text contrastive loss + "caption": 0, # image captioning loss + "mvlm": 0, # masked language modeling loss + "textmlm": 0, # text-only masked language modeling + "imagemlm": 0, # image-only masked language modeling + "vqa": 0, + "nlvr2": 0, + "irtr": 0, # retrieval task ft + } + ret.update(d) + return ret + + +@ex.config +def config(): + exp_name = "vlmo" + seed = 1 + datasets = ["coco", "vg", "sbu", "gcc"] # dataset name, the definition can refer to: vlmo/datamodules/__init__.py # noqa + loss_names = _loss_names({"itm": 0, "itc": 0, "mvlm": 0}) # training loss + batch_size = 1024 # this is a desired batch size; pl trainer will accumulate gradients. 
+ + # BEiT-v3 setting + encoder_layers = 12 # the layer number of backbone + encoder_embed_dim = 768 # the hidden size of tokenizer + out_embed_dim = 768 # the hidden size of output embedding + beit_version = "base" # model size: base(0.4B)|large(1B)|huge(10B) + beit3_vl_layers = 3 # the layer number of vl_backbone + deepnorm_init = True # init method + share_layer = False # if share the weight between layer within backbone + share_attn = False # if share the attention weight of different layer + one_attn = False # if share the attention weight of vision and language + + # Image setting + train_transform_keys = ["square_transform_randaug"] # train transform: refer to vlmo/transforms/__init__.py + val_transform_keys = ["square_transform"] # test transform: refer to refer to vlmo/transforms/__init__.py + image_size = 224 # image size + reclip_image_size = None # reclip image size + patch_size = 16 # patch size + draw_false_image = 0 # if get negative image + image_only = False # only input image + text_only = False # # only input text + + # Video setting, video_num_frm is not None means video input + video_num_frm = None + + # Visual tokenizer setting based on beit2 + tokenizer_model = "beit2_visual_tokenizer" + codebook_size = 8192 + codebook_dim = 32 + visual_mask_size = 14 + visual_mask_num = 80 + + # Text Setting + lang = 'cn' # language for zero-shot imagenet testing: cn|en + vqav2_label_size = 3129 + max_text_len = 40 # the number of characters + max_text_len_of_initckpt = 196 + tokenizer_type = "BertTokenizer" # Chinese text + vocab_size = 21128 + tokenizer = "./vocab.txt" + whole_word_masking = True + mlm_prob = 0.15 # language mask ratio + draw_false_text = 0 + mvlm_prob = 0.50 # vision-langurage mlm task + mask_ratio = 0 # flip: mask ratio for image + + # cap setting + cap_onlytext = False # default caption image to text + + # imagemlm setting + split_data_for_imagemlm = False # if True, split a batch data to two parts, and the first part for imagemlm. + + # itc setting + itc_mask = False # itc use masked token + aggregate_nodes = -1 # aggregate nodes num for compute_itc, default -1 is for all nodes + + # Transformer Setting + model_arch = "vlmo_base_patch16" + drop_path_rate = 0.1 + + # Downstream Setting + get_recall_metric = False + get_recall_rerank_metric = False + get_zeroshot_metric = False + get_muge_feat = False + get_f30k_feat = False + k_test = 32 + + # PL Trainer Setting + resume_from = None + fast_dev_run = False + val_check_interval = 1.0 + test_only = False + use_sharded_training = False + resume_during_training = False + save_top_k = 10 + every_n_train_steps = 2000 # the step to save checkpoint + log_metric_steps = 100 # the step to log metric + + # below params varies with the environment + use_pcache = False # data storage method: pcache or nas + pcache_root = "" + # main_site: pcache://multimodalproxyi-pool.cz50c.alipay.com:39999/mnt/ + # public_cloud: pcache://pcache_public_cloud.pcache.local:39999/mnt/abc7c88079a60b45ddfce7afa40720b7/ + gpu_env = "main_site" # public_cloud or main_site + data_root = "" # data root for data list + + + log_dir = "result" + per_gpu_batchsize = 4 # you should define this manually with per_gpu_batch_size=# + num_gpus = 1 + num_nodes = 1 + load_path = "" + num_workers = 8 + precision = 16 + local_run = True + flash_attn = False + deepspeed_config = None # "ds_config.json" + coalesce_backbone = False + mask_data = "v+l" # 'v+l':choose input of imagemlm+textmlm task, 'vl': choose input of mvlm task. 
+ communication_benchmark = False + checkpoint_activations = False + + # dataset setting + single_cap = True # if have only one caption + random_one = False # if choose one caption from caption list + + # ITC setting + itc_feats_name = "cls_vlffn_feats" # feat for itc loss + itc_distill = "" + itc_distill_dim = 1024 + itc_teacher_weights = "" + + # mup training setting + mup = False + base_encoder_embed_dim = 1 + delta_encoder_embed_dim = 2 + mup_encoder_attention_heads = 1 + base_encoder_ffn_embed_dim = 1 + delta_encoder_ffn_embed_dim = 2 + + # atorch + atorch_config = None + compile_op = False + optimizer_state_shard_save = False + model_state_shard_save = False + + # itc loss + local_loss = False + use_dual_softmax = False + + num_frames = 1 +# ----------------------- LMM pretraining config ----------------------- + + # norm setting + deepnorm = False + diff --git a/vlmo/modules/__init__.py b/vlmo/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fea788a622f782141d0822d8735184dde18df7a6 --- /dev/null +++ b/vlmo/modules/__init__.py @@ -0,0 +1 @@ +from .vlmo_module import VLMo diff --git a/vlmo/modules/heads.py b/vlmo/modules/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6afe1d7cd8caf85ddb62ba3f85830f0b63eea2 --- /dev/null +++ b/vlmo/modules/heads.py @@ -0,0 +1,24 @@ +import torch.nn as nn + + +class Pooler(nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class ITCHead(nn.Module): + def __init__(self, hidden_size, out_size): + super().__init__() + self.fc = nn.Linear(hidden_size, out_size, bias=False) + + def forward(self, x): + x = self.fc(x) + return x diff --git a/vlmo/modules/modeling_utils.py b/vlmo/modules/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..490f1472ca10fc7dc5ee9f3b62b9126369206928 --- /dev/null +++ b/vlmo/modules/modeling_utils.py @@ -0,0 +1,179 @@ +# -------------------------------------------------------- +# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) +# Github source: https://github.com/microsoft/unilm/tree/master/beit3 +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# --------------------------------------------------------' + +import math +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_ as __call_trunc_normal_ + +from vlmo.torchscale.model.BEiT3 import BEiT3 +from vlmo.torchscale.architecture.config import EncoderConfig + + +def trunc_normal_(tensor, mean=0.0, std=1.0): + __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) + + +def _get_base_config( + img_size=224, + patch_size=16, + drop_path_rate=0, + checkpoint_activations=None, + mlp_ratio=4, + vocab_size=64010, + encoder_layers=12, + encoder_embed_dim=768, + encoder_attention_heads=12, + share_layer=False, + share_attn=False, + deepnorm=False, + mask_ratio=0, + max_text_len=52, + one_attn=False, + **kwargs +): + return EncoderConfig( + img_size=img_size, + patch_size=patch_size, + vocab_size=vocab_size, + multiway=True, + layernorm_embedding=False, + normalize_output=True, + no_output_layer=True, + drop_path_rate=drop_path_rate, + 
encoder_embed_dim=encoder_embed_dim, + encoder_attention_heads=encoder_attention_heads, + encoder_layers=encoder_layers, + encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio), + checkpoint_activations=checkpoint_activations, + share_layer=share_layer, + share_attn=share_attn, + deepnorm=deepnorm, + mask_ratio=mask_ratio, + max_text_len=max_text_len, + one_attn=one_attn, + ) + + +def _get_large_config( + img_size=224, + patch_size=16, + drop_path_rate=0, + checkpoint_activations=None, + mlp_ratio=4, + vocab_size=64010, + encoder_layers=24, + encoder_embed_dim=1024, + encoder_attention_heads=16, + share_layer=False, + share_attn=False, + deepnorm=False, + mask_ratio=0, + max_text_len=52, + one_attn=False, + **kwargs +): + return EncoderConfig( + img_size=img_size, + patch_size=patch_size, + vocab_size=vocab_size, + multiway=True, + layernorm_embedding=False, + normalize_output=True, + no_output_layer=True, + drop_path_rate=drop_path_rate, + encoder_embed_dim=encoder_embed_dim, + encoder_attention_heads=encoder_attention_heads, + encoder_layers=encoder_layers, + encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio), + checkpoint_activations=checkpoint_activations, + share_layer=share_layer, + share_attn=share_attn, + deepnorm=deepnorm, + mask_ratio=mask_ratio, + max_text_len=max_text_len, + one_attn=one_attn, + ) + + +def _get_huge_config( + img_size=224, + patch_size=16, + drop_path_rate=0, + checkpoint_activations=None, + mlp_ratio=4, + vocab_size=30522, + encoder_layers=32, + encoder_embed_dim=4096, + encoder_attention_heads=32, + share_layer=False, + share_attn=False, + deepnorm=False, + mask_ratio=0, + max_text_len=52, + one_attn=False, + **kwargs +): + return EncoderConfig( + img_size=img_size, + patch_size=patch_size, + vocab_size=vocab_size, + multiway=True, + layernorm_embedding=False, + normalize_output=True, + no_output_layer=True, + drop_path_rate=drop_path_rate, + encoder_embed_dim=encoder_embed_dim, + encoder_attention_heads=encoder_attention_heads, + encoder_layers=encoder_layers, + encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio), + checkpoint_activations=checkpoint_activations, + share_layer=share_layer, + share_attn=share_attn, + deepnorm=deepnorm, + mask_ratio=mask_ratio, + max_text_len=max_text_len, + one_attn=one_attn, + ) + + +class BEiT3Wrapper(nn.Module): + def __init__(self, args, **kwargs): + super().__init__() + self.args = args + self.beit3 = BEiT3(args) + self.apply(self._init_weights) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_num_layers(self): + return self.beit3.encoder.num_layers + + @torch.jit.ignore + def no_weight_decay(self): + return { + "pos_embed", + "cls_token", + "beit3.encoder.embed_positions.A.weight", + "beit3.vision_embed.cls_token", + "logit_scale", + } + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) diff --git a/vlmo/modules/multiway_transformer.py b/vlmo/modules/multiway_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..599bc4707654333e8570ce140835f48627cb809f --- /dev/null +++ b/vlmo/modules/multiway_transformer.py @@ -0,0 +1,396 @@ +""" 
Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Hacked together by / Copyright 2020 Ross Wightman +""" +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from pytorch_lightning.utilities.rank_zero import rank_zero_info + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, mask=None, relative_position_bias=None): + B, N, C = x.shape + + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q.float() @ k.float().transpose(-2, -1) + + if relative_position_bias is not None: + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + mask = mask.bool() + attn = attn.masked_fill(~mask[:, None, None, :], float("-inf")) + attn = attn.softmax(dim=-1).type_as(x) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + 
drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + with_vlffn=False, + layer_scale_init_values=0.1, + max_text_len=40, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2_text = norm_layer(dim) + self.norm2_imag = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_text = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + self.mlp_imag = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + self.mlp_vl = None + if with_vlffn: + self.mlp_vl = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + self.norm2_vl = norm_layer(dim) + + self.gamma_1 = ( + nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True) + if layer_scale_init_values is not None + else 1.0 + ) + self.gamma_2 = ( + nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True) + if layer_scale_init_values is not None + else 1.0 + ) + + self.max_text_len = max_text_len + + def forward(self, x, mask=None, modality_type=None, relative_position_bias=None): + x = x + self.drop_path( + self.gamma_1 * self.attn(self.norm1(x), mask=mask, relative_position_bias=relative_position_bias) + ) + + if modality_type == "image": + x = x + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x))) + elif modality_type == "text": + x = x + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x))) + else: + if self.mlp_vl is None: + x_text = x[:, : self.max_text_len] + x_imag = x[:, self.max_text_len :] + x_text = x_text + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x_text))) + x_imag = x_imag + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x_imag))) + x = torch.cat([x_text, x_imag], dim=1) + else: + x = x + self.drop_path(self.gamma_2 * self.mlp_vl(self.norm2_vl(x))) + + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + no_patch_embed_bias=False, + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=False if no_patch_embed_bias else True, + ) + + def forward(self, x): + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
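+        # proj is a stride-`patch_size` Conv2d, giving a (B, embed_dim, H/patch, W/patch) grid; MultiWayTransformer.visual_embed flattens it into tokens.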
+ # FIXME look at relaxing size constraints + x = self.proj(x) + return x + + +class MultiWayTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=None, + need_relative_position_embed=True, + use_abs_pos_emb=False, + layer_scale_init_values=0.1, + vlffn_start_layer_index=10, + config=None, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + need_relative_position_embed (bool): enable relative position bias on self-attention + use_abs_pos_emb (bool): enable abs pos emb + layer_scale_init_values (float or None): layer scale init values, set None to disable + vlffn_start_layer_index (int): vl-ffn start index + config: (dict): other hyper from pytorch-lighting + """ + super().__init__() + drop_path_rate = drop_path_rate if config is None else config["drop_path_rate"] + rank_zero_info("drop path rate: {}".format(drop_path_rate)) + self.use_abs_pos_emb = use_abs_pos_emb + self.need_relative_position_embed = need_relative_position_embed + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + self.patch_size = patch_size + self.num_heads = num_heads + self.vlffn_start_layer_index = vlffn_start_layer_index + if config["loss_names"]["textmlm"] > 0: + self.vlffn_start_layer_index = depth + rank_zero_info( + "Set vlffn_start_layer_index={} for text-only pretraining".format(self.vlffn_start_layer_index) + ) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if self.use_abs_pos_emb else None + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + with_vlffn=(i >= self.vlffn_start_layer_index), + layer_scale_init_values=layer_scale_init_values, + max_text_len=config["max_text_len"], + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + trunc_normal_(self.cls_token, std=0.02) + self.apply(self._init_weights) + + def 
_init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def visual_embed(self, _x): + x = self.patch_embed(_x) + x = x.flatten(2).transpose(1, 2) + B, L, _ = x.shape + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + x_mask = torch.ones(x.shape[0], x.shape[1]) + + return x, x_mask + + +# VLMo base/p16 +@register_model +def vlmo_base_patch16(pretrained=False, **kwargs): + img_size = kwargs.pop("img_size", 224) + model = MultiWayTransformer( + img_size=img_size, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + vlffn_start_layer_index=10, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ) + return model diff --git a/vlmo/modules/objectives.py b/vlmo/modules/objectives.py new file mode 100644 index 0000000000000000000000000000000000000000..427ff860aa7c304de5b7a221e5e97ef61ab55ba9 --- /dev/null +++ b/vlmo/modules/objectives.py @@ -0,0 +1,12 @@ +import torch.nn as nn + + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() diff --git a/vlmo/modules/vlmo_module.py b/vlmo/modules/vlmo_module.py new file mode 100644 index 0000000000000000000000000000000000000000..81d8536e4662bc8bf95c752670177b38d049fbcc --- /dev/null +++ b/vlmo/modules/vlmo_module.py @@ -0,0 +1,405 @@ +import math +import os +import time + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.distributed as dist +import torch.nn as nn +from pytorch_lightning.utilities.rank_zero import rank_zero_info +from timm.models import create_model +from transformers import AutoTokenizer, BertTokenizer, XLMRobertaTokenizer # noqa +from vlmo.modules import heads, objectives, vlmo_utils +from vlmo.tokenizer.tokenization_glm import GLMChineseTokenizer # noqa +from vlmo.torchscale.architecture.encoder import Encoder +from vlmo.torchscale.model.BEiT3 import BEiT3 as ts_backbone +from vlmo.transforms.utils import inception_normalize as img_norm + +from .modeling_utils import _get_base_config, _get_large_config, _get_huge_config, trunc_normal_ # noqa + + +def convert_pl_ckpt(state_dict, num_visual_token=197): + print("start convert_pl_ckpt!!!") + new_state_dict = {} + for key in state_dict: + value = state_dict[key] + if "visual_tokenizer" in key: + continue + elif "backbone.encoder.embed_positions.A.weight" in key: + if value.shape[0] < num_visual_token + 2: + N = value.shape[0] - 3 + dim = value.shape[-1] + class_pos_embed = value[:3, ] + patch_pos_embed = value[3:, ] + w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1)) + patch_pos_embed = patch_pos_embed.float() + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + size=(w0, h0), + mode="area", + ) + patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype) + patch_pos_embed = 
patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim) + new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0) + new_state_dict[key] = new_value + print("reshape ", key, "raw shape: ", value.shape, "new shape: ", new_value.shape, num_visual_token) + elif value.shape[0] > num_visual_token + 2: + new_state_dict[key] = value[: num_visual_token + 2, :] + print("first ", key, "raw shape: ", value.shape, new_state_dict[key].shape, num_visual_token) + else: + new_state_dict[key] = value + print("raw shape") + else: + new_state_dict[key] = state_dict[key] + + return new_state_dict + + +def convert_deepspeed_ckpt(state_dict, num_visual_token=197): + new_state_dict = {} + for key in state_dict: + if key.startswith("_forward_module."): + new_key = key[len("_forward_module."):] + value = state_dict[key] + new_state_dict[new_key] = value + if "visual_tokenizer.encoder.pos_embed" in new_key or "visual_tokenizer.decoder.pos_embed" in new_key: + if value.shape[1] != num_visual_token: + N = value.shape[1] - 1 + dim = value.shape[-1] + class_pos_embed = value[:, 0] + patch_pos_embed = value[:, 1:] + w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1)) + patch_pos_embed = patch_pos_embed.float() + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + size=(w0, h0), + mode="area", + ) + patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + new_value = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + new_state_dict[new_key] = new_value + print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape) + if "backbone.encoder.embed_positions.A.weight" in new_key: + if value.shape[1] != num_visual_token + 2: + N = value.shape[0] - 3 + dim = value.shape[-1] + class_pos_embed = value[:3, ] + patch_pos_embed = value[3:, ] + w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1)) + patch_pos_embed = patch_pos_embed.float() + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + size=(w0, h0), + mode="area", + ) + patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim) + new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0) + new_state_dict[new_key] = new_value + print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape) + + else: + new_state_dict[key] = state_dict[key] + + return new_state_dict + + +def get_visual_tokenizer(config): + tokenizer_name = config["tokenizer_model"] + print(f"Creating visual tokenizer: {tokenizer_name}") + model = create_model( + config["tokenizer_model"], + img_size=config["image_size"], + n_code=config["codebook_size"], + code_dim=config["codebook_dim"], + ).eval() + return model + + +def get_pretrained_tokenizer(tokenizer_type, from_pretrained): + _Tokenizer = eval(f"{tokenizer_type}") + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + _Tokenizer.from_pretrained(from_pretrained) + torch.distributed.barrier() + return _Tokenizer.from_pretrained(from_pretrained) + + +class VLMo(pl.LightningModule): + def __init__(self, config): + super().__init__() + self.save_hyperparameters() + s_t = time.time() + + # tokenizer & backbone + self.img_size = config["image_size"] + if not config["test_only"]: + 
self.visual_tokenizer = get_visual_tokenizer(config) + kwargs = {} + if "encoder_attention_heads" in config: + kwargs["encoder_attention_heads"] = config["encoder_attention_heads"] + if "atorch_config" in config and config["atorch_config"]: + checkpoint_activations = False # ? + else: + checkpoint_activations = config["checkpoint_activations"] + args = eval(f'_get_{config["beit_version"]}_config')( + img_size=config["image_size"], + patch_size=config["patch_size"], + vocab_size=config["vocab_size"], + encoder_layers=config["encoder_layers"], + encoder_embed_dim=config["encoder_embed_dim"], + checkpoint_activations=checkpoint_activations, + share_layer=config["share_layer"], + share_attn=config["share_attn"], + deepnorm=config["deepnorm"], + mask_ratio=config["mask_ratio"], + max_text_len=config["max_text_len"], + one_attn=config["one_attn"], + **kwargs, + ) + self.num_features = args.encoder_embed_dim + self.out_features = config["out_embed_dim"] + self.cap_onlytext = config["cap_onlytext"] + self.lang = config["lang"] + self.num_frames = config["num_frames"] + self.tokenizer_type = config["tokenizer_type"] + self.text_tokenizer = get_pretrained_tokenizer(self.tokenizer_type, from_pretrained=config["tokenizer"]) # noqa + print("BEiT args", args.__dict__) + self.backbone = ts_backbone(args) + + self.use_vl = config["beit3_vl_layers"] > 0 + if self.use_vl: + args.encoder_layers = config["beit3_vl_layers"] + self.backbone_vl = Encoder(args) + + self.norm = nn.LayerNorm(self.num_features, eps=1e-6) + + # task layers + self.pooler = heads.Pooler(self.num_features) + self.pooler.apply(objectives.init_weights) + + # contrastive loss (or sampling for global hard negative) + if config["loss_names"]["itc"] > 0: + self.itc_text_proj = heads.ITCHead(self.num_features, self.out_features) + self.itc_image_proj = heads.ITCHead(self.num_features, self.out_features) + self.itc_text_proj.apply(objectives.init_weights) + self.itc_image_proj.apply(objectives.init_weights) + + self.itc_vl_text_proj = heads.ITCHead(self.num_features, self.out_features) + self.itc_vl_image_proj = heads.ITCHead(self.num_features, self.out_features) + self.itc_vl_text_proj.apply(objectives.init_weights) + self.itc_vl_image_proj.apply(objectives.init_weights) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_vl_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + lp_s_t = time.time() + + self.load_pretrained_weight() + load_pretrain_time = time.time() - lp_s_t + + self.current_tasks = list() + + # ===================== load downstream (test_only) ====================== + + if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]: + rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"])) + ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu") + + state_dict = None + + for state_dict_key in ("state_dict", "module", "model"): + if state_dict_key in ckpt: + rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key) + state_dict = ckpt[state_dict_key] + break + if state_dict_key == "module": + state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings()) + if state_dict_key == "state_dict": + state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings()) + if state_dict is None: + if list(ckpt.keys())[0].startswith('_forward_module.'): + rank_zero_info("Read state dict from ckpt with _forward_module prefix. 
") + state_dict = convert_deepspeed_ckpt(ckpt, self.backbone.vision_embed.num_position_embeddings()) + else: + rank_zero_info("Read state dict from ckpt. ") + state_dict = ckpt + + missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False) + rank_zero_info("missing_keys: {}".format(missing_keys)) + rank_zero_info("unexpected_keys: {}".format(unexpected_keys)) + + construct_time = time.time() - s_t + print( + f"Process {os.getpid()}. VLMo Constructor time: {construct_time}s;", + f"load_pretrain_time: {load_pretrain_time}s", + flush=True, + ) + # coalesce backbone calls + self._coalesce_backbone = config["coalesce_backbone"] + self._mask_data = config["mask_data"] + self._backbone_inputs = {} + self._backbone_inputs_current_size = 0 + self._backbone_inputs_keys = {} + self._backbone_outputs = None + self._default_attn_masks = {} + self._itc_group = None + self._itc_aggregate_dict = None + self._itc_mask = config["itc_mask"] + self._local_loss = config["local_loss"] + self._aggregate_nodes = config["aggregate_nodes"] + self.accumulated_batches_reached = False + vlmo_utils.set_task(self) + self._only_itc_single_machine = ( + self._aggregate_nodes > 0 and len(self.current_tasks) == 1 and "itc" in self.current_tasks + ) + self._split_data_for_imagemlm = config["split_data_for_imagemlm"] + self.log_metric_steps = config["log_metric_steps"] + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.backbone.encoder.layers): + rescale(layer.self_attn.v_proj.A.weight.data, layer_id + 1) + rescale(layer.self_attn.v_proj.B.weight.data, layer_id + 1) + rescale(layer.self_attn.out_proj.A.weight.data, layer_id + 1) + rescale(layer.self_attn.out_proj.B.weight.data, layer_id + 1) + rescale(layer.ffn.A.fc2.weight.data, layer_id + 1) + rescale(layer.ffn.B.fc2.weight.data, layer_id + 1) + + if self.use_vl: + pre_layers = len(self.backbone.encoder.layers) + 1 + for layer_id, layer in enumerate(self.backbone_vl.layers): + rescale(layer.self_attn.v_proj.A.weight.data, layer_id + pre_layers) + rescale(layer.self_attn.v_proj.B.weight.data, layer_id + pre_layers) + rescale(layer.self_attn.out_proj.A.weight.data, layer_id + pre_layers) + rescale(layer.self_attn.out_proj.B.weight.data, layer_id + pre_layers) + rescale(layer.ffn.A.fc2.weight.data, layer_id + pre_layers) + rescale(layer.ffn.B.fc2.weight.data, layer_id + pre_layers) + + def load_pretrained_weight(self): + if self.hparams.config["load_path"] != "" and not self.hparams.config["test_only"]: + config = self.hparams.config + ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu") + rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"])) + + state_dict = None + + for state_dict_key in ("state_dict", "module", "model"): + if state_dict_key in ckpt: + rank_zero_info("Read state dict from ckpt[%s]. 
" % state_dict_key) + state_dict = ckpt[state_dict_key] + break + if state_dict_key == "module": + state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings()) + if state_dict_key == "state_dict": + state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings()) + if state_dict is None: + if list(ckpt.keys())[0].startswith('_forward_module.'): + rank_zero_info("Read state dict from ckpt with _forward_module prefix. ") + state_dict = convert_deepspeed_ckpt(ckpt, + self.backbone.vision_embed.num_position_embeddings()) + else: + rank_zero_info("Read state dict from ckpt. ") + state_dict = ckpt + + missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False) + missing_keys = [k for k in missing_keys if "itc_teacher" not in k] + rank_zero_info("missing_keys: {}".format(missing_keys)) + rank_zero_info("unexpected_keys: {}".format(unexpected_keys)) + + def infer_text( + self, + batch, + mask_text=False, + ): + do_mlm = "_mlm" if mask_text else "" + text_ids = batch[f"text_ids{do_mlm}"] + text_labels = batch[f"text_labels{do_mlm}"] + text_masks = batch[f"text_masks"] + text_embed = self.backbone.text_embed(text_ids) + text_padding_position = 1 - text_masks + lffn_hiddens = self.backbone( + textual_tokens=text_ids, + text_padding_position=text_padding_position, + )["encoder_out"] + vlffn_hiddens = self.backbone_vl( + src_tokens=None, + token_embeddings=lffn_hiddens, + encoder_padding_mask=text_padding_position, + multiway_split_position=-1, + )["encoder_out"] + + cls_feats = self.itc_text_proj(lffn_hiddens[:, 0]) + cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True) + + cls_vlffn_feats = self.itc_vl_text_proj(vlffn_hiddens[:, 0]) + cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True) + + ret = { + "cls_feats": cls_feats, + "cls_vlffn_feats": cls_vlffn_feats, + "text_embed": text_embed, + } + + return ret + + def infer_image( + self, + batch, + mask_image=False, + image_token_type_idx=1, + image_embeds=None, + image_masks=None, + ): + if f"image_{image_token_type_idx - 1}" in batch: + imgkey = f"image_{image_token_type_idx - 1}" + else: + imgkey = "image" + + img = batch[imgkey][0] + if mask_image: + image_masks = batch[f"{imgkey}_masks"][0].flatten(1) + + with torch.no_grad(): + img = self.visual_tokenizer.pre_process(img) + quantize, embed_ind, _ = self.visual_tokenizer.encode(img) + image_ids = embed_ind.view(img.shape[0], -1) + + image_labels = torch.full_like(image_ids, -100) + bool_masked_pos = image_masks.to(torch.bool) + image_labels[bool_masked_pos] = image_ids[bool_masked_pos] + + img_tensor = img_norm(img) + vffn_hiddens = self.backbone(visual_tokens=img_tensor)["encoder_out"] + vlffn_hiddens = self.backbone_vl( + src_tokens=None, + token_embeddings=vffn_hiddens, + multiway_split_position=-1, + )["encoder_out"] + + cls_feats = self.itc_image_proj(vffn_hiddens[:, 0]) + cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True) + + cls_vlffn_feats = self.itc_vl_image_proj(vlffn_hiddens[:, 0]) + cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True) + + ret = { + "image_feats": vffn_hiddens, + "cls_feats": cls_feats, + "cls_vlffn_feats": cls_vlffn_feats, + } + + return ret diff --git a/vlmo/modules/vlmo_utils.py b/vlmo/modules/vlmo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eca171bbbcd4ac851b50e578b54cb5a6ee083cb2 --- /dev/null +++ b/vlmo/modules/vlmo_utils.py @@ -0,0 +1,12 @@ +def set_task(pl_module): + 
pl_module.current_tasks = [k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1] + return + + +def no_sync_module_apply(module, fn): + """FSDP module .apply will use _unshard_params_recurse which will sync params across ranks. + using this function when apply fn is unnecessary to sync params across ranks. + """ + for child in module.children(): + fn(child) + no_sync_module_apply(child, fn) diff --git a/vlmo/tokenizer/__init__.py b/vlmo/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1b262cf63905b2009a53c159d187fecb875a26 --- /dev/null +++ b/vlmo/tokenizer/__init__.py @@ -0,0 +1,6 @@ +# coding: utf-8 +# Copyright (c) Antfin, Inc. All rights reserved. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/vlmo/tokenizer/sp.model b/vlmo/tokenizer/sp.model new file mode 100644 index 0000000000000000000000000000000000000000..9bb672d0fd1671690eb72b8274b4e835eb3540f6 --- /dev/null +++ b/vlmo/tokenizer/sp.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7fe3bcc8d284fcb782691411e8b6fd4f45d7245565b094de6ab795e66bcd32f +size 2270960 diff --git a/vlmo/tokenizer/tokenization_glm.py b/vlmo/tokenizer/tokenization_glm.py new file mode 100644 index 0000000000000000000000000000000000000000..611cb97d9fa02e8fe3cf7a116bd8520c567442cb --- /dev/null +++ b/vlmo/tokenizer/tokenization_glm.py @@ -0,0 +1,307 @@ +import os +from shutil import copyfile +from typing import Optional, Tuple, List, Union + +import sentencepiece as spm +import torch +from transformers import PreTrainedTokenizer +from transformers.models.auto.tokenization_auto import get_tokenizer_config +from transformers.tokenization_utils_base import BatchEncoding +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class GLMBatchEncoding(BatchEncoding): + def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": + """ + Send all values to device by calling `v.to(device)` (PyTorch only). + + Args: + device (`str` or `torch.device`): The device to put the tensors on. + + Returns: + [`BatchEncoding`]: The same instance after modification. + """ + + # This check catches things like APEX blindly calling "to" on all inputs to a module + # Otherwise it passes the casts down and casts the LongTensor containing the token idxs + # into a HalfTensor + if isinstance(device, str) or isinstance(device, int): + #if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): + self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()} + else: + logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") + return self + + +class GLMTokenizerMixin: + @property + def sop_token(self) -> Optional[str]: + return "<|startofpiece|>" + + @property + def sop_token_id(self) -> Optional[int]: + """ + `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling. + """ + return self.convert_tokens_to_ids(self.sop_token) + + @property + def eop_token(self) -> Optional[str]: + return "<|endofpiece|>" + + @property + def eop_token_id(self) -> Optional[int]: + """ + `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling. 
+ """ + return self.convert_tokens_to_ids(self.eop_token) + + @property + def gmask_token_id(self) -> int: + return self.convert_tokens_to_ids("[gMASK]") + + @property + def smask_token_id(self) -> int: + return self.convert_tokens_to_ids("[sMASK]") + + @property + def mask_token_ids(self): + return [self.mask_token_id, self.smask_token_id, self.gmask_token_id] + + def _build_input_for_multiple_choice(self, context, choices): + context_id = context["input_ids"] + if torch.is_tensor(context_id): + context_id = context_id.tolist() + + division = len(context_id) + mask_position = context_id.index(self.mask_token_id) + + token = torch.tensor(context_id, dtype=torch.long) + attention_mask = [context["attention_mask"].expand(division, -1)] + position_id = torch.arange(division, dtype=torch.long) + block_position_id = torch.zeros(division, dtype=torch.long) + + choice_ids, choice_indices = [], [] + + for choice_str in choices: + choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'], + dtype=torch.long) + choice_ids.append(choice) + choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long)) + attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long))) + + token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1])) + position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long))) + block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long))) + + attention_mask = torch.block_diag(*attention_mask) + attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0) + + return { + "input_ids": token, + "position_ids": torch.stack((position_id, block_position_id)), + "attention_mask": attention_mask, + "choice_ids": choice_ids, + "choice_indices": choice_indices + } + + def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length): + pad_length = max_seq_length - len(tokens) + attention_mask = torch.nn.functional.pad( + attention_mask, + (0, pad_length, 0, pad_length), + mode="constant", + value=0, + ) + tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long))) + position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1) + return tokens, position_ids, attention_mask + + def _collate(self, samples): + TILE = 1 + length_to_pad = (max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE + + token_batch, position_id_batch, attention_mask_batch = [], [], [] + choices_batch, choice_target_ids_batch = [], [] + + for sample in samples: + token, position_id, attention_mask = self._pad_batch( + sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad + ) + token_batch.append(token) + position_id_batch.append(position_id) + attention_mask_batch.append(attention_mask) + choices_batch.append(sample["choice_ids"]) + choice_target_ids_batch.append(sample["choice_indices"]) + return { + "input_ids": torch.stack(token_batch), + "position_ids": torch.stack(position_id_batch), + "attention_mask": torch.stack(attention_mask_batch).unsqueeze(1), + "choice_ids": choices_batch, + "choice_indices": choice_target_ids_batch, + } + + def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None): + samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input["input_ids"]))] + samples = 
[self._build_input_for_multiple_choice(sample, choice) for sample, choice in + zip(samples, choices)] + inputs = self._collate(samples) + return GLMBatchEncoding(inputs) + + def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False): + mask_ids = self.mask_token_ids + input_ids = model_input.input_ids + batch_size, seq_length = input_ids.shape[:2] + position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)] + position_ids, block_position_ids = [], [] + labels = None + if targets is not None: + is_batched = isinstance(targets, (list, tuple)) + targets = self(targets, add_special_tokens=False, padding=False).input_ids + if not is_batched: + targets = [targets] + assert len(targets) == len(input_ids) + targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets] + if not padding: + max_gen_length = max(map(len, targets)) + targets = [[self.sop_token_id] + target for target in targets] + labels = [target[1:] for target in targets] + targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets] + labels = [label + [-100] * (max_gen_length - len(label)) for label in labels] + targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device) + labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device) + labels = torch.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1) + for i in range(batch_size): + mask_positions = [] + for mask_id in mask_ids: + mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist() + if not mask_positions: + raise ValueError("Cannot find mask token in the input") + mask_positions.sort() + mask_pos = mask_positions[0] + position_ids.append(position_id + [mask_pos] * max_gen_length) + block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1))) + position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device) + block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + attention_mask = model_input.attention_mask + attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1) + generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)), + torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))], + dim=0).unsqueeze(0).expand(batch_size, -1, -1) + attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2) + attention_mask = attention_mask.unsqueeze(1) + if targets is None: + input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1) + else: + input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1) + batch = {"input_ids": input_ids, "position_ids": position_ids} + if labels is None: + batch["generation_attention_mask"] = attention_mask + else: + batch["attention_mask"] = attention_mask + batch["labels"] = labels + return BatchEncoding(batch) + +def encode_whitespaces(content): + for i in range(10, 1, -1): + content = content.replace(' '*i, f'<|blank_{i}|>') + return content + +def decode_whitespaces(content): + for i in range(10, 1, -1): + content = content.replace(f'<|blank_{i}|>', ' '*i) + return content + + +class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin): + vocab_files_names = {"vocab_file": "sp.model"} + truncation_side: str = "left" + + def 
__init__(self, vocab_file, **kwargs): + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + super().__init__(**kwargs) + + @property + def vocab_size(self): + return len(self.sp_model) + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + text = encode_whitespaces(text) + return self.sp_model.EncodeAsPieces(text) + #return self.sp_model.EncodeAsPieces(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index) + + def convert_tokens_to_string(self, tokens): + res = self.sp_model.DecodeIds(tokens) + return decode_whitespaces(res) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 
+ """ + assert token_ids_1 is None + cls = [self.cls_token_id] + eos = [self.eos_token_id] + return cls + token_ids_0 + eos + + +class GLMTokenizer: + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) + config_tokenizer_class = tokenizer_config.get("tokenizer_class") + + if config_tokenizer_class == "GLMChineseTokenizer": + tokenizer_class = GLMChineseTokenizer + else: + raise NotImplementedError("Not implemented tokenizer type:", config_tokenizer_class) + return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) diff --git a/vlmo/tokenizer/tokenizer_config.json b/vlmo/tokenizer/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..6d22dc57fabc8dedb260718bcf4753381a9d3066 --- /dev/null +++ b/vlmo/tokenizer/tokenizer_config.json @@ -0,0 +1,17 @@ +{ + "name_or_path": "THUDM/glm-10b-chinese", + "eos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + "cls_token": "[CLS]", + "mask_token": "[MASK]", + "unk_token": "[UNK]", + "add_prefix_space": false, + "tokenizer_class": "GLMChineseTokenizer", + "use_fast": false, + "auto_map": { + "AutoTokenizer": [ + "tokenization_glm.GLMChineseTokenizer", + null + ] + } +} diff --git a/vlmo/torchscale/__init__.py b/vlmo/torchscale/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae31e2507e8759f2ac7f85e517288f536c04ac3 --- /dev/null +++ b/vlmo/torchscale/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] diff --git a/vlmo/torchscale/architecture/__init__.py b/vlmo/torchscale/architecture/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae31e2507e8759f2ac7f85e517288f536c04ac3 --- /dev/null +++ b/vlmo/torchscale/architecture/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] diff --git a/vlmo/torchscale/architecture/config.py b/vlmo/torchscale/architecture/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a22fbf58c151f9a6c01f23ba3954171eb432e7cb --- /dev/null +++ b/vlmo/torchscale/architecture/config.py @@ -0,0 +1,197 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + + +class EncoderConfig(object): + def __init__(self, **kwargs): + self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768) + self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12) + self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072) + self.encoder_layers = kwargs.pop("encoder_layers", 12) + self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True) + self.normalize_output = kwargs.pop("normalize_output", True) + self.activation_fn = kwargs.pop("activation_fn", "gelu") + self.dropout = kwargs.pop("dropout", 0.0) + self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) + self.attention_dropout = kwargs.pop("attention_dropout", 0.0) + self.activation_dropout = kwargs.pop("activation_dropout", 0.0) + self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) + self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) + self.moe_freq = kwargs.pop("moe_freq", 0) + self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) + self.moe_expert_count = kwargs.pop("moe_expert_count", 0) + self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True) + 
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25) + self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random") + self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False) + self.use_xmoe = kwargs.pop("use_xmoe", False) + self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0) + self.max_rel_pos = kwargs.pop("max_rel_pos", 0) + self.deepnorm = kwargs.pop("deepnorm", False) + self.subln = kwargs.pop("subln", True) + self.bert_init = kwargs.pop("bert_init", False) + self.multiway = kwargs.pop("multiway", False) + self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False) + self.max_source_positions = kwargs.pop("max_source_positions", 1024) + self.no_output_layer = kwargs.pop("no_output_layer", False) + self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5) + self.share_layer = kwargs.pop("share_layer", False) + self.share_attn = kwargs.pop("share_attn", False) + self.mask_ratio = kwargs.pop("mask_ratio", 0) + self.max_text_len = kwargs.pop("max_text_len", 52) + self.one_attn = kwargs.pop('one_attn', False) + + + # Text + self.vocab_size = kwargs.pop("vocab_size", -1) + # Vision + self.img_size = kwargs.pop("img_size", 224) + self.patch_size = kwargs.pop("patch_size", 16) + self.in_chans = kwargs.pop("in_chans", 3) + # Fairscale + self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) + self.fsdp = kwargs.pop("fsdp", False) + self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) + self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + + if self.deepnorm: + self.encoder_normalize_before = False + self.subln = False + if self.subln: + self.encoder_normalize_before = True + self.deepnorm = False + if self.use_xmoe: + self.moe_normalize_gate_prob_before_dropping = True + self.moe_second_expert_policy = "random" + assert self.moe_freq > 0 and self.moe_expert_count > 0 + + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) + + +class DecoderConfig(object): + def __init__(self, **kwargs): + self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768) + self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12) + self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072) + self.decoder_layers = kwargs.pop("decoder_layers", 12) + self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True) + self.activation_fn = kwargs.pop("activation_fn", "gelu") + self.dropout = kwargs.pop("dropout", 0.0) + self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) + self.attention_dropout = kwargs.pop("attention_dropout", 0.0) + self.activation_dropout = kwargs.pop("activation_dropout", 0.0) + self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) + self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) + self.moe_freq = kwargs.pop("moe_freq", 0) + self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) + self.moe_expert_count = kwargs.pop("moe_expert_count", 0) + self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True) + self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25) + self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random") + self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False) + self.use_xmoe = 
kwargs.pop("use_xmoe", False) + self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0) + self.max_rel_pos = kwargs.pop("max_rel_pos", 0) + self.deepnorm = kwargs.pop("deepnorm", False) + self.subln = kwargs.pop("subln", True) + self.bert_init = kwargs.pop("bert_init", False) + self.multiway = kwargs.pop("multiway", False) + self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False) + self.max_target_positions = kwargs.pop("max_target_positions", 1024) + self.no_output_layer = kwargs.pop("no_output_layer", False) + self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5) + # Text + self.vocab_size = kwargs.pop("vocab_size", -1) + # Fairscale + self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) + self.fsdp = kwargs.pop("fsdp", False) + self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) + self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + + if self.deepnorm: + self.decoder_normalize_before = False + self.subln = False + if self.subln: + self.decoder_normalize_before = True + self.deepnorm = False + if self.use_xmoe: + self.moe_normalize_gate_prob_before_dropping = True + self.moe_second_expert_policy = "random" + assert self.moe_freq > 0 and self.moe_expert_count > 0 + + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) + + +class EncoderDecoderConfig(object): + def __init__(self, **kwargs): + self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768) + self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12) + self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072) + self.encoder_layers = kwargs.pop("encoder_layers", 12) + self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True) + self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768) + self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12) + self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072) + self.decoder_layers = kwargs.pop("decoder_layers", 12) + self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True) + self.activation_fn = kwargs.pop("activation_fn", "gelu") + self.dropout = kwargs.pop("dropout", 0.0) + self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0) + self.attention_dropout = kwargs.pop("attention_dropout", 0.0) + self.activation_dropout = kwargs.pop("activation_dropout", 0.0) + self.no_scale_embedding = kwargs.pop("no_scale_embedding", True) + self.layernorm_embedding = kwargs.pop("layernorm_embedding", False) + self.moe_freq = kwargs.pop("moe_freq", 0) + self.moe_top1_expert = kwargs.pop("moe_top1_expert", False) + self.moe_expert_count = kwargs.pop("moe_expert_count", 0) + self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True) + self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25) + self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random") + self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False) + self.use_xmoe = kwargs.pop("use_xmoe", False) + self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0) + self.max_rel_pos = kwargs.pop("max_rel_pos", 0) + self.deepnorm = kwargs.pop("deepnorm", False) + self.subln = kwargs.pop("subln", True) + self.bert_init = kwargs.pop("bert_init", False) + self.multiway = kwargs.pop("multiway", False) + self.share_all_embeddings = 
kwargs.pop("share_all_embeddings", False) + self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False) + self.max_source_positions = kwargs.pop("max_source_positions", 1024) + self.max_target_positions = kwargs.pop("max_target_positions", 1024) + self.no_output_layer = kwargs.pop("no_output_layer", False) + self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5) + # Text + self.vocab_size = kwargs.pop("vocab_size", -1) + # Fairscale + self.checkpoint_activations = kwargs.pop("checkpoint_activations", False) + self.fsdp = kwargs.pop("fsdp", False) + self.ddp_rank = kwargs.pop("ddp_rank", 0) + self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False) + self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512) + + if self.deepnorm: + self.encoder_normalize_before = False + self.decoder_normalize_before = False + self.subln = False + if self.subln: + self.encoder_normalize_before = True + self.decoder_normalize_before = True + self.deepnorm = False + if self.use_xmoe: + self.moe_normalize_gate_prob_before_dropping = True + self.moe_second_expert_policy = "random" + assert self.moe_freq > 0 and self.moe_expert_count > 0 + + def override(self, args): + for hp in self.__dict__.keys(): + if getattr(args, hp, None) is not None: + self.__dict__[hp] = getattr(args, hp, None) diff --git a/vlmo/torchscale/architecture/decoder.py b/vlmo/torchscale/architecture/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e5d3b961eb121c79fba8f3fb57fd4af33710af --- /dev/null +++ b/vlmo/torchscale/architecture/decoder.py @@ -0,0 +1,428 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import math + +import numpy as np +import torch +import torch.nn as nn +from fairscale.nn import checkpoint_wrapper, wrap + +from vlmo.torchscale.architecture.utils import init_bert_params +from vlmo.torchscale.component.droppath import DropPath +from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts +from vlmo.torchscale.component.multihead_attention import MultiheadAttention +from vlmo.torchscale.component.relative_position_bias import RelativePositionBias +#from vlmo.torchscale.component.xmoe.moe_layer import MOELayer +#from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate + +try: + from apex.normalization import FusedLayerNorm as LayerNorm +except ModuleNotFoundError: + from torch.nn import LayerNorm + + +class DecoderLayer(nn.Module): + def __init__( + self, + args, + depth, + is_moe_layer=False, + is_encoder_decoder=False, + ): + super().__init__() + self.args = args + self.embed_dim = args.decoder_embed_dim + self.dropout_module = torch.nn.Dropout(args.dropout) + + if args.drop_path_rate > 0: + drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth] + self.drop_path = DropPath(drop_path_prob) + else: + self.drop_path = None + + self.self_attn = self.build_self_attention(self.embed_dim, args) + + self.normalize_before = args.decoder_normalize_before + + self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps) + + if not is_encoder_decoder: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps) + + self.is_moe_layer = is_moe_layer + self.ffn_dim = args.decoder_ffn_embed_dim + + if not self.is_moe_layer: + self.ffn = self.build_ffn( + self.embed_dim, + 
self.args, + ) + else: + if args.moe_top1_expert: + gate = Top1Gate( + self.embed_dim, + args.moe_expert_count, + use_fp32=args.moe_gating_use_fp32, + moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction, + use_xmoe=args.use_xmoe, + ) + else: + gate = Top2Gate( + self.embed_dim, + args.moe_expert_count, + args.moe_gating_use_fp32, + args.moe_second_expert_policy, + args.moe_normalize_gate_prob_before_dropping, + args.moe_eval_capacity_token_fraction, + use_xmoe=args.use_xmoe, + ) + experts = make_experts(args, self.embed_dim, self.ffn_dim) + self.moe_layer = MOELayer(gate, experts, args) + + self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps) + + if args.deepnorm: + if is_encoder_decoder: + self.alpha = math.pow(3.0 * args.decoder_layers, 0.25) + else: + self.alpha = math.pow(2.0 * args.decoder_layers, 0.25) + else: + self.alpha = 1.0 + + def build_ffn(self, embed_dim, args): + return FeedForwardNetwork( + embed_dim, + self.ffn_dim, + args.activation_fn, + args.dropout, + args.activation_dropout, + args.layernorm_eps, + args.subln, + ) + + def build_self_attention(self, embed_dim, args): + return MultiheadAttention( + args, + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + encoder_decoder_attention=False, + subln=args.subln, + ) + + def build_encoder_attention(self, embed_dim, args): + return MultiheadAttention( + args, + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + self_attention=False, + encoder_decoder_attention=True, + subln=args.subln, + ) + + def residual_connection(self, x, residual): + return residual * self.alpha + x + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + incremental_state=None, + self_attn_mask=None, + self_attn_padding_mask=None, + self_attn_rel_pos=None, + cross_attn_rel_pos=None, + ): + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + attn_mask=self_attn_mask, + rel_pos=self_attn_rel_pos, + ) + x = self.dropout_module(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + if self.encoder_attn is not None and encoder_out is not None: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=None, + rel_pos=cross_attn_rel_pos, + ) + x = self.dropout_module(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + if not self.is_moe_layer: + x = self.ffn(x) + l_aux = None + else: + x, l_aux = self.moe_layer(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + + return x, attn, None, l_aux + + +class Decoder(nn.Module): + def __init__( + self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs + ): + super().__init__(**kwargs) + self.args = args + + self.dropout_module = 
torch.nn.Dropout(args.dropout) + + embed_dim = args.decoder_embed_dim + self.embed_dim = embed_dim + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + + self.embed_tokens = embed_tokens + self.embed_positions = embed_positions + + if output_projection is None and not args.no_output_layer and args.vocab_size > 0: + self.output_projection = self.build_output_projection(args) + else: + self.output_projection = output_projection + + if args.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps) + else: + self.layernorm_embedding = None + + self.layers = nn.ModuleList([]) + + moe_freq = args.moe_freq + for i in range(args.decoder_layers): + is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0 + self.layers.append( + self.build_decoder_layer( + args, + depth=i, + is_moe_layer=is_moe_layer, + is_encoder_decoder=is_encoder_decoder, + ) + ) + + self.num_layers = len(self.layers) + + if args.decoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps) + else: + self.layer_norm = None + + self.self_attn_relative_position = None + self.cross_attn_relative_position = None + + if args.rel_pos_buckets > 0 and args.max_rel_pos > 0: + self.self_attn_relative_position = RelativePositionBias( + num_buckets=args.rel_pos_buckets, + max_distance=args.max_rel_pos, + n_heads=args.decoder_attention_heads, + ) + if is_encoder_decoder: + self.cross_attn_relative_position = RelativePositionBias( + num_buckets=args.rel_pos_buckets, + max_distance=args.max_rel_pos, + n_heads=args.decoder_attention_heads, + ) + + if args.bert_init: + self.apply(init_bert_params) + + if args.deepnorm: + if is_encoder_decoder: + init_scale = math.pow(12.0 * args.decoder_layers, 0.25) + else: + init_scale = math.pow(8.0 * args.decoder_layers, 0.25) + for name, p in self.named_parameters(): + if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name: + p.data.div_(init_scale) + + if args.subln: + if is_encoder_decoder: + init_scale = math.sqrt(math.log(args.decoder_layers * 3)) + else: + init_scale = math.sqrt(math.log(args.decoder_layers * 2)) + for name, p in self.named_parameters(): + if "encoder_attn" in name: + continue + if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name: + p.data.mul_(init_scale) + + def build_output_projection( + self, + args, + ): + if args.share_decoder_input_output_embed: + output_projection = torch.nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + output_projection.weight = self.embed_tokens.weight + else: + output_projection = torch.nn.Linear(args.decoder_embed_dim, args.vocab_size, bias=False) + torch.nn.init.normal_(output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5) + return output_projection + + def build_decoder_layer(self, args, depth, is_moe_layer=False, is_encoder_decoder=False): + layer = DecoderLayer( + args, + depth, + is_moe_layer=is_moe_layer, + is_encoder_decoder=is_encoder_decoder, + ) + if args.checkpoint_activations: + layer = checkpoint_wrapper(layer) + if args.fsdp: + layer = wrap(layer) + return layer + + def forward_embedding( + self, + tokens, + token_embedding=None, + incremental_state=None, + ): + positions = None + if self.embed_positions is not None: + positions = self.embed_positions(tokens, incremental_state=incremental_state) + + if incremental_state is not None: + tokens = tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + if token_embedding 
is None: + token_embedding = self.embed_tokens(tokens) + + x = embed = self.embed_scale * token_embedding + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + return x, embed + + def forward( + self, + prev_output_tokens, + self_attn_padding_mask=None, + encoder_out=None, + incremental_state=None, + features_only=False, + return_all_hiddens=False, + token_embeddings=None, + **kwargs + ): + # embed tokens and positions + x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state) + + # relative position + self_attn_rel_pos_bias = None + slen = prev_output_tokens.size(1) + if self.self_attn_relative_position is not None: + self_attn_rel_pos_bias = self.self_attn_relative_position(batch_size=x.size(0), qlen=slen, klen=slen) + if incremental_state is not None: + self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :] + cross_attn_rel_pos_bias = None + if self.cross_attn_relative_position is not None: + cross_attn_rel_pos_bias = self.cross_attn_relative_position( + batch_size=x.size(0), + qlen=slen, + klen=encoder_out["encoder_out"].size(1), + ) + if incremental_state is not None: + cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :] + + # decoder layers + inner_states = [x] + + if encoder_out is None: + l_aux = [] + else: + l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else [] + + for idx, layer in enumerate(self.layers): + if incremental_state is None: + self_attn_mask = torch.triu( + torch.zeros([x.size(1), x.size(1)]).float().fill_(float("-inf")).type_as(x), + 1, + ) + else: + self_attn_mask = None + if idx not in incremental_state: + incremental_state[idx] = {} + + x, layer_attn, _, l_aux_i = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["encoder_padding_mask"] if encoder_out is not None else None, + incremental_state[idx] if incremental_state is not None else None, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + self_attn_rel_pos=self_attn_rel_pos_bias, + cross_attn_rel_pos=cross_attn_rel_pos_bias, + ) + l_aux.append(l_aux_i) + inner_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + if not features_only: + x = self.output_layer(x) + + return x, { + "inner_states": inner_states, + "l_aux": l_aux, + "attn": None, + } + + def output_layer(self, features): + return self.output_projection(features) diff --git a/vlmo/torchscale/architecture/encoder.py b/vlmo/torchscale/architecture/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d721e37d484ff82754ad264e44c9a025059240ac --- /dev/null +++ b/vlmo/torchscale/architecture/encoder.py @@ -0,0 +1,489 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import math + +import numpy as np +import torch +import torch.nn as nn +from fairscale.nn import checkpoint_wrapper, wrap + +try: + from apex.normalization import FusedLayerNorm as LayerNorm +except ModuleNotFoundError: + from torch.nn import LayerNorm + +from vlmo.torchscale.architecture.utils import init_bert_params +from vlmo.torchscale.component.droppath import DropPath +from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts +from vlmo.torchscale.component.multihead_attention import MultiheadAttention +from vlmo.torchscale.component.multiway_network import MultiwayWrapper, set_split_position +from 
vlmo.torchscale.component.relative_position_bias import RelativePositionBias +#from vlmo.torchscale.component.xmoe.moe_layer import MOELayer +#from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate +# from vlmo.modules.vlmo_utils import no_sync_module_apply +from pytorch_lightning.utilities.rank_zero import rank_zero_info + +def no_sync_module_apply(module, fn): + """FSDP module .apply will use _unshard_params_recurse which will sync params across ranks. + using this function when apply fn is unnecessary to sync params across ranks. + """ + for child in module.children(): + fn(child) + no_sync_module_apply(child, fn) + +class EncoderLayer(nn.Module): + def __init__(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False): + super().__init__() + self.args = args + self.embed_dim = args.encoder_embed_dim + self.self_attn = self.build_self_attention(self.embed_dim, args) if attn is None else attn + self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps)) + self.dropout_module = torch.nn.Dropout(args.dropout) + + if args.drop_path_rate > 0: + drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth] + self.drop_path = DropPath(drop_path_prob) + else: + self.drop_path = None + + self.normalize_before = args.encoder_normalize_before + self.is_moe_layer = is_moe_layer + self.ffn_dim = args.encoder_ffn_embed_dim + + if not self.is_moe_layer: + self.ffn = MultiwayWrapper( + args, + self.build_ffn( + self.embed_dim, + self.args, + ), + ) + else: + assert not self.args.multiway + if args.moe_top1_expert: + gate = Top1Gate( + self.embed_dim, + args.moe_expert_count, + use_fp32=args.moe_gating_use_fp32, + moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction, + use_xmoe=args.use_xmoe, + ) + else: + gate = Top2Gate( + self.embed_dim, + args.moe_expert_count, + args.moe_gating_use_fp32, + args.moe_second_expert_policy, + args.moe_normalize_gate_prob_before_dropping, + args.moe_eval_capacity_token_fraction, + use_xmoe=args.use_xmoe, + ) + experts = make_experts(args, self.embed_dim, self.ffn_dim) + self.moe_layer = MOELayer(gate, experts, args) + self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps)) + + if args.deepnorm: + if is_encoder_decoder: + self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81 + else: + self.alpha = math.pow(2.0 * args.encoder_layers, 0.25) + else: + self.alpha = 1.0 + + def build_ffn(self, embed_dim, args): + return FeedForwardNetwork( + embed_dim, + self.ffn_dim, + args.activation_fn, + args.dropout, + args.activation_dropout, + args.layernorm_eps, + args.subln, + ) + + def build_self_attention(self, embed_dim, args): + return MultiheadAttention( + args, + embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + encoder_decoder_attention=False, + subln=args.subln, + one_attn=args.one_attn, + ) + + def residual_connection(self, x, residual): + return residual * self.alpha + x + + def forward( + self, + x, + encoder_padding_mask, + attn_mask=None, + rel_pos=None, + multiway_split_position=None, + incremental_state=None, + ): + if multiway_split_position is not None: + assert self.args.multiway + no_sync_module_apply(self, set_split_position(multiway_split_position)) + + if attn_mask is not None: + # float16: -1e8 equal 0 + attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) + + residual = x + if self.normalize_before: + x = 
self.self_attn_layer_norm(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + attn_mask=attn_mask, + rel_pos=rel_pos, + incremental_state=incremental_state, + ) + x = self.dropout_module(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + if not self.is_moe_layer: + x = self.ffn(x) + l_aux = None + else: + x = x.transpose(0, 1) + x, l_aux = self.moe_layer(x) + x = x.transpose(0, 1) + + if self.drop_path is not None: + x = self.drop_path(x) + + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + return x, l_aux + + +class Encoder(nn.Module): + def __init__( + self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs + ): + self.args = args + super().__init__(**kwargs) + + self.dropout_module = torch.nn.Dropout(args.dropout) + + embed_dim = args.encoder_embed_dim + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + self.mask_ratio = args.mask_ratio + self.max_text_len = args.max_text_len + self.vision_len = (args.img_size // args.patch_size) * (args.img_size // args.patch_size) + + self.embed_tokens = embed_tokens + self.embed_positions = embed_positions + + if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0: + self.output_projection = self.build_output_projection(args) + else: + self.output_projection = output_projection + + if args.layernorm_embedding: + self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1) + else: + self.layernorm_embedding = None + + self.layers = nn.ModuleList([]) + if self.args.share_layer: + single_layer = self.build_encoder_layer( + args, depth=0, is_moe_layer=False, is_encoder_decoder=is_encoder_decoder + ) + for i in range(args.encoder_layers): + self.layers.append(single_layer) + elif self.args.share_attn: + moe_freq = args.moe_freq + embed_dim = args.encoder_embed_dim + shared_attn = self.build_self_attention(embed_dim, self.args) + for i in range(args.encoder_layers): + is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0 + self.layers.append( + self.build_encoder_layer( + args, + depth=i, + attn=shared_attn, + is_moe_layer=is_moe_layer, + is_encoder_decoder=is_encoder_decoder, + ) + ) + + else: + moe_freq = args.moe_freq + for i in range(args.encoder_layers): + is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0 + self.layers.append( + self.build_encoder_layer( + args, + depth=i, + is_moe_layer=is_moe_layer, + is_encoder_decoder=is_encoder_decoder, + ) + ) + self.num_layers = len(self.layers) + + if args.encoder_normalize_before and args.normalize_output: + self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps)) + else: + self.layer_norm = None + + if args.rel_pos_buckets > 0 and args.max_rel_pos > 0: + self.relative_position = RelativePositionBias( + num_buckets=args.rel_pos_buckets, + max_distance=args.max_rel_pos, + n_heads=args.encoder_attention_heads, + ) + else: + self.relative_position = None + + if args.bert_init: + self.apply(init_bert_params) + + if args.deepnorm: + if is_encoder_decoder: + init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15 + else: + init_scale = math.pow(8.0 * 
args.encoder_layers, 0.25) + for name, p in self.named_parameters(): + if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name: + p.data.div_(init_scale) + + if args.subln: + if is_encoder_decoder: + init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3) + else: + init_scale = math.sqrt(math.log(args.encoder_layers * 2)) + for name, p in self.named_parameters(): + if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name: + p.data.mul_(init_scale) + + def random_masking(self, x, mask_ratio): + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L - 1, device=x.device) + ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device, dtype=int) + ids_keep = ids_shuffle[:, :len_keep] + + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + x0 = x[:, 0, :] + x0 = x0.reshape(N, 1, D) + x_masked_add = torch.cat([x0, x_masked], axis=1) + return x_masked_add, ids_keep + + def build_self_attention(self, embed_dim, args): + return MultiheadAttention( + args, + embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + encoder_decoder_attention=False, + subln=args.subln, + one_attn=args.one_attn, + ) + + def build_output_projection( + self, + args, + ): + if args.share_encoder_input_output_embed: + assert args.encoder_embedding_type == "language" + output_projection = torch.nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + output_projection.weight = self.embed_tokens.weight + else: + output_projection = torch.nn.Linear(args.encoder_embed_dim, args.vocab_size, bias=False) + torch.nn.init.normal_(output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5) + return output_projection + + def checkpointing_and_params_allgather( + self, + origin_layer, + ): + origin_forward = origin_layer.forward + + from deepspeed import checkpointing + def forward(*args, **kwargs): + # deepspeed checkpoint not support kwargs + ret = checkpointing.checkpoint(origin_forward, *args, **kwargs) + return ret + + return forward + + def build_encoder_layer(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False): + layer = EncoderLayer( + args, + depth, + attn, + is_moe_layer=is_moe_layer, + is_encoder_decoder=is_encoder_decoder, + ) + if args.checkpoint_activations: + rank_zero_info("EncoderLayer params: %s", sum(p.numel() for p in layer.parameters() if p.requires_grad)) + layer = checkpoint_wrapper(layer) + # layer.ffn = checkpoint_wrapper(layer.ffn,) + if args.fsdp: + layer = wrap(layer) + return layer + + def checkpointing_layers(self): + for i, layer in enumerate(self.layers): + rank_zero_info(f"Checkpointing wrapper EncoderLayers: {i}") + self.layers[i] = checkpoint_wrapper(layer) + + def forward_embedding( + self, + src_tokens, + token_embedding=None, + positions=None, + ): + if token_embedding is None: + token_embedding = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * token_embedding + if self.embed_positions is not None: + if src_tokens is not None: + x = embed + self.embed_positions(src_tokens, positions=positions) + else: + x = embed + self.embed_positions(x, positions=positions) + is_flip, ids_keep = 0, None + if self.mask_ratio > 0: + if x.shape[1] == self.vision_len + 1: + x, ids_keep = self.random_masking(x, self.mask_ratio) + is_flip = 1 + elif x.shape[1] == self.vision_len + self.max_text_len + 
1: + vision_tokens = x[:, : self.vision_len + 1, :] + vision_tokens, ids_keep = self.random_masking(vision_tokens, self.mask_ratio) + x = torch.cat( + [ + vision_tokens, + x[ + :, + self.vision_len + 1 :, + ], + ], + dim=1, + ) + is_flip = 2 + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + return x, embed, ids_keep, is_flip + + def forward( + self, + src_tokens, + encoder_padding_mask=None, + attn_mask=None, + return_all_hiddens=False, + token_embeddings=None, + multiway_split_position=None, + features_only=False, + incremental_state=None, + positions=None, + **kwargs + ): + assert src_tokens is not None or token_embeddings is not None + + if encoder_padding_mask is None: + if src_tokens is not None: + encoder_padding_mask = torch.zeros_like(src_tokens, device=src_tokens.device).bool() + else: + encoder_padding_mask = torch.zeros( + [token_embeddings.size(0), token_embeddings.size(1)], + device=token_embeddings.device, + ).bool() + + if multiway_split_position is not None: + assert self.args.multiway + no_sync_module_apply(self, set_split_position(multiway_split_position)) + + x, encoder_embedding, ids_keep, is_flip = self.forward_embedding(src_tokens, token_embeddings, positions) + if is_flip > 0: + if is_flip == 2: + text_ids = ( + torch.arange( + self.vision_len + 1, self.vision_len + 1 + self.max_text_len, device=x.device, dtype=torch.int64 + ) + .unsqueeze(0) + .repeat(ids_keep.shape[0], 1) + ) + cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64) + ids_keep = torch.cat([cls_ids, ids_keep, text_ids], dim=1) + elif is_flip == 1: + cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64) + ids_keep = torch.cat([cls_ids, ids_keep], dim=1) + if encoder_padding_mask is not None: + encoder_padding_mask = torch.gather(encoder_padding_mask, dim=1, index=ids_keep) + if attn_mask is not None: + attn_mask = torch.gather( + attn_mask, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, attn_mask.shape[-1]) + ) + attn_mask = torch.gather(attn_mask, dim=2, index=ids_keep.unsqueeze(1).repeat(1, attn_mask.shape[1], 1)) + if multiway_split_position > 0: + multiway_split_position = ids_keep.shape[1] - self.max_text_len + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + + encoder_states = [] + + if return_all_hiddens: + encoder_states.append(x) + + rel_pos_bias = None + if self.relative_position is not None: + rel_pos_bias = self.relative_position(batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)) + + l_aux = [] + for idx, layer in enumerate(self.layers): + x, l_aux_i = layer( + x, + encoder_padding_mask=encoder_padding_mask if incremental_state is None else None, + attn_mask=attn_mask, + rel_pos=rel_pos_bias, + multiway_split_position=multiway_split_position, + incremental_state=incremental_state[idx] if incremental_state is not None else None, + ) + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) + l_aux.append(l_aux_i) + + if multiway_split_position is not None: + assert self.args.multiway + no_sync_module_apply(self, set_split_position(multiway_split_position)) + if self.layer_norm is not None: + x = self.layer_norm(x) + + if not features_only and self.output_projection is not None: + x = self.output_projection(x) + + return { + "encoder_out": x, + "encoder_embedding": encoder_embedding, + "encoder_padding_mask": encoder_padding_mask, + "encoder_states": encoder_states, + "l_aux": l_aux, + "multiway_split_position": multiway_split_position, 
+ } diff --git a/vlmo/torchscale/architecture/encoder_decoder.py b/vlmo/torchscale/architecture/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5f3e3eb8c2ec9dcc2b80f055ee81ae9fc07a58cf --- /dev/null +++ b/vlmo/torchscale/architecture/encoder_decoder.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch.nn as nn + +from vlmo.torchscale.architecture.decoder import Decoder +from vlmo.torchscale.architecture.encoder import Encoder + + +class EncoderDecoder(nn.Module): + def __init__( + self, + args, + encoder_embed_tokens=None, + encoder_embed_positions=None, + decoder_embed_tokens=None, + decoder_embed_positions=None, + output_projection=None, + **kwargs + ): + super().__init__() + self.args = args + if args.share_all_embeddings: + args.share_decoder_input_output_embed = True + + self.encoder = Encoder(args, encoder_embed_tokens, encoder_embed_positions, is_encoder_decoder=True, **kwargs) + + if args.share_all_embeddings and decoder_embed_tokens is None: + decoder_embed_tokens = self.encoder.embed_tokens + + self.decoder = Decoder( + args, decoder_embed_tokens, decoder_embed_positions, output_projection, is_encoder_decoder=True, **kwargs + ) + + def forward(self, src_tokens, prev_output_tokens, return_all_hiddens=False, features_only=False, **kwargs): + encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + features_only=features_only, + return_all_hiddens=return_all_hiddens, + ) + return decoder_out diff --git a/vlmo/torchscale/architecture/utils.py b/vlmo/torchscale/architecture/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6311020f488af03806c8f9779750382af4f9d51d --- /dev/null +++ b/vlmo/torchscale/architecture/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch.nn as nn + +from vlmo.torchscale.component.multihead_attention import MultiheadAttention +from vlmo.torchscale.component.multiway_network import MultiwayNetwork + + +def init_bert_params(module): + def normal_(data): + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + if isinstance(module.q_proj, MultiwayNetwork): + normal_(module.q_proj.A.weight.data) + normal_(module.q_proj.B.weight.data) + normal_(module.k_proj.A.weight.data) + normal_(module.k_proj.B.weight.data) + normal_(module.v_proj.A.weight.data) + normal_(module.v_proj.B.weight.data) + else: + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) diff --git a/vlmo/torchscale/component/__init__.py b/vlmo/torchscale/component/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae31e2507e8759f2ac7f85e517288f536c04ac3 --- /dev/null +++ b/vlmo/torchscale/component/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] diff --git a/vlmo/torchscale/component/droppath.py b/vlmo/torchscale/component/droppath.py new file mode 100644 index 
0000000000000000000000000000000000000000..18c06440816d67402470f8a0876e9c4806d172fc --- /dev/null +++ b/vlmo/torchscale/component/droppath.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch.nn as nn +from timm.models.layers import drop_path + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return "p={}".format(self.drop_prob) diff --git a/vlmo/torchscale/component/embedding.py b/vlmo/torchscale/component/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..cbab09ea6f0041368f6f1f269cc7574df50105c9 --- /dev/null +++ b/vlmo/torchscale/component/embedding.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class VisionLanguageEmbedding(nn.Module): + def __init__(self, text_embed, vision_embed): + super().__init__() + self.text_embed = text_embed + self.vision_embed = vision_embed + + def forward(self, textual_tokens, visual_tokens, **kwargs): + if textual_tokens is None: + return self.vision_embed(visual_tokens) + + if visual_tokens is None: + return self.text_embed(textual_tokens) + + x1 = self.vision_embed(visual_tokens) + x2 = self.text_embed(textual_tokens) + + return torch.cat([x1, x2], dim=1) + + +class VisionEmbedding(nn.Module): + """Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + contain_mask_token=False, + prepend_cls_token=False, + ): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + if contain_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + else: + self.mask_token = None + + if prepend_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + else: + self.cls_token = None + + def num_position_embeddings(self): + if self.cls_token is None: + return self.num_patches + else: + return self.num_patches + 1 + + def forward(self, x, masked_position=None, **kwargs): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + + batch_size, seq_len, _ = x.size() + + if masked_position is not None: + assert self.mask_token is not None + mask_token = self.mask_token.expand(batch_size, seq_len, -1) + w = masked_position.unsqueeze(-1).type_as(mask_token) + x = x * (1 - w) + mask_token * w + + if self.cls_token is not None: + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class TextEmbedding(nn.Embedding): + def reset_parameters(self): + nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5) + self._fill_padding_idx_with_zero() + + +class PositionalEmbedding(nn.Embedding): + def forward( + self, + x, + positions=None, + **kwargs, + ): + if 
positions is None: + # being consistent with Fairseq, which starts from 2. + positions = torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0) + return F.embedding( + positions, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) diff --git a/vlmo/torchscale/component/feedforward_network.py b/vlmo/torchscale/component/feedforward_network.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4b1388a0aff21ec343e04be92571d7b428a580 --- /dev/null +++ b/vlmo/torchscale/component/feedforward_network.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from apex.normalization import FusedLayerNorm as LayerNorm +except ModuleNotFoundError: + from torch.nn import LayerNorm + + +class set_torch_seed(object): + def __init__(self, seed): + assert isinstance(seed, int) + self.rng_state = self.get_rng_state() + + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + def get_rng_state(self): + state = {"torch_rng_state": torch.get_rng_state()} + if torch.cuda.is_available(): + state["cuda_rng_state"] = torch.cuda.get_rng_state() + return state + + def set_rng_state(self, state): + torch.set_rng_state(state["torch_rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state["cuda_rng_state"]) + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.set_rng_state(self.rng_state) + + +def make_experts(args, embed_dim, expert_ffn_dim): + world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size() + expert_list = [] + ddp_rank = args.ddp_rank + start_seed = torch.randint(1000000, (1,)).item() + # at least as many experts than gpus + if args.moe_expert_count >= world_size: + assert args.moe_expert_count % world_size == 0, f"{args.moe_expert_count}, {world_size}" + local_moe_expert_count = args.moe_expert_count // world_size + for i in range(local_moe_expert_count): + with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i): + expert_list.append( + FeedForwardNetwork( + embed_dim, + expert_ffn_dim, + args.activation_fn, + args.dropout, + args.activation_dropout, + args.layernorm_eps, + args.subln, + ) + ) + else: + assert world_size % args.moe_expert_count == 0, f"{world_size}, {args.moe_expert_count}" + + with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count): + expert_list.append( + FeedForwardNetwork( + embed_dim, + expert_ffn_dim, + args.activation_fn, + args.dropout, + args.activation_dropout, + args.layernorm_eps, + args.subln, + ) + ) + experts = nn.ModuleList(expert_list) + return experts + + +def get_activation_fn(activation): + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + else: + raise NotImplementedError + + +class FeedForwardNetwork(nn.Module): + def __init__( + self, + embed_dim, + ffn_dim, + activation_fn, + dropout, + activation_dropout, + layernorm_eps, + subln=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.activation_fn = get_activation_fn(activation=str(activation_fn)) + self.activation_dropout_module = torch.nn.Dropout(activation_dropout) + self.dropout_module = torch.nn.Dropout(dropout) + self.fc1 = nn.Linear(self.embed_dim, ffn_dim) + self.fc2 = nn.Linear(ffn_dim, self.embed_dim) + self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else 
None + + def reset_parameters(self): + self.fc1.reset_parameters() + self.fc2.reset_parameters() + if self.ffn_layernorm is not None: + self.ffn_layernorm.reset_parameters() + + def forward(self, x): + # x = x.reshape(-1, x.size(-1)) + x = self.fc1(x) + # x = self.activation_fn(x.float()).type_as(x) + x = self.activation_fn(x) + x = self.activation_dropout_module(x) + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + x = self.fc2(x) + # x = x.view(x_shape) + x = self.dropout_module(x) + return x diff --git a/vlmo/torchscale/component/multihead_attention.py b/vlmo/torchscale/component/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..528290e2a44527c45d2b5b72bf9c408039ecab5e --- /dev/null +++ b/vlmo/torchscale/component/multihead_attention.py @@ -0,0 +1,154 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import math + +import torch +import torch.nn.functional as F +from torch import nn + +try: + from apex.normalization import FusedLayerNorm as LayerNorm +except ModuleNotFoundError: + from torch.nn import LayerNorm + +from .multiway_network import MultiwayWrapper +from .xpos_relative_position import XPOS + + +class MultiheadAttention(nn.Module): + def __init__( + self, + args, + embed_dim, + num_heads, + dropout=0.0, + self_attention=False, + encoder_decoder_attention=False, + subln=False, + one_attn=False, + ): + super().__init__() + self.args = args + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** (-0.5) + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + assert self.self_attention ^ self.encoder_decoder_attention + if one_attn: + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + else: + self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) + self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) + self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) + # self.qkv_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim*3, bias=True)) + self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) + self.inner_attn_ln = ( + MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps)) + if subln and self.self_attention + else None + ) + self.dropout_module = torch.nn.Dropout(dropout) + self.xpos = XPOS(self.head_dim, args.xpos_scale_base) if args.xpos_rel_pos and self.self_attention else None + + def reset_parameters(self): + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.out_proj.weight) + nn.init.constant_(self.out_proj.bias, 0.0) + + def forward( + self, + query, + key, + value, + incremental_state=None, + key_padding_mask=None, + attn_mask=None, + rel_pos=None, + ): + bsz, tgt_len, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" + + key_bsz, src_len, _ = key.size() + assert key_bsz == bsz, f"{query.size(), key.size()}" + assert value is not None + assert bsz, src_len == value.shape[:2] + # if 
query is key and key is value: + # qkv = self.qkv_proj(query) + # else: + # # W*(q+k+v) = W(q) + W(k) + W(v) + # qkv = self.qkv_proj(query+key+value) + # q,k,v = qkv.split(self.embed_dim, dim=-1) + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + q = (q * self.scaling).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) + q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim) + k = k.reshape(bsz * self.num_heads, src_len, self.head_dim) + v = v.reshape(bsz * self.num_heads, src_len, self.head_dim) + + if incremental_state is not None: + if "prev_key" in incremental_state: + prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim) + prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim) + k = torch.cat([prev_key, k], dim=1) + v = torch.cat([prev_value, v], dim=1) + incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + src_len = k.size(1) + + if self.xpos is not None: + if incremental_state is not None: + offset = src_len - 1 + else: + offset = 0 + k = self.xpos(k, offset=0, downscale=True) + q = self.xpos(q, offset=offset, downscale=False) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + if attn_mask is not None: + attn_weights = torch.nan_to_num(attn_weights) + if len(attn_mask.shape) != len(attn_weights.shape): + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0) + attn_weights += attn_mask + + if key_padding_mask is not None: + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if rel_pos is not None: + rel_pos = rel_pos.view(attn_weights.size()) + attn_weights = attn_weights + rel_pos + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + attn = torch.bmm(attn_probs, v) + attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1) + + if self.inner_attn_ln is not None: + attn = self.inner_attn_ln(attn) + + attn = self.out_proj(attn) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + + return attn, attn_weights diff --git a/vlmo/torchscale/component/multiway_network.py b/vlmo/torchscale/component/multiway_network.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb15a2cae975459655f976ac24d56297be52d3d --- /dev/null +++ b/vlmo/torchscale/component/multiway_network.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import copy + +import torch +import torch.nn as nn + + +def MultiwayWrapper(args, module, dim=1): + if args.multiway: + return MultiwayNetwork(module, dim=dim) + return module + + +def set_split_position(position): + def apply_fn(module): + if hasattr(module, "split_position"): + module.split_position = position + + return apply_fn + + +class MultiwayNetwork(nn.Module): + def __init__(self, module, dim=1): + super().__init__() + self.dim = dim + self.A = module + self.B = copy.deepcopy(module) + 
self.B.reset_parameters() + self.split_position = -1 + + def forward(self, x, **kwargs): + if self.split_position == -1: + return self.A(x, **kwargs) + if self.split_position == 0: + return self.B(x, **kwargs) + x1, x2 = torch.split( + x, + [self.split_position, x.size(self.dim) - self.split_position], + dim=self.dim, + ) + # x1, x2 = x[:self.split_position], x[self.split_position:] + y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs) + return torch.cat([y1, y2], dim=self.dim) + + +class MutliwayEmbedding(MultiwayNetwork): + def __init__(self, modules, dim=1): + super(MultiwayNetwork, self).__init__() + self.dim = dim + assert len(modules) == 2 + self.A = modules[0] + self.B = modules[1] + self.split_position = -1 diff --git a/vlmo/torchscale/component/relative_position_bias.py b/vlmo/torchscale/component/relative_position_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..b5249f20cf6f6091a3738199299dff096796bd0d --- /dev/null +++ b/vlmo/torchscale/component/relative_position_bias.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import math + +import torch +import torch.nn as nn + + +class RelativePositionBias(nn.Module): + def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12): + super().__init__() + self.bidirectional = bidirectional + self.num_buckets = num_buckets + self.max_distance = max_distance + self.n_heads = n_heads + self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + if bidirectional: + num_buckets //= 2 + ret += (n < 0).to(torch.long) * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def compute_bias(self, qlen, klen, step=None): + step = 0 if step is None else step + context_position = torch.arange( + step, + step + qlen, + dtype=torch.long, + device=self.relative_attention_bias.weight.device, + )[:, None] + memory_position = torch.arange(klen, dtype=torch.long, device=self.relative_attention_bias.weight.device)[ + None, : + ] + relative_position = memory_position - context_position # shape (qlen, klen) + + rp_bucket = self._relative_position_bucket( + relative_position, # shape (qlen, klen) + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) + return values + + def forward(self, batch_size, qlen, klen, step=None): + # shape (batch * num_heads, qlen, klen) + return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen) diff --git a/vlmo/torchscale/component/xpos_relative_position.py b/vlmo/torchscale/component/xpos_relative_position.py new file mode 100644 index 0000000000000000000000000000000000000000..6d96d8e19658805f389c94cd4e2977a68b7960bc --- /dev/null +++ 
b/vlmo/torchscale/component/xpos_relative_position.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch +import torch.nn as nn + + +def fixed_pos_embedding(x): + seq_len, dim = x.shape + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, ::2] + x2 = x[:, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m + + +def apply_rotary_pos_emb(x, sin, cos, scale=1): + sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class XPOS(nn.Module): + def __init__(self, head_dim, scale_base=512): + super().__init__() + self.head_dim = head_dim + self.scale_base = scale_base + self.register_buffer("scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim)) + + def forward(self, x, offset=0, downscale=False): + length = x.shape[1] + min_pos = -(length + offset) // 2 + max_pos = length + offset + min_pos + scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] + sin, cos = fixed_pos_embedding(scale) + + if scale.shape[0] > length: + scale = scale[-length:] + sin = sin[-length:] + cos = cos[-length:] + + if downscale: + scale = 1 / scale + + x = apply_rotary_pos_emb(x, sin, cos, scale) + return x diff --git a/vlmo/torchscale/model/BEiT3.py b/vlmo/torchscale/model/BEiT3.py new file mode 100644 index 0000000000000000000000000000000000000000..103ef8d5cde9204d7dcb0e8054d3ea291dbf3061 --- /dev/null +++ b/vlmo/torchscale/model/BEiT3.py @@ -0,0 +1,96 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +import torch +import torch.nn as nn + +from vlmo.torchscale.architecture.encoder import Encoder +from vlmo.torchscale.component.embedding import ( + PositionalEmbedding, + TextEmbedding, + VisionEmbedding, +) +from vlmo.torchscale.component.multiway_network import MutliwayEmbedding + + +class BEiT3(nn.Module): + def __init__(self, args, **kwargs): + super().__init__() + self.args = args + assert args.multiway + assert args.vocab_size > 0 + assert not args.share_encoder_input_output_embed + self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim) + self.vision_embed = VisionEmbedding( + args.img_size, + args.patch_size, + args.in_chans, + args.encoder_embed_dim, + contain_mask_token=True, + prepend_cls_token=True, + ) + # being consistent with Fairseq, which starts from 2 for position embedding + embed_positions = MutliwayEmbedding( + modules=[ + PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim), + PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim), + ], + dim=1, + ) + self.encoder = Encoder( + args, + embed_tokens=None, + embed_positions=embed_positions, + 
output_projection=None, + is_encoder_decoder=False, + ) + + def forward( + self, + textual_tokens=None, + visual_tokens=None, + text_padding_position=None, + attn_mask=None, + vision_masked_position=None, + incremental_state=None, + positions=None, + ): + assert textual_tokens is not None or visual_tokens is not None + + if textual_tokens is None: + x = self.vision_embed(visual_tokens, vision_masked_position) + encoder_padding_mask = None + multiway_split_position = -1 + elif visual_tokens is None: + x = self.text_embed(textual_tokens) + encoder_padding_mask = text_padding_position + multiway_split_position = 0 + else: + x1 = self.vision_embed(visual_tokens, vision_masked_position) + multiway_split_position = x1.size(1) + x2 = self.text_embed(textual_tokens) + diff = x1.shape[0] // x2.shape[0] + if diff != 1: + x2 = torch.repeat_interleave(x2, diff, dim=0) + text_padding_position = torch.repeat_interleave(text_padding_position, diff, dim=0) + x = torch.cat([x1, x2], dim=1) + if text_padding_position is not None: + encoder_padding_mask = torch.cat( + [ + torch.zeros(x1.shape[:-1], device=x1.device, dtype=torch.bool), + text_padding_position, + ], + dim=1, + ) + else: + encoder_padding_mask = None + encoder_out = self.encoder( + src_tokens=None, + encoder_padding_mask=encoder_padding_mask, + attn_mask=attn_mask, + token_embeddings=x, + multiway_split_position=multiway_split_position, + incremental_state=incremental_state, + positions=positions, + ) + return encoder_out diff --git a/vlmo/torchscale/model/__init__.py b/vlmo/torchscale/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae31e2507e8759f2ac7f85e517288f536c04ac3 --- /dev/null +++ b/vlmo/torchscale/model/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] diff --git a/vlmo/transforms/__init__.py b/vlmo/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..368f0c5938d030326aa617b4b3a4811920d4ab3f --- /dev/null +++ b/vlmo/transforms/__init__.py @@ -0,0 +1,19 @@ +from .pixelbert import ( + pixelbert_transform, + pixelbert_transform_randaug, +) +from .square_transform import ( + square_transform, + square_transform_randaug, +) + +_transforms = { + "pixelbert": pixelbert_transform, + "pixelbert_randaug": pixelbert_transform_randaug, + "square_transform": square_transform, + "square_transform_randaug": square_transform_randaug, +} + + +def keys_to_transforms(keys: list, size=224): + return [_transforms[key](size=size) for key in keys] diff --git a/vlmo/transforms/pixelbert.py b/vlmo/transforms/pixelbert.py new file mode 100644 index 0000000000000000000000000000000000000000..d7316e4d7ac0e71e4b19a4783008d910ed9cf3e3 --- /dev/null +++ b/vlmo/transforms/pixelbert.py @@ -0,0 +1,30 @@ +from .utils import ( + inception_normalize, + MinMaxResize, +) +from torchvision import transforms +from .randaug import RandAugment + + +def pixelbert_transform(size=800): + longer = int((1333 / 800) * size) + return transforms.Compose( + [ + MinMaxResize(shorter=size, longer=longer), + transforms.ToTensor(), + inception_normalize, + ] + ) + + +def pixelbert_transform_randaug(size=800): + longer = int((1333 / 800) * size) + trs = transforms.Compose( + [ + MinMaxResize(shorter=size, longer=longer), + transforms.ToTensor(), + inception_normalize, + ] + ) + trs.transforms.insert(0, RandAugment(2, 9)) + return trs diff --git a/vlmo/transforms/randaug.py b/vlmo/transforms/randaug.py new file mode 100644 index 
0000000000000000000000000000000000000000..3d8f2bc6f62fc704a03256f772f107124afbb195 --- /dev/null +++ b/vlmo/transforms/randaug.py @@ -0,0 +1,271 @@ +# code in this file is adpated from rpmcruz/autoaugment +# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py +import random + +import PIL + +# from PIL import ImageOps, ImageEnhance, ImageDraw +import numpy as np +import torch +from PIL import Image + + +def ShearX(img, v): # [-0.3, 0.3] + assert -0.3 <= v <= 0.3 + if random.random() > 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) + + +def ShearY(img, v): # [-0.3, 0.3] + assert -0.3 <= v <= 0.3 + if random.random() > 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) + + +def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert -0.45 <= v <= 0.45 + if random.random() > 0.5: + v = -v + v = v * img.size[0] + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert 0 <= v + if random.random() > 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert -0.45 <= v <= 0.45 + if random.random() > 0.5: + v = -v + v = v * img.size[1] + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + assert 0 <= v + if random.random() > 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +def Rotate(img, v): # [-30, 30] + assert -30 <= v <= 30 + if random.random() > 0.5: + v = -v + return img.rotate(v) + + +def AutoContrast(img, _): + return PIL.ImageOps.autocontrast(img) + + +def Invert(img, _): + return PIL.ImageOps.invert(img) + + +def Equalize(img, _): + return PIL.ImageOps.equalize(img) + + +def Flip(img, _): # not from the paper + return PIL.ImageOps.mirror(img) + + +def Solarize(img, v): # [0, 256] + assert 0 <= v <= 256 + return PIL.ImageOps.solarize(img, v) + + +def SolarizeAdd(img, addition=0, threshold=128): + img_np = np.array(img).astype(np.int) + img_np = img_np + addition + img_np = np.clip(img_np, 0, 255) + img_np = img_np.astype(np.uint8) + img = Image.fromarray(img_np) + return PIL.ImageOps.solarize(img, threshold) + + +def Posterize(img, v): # [4, 8] + v = int(v) + v = max(1, v) + return PIL.ImageOps.posterize(img, v) + + +def Contrast(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Contrast(img).enhance(v) + + +def Color(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Color(img).enhance(v) + + +def Brightness(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Brightness(img).enhance(v) + + +def Sharpness(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Sharpness(img).enhance(v) + + +def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] + assert 0.0 <= v <= 0.2 + if v <= 0.0: + return img + + v = v * img.size[0] + return CutoutAbs(img, v) + + +def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] + # assert 0 <= v <= 20 + if v < 0: + return img + w, h = img.size + x0 = np.random.uniform(w) + y0 = np.random.uniform(h) + + x0 = int(max(0, x0 - v / 2.0)) + y0 = int(max(0, y0 - v / 2.0)) + x1 = min(w, x0 + v) + y1 = min(h, y0 + v) + + xy = (x0, y0, x1, y1) + color = (125, 123, 114) + # color = (0, 0, 0) + img = img.copy() + PIL.ImageDraw.Draw(img).rectangle(xy, color) + 
return img + + +def SamplePairing(imgs): # [0, 0.4] + def f(img1, v): + i = np.random.choice(len(imgs)) + img2 = PIL.Image.fromarray(imgs[i]) + return PIL.Image.blend(img1, img2, v) + + return f + + +def Identity(img, v): + return img + + +def augment_list(): # 16 oeprations and their ranges + # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 + # l = [ + # (Identity, 0., 1.0), + # (ShearX, 0., 0.3), # 0 + # (ShearY, 0., 0.3), # 1 + # (TranslateX, 0., 0.33), # 2 + # (TranslateY, 0., 0.33), # 3 + # (Rotate, 0, 30), # 4 + # (AutoContrast, 0, 1), # 5 + # (Invert, 0, 1), # 6 + # (Equalize, 0, 1), # 7 + # (Solarize, 0, 110), # 8 + # (Posterize, 4, 8), # 9 + # # (Contrast, 0.1, 1.9), # 10 + # (Color, 0.1, 1.9), # 11 + # (Brightness, 0.1, 1.9), # 12 + # (Sharpness, 0.1, 1.9), # 13 + # # (Cutout, 0, 0.2), # 14 + # # (SamplePairing(imgs), 0, 0.4), # 15 + # ] + + # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 + l = [ + (AutoContrast, 0, 1), + (Equalize, 0, 1), + # (Invert, 0, 1), + (Rotate, 0, 30), + (Posterize, 0, 4), + (Solarize, 0, 256), + (SolarizeAdd, 0, 110), + (Color, 0.1, 1.9), + (Contrast, 0.1, 1.9), + (Brightness, 0.1, 1.9), + (Sharpness, 0.1, 1.9), + (ShearX, 0.0, 0.3), + (ShearY, 0.0, 0.3), + # (CutoutAbs, 0, 40), + (TranslateXabs, 0.0, 100), + (TranslateYabs, 0.0, 100), + ] + + return l + + +class Lighting(object): + """Lighting noise(AlexNet - style PCA - based noise)""" + + def __init__(self, alphastd, eigval, eigvec): + self.alphastd = alphastd + self.eigval = torch.Tensor(eigval) + self.eigvec = torch.Tensor(eigvec) + + def __call__(self, img): + if self.alphastd == 0: + return img + + alpha = img.new().resize_(3).normal_(0, self.alphastd) + rgb = ( + self.eigvec.type_as(img) + .clone() + .mul(alpha.view(1, 3).expand(3, 3)) + .mul(self.eigval.view(1, 3).expand(3, 3)) + .sum(1) + .squeeze() + ) + + return img.add(rgb.view(3, 1, 1).expand_as(img)) + + +class CutoutDefault(object): + """ + Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py + """ + + def __init__(self, length): + self.length = length + + def __call__(self, img): + h, w = img.size(1), img.size(2) + mask = np.ones((h, w), np.float32) + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1:y2, x1:x2] = 0.0 + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img *= mask + return img + + +class RandAugment: + def __init__(self, n, m): + self.n = n + self.m = m # [0, 30] + self.augment_list = augment_list() + + def __call__(self, img): + ops = random.choices(self.augment_list, k=self.n) + for op, minval, maxval in ops: + val = (float(self.m) / 30) * float(maxval - minval) + minval + img = op(img, val) + + return img diff --git a/vlmo/transforms/randaugment.py b/vlmo/transforms/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0c46b60c3d8330f82af0d55c7b0d033f4a4923 --- /dev/null +++ b/vlmo/transforms/randaugment.py @@ -0,0 +1,334 @@ +import cv2 +import numpy as np + + +# aug functions +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + """ + same output as PIL.ImageOps.autocontrast + """ + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = 
cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + """ + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + """ + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: + return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + """ + like PIL, rotate by degree, not radians + """ + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + """ + same output as PIL.ImageOps.posterize + """ + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + """ + same output as PIL.ImageEnhance.Color + """ + # implementation according to PIL definition, quite slow + # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] + # out = blend(degenerate, img, factor) + # M = ( + # np.eye(3) * factor + # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) + # )[np.newaxis, np.newaxis, :] + M = np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]) * factor + np.float32( + [[0.114], [0.587], [0.299]] + ) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = np.array([(el - mean) * factor + mean for el in range(256)]).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def brightness_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + """ + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + """ + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + """ + same output as PIL.Image.transform + """ + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def posterize_func(img, bits): + """ + same output as PIL.ImageOps.posterize + """ + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +# level to args +def enhance_level_to_args(MAX_LEVEL): + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1,) + + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * float(translate_const) + if 
np.random.random() > 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 256) + return (level,) + + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level,) + + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + "Identity": identity_func, + "AutoContrast": autocontrast_func, + "Equalize": equalize_func, + "Rotate": rotate_func, + "Solarize": solarize_func, + "Color": color_func, + "Contrast": contrast_func, + "Brightness": brightness_func, + "Sharpness": sharpness_func, + "ShearX": shear_x_func, + "TranslateX": translate_x_func, + "TranslateY": translate_y_func, + "Posterize": posterize_func, + "ShearY": shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 +replace_value = (128, 128, 128) +arg_dict = { + "Identity": none_level_to_args, + "AutoContrast": none_level_to_args, + "Equalize": none_level_to_args, + "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), + "Solarize": solarize_level_to_args(MAX_LEVEL), + "Color": enhance_level_to_args(MAX_LEVEL), + "Contrast": enhance_level_to_args(MAX_LEVEL), + "Brightness": enhance_level_to_args(MAX_LEVEL), + "Sharpness": enhance_level_to_args(MAX_LEVEL), + "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), + "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), + "Posterize": posterize_level_to_args(MAX_LEVEL), + "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img + + +if __name__ == "__main__": + a = RandomAugment() + img = np.random.randn(32, 32, 3) + a(img) diff --git a/vlmo/transforms/square_transform.py b/vlmo/transforms/square_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..ece09f8532d44670a46c70c454f1e2f9c2a5ff02 --- /dev/null +++ b/vlmo/transforms/square_transform.py @@ -0,0 +1,41 @@ +# code in this file is adpated from the ALBEF repo (https://github.com/salesforce/ALBEF) + +from torchvision import transforms +from .randaugment import RandomAugment +from PIL import Image + + +def square_transform(size=224): + return transforms.Compose( + [ + transforms.Resize((size, size), interpolation=Image.BICUBIC), + transforms.ToTensor(), + ] + ) + + +def square_transform_randaug(size=224): + return 
transforms.Compose( + [ + transforms.RandomResizedCrop(size, scale=(0.8, 1.0), interpolation=Image.BICUBIC), + transforms.RandomHorizontalFlip(), + RandomAugment( + 2, + 7, + isPIL=True, + augs=[ + "Identity", + "AutoContrast", + "Equalize", + "Brightness", + "Sharpness", + "ShearX", + "ShearY", + "TranslateX", + "TranslateY", + "Rotate", + ], + ), + transforms.ToTensor(), + ] + ) diff --git a/vlmo/transforms/utils.py b/vlmo/transforms/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d31422be7d26acd4e18307ee2fccbd8b452e9080 --- /dev/null +++ b/vlmo/transforms/utils.py @@ -0,0 +1,56 @@ +from torchvision import transforms +from PIL import Image + + +class MinMaxResize: + def __init__(self, shorter=800, longer=1333): + self.min = shorter + self.max = longer + + def __call__(self, x): + w, h = x.size + scale = self.min / min(w, h) + if h < w: + newh, neww = self.min, scale * w + else: + newh, neww = scale * h, self.min + + if max(newh, neww) > self.max: + scale = self.max / max(newh, neww) + newh = newh * scale + neww = neww * scale + + newh, neww = int(newh + 0.5), int(neww + 0.5) + newh, neww = newh // 32 * 32, neww // 32 * 32 + + return x.resize((neww, newh), resample=Image.BICUBIC) + + +class UnNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + Returns: + Tensor: Normalized image. + """ + for t, m, s in zip(tensor, self.mean, self.std): + t.mul_(s).add_(m) + # The normalize code -> t.sub_(m).div_(s) + return tensor + + +# This is simple maximum entropy normalization performed in Inception paper +inception_normalize = transforms.Compose([transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) + +# ViT uses simple non-biased inception normalization +# https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132 +inception_unnormalize = transforms.Compose([UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) + +cn_clip_normalize = transforms.Compose( + [transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])] +) diff --git a/vlmo/utils/__init__.py b/vlmo/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vlmo/utils/beit_utils.py b/vlmo/utils/beit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3b3bbd355bfdb22f47019ba83aa566e2e1fd94e8 --- /dev/null +++ b/vlmo/utils/beit_utils.py @@ -0,0 +1,75 @@ +import json +import os +import urllib +from tqdm import tqdm + +from vlmo.config import config, _loss_names # noqa +from vlmo.modules import VLMo +from vlmo.transforms import keys_to_transforms + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024 + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + 
return download_target + + +def config_setting(custom_config: dict): + cfg = eval("config")() + for k, v in custom_config.items(): + cfg[k] = v + return cfg + + +def load_from_config(model_config): + if isinstance(model_config, str): + model_config = json.loads(open(model_config, 'r').read()) + else: + assert isinstance(model_config, dict) + + model_url = model_config.pop('model_url', None) + model_path = model_config.pop('model_path', None) + if model_path and os.path.exists(model_path): + load_path = model_path + elif model_url: + load_path = _download(model_url, os.path.expanduser("~/.cache/m2_encoder")) + else: + from modelscope import snapshot_download + modelscope_cfg = model_config.pop('modelscope', None) + model_dir = snapshot_download(**modelscope_cfg) + load_path = os.path.join(model_dir, model_config.pop('model_file')) + + cfg = config_setting(model_config) + cfg["load_path"] = load_path + + if cfg["flash_attn"]: + from vlmo.utils.patch_utils import patch_torch_scale_with_flash_attn + patch_torch_scale_with_flash_attn() + + model = VLMo(cfg) + + from vlmo.modules.vlmo_module import get_pretrained_tokenizer + txt_processor = get_pretrained_tokenizer(cfg["tokenizer_type"], from_pretrained=cfg["tokenizer"]) + img_processor = keys_to_transforms(cfg["val_transform_keys"], size=cfg["image_size"])[0] + + return model, [txt_processor, img_processor] diff --git a/vlmo/utils/patch_utils.py b/vlmo/utils/patch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fc1e580bc38850f28acf3125f8b1cb93de0fbf --- /dev/null +++ b/vlmo/utils/patch_utils.py @@ -0,0 +1,107 @@ +# coding: utf-8 +# Copyright (c) Antfin, Inc. All rights reserved. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch + + +def _patch_forward( + self, + query, + key, + value, + incremental_state=None, + key_padding_mask=None, + attn_mask=None, + rel_pos=None, +): + bsz, tgt_len, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" + + key_bsz, src_len, _ = key.size() + assert key_bsz == bsz, f"{query.size(), key.size()}" + assert value is not None + assert bsz, src_len == value.shape[:2] + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) + + if incremental_state is not None or self.xpos is not None: + q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim) + k = k.reshape(bsz * self.num_heads, src_len, self.head_dim) + v = v.reshape(bsz * self.num_heads, src_len, self.head_dim) + if incremental_state is not None: + if "prev_key" in incremental_state: + prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim) + prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim) + k = torch.cat([prev_key, k], dim=1) + v = torch.cat([prev_value, v], dim=1) + incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + src_len = k.size(1) + + if self.xpos is not None: + if incremental_state is not None: + offset = src_len - 1 + else: + offset = 0 + k = self.xpos(k, offset=0, downscale=True) + q = self.xpos(q, offset=offset, downscale=False) + q = 
q.reshape(bsz, self.num_heads, tgt_len, self.head_dim) + k = k.reshape(bsz, self.num_heads, src_len, self.head_dim) + v = v.reshape(bsz, self.num_heads, src_len, self.head_dim) + + assert rel_pos is None + + # Moving the repeat_interleave up into encoder.py is probably not worthwhile (recomputation would keep more tensors alive). + if attn_mask is not None: + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.unsqueeze(0).repeat_interleave(bsz * self.num_heads, dim=0) + else: + attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0) + + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) + key_padding_mask = key_padding_mask.repeat_interleave(tgt_len, dim=2) + key_padding_mask = key_padding_mask.repeat_interleave(self.num_heads, dim=1) + key_padding_mask = key_padding_mask.view(bsz * self.num_heads, tgt_len, src_len) + if attn_mask is not None: + attn_mask.masked_fill_(key_padding_mask.to(torch.bool), -torch.inf) + else: + attn_mask = key_padding_mask.to(q.dtype).masked_fill(key_padding_mask.to(torch.bool), -torch.inf) + if attn_mask is not None: + attn_mask = attn_mask.to(q.dtype).reshape(bsz, self.num_heads, *tuple(attn_mask.shape[-2:])) + with torch.backends.cuda.sdp_kernel(enable_math=False if attn_mask is None else True): + attn = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.dropout_module.p if self.training else 0.0, + ) + + attn = attn.transpose(1, 2).reshape(bsz, tgt_len, embed_dim) + + if self.inner_attn_ln is not None: + attn = self.inner_attn_ln(attn) + + attn = self.out_proj(attn) + # The encoder does not use the attention weights, so return None directly. + return attn, None + + +def patch_torch_scale_with_flash_attn(): + from vlmo.torchscale.component.multihead_attention import MultiheadAttention + torch.backends.cuda.enable_flash_sdp(True) + MultiheadAttention._origin_forward = MultiheadAttention.forward + MultiheadAttention.forward = _patch_forward + print('Finish patch_torch_scale_with_flash_attn!')
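
As a reading aid (not part of the diff): load_from_config in vlmo/utils/beit_utils.py calls patch_torch_scale_with_flash_attn() whenever cfg["flash_attn"] is set, so every MultiheadAttention in the model picks up the SDPA-based forward above. The sketch below applies the same monkey-patch by hand to a standalone attention layer. It is only an illustrative sketch, assuming PyTorch >= 2.0 and an importable vlmo package; the small Args holder and the tensor sizes are hypothetical stand-ins for the real vlmo.config namespace.

```
# Illustrative sketch only: apply the SDPA patch by hand to one attention layer.
# Assumes PyTorch >= 2.0 and that the vlmo package from this repo is importable.
import torch

from vlmo.utils.patch_utils import patch_torch_scale_with_flash_attn
from vlmo.torchscale.component.multihead_attention import MultiheadAttention

# After this call, MultiheadAttention.forward dispatches to scaled_dot_product_attention.
patch_torch_scale_with_flash_attn()


class Args:  # hypothetical stand-in for the fields the layer reads from vlmo.config
    multiway = False
    xpos_rel_pos = False
    xpos_scale_base = 512


attn = MultiheadAttention(Args(), embed_dim=64, num_heads=8, self_attention=True, one_attn=True)
x = torch.randn(2, 16, 64)    # (batch, seq_len, embed_dim)
out, weights = attn(x, x, x)  # patched forward returns (attn, None)
print(out.shape, weights)     # torch.Size([2, 16, 64]) None
```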
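
Also for reference, a minimal sketch of the multiway routing defined in vlmo/torchscale/component/multiway_network.py: set_split_position marks where the vision tokens end along dim=1, and MultiwayNetwork routes the two slices through separate copies of the wrapped module (A and B). The toy Linear module and token counts below are made up for illustration.

```
# Illustrative sketch only: first 3 tokens go through expert A, the rest through expert B.
import torch
import torch.nn as nn

from vlmo.torchscale.component.multiway_network import MultiwayNetwork, set_split_position

net = MultiwayNetwork(nn.Linear(8, 8))  # expert B is a re-initialized copy of the wrapped module
x = torch.randn(2, 5, 8)                # e.g. 3 vision tokens followed by 2 text tokens
net.apply(set_split_position(3))        # split along dim=1 at position 3
y = net(x)                              # cat([A(x[:, :3]), B(x[:, 3:])], dim=1)
print(y.shape)                          # torch.Size([2, 5, 8])
```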
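
Finally, a small usage sketch for the transform registry in vlmo/transforms/__init__.py, mirroring how load_from_config builds the image preprocessor from cfg["val_transform_keys"] and cfg["image_size"]. It assumes the dependencies from requirements.txt (torchvision, Pillow, OpenCV) are installed; the dummy PIL image is only for illustration.

```
# Illustrative sketch only: build the eval-time image transform used by the model.
from PIL import Image

from vlmo.transforms import keys_to_transforms

img_processor = keys_to_transforms(["square_transform"], size=224)[0]
tensor = img_processor(Image.new("RGB", (640, 480)))  # resize to 224x224, then ToTensor
print(tensor.shape)                                   # torch.Size([3, 224, 224])
```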