acai66 committed (verified)
Commit 3440f83 · 1 Parent(s): 24b8390

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50):
  1. .DS_Store +0 -0
  2. .gitattributes +1 -0
  3. .msc +0 -0
  4. .mv +1 -0
  5. README.md +94 -0
  6. configuration.json +1 -0
  7. m2_encoder_1B.ckpt +3 -0
  8. ms_wrapper.py +219 -0
  9. requirements.txt +14 -0
  10. res/effect.png +3 -0
  11. vlmo/.DS_Store +0 -0
  12. vlmo/Encoder_0.4B.json +17 -0
  13. vlmo/README.md +10 -0
  14. vlmo/__init__.py +0 -0
  15. vlmo/config.py +165 -0
  16. vlmo/modules/__init__.py +1 -0
  17. vlmo/modules/heads.py +24 -0
  18. vlmo/modules/modeling_utils.py +179 -0
  19. vlmo/modules/multiway_transformer.py +396 -0
  20. vlmo/modules/objectives.py +12 -0
  21. vlmo/modules/vlmo_module.py +405 -0
  22. vlmo/modules/vlmo_utils.py +12 -0
  23. vlmo/tokenizer/__init__.py +6 -0
  24. vlmo/tokenizer/sp.model +3 -0
  25. vlmo/tokenizer/tokenization_glm.py +307 -0
  26. vlmo/tokenizer/tokenizer_config.json +17 -0
  27. vlmo/torchscale/__init__.py +2 -0
  28. vlmo/torchscale/architecture/__init__.py +2 -0
  29. vlmo/torchscale/architecture/config.py +197 -0
  30. vlmo/torchscale/architecture/decoder.py +428 -0
  31. vlmo/torchscale/architecture/encoder.py +489 -0
  32. vlmo/torchscale/architecture/encoder_decoder.py +43 -0
  33. vlmo/torchscale/architecture/utils.py +33 -0
  34. vlmo/torchscale/component/__init__.py +2 -0
  35. vlmo/torchscale/component/droppath.py +19 -0
  36. vlmo/torchscale/component/embedding.py +110 -0
  37. vlmo/torchscale/component/feedforward_network.py +128 -0
  38. vlmo/torchscale/component/multihead_attention.py +154 -0
  39. vlmo/torchscale/component/multiway_network.py +55 -0
  40. vlmo/torchscale/component/relative_position_bias.py +67 -0
  41. vlmo/torchscale/component/xpos_relative_position.py +62 -0
  42. vlmo/torchscale/model/BEiT3.py +96 -0
  43. vlmo/torchscale/model/__init__.py +2 -0
  44. vlmo/transforms/__init__.py +19 -0
  45. vlmo/transforms/pixelbert.py +30 -0
  46. vlmo/transforms/randaug.py +271 -0
  47. vlmo/transforms/randaugment.py +334 -0
  48. vlmo/transforms/square_transform.py +41 -0
  49. vlmo/transforms/utils.py +56 -0
  50. vlmo/utils/__init__.py +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ res/effect.png filter=lfs diff=lfs merge=lfs -text
.msc ADDED
Binary file (4.2 kB).
 
.mv ADDED
@@ -0,0 +1 @@
+ Revision:master,CreatedAt:1742457812
README.md ADDED
@@ -0,0 +1,94 @@
+ ---
+ tasks:
+ - multi-modal-embedding
+ - image-text-retrieval
+ domain:
+ - multi-modal
+ frameworks:
+ - pytorch
+ backbone:
+ - transformers
+ metrics:
+ - R@1
+ license: apache-2.0
+ tags:
+ - Ant Group
+ - multi-modal-embedding
+ widgets:
+ - inputs:
+ - validator:
+ max_words: 52
+ type: text
+ title: 查询文本
+ output:
+ maximize: false
+ examples:
+ - name: 1
+ inputs:
+ - data: 戴眼镜的猫
+ - name: 2
+ inputs:
+ - data: 一个在逛公园的女孩
+ task: multi-modal-embedding
+ ---
+
+ ## Model Description
+ M2-Encoder is a powerful bilingual (Chinese-English) multi-modal model. It is trained on BM-6B, a dataset we built that contains 6 billion image-text pairs (3 billion Chinese + 3 billion English), and it supports zero-shot cross-modal image-text retrieval (text-to-image and image-to-text search) as well as zero-shot image classification.
+
+ Model performance is shown below:
+
+ ![M2-Encoder](./res/effect.png)
+
+ ## Intended Use and Scope
+ This model is mainly used for:
+ 1. Image-to-text or text-to-image retrieval: taking text-to-image retrieval as an example, first use M2-Encoder to extract features for the entire image gallery; given a text query, extract its features with M2-Encoder and compute similarity against the stored gallery features.
+ 2. Zero-shot open-set image classification: given an image and a list of candidate labels, output the label that best matches the image according to image-label similarity.
+
+
+ ## How to Use
+
+ ### Code Example
+ ```
+ # Create a new environment (Python 3.8)
+ conda create -n m2-encoder python=3.8
+ source activate m2-encoder
+
+ # Clone the project
+ cd /YourPath/
+ git clone https://github.com/alipay/Ant-Multi-Modal-Framework
+
+ # Install dependencies
+ cd ./Ant-Multi-Modal-Framework/prj/M2_Encoder/
+ pip install -r requirements.txt
+
+ # Run the demo; the corresponding model weights are downloaded automatically via ModelScope
+ python run.py
+ ```
+
+ ### Model Limitations and Possible Bias
+ The model is trained on a dataset and may therefore exhibit some bias; please evaluate it yourself before deciding how to use it.
+
+ ## Training Data
+ BM-6B dataset: 6 billion cleaned, high-quality bilingual (Chinese-English) image-text pairs, split roughly evenly between Chinese and English (3 billion each). See the [technical report](https://arxiv.org/abs/2401.15896) for details on how the dataset was collected and constructed.
+
+ ## Training Procedure
+ Training through the ModelScope interface is not supported yet; stay tuned.
+
+
+ ### Training
+ Not supported yet.
+ ## Evaluation and Results
+ The model reaches state-of-the-art results on both zero-shot cross-modal image-text retrieval and zero-shot classification.
+
+
+
+ ### Related Papers and Citation
+ If you find this model helpful, please consider citing the related paper below:
+ ```
+ @misc{guo2024m2encoder,
+ title={M2-Encoder: Advancing Bilingual Image-Text Understanding by Large-scale Efficient Pretraining},
+ author={Qingpei Guo and Furong Xu and Hanxiao Zhang and Wang Ren and Ziping Ma and Lin Ju and Jian Wang and Jingdong Chen and Ming Yang},
+ year={2024},
+ url={https://arxiv.org/abs/2401.15896},
+ }
+ ```
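The retrieval and zero-shot classification usages described in the README above are exposed through the custom pipeline registered in ms_wrapper.py. Below is a minimal sketch of driving it via ModelScope, mirroring the `__main__` block of that file; it assumes the published model id `M2Cognition/M2-Encoder`, an environment installed from requirements.txt, and network access to download the checkpoint.

```python
# Minimal sketch based on ms_wrapper.py's __main__ block; not part of the commit.
from modelscope.pipelines import pipeline
from modelscope.preprocessors.image import load_image

pipe = pipeline("multi-modal-embeddings", model="M2Cognition/M2-Encoder")

# Zero-shot classification: an image plus a comma-separated label list.
result = pipe({
    "image": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg",
    "label_list": "杰尼龟,妙蛙种子,小火龙,皮卡丘",
})
print(result["text"], result["scores"])  # best-matching label, per-label probabilities

# Embedding extraction: text and image features live in the same space, so
# retrieval reduces to a dot product between the normalized embeddings.
text_emb = pipe({"text": ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]})["text_embedding"]
image = load_image("https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg")
image_emb = pipe({"img": image})["img_embedding"]
similarity = image_emb @ text_emb.t()  # shape [num_images, num_texts]
```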
configuration.json ADDED
@@ -0,0 +1 @@
+ {"framework":"Pytorch","task":"multi-modal-embeddings","pipeline":{"type":"multi-modal-embedding-pipeline"},"allow_remote": true}
m2_encoder_1B.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac4c5d9a0e44fff05f0ccadf54617b7809f489bc401212abd836c8d075047e9b
+ size 2921990385
ms_wrapper.py ADDED
@@ -0,0 +1,219 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import torch
4
+ import os
5
+
6
+ from modelscope.models.base import TorchModel
7
+ from modelscope.preprocessors.base import Preprocessor
8
+ from modelscope.pipelines.base import Model, Pipeline
9
+ from modelscope.utils.config import Config
10
+ from modelscope.pipelines.builder import PIPELINES
11
+ from modelscope.preprocessors.builder import PREPROCESSORS
12
+ from modelscope.models.builder import MODELS
13
+ from modelscope.preprocessors.image import load_image
14
+
15
+
16
+ from vlmo.utils.beit_utils import load_from_config
17
+
18
+
19
+ @PIPELINES.register_module(
20
+ "multi-modal-embeddings", module_name="multi-modal-embedding-pipeline"
21
+ )
22
+ class MyCustomPipeline(Pipeline):
23
+ """Give simple introduction to this pipeline.
24
+
25
+ Examples:
26
+
27
+ >>> from modelscope.pipelines import pipeline
28
+ >>> input = "Hello, ModelScope!"
29
+ >>> my_pipeline = pipeline('my-task', 'my-model-id')
30
+ >>> result = my_pipeline(input)
31
+
32
+ """
33
+
34
+ def __init__(self, model, preprocessor=None, **kwargs):
35
+ """
36
+ use `model` and `preprocessor` to create a custom pipeline for prediction
37
+ Args:
38
+ model: model id on modelscope hub.
39
+ preprocessor: the class of method be init_preprocessor
40
+ """
41
+ super().__init__(model=model, auto_collate=False)
42
+ self.model_dir = model
43
+ self._device = "cuda" if torch.cuda.is_available() else "cpu"
44
+ # model_config = {
45
+ # "loss_names": {"itc": 1},
46
+ # "encoder_layers": 9,
47
+ # "beit3_vl_layers": 3,
48
+ # "tokenizer_type": "GLMChineseTokenizer",
49
+ # "tokenizer": os.path.join(self.model_dir, "./vlmo/tokenizer"),
50
+ # "vocab_size": 115244,
51
+ # "whole_word_masking": True,
52
+ # "precision": 32,
53
+ # "test_only": True,
54
+ # "flash_attn": True,
55
+ # "model_path": os.path.join(self.model_dir, "m2_encoder_1B.ckpt"),
56
+ # "modelscope": {"model_id": "M2Cognition/M2-Encoder-Large"},
57
+ # "model_file": "m2_encoder_1B.ckpt",
58
+ # }
59
+ model_config = {
60
+ "loss_names": {"itc": 1},
61
+ "beit_version": "large",
62
+ "encoder_embed_dim": 1024,
63
+ "out_embed_dim": 1024,
64
+ "encoder_layers": 21,
65
+ "beit3_vl_layers": 3,
66
+ # "image_size": 224,
67
+ "visual_mask_size": 14,
68
+ "tokenizer_type": "GLMChineseTokenizer",
69
+ "tokenizer": os.path.join(self.model_dir, "./vlmo/tokenizer"),
70
+ "vocab_size": 115244,
71
+ "whole_word_masking": False,
72
+ "precision": 32,
73
+ "test_only": True,
74
+ "flash_attn": True,
75
+ "model_path": os.path.join(self.model_dir, "m2_encoder_1B.ckpt"),
76
+ "modelscope": {
77
+ "model_id": "M2Cognition/M2_Encoder_Large"
78
+ },
79
+ "model_file": "m2_encoder_1B.ckpt"
80
+ }
81
+ model, processors = load_from_config(model_config)
82
+ self.model = model
83
+ self.model.to(self._device).eval()
84
+ self._tokenizer, self._img_processor = processors
85
+
86
+ def _sanitize_parameters(self, **pipeline_parameters):
87
+ """
88
+ this method should sanitize the keyword args to preprocessor params,
89
+ forward params and postprocess params on '__call__' or '_process_single' method
90
+ considered to be a normal classmethod with default implementation / output
91
+
92
+ Default Returns:
93
+ Dict[str, str]: preprocess_params = {}
94
+ Dict[str, str]: forward_params = {}
95
+ Dict[str, str]: postprocess_params = pipeline_parameters
96
+ """
97
+ return {}, pipeline_parameters, {}
98
+
99
+ def _check_input(self, inputs):
100
+ pass
101
+
102
+ def _check_output(self, outputs):
103
+ pass
104
+
105
+ def forward(self, forward_params):
106
+ """Provide default implementation using self.model and user can reimplement it"""
107
+ # print("forward_params", forward_params)
108
+ labels = forward_params.get("label_list", "")
109
+ labels = labels.split(",")
110
+ if len(labels) > 1 and labels[0] != "":
111
+ txt_encoding = self._tokenizer(
112
+ labels,
113
+ padding="max_length",
114
+ truncation=True,
115
+ max_length=self.model.hparams.config["max_text_len"],
116
+ return_special_tokens_mask=True,
117
+ )
118
+ txt_data = {
119
+ "text_ids": torch.tensor(txt_encoding["input_ids"]).to(self._device),
120
+ "text_masks": torch.tensor(txt_encoding["attention_mask"]).to(
121
+ self._device
122
+ ),
123
+ "text_labels": None,
124
+ }
125
+ txt_feats = self.model.infer_text(txt_data)["cls_vlffn_feats"]
126
+ image = forward_params["image"]
127
+ image = load_image(image)
128
+ img = self._img_processor(image).unsqueeze(0)
129
+ img_data = {"image": [img.to(self._device)]}
130
+ img_feats = self.model.infer_image(img_data)["cls_vlffn_feats"]
131
+ logits_per_image = self.model.logit_scale.exp() * img_feats @ txt_feats.t()
132
+ probs = logits_per_image.softmax(dim=-1).detach().cpu()
133
+ index = probs.max(dim=-1)[1][0]
134
+ label = labels[index]
135
+ return {"text": label, "scores": probs.numpy().tolist()[0]}
136
+ else:
137
+ rets = {}
138
+ if "text" in forward_params:
139
+ text = forward_params.get("text")
140
+ txt_encoding = self._tokenizer(
141
+ text,
142
+ padding="max_length",
143
+ truncation=True,
144
+ max_length=self.model.hparams.config["max_text_len"],
145
+ return_special_tokens_mask=True,
146
+ )
147
+ txt_data = {
148
+ "text_ids": torch.tensor(txt_encoding["input_ids"]).to(
149
+ self._device
150
+ ),
151
+ "text_masks": torch.tensor(txt_encoding["attention_mask"]).to(
152
+ self._device
153
+ ),
154
+ "text_labels": None,
155
+ }
156
+ txt_feats = self.model.infer_text(txt_data)["cls_vlffn_feats"]
157
+ rets.update({"text_embedding": txt_feats.detach()})
158
+ if "img" in forward_params:
159
+ input_img = forward_params["img"]
160
+ img = self._img_processor(input_img).unsqueeze(0)
161
+ img_data = {"image": [img.to(self._device)]}
162
+ img_feats = self.model.infer_image(img_data)["cls_vlffn_feats"]
163
+ rets.update({"img_embedding": img_feats.detach()})
164
+
165
+ return rets
166
+
167
+ def preprocess(self, inputs):
168
+ return inputs
169
+
170
+ def postprocess(self, inputs):
171
+ """If current pipeline support model reuse, common postprocess
172
+ code should be write here.
173
+
174
+ Args:
175
+ inputs: input data
176
+
177
+ Return:
178
+ dict of results: a dict containing outputs of model, each
179
+ output should have the standard output name.
180
+ """
181
+ return inputs
182
+
183
+
184
+ """
185
+ # Tip: usr_config_path is the temporary location where the configuration is saved; after uploading to the ModelScope hub, it becomes the model_id
186
+ usr_config_path = "/tmp/snapdown/"
187
+ config = Config(
188
+ {
189
+ "framework": "pytorch",
190
+ "task": "multi-modal-embeddings",
191
+ "model": {"type": "m2-encoder"},
192
+ "pipeline": {"type": "multi-modal-embedding-pipeline"},
193
+ "allow_remote": True,
194
+ }
195
+ )
196
+ config.dump("/tmp/snapdown/" + "configuration.json")
197
+ """
198
+
199
+ if __name__ == "__main__":
200
+ from modelscope.pipelines import pipeline
201
+ from modelscope.preprocessors.image import load_image
202
+
203
+ model = "M2Cognition/M2-Encoder"
204
+ pipe = pipeline("multi-modal-embeddings", model=model)
205
+ input = {
206
+ "image": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg",
207
+ "label_list": "杰尼龟,妙蛙种子,小火龙,皮卡丘",
208
+ }
209
+ demo = pipe(input)
210
+ print("demo output", demo)
211
+ inputs = {"text": ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]}
212
+ output = pipe(inputs)
213
+ print("text output", output)
214
+ input_img = load_image(
215
+ "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
216
+ ) # Accepts the Pikachu example image URL or a local image path; returns a PIL.Image
217
+ inputs = {"img": input_img}
218
+ img_embedding = pipe(inputs) # 2D tensor, [num_images, feature_dim]
219
+ print("image output", img_embedding)
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ torch
2
+ pytorch_lightning<=2.0.8
3
+ transformers
4
+ Pillow
5
+ tqdm
6
+ einops
7
+ sacred
8
+ timm
9
+ torchvision
10
+ fairscale
11
+ numpy
12
+ opencv-python
13
+ sentencepiece
14
+ modelscope
res/effect.png ADDED

Git LFS Details

  • SHA256: 4d6a1950c1ab8770d7d6949d9164c03627fd6bfe673538e71ab4700a68aa6167
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
vlmo/.DS_Store ADDED
Binary file (6.15 kB).
 
vlmo/Encoder_0.4B.json ADDED
@@ -0,0 +1,17 @@
+ {
+ "loss_names": {"itc": 1},
+ "encoder_layers": 9,
+ "beit3_vl_layers": 3,
+ "tokenizer_type": "GLMChineseTokenizer",
+ "tokenizer": "./vlmo/tokenizer",
+ "vocab_size": 115244,
+ "whole_word_masking": true,
+ "precision": 32,
+ "test_only": true,
+ "flash_attn": true,
+ "model_path": "m2_encoder_0.4B.ckpt",
+ "modelscope": {
+ "model_id": "M2Cognition/M2-Encoder"
+ },
+ "model_file": "m2_encoder_0.4B.ckpt"
+ }
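For reference, here is a minimal sketch (not part of the commit) of turning a config such as Encoder_0.4B.json into a model. It assumes the same `vlmo.utils.beit_utils.load_from_config` helper that ms_wrapper.py uses, that the script runs from the repository root so `./vlmo/tokenizer` resolves, and that the checkpoint named by the `modelscope` entry can be downloaded.

```python
# Sketch of loading the 0.4B variant; mirrors the pattern in ms_wrapper.py.
import json

import torch
from vlmo.utils.beit_utils import load_from_config

with open("vlmo/Encoder_0.4B.json") as f:
    model_config = json.load(f)

model, (tokenizer, img_processor) = load_from_config(model_config)
model.eval()

# Text embedding, mirroring MyCustomPipeline.forward in ms_wrapper.py.
enc = tokenizer(
    ["戴眼镜的猫"],
    padding="max_length",
    truncation=True,
    max_length=model.hparams.config["max_text_len"],
    return_special_tokens_mask=True,
)
txt_data = {
    "text_ids": torch.tensor(enc["input_ids"]),
    "text_masks": torch.tensor(enc["attention_mask"]),
    "text_labels": None,
}
with torch.no_grad():
    txt_feats = model.infer_text(txt_data)["cls_vlffn_feats"]  # L2-normalized features
print(txt_feats.shape)
```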
vlmo/README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ license: Apache License 2.0
+ ---
+ ###### This model currently uses the default introduction template and is in the "pre-release" stage; the page is visible to the owner only.
+ ###### Please complete the model card according to the [model contribution guide](https://www.modelscope.cn/docs/%E5%A6%82%E4%BD%95%E6%92%B0%E5%86%99%E5%A5%BD%E7%94%A8%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8D%A1%E7%89%87). The ModelScope platform will display the model once the card is completed. Thank you for your understanding.
+
+ #### Clone with HTTP
+ ```bash
+ git clone https://www.modelscope.cn/M2Cognition/M2_Encoder_demo.git
+ ```
vlmo/__init__.py ADDED
File without changes
vlmo/config.py ADDED
@@ -0,0 +1,165 @@
1
+ from sacred import Experiment
2
+
3
+ ex = Experiment("VLMo")
4
+
5
+
6
+ def _loss_names(d):
7
+ ret = {
8
+ "itm": 0, # image-text matching loss
9
+ "itc": 0, # image-text contrastive loss
10
+ "caption": 0, # image captioning loss
11
+ "mvlm": 0, # masked language modeling loss
12
+ "textmlm": 0, # text-only masked language modeling
13
+ "imagemlm": 0, # image-only masked language modeling
14
+ "vqa": 0,
15
+ "nlvr2": 0,
16
+ "irtr": 0, # retrieval task ft
17
+ }
18
+ ret.update(d)
19
+ return ret
20
+
21
+
22
+ @ex.config
23
+ def config():
24
+ exp_name = "vlmo"
25
+ seed = 1
26
+ datasets = ["coco", "vg", "sbu", "gcc"] # dataset name, the definition can refer to: vlmo/datamodules/__init__.py # noqa
27
+ loss_names = _loss_names({"itm": 0, "itc": 0, "mvlm": 0}) # training loss
28
+ batch_size = 1024 # this is a desired batch size; pl trainer will accumulate gradients.
29
+
30
+ # BEiT-v3 setting
31
+ encoder_layers = 12 # the layer number of backbone
32
+ encoder_embed_dim = 768 # the hidden size of tokenizer
33
+ out_embed_dim = 768 # the hidden size of output embedding
34
+ beit_version = "base" # model size: base(0.4B)|large(1B)|huge(10B)
35
+ beit3_vl_layers = 3 # the layer number of vl_backbone
36
+ deepnorm_init = True # init method
37
+ share_layer = False # whether to share weights across layers within the backbone
38
+ share_attn = False # whether to share attention weights across layers
39
+ one_attn = False # whether vision and language share the same attention weights
40
+
41
+ # Image setting
42
+ train_transform_keys = ["square_transform_randaug"] # train transform: refer to vlmo/transforms/__init__.py
43
+ val_transform_keys = ["square_transform"] # test transform: refer to vlmo/transforms/__init__.py
44
+ image_size = 224 # image size
45
+ reclip_image_size = None # reclip image size
46
+ patch_size = 16 # patch size
47
+ draw_false_image = 0 # whether to sample a negative image
48
+ image_only = False # only input image
49
+ text_only = False # only input text
50
+
51
+ # Video setting, video_num_frm is not None means video input
52
+ video_num_frm = None
53
+
54
+ # Visual tokenizer setting based on beit2
55
+ tokenizer_model = "beit2_visual_tokenizer"
56
+ codebook_size = 8192
57
+ codebook_dim = 32
58
+ visual_mask_size = 14
59
+ visual_mask_num = 80
60
+
61
+ # Text Setting
62
+ lang = 'cn' # language for zero-shot imagenet testing: cn|en
63
+ vqav2_label_size = 3129
64
+ max_text_len = 40 # the number of characters
65
+ max_text_len_of_initckpt = 196
66
+ tokenizer_type = "BertTokenizer" # Chinese text
67
+ vocab_size = 21128
68
+ tokenizer = "./vocab.txt"
69
+ whole_word_masking = True
70
+ mlm_prob = 0.15 # language mask ratio
71
+ draw_false_text = 0
72
+ mvlm_prob = 0.50 # mask ratio for the vision-language mlm task
73
+ mask_ratio = 0 # flip: mask ratio for image
74
+
75
+ # cap setting
76
+ cap_onlytext = False # default caption image to text
77
+
78
+ # imagemlm setting
79
+ split_data_for_imagemlm = False # if True, split a batch data to two parts, and the first part for imagemlm.
80
+
81
+ # itc setting
82
+ itc_mask = False # itc use masked token
83
+ aggregate_nodes = -1 # aggregate nodes num for compute_itc, default -1 is for all nodes
84
+
85
+ # Transformer Setting
86
+ model_arch = "vlmo_base_patch16"
87
+ drop_path_rate = 0.1
88
+
89
+ # Downstream Setting
90
+ get_recall_metric = False
91
+ get_recall_rerank_metric = False
92
+ get_zeroshot_metric = False
93
+ get_muge_feat = False
94
+ get_f30k_feat = False
95
+ k_test = 32
96
+
97
+ # PL Trainer Setting
98
+ resume_from = None
99
+ fast_dev_run = False
100
+ val_check_interval = 1.0
101
+ test_only = False
102
+ use_sharded_training = False
103
+ resume_during_training = False
104
+ save_top_k = 10
105
+ every_n_train_steps = 2000 # the step to save checkpoint
106
+ log_metric_steps = 100 # the step to log metric
107
+
108
+ # below params varies with the environment
109
+ use_pcache = False # data storage method: pcache or nas
110
+ pcache_root = ""
111
+ # main_site: pcache://multimodalproxyi-pool.cz50c.alipay.com:39999/mnt/
112
+ # public_cloud: pcache://pcache_public_cloud.pcache.local:39999/mnt/abc7c88079a60b45ddfce7afa40720b7/
113
+ gpu_env = "main_site" # public_cloud or main_site
114
+ data_root = "" # data root for data list
115
+
116
+
117
+ log_dir = "result"
118
+ per_gpu_batchsize = 4 # you should define this manually with per_gpu_batch_size=#
119
+ num_gpus = 1
120
+ num_nodes = 1
121
+ load_path = ""
122
+ num_workers = 8
123
+ precision = 16
124
+ local_run = True
125
+ flash_attn = False
126
+ deepspeed_config = None # "ds_config.json"
127
+ coalesce_backbone = False
128
+ mask_data = "v+l" # 'v+l':choose input of imagemlm+textmlm task, 'vl': choose input of mvlm task.
129
+ communication_benchmark = False
130
+ checkpoint_activations = False
131
+
132
+ # dataset setting
133
+ single_cap = True # whether each sample has only one caption
134
+ random_one = False # whether to randomly pick one caption from the caption list
135
+
136
+ # ITC setting
137
+ itc_feats_name = "cls_vlffn_feats" # feat for itc loss
138
+ itc_distill = ""
139
+ itc_distill_dim = 1024
140
+ itc_teacher_weights = ""
141
+
142
+ # mup training setting
143
+ mup = False
144
+ base_encoder_embed_dim = 1
145
+ delta_encoder_embed_dim = 2
146
+ mup_encoder_attention_heads = 1
147
+ base_encoder_ffn_embed_dim = 1
148
+ delta_encoder_ffn_embed_dim = 2
149
+
150
+ # atorch
151
+ atorch_config = None
152
+ compile_op = False
153
+ optimizer_state_shard_save = False
154
+ model_state_shard_save = False
155
+
156
+ # itc loss
157
+ local_loss = False
158
+ use_dual_softmax = False
159
+
160
+ num_frames = 1
161
+ # ----------------------- LMM pretraining config -----------------------
162
+
163
+ # norm setting
164
+ deepnorm = False
165
+
vlmo/modules/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .vlmo_module import VLMo
vlmo/modules/heads.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ class Pooler(nn.Module):
5
+ def __init__(self, hidden_size):
6
+ super().__init__()
7
+ self.dense = nn.Linear(hidden_size, hidden_size)
8
+ self.activation = nn.Tanh()
9
+
10
+ def forward(self, hidden_states):
11
+ first_token_tensor = hidden_states[:, 0]
12
+ pooled_output = self.dense(first_token_tensor)
13
+ pooled_output = self.activation(pooled_output)
14
+ return pooled_output
15
+
16
+
17
+ class ITCHead(nn.Module):
18
+ def __init__(self, hidden_size, out_size):
19
+ super().__init__()
20
+ self.fc = nn.Linear(hidden_size, out_size, bias=False)
21
+
22
+ def forward(self, x):
23
+ x = self.fc(x)
24
+ return x
vlmo/modules/modeling_utils.py ADDED
@@ -0,0 +1,179 @@
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.models.layers import trunc_normal_ as __call_trunc_normal_
12
+
13
+ from vlmo.torchscale.model.BEiT3 import BEiT3
14
+ from vlmo.torchscale.architecture.config import EncoderConfig
15
+
16
+
17
+ def trunc_normal_(tensor, mean=0.0, std=1.0):
18
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
19
+
20
+
21
+ def _get_base_config(
22
+ img_size=224,
23
+ patch_size=16,
24
+ drop_path_rate=0,
25
+ checkpoint_activations=None,
26
+ mlp_ratio=4,
27
+ vocab_size=64010,
28
+ encoder_layers=12,
29
+ encoder_embed_dim=768,
30
+ encoder_attention_heads=12,
31
+ share_layer=False,
32
+ share_attn=False,
33
+ deepnorm=False,
34
+ mask_ratio=0,
35
+ max_text_len=52,
36
+ one_attn=False,
37
+ **kwargs
38
+ ):
39
+ return EncoderConfig(
40
+ img_size=img_size,
41
+ patch_size=patch_size,
42
+ vocab_size=vocab_size,
43
+ multiway=True,
44
+ layernorm_embedding=False,
45
+ normalize_output=True,
46
+ no_output_layer=True,
47
+ drop_path_rate=drop_path_rate,
48
+ encoder_embed_dim=encoder_embed_dim,
49
+ encoder_attention_heads=encoder_attention_heads,
50
+ encoder_layers=encoder_layers,
51
+ encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
52
+ checkpoint_activations=checkpoint_activations,
53
+ share_layer=share_layer,
54
+ share_attn=share_attn,
55
+ deepnorm=deepnorm,
56
+ mask_ratio=mask_ratio,
57
+ max_text_len=max_text_len,
58
+ one_attn=one_attn,
59
+ )
60
+
61
+
62
+ def _get_large_config(
63
+ img_size=224,
64
+ patch_size=16,
65
+ drop_path_rate=0,
66
+ checkpoint_activations=None,
67
+ mlp_ratio=4,
68
+ vocab_size=64010,
69
+ encoder_layers=24,
70
+ encoder_embed_dim=1024,
71
+ encoder_attention_heads=16,
72
+ share_layer=False,
73
+ share_attn=False,
74
+ deepnorm=False,
75
+ mask_ratio=0,
76
+ max_text_len=52,
77
+ one_attn=False,
78
+ **kwargs
79
+ ):
80
+ return EncoderConfig(
81
+ img_size=img_size,
82
+ patch_size=patch_size,
83
+ vocab_size=vocab_size,
84
+ multiway=True,
85
+ layernorm_embedding=False,
86
+ normalize_output=True,
87
+ no_output_layer=True,
88
+ drop_path_rate=drop_path_rate,
89
+ encoder_embed_dim=encoder_embed_dim,
90
+ encoder_attention_heads=encoder_attention_heads,
91
+ encoder_layers=encoder_layers,
92
+ encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
93
+ checkpoint_activations=checkpoint_activations,
94
+ share_layer=share_layer,
95
+ share_attn=share_attn,
96
+ deepnorm=deepnorm,
97
+ mask_ratio=mask_ratio,
98
+ max_text_len=max_text_len,
99
+ one_attn=one_attn,
100
+ )
101
+
102
+
103
+ def _get_huge_config(
104
+ img_size=224,
105
+ patch_size=16,
106
+ drop_path_rate=0,
107
+ checkpoint_activations=None,
108
+ mlp_ratio=4,
109
+ vocab_size=30522,
110
+ encoder_layers=32,
111
+ encoder_embed_dim=4096,
112
+ encoder_attention_heads=32,
113
+ share_layer=False,
114
+ share_attn=False,
115
+ deepnorm=False,
116
+ mask_ratio=0,
117
+ max_text_len=52,
118
+ one_attn=False,
119
+ **kwargs
120
+ ):
121
+ return EncoderConfig(
122
+ img_size=img_size,
123
+ patch_size=patch_size,
124
+ vocab_size=vocab_size,
125
+ multiway=True,
126
+ layernorm_embedding=False,
127
+ normalize_output=True,
128
+ no_output_layer=True,
129
+ drop_path_rate=drop_path_rate,
130
+ encoder_embed_dim=encoder_embed_dim,
131
+ encoder_attention_heads=encoder_attention_heads,
132
+ encoder_layers=encoder_layers,
133
+ encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
134
+ checkpoint_activations=checkpoint_activations,
135
+ share_layer=share_layer,
136
+ share_attn=share_attn,
137
+ deepnorm=deepnorm,
138
+ mask_ratio=mask_ratio,
139
+ max_text_len=max_text_len,
140
+ one_attn=one_attn,
141
+ )
142
+
143
+
144
+ class BEiT3Wrapper(nn.Module):
145
+ def __init__(self, args, **kwargs):
146
+ super().__init__()
147
+ self.args = args
148
+ self.beit3 = BEiT3(args)
149
+ self.apply(self._init_weights)
150
+
151
+ def fix_init_weight(self):
152
+ def rescale(param, layer_id):
153
+ param.div_(math.sqrt(2.0 * layer_id))
154
+
155
+ for layer_id, layer in enumerate(self.blocks):
156
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
157
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
158
+
159
+ def get_num_layers(self):
160
+ return self.beit3.encoder.num_layers
161
+
162
+ @torch.jit.ignore
163
+ def no_weight_decay(self):
164
+ return {
165
+ "pos_embed",
166
+ "cls_token",
167
+ "beit3.encoder.embed_positions.A.weight",
168
+ "beit3.vision_embed.cls_token",
169
+ "logit_scale",
170
+ }
171
+
172
+ def _init_weights(self, m):
173
+ if isinstance(m, nn.Linear):
174
+ trunc_normal_(m.weight, std=0.02)
175
+ if isinstance(m, nn.Linear) and m.bias is not None:
176
+ nn.init.constant_(m.bias, 0)
177
+ elif isinstance(m, nn.LayerNorm):
178
+ nn.init.constant_(m.bias, 0)
179
+ nn.init.constant_(m.weight, 1.0)
vlmo/modules/multiway_transformer.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Vision Transformer (ViT) in PyTorch
2
+
3
+ A PyTorch implement of Vision Transformers as described in
4
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
5
+
6
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
7
+
8
+ Acknowledgments:
9
+ * The paper authors for releasing code and weights, thanks!
10
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
11
+ for some einops/einsum fun
12
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
13
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
14
+
15
+ DeiT model defs and weights from https://github.com/facebookresearch/deit,
16
+ paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+ """
20
+ from functools import partial
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
+ from timm.models.registry import register_model
28
+ from pytorch_lightning.utilities.rank_zero import rank_zero_info
29
+
30
+
31
+ class Mlp(nn.Module):
32
+ def __init__(
33
+ self,
34
+ in_features,
35
+ hidden_features=None,
36
+ out_features=None,
37
+ act_layer=nn.GELU,
38
+ drop=0.0,
39
+ ):
40
+ super().__init__()
41
+ out_features = out_features or in_features
42
+ hidden_features = hidden_features or in_features
43
+ self.fc1 = nn.Linear(in_features, hidden_features)
44
+ self.act = act_layer()
45
+ self.fc2 = nn.Linear(hidden_features, out_features)
46
+ self.drop = nn.Dropout(drop)
47
+
48
+ def forward(self, x):
49
+ x = self.fc1(x)
50
+ x = self.act(x)
51
+ x = self.drop(x)
52
+ x = self.fc2(x)
53
+ x = self.drop(x)
54
+ return x
55
+
56
+
57
+ class Attention(nn.Module):
58
+ def __init__(
59
+ self,
60
+ dim,
61
+ num_heads=8,
62
+ qkv_bias=False,
63
+ qk_scale=None,
64
+ attn_drop=0.0,
65
+ proj_drop=0.0,
66
+ ):
67
+ super().__init__()
68
+ self.num_heads = num_heads
69
+ head_dim = dim // num_heads
70
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
71
+ self.scale = qk_scale or head_dim**-0.5
72
+
73
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
74
+ if qkv_bias:
75
+ self.q_bias = nn.Parameter(torch.zeros(dim))
76
+ self.v_bias = nn.Parameter(torch.zeros(dim))
77
+ else:
78
+ self.q_bias = None
79
+ self.v_bias = None
80
+
81
+ self.attn_drop = nn.Dropout(attn_drop)
82
+ self.proj = nn.Linear(dim, dim)
83
+ self.proj_drop = nn.Dropout(proj_drop)
84
+
85
+ def forward(self, x, mask=None, relative_position_bias=None):
86
+ B, N, C = x.shape
87
+
88
+ qkv_bias = None
89
+ if self.q_bias is not None:
90
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
91
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
92
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
93
+
94
+ q, k, v = (
95
+ qkv[0],
96
+ qkv[1],
97
+ qkv[2],
98
+ ) # make torchscript happy (cannot use tensor as tuple)
99
+
100
+ q = q * self.scale
101
+ attn = q.float() @ k.float().transpose(-2, -1)
102
+
103
+ if relative_position_bias is not None:
104
+ attn = attn + relative_position_bias.unsqueeze(0)
105
+
106
+ if mask is not None:
107
+ mask = mask.bool()
108
+ attn = attn.masked_fill(~mask[:, None, None, :], float("-inf"))
109
+ attn = attn.softmax(dim=-1).type_as(x)
110
+ attn = self.attn_drop(attn)
111
+
112
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
113
+ x = self.proj(x)
114
+ x = self.proj_drop(x)
115
+ return x
116
+
117
+
118
+ class Block(nn.Module):
119
+ def __init__(
120
+ self,
121
+ dim,
122
+ num_heads,
123
+ mlp_ratio=4.0,
124
+ qkv_bias=False,
125
+ qk_scale=None,
126
+ drop=0.0,
127
+ attn_drop=0.0,
128
+ drop_path=0.0,
129
+ act_layer=nn.GELU,
130
+ norm_layer=nn.LayerNorm,
131
+ with_vlffn=False,
132
+ layer_scale_init_values=0.1,
133
+ max_text_len=40,
134
+ ):
135
+ super().__init__()
136
+ self.norm1 = norm_layer(dim)
137
+ self.attn = Attention(
138
+ dim,
139
+ num_heads=num_heads,
140
+ qkv_bias=qkv_bias,
141
+ qk_scale=qk_scale,
142
+ attn_drop=attn_drop,
143
+ proj_drop=drop,
144
+ )
145
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
146
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
147
+ self.norm2_text = norm_layer(dim)
148
+ self.norm2_imag = norm_layer(dim)
149
+ mlp_hidden_dim = int(dim * mlp_ratio)
150
+ self.mlp_text = Mlp(
151
+ in_features=dim,
152
+ hidden_features=mlp_hidden_dim,
153
+ act_layer=act_layer,
154
+ drop=drop,
155
+ )
156
+ self.mlp_imag = Mlp(
157
+ in_features=dim,
158
+ hidden_features=mlp_hidden_dim,
159
+ act_layer=act_layer,
160
+ drop=drop,
161
+ )
162
+ self.mlp_vl = None
163
+ if with_vlffn:
164
+ self.mlp_vl = Mlp(
165
+ in_features=dim,
166
+ hidden_features=mlp_hidden_dim,
167
+ act_layer=act_layer,
168
+ drop=drop,
169
+ )
170
+ self.norm2_vl = norm_layer(dim)
171
+
172
+ self.gamma_1 = (
173
+ nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True)
174
+ if layer_scale_init_values is not None
175
+ else 1.0
176
+ )
177
+ self.gamma_2 = (
178
+ nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True)
179
+ if layer_scale_init_values is not None
180
+ else 1.0
181
+ )
182
+
183
+ self.max_text_len = max_text_len
184
+
185
+ def forward(self, x, mask=None, modality_type=None, relative_position_bias=None):
186
+ x = x + self.drop_path(
187
+ self.gamma_1 * self.attn(self.norm1(x), mask=mask, relative_position_bias=relative_position_bias)
188
+ )
189
+
190
+ if modality_type == "image":
191
+ x = x + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x)))
192
+ elif modality_type == "text":
193
+ x = x + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x)))
194
+ else:
195
+ if self.mlp_vl is None:
196
+ x_text = x[:, : self.max_text_len]
197
+ x_imag = x[:, self.max_text_len :]
198
+ x_text = x_text + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x_text)))
199
+ x_imag = x_imag + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x_imag)))
200
+ x = torch.cat([x_text, x_imag], dim=1)
201
+ else:
202
+ x = x + self.drop_path(self.gamma_2 * self.mlp_vl(self.norm2_vl(x)))
203
+
204
+ return x
205
+
206
+
207
+ class PatchEmbed(nn.Module):
208
+ """Image to Patch Embedding"""
209
+
210
+ def __init__(
211
+ self,
212
+ img_size=224,
213
+ patch_size=16,
214
+ in_chans=3,
215
+ embed_dim=768,
216
+ no_patch_embed_bias=False,
217
+ ):
218
+ super().__init__()
219
+ img_size = to_2tuple(img_size)
220
+ patch_size = to_2tuple(patch_size)
221
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
222
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
223
+ self.img_size = img_size
224
+ self.patch_size = patch_size
225
+ self.num_patches = num_patches
226
+
227
+ self.proj = nn.Conv2d(
228
+ in_chans,
229
+ embed_dim,
230
+ kernel_size=patch_size,
231
+ stride=patch_size,
232
+ bias=False if no_patch_embed_bias else True,
233
+ )
234
+
235
+ def forward(self, x):
236
+ B, C, H, W = x.shape
237
+ assert (
238
+ H == self.img_size[0] and W == self.img_size[1]
239
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
240
+ # FIXME look at relaxing size constraints
241
+ x = self.proj(x)
242
+ return x
243
+
244
+
245
+ class MultiWayTransformer(nn.Module):
246
+ """Vision Transformer
247
+
248
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
249
+ https://arxiv.org/abs/2010.11929
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ img_size=224,
255
+ patch_size=16,
256
+ in_chans=3,
257
+ embed_dim=768,
258
+ depth=12,
259
+ num_heads=12,
260
+ mlp_ratio=4.0,
261
+ qkv_bias=True,
262
+ qk_scale=None,
263
+ drop_rate=0.0,
264
+ attn_drop_rate=0.0,
265
+ drop_path_rate=0.0,
266
+ norm_layer=None,
267
+ need_relative_position_embed=True,
268
+ use_abs_pos_emb=False,
269
+ layer_scale_init_values=0.1,
270
+ vlffn_start_layer_index=10,
271
+ config=None,
272
+ ):
273
+ """
274
+ Args:
275
+ img_size (int, tuple): input image size
276
+ patch_size (int, tuple): patch size
277
+ in_chans (int): number of input channels
278
+ num_classes (int): number of classes for classification head
279
+ embed_dim (int): embedding dimension
280
+ depth (int): depth of transformer
281
+ num_heads (int): number of attention heads
282
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
283
+ qkv_bias (bool): enable bias for qkv if True
284
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
285
+ drop_rate (float): dropout rate
286
+ attn_drop_rate (float): attention dropout rate
287
+ drop_path_rate (float): stochastic depth rate
288
+ norm_layer: (nn.Module): normalization layer
289
+ need_relative_position_embed (bool): enable relative position bias on self-attention
290
+ use_abs_pos_emb (bool): enable abs pos emb
291
+ layer_scale_init_values (float or None): layer scale init values, set None to disable
292
+ vlffn_start_layer_index (int): vl-ffn start index
293
+ config (dict): other hyperparameters passed from pytorch-lightning
294
+ """
295
+ super().__init__()
296
+ drop_path_rate = drop_path_rate if config is None else config["drop_path_rate"]
297
+ rank_zero_info("drop path rate: {}".format(drop_path_rate))
298
+ self.use_abs_pos_emb = use_abs_pos_emb
299
+ self.need_relative_position_embed = need_relative_position_embed
300
+
301
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
302
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
303
+
304
+ self.patch_embed = PatchEmbed(
305
+ img_size=img_size,
306
+ patch_size=patch_size,
307
+ in_chans=in_chans,
308
+ embed_dim=embed_dim,
309
+ )
310
+ num_patches = self.patch_embed.num_patches
311
+ self.patch_size = patch_size
312
+ self.num_heads = num_heads
313
+ self.vlffn_start_layer_index = vlffn_start_layer_index
314
+ if config["loss_names"]["textmlm"] > 0:
315
+ self.vlffn_start_layer_index = depth
316
+ rank_zero_info(
317
+ "Set vlffn_start_layer_index={} for text-only pretraining".format(self.vlffn_start_layer_index)
318
+ )
319
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
320
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if self.use_abs_pos_emb else None
321
+ self.pos_drop = nn.Dropout(p=drop_rate)
322
+
323
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
324
+ self.blocks = nn.ModuleList(
325
+ [
326
+ Block(
327
+ dim=embed_dim,
328
+ num_heads=num_heads,
329
+ mlp_ratio=mlp_ratio,
330
+ qkv_bias=qkv_bias,
331
+ qk_scale=qk_scale,
332
+ drop=drop_rate,
333
+ attn_drop=attn_drop_rate,
334
+ drop_path=dpr[i],
335
+ norm_layer=norm_layer,
336
+ with_vlffn=(i >= self.vlffn_start_layer_index),
337
+ layer_scale_init_values=layer_scale_init_values,
338
+ max_text_len=config["max_text_len"],
339
+ )
340
+ for i in range(depth)
341
+ ]
342
+ )
343
+ self.norm = norm_layer(embed_dim)
344
+
345
+ if self.pos_embed is not None:
346
+ trunc_normal_(self.pos_embed, std=0.02)
347
+ trunc_normal_(self.cls_token, std=0.02)
348
+ self.apply(self._init_weights)
349
+
350
+ def _init_weights(self, m):
351
+ if isinstance(m, nn.Linear):
352
+ trunc_normal_(m.weight, std=0.02)
353
+ if isinstance(m, nn.Linear) and m.bias is not None:
354
+ nn.init.constant_(m.bias, 0)
355
+ elif isinstance(m, nn.LayerNorm):
356
+ nn.init.constant_(m.bias, 0)
357
+ nn.init.constant_(m.weight, 1.0)
358
+
359
+ @torch.jit.ignore
360
+ def no_weight_decay(self):
361
+ return {"pos_embed", "cls_token"}
362
+
363
+ def visual_embed(self, _x):
364
+ x = self.patch_embed(_x)
365
+ x = x.flatten(2).transpose(1, 2)
366
+ B, L, _ = x.shape
367
+
368
+ cls_tokens = self.cls_token.expand(B, -1, -1)
369
+ x = torch.cat((cls_tokens, x), dim=1)
370
+
371
+ if self.pos_embed is not None:
372
+ x = x + self.pos_embed
373
+ x = self.pos_drop(x)
374
+
375
+ x_mask = torch.ones(x.shape[0], x.shape[1])
376
+
377
+ return x, x_mask
378
+
379
+
380
+ # VLMo base/p16
381
+ @register_model
382
+ def vlmo_base_patch16(pretrained=False, **kwargs):
383
+ img_size = kwargs.pop("img_size", 224)
384
+ model = MultiWayTransformer(
385
+ img_size=img_size,
386
+ patch_size=16,
387
+ embed_dim=768,
388
+ depth=12,
389
+ num_heads=12,
390
+ mlp_ratio=4,
391
+ qkv_bias=True,
392
+ vlffn_start_layer_index=10,
393
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
394
+ **kwargs,
395
+ )
396
+ return model
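The `Block` above is the Multiway building block: one shared self-attention followed by modality-specific FFN experts (`mlp_text`, `mlp_imag`, and optionally `mlp_vl`). A small illustrative sketch of how the `modality_type` argument routes tokens (hypothetical shapes, not the actual training wiring):

```python
# Illustrative sketch of Multiway routing in Block.forward; not the training setup.
import torch
from vlmo.modules.multiway_transformer import Block

blk = Block(dim=768, num_heads=12, max_text_len=40)

text_tokens = torch.randn(2, 40, 768)
image_tokens = torch.randn(2, 197, 768)   # 14*14 patches + CLS token at 224px / patch 16

out_text = blk(text_tokens, modality_type="text")     # FFN: mlp_text
out_image = blk(image_tokens, modality_type="image")  # FFN: mlp_imag

# Fused vision-language input: the first max_text_len tokens go through mlp_text,
# the rest through mlp_imag (or through mlp_vl in blocks built with with_vlffn=True).
fused = torch.cat([text_tokens, image_tokens], dim=1)
out_vl = blk(fused, modality_type=None)
```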
vlmo/modules/objectives.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def init_weights(module):
5
+ if isinstance(module, (nn.Linear, nn.Embedding)):
6
+ module.weight.data.normal_(mean=0.0, std=0.02)
7
+ elif isinstance(module, nn.LayerNorm):
8
+ module.bias.data.zero_()
9
+ module.weight.data.fill_(1.0)
10
+
11
+ if isinstance(module, nn.Linear) and module.bias is not None:
12
+ module.bias.data.zero_()
vlmo/modules/vlmo_module.py ADDED
@@ -0,0 +1,405 @@
1
+ import math
2
+ import os
3
+ import time
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.nn as nn
10
+ from pytorch_lightning.utilities.rank_zero import rank_zero_info
11
+ from timm.models import create_model
12
+ from transformers import AutoTokenizer, BertTokenizer, XLMRobertaTokenizer # noqa
13
+ from vlmo.modules import heads, objectives, vlmo_utils
14
+ from vlmo.tokenizer.tokenization_glm import GLMChineseTokenizer # noqa
15
+ from vlmo.torchscale.architecture.encoder import Encoder
16
+ from vlmo.torchscale.model.BEiT3 import BEiT3 as ts_backbone
17
+ from vlmo.transforms.utils import inception_normalize as img_norm
18
+
19
+ from .modeling_utils import _get_base_config, _get_large_config, _get_huge_config, trunc_normal_ # noqa
20
+
21
+
22
+ def convert_pl_ckpt(state_dict, num_visual_token=197):
23
+ print("start convert_pl_ckpt!!!")
24
+ new_state_dict = {}
25
+ for key in state_dict:
26
+ value = state_dict[key]
27
+ if "visual_tokenizer" in key:
28
+ continue
29
+ elif "backbone.encoder.embed_positions.A.weight" in key:
30
+ if value.shape[0] < num_visual_token + 2:
31
+ N = value.shape[0] - 3
32
+ dim = value.shape[-1]
33
+ class_pos_embed = value[:3, ]
34
+ patch_pos_embed = value[3:, ]
35
+ w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
36
+ patch_pos_embed = patch_pos_embed.float()
37
+ patch_pos_embed = nn.functional.interpolate(
38
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
39
+ size=(w0, h0),
40
+ mode="area",
41
+ )
42
+ patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
43
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
44
+ new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
45
+ new_state_dict[key] = new_value
46
+ print("reshape ", key, "raw shape: ", value.shape, "new shape: ", new_value.shape, num_visual_token)
47
+ elif value.shape[0] > num_visual_token + 2:
48
+ new_state_dict[key] = value[: num_visual_token + 2, :]
49
+ print("first ", key, "raw shape: ", value.shape, new_state_dict[key].shape, num_visual_token)
50
+ else:
51
+ new_state_dict[key] = value
52
+ print("raw shape")
53
+ else:
54
+ new_state_dict[key] = state_dict[key]
55
+
56
+ return new_state_dict
57
+
58
+
59
+ def convert_deepspeed_ckpt(state_dict, num_visual_token=197):
60
+ new_state_dict = {}
61
+ for key in state_dict:
62
+ if key.startswith("_forward_module."):
63
+ new_key = key[len("_forward_module."):]
64
+ value = state_dict[key]
65
+ new_state_dict[new_key] = value
66
+ if "visual_tokenizer.encoder.pos_embed" in new_key or "visual_tokenizer.decoder.pos_embed" in new_key:
67
+ if value.shape[1] != num_visual_token:
68
+ N = value.shape[1] - 1
69
+ dim = value.shape[-1]
70
+ class_pos_embed = value[:, 0]
71
+ patch_pos_embed = value[:, 1:]
72
+ w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
73
+ patch_pos_embed = patch_pos_embed.float()
74
+ patch_pos_embed = nn.functional.interpolate(
75
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
76
+ size=(w0, h0),
77
+ mode="area",
78
+ )
79
+ patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
80
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
81
+ new_value = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
82
+ new_state_dict[new_key] = new_value
83
+ print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
84
+ if "backbone.encoder.embed_positions.A.weight" in new_key:
85
+ if value.shape[1] != num_visual_token + 2:
86
+ N = value.shape[0] - 3
87
+ dim = value.shape[-1]
88
+ class_pos_embed = value[:3, ]
89
+ patch_pos_embed = value[3:, ]
90
+ w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
91
+ patch_pos_embed = patch_pos_embed.float()
92
+ patch_pos_embed = nn.functional.interpolate(
93
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
94
+ size=(w0, h0),
95
+ mode="area",
96
+ )
97
+ patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
98
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
99
+ new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
100
+ new_state_dict[new_key] = new_value
101
+ print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
102
+
103
+ else:
104
+ new_state_dict[key] = state_dict[key]
105
+
106
+ return new_state_dict
107
+
108
+
109
+ def get_visual_tokenizer(config):
110
+ tokenizer_name = config["tokenizer_model"]
111
+ print(f"Creating visual tokenizer: {tokenizer_name}")
112
+ model = create_model(
113
+ config["tokenizer_model"],
114
+ img_size=config["image_size"],
115
+ n_code=config["codebook_size"],
116
+ code_dim=config["codebook_dim"],
117
+ ).eval()
118
+ return model
119
+
120
+
121
+ def get_pretrained_tokenizer(tokenizer_type, from_pretrained):
122
+ _Tokenizer = eval(f"{tokenizer_type}")
123
+ if torch.distributed.is_initialized():
124
+ if torch.distributed.get_rank() == 0:
125
+ _Tokenizer.from_pretrained(from_pretrained)
126
+ torch.distributed.barrier()
127
+ return _Tokenizer.from_pretrained(from_pretrained)
128
+
129
+
130
+ class VLMo(pl.LightningModule):
131
+ def __init__(self, config):
132
+ super().__init__()
133
+ self.save_hyperparameters()
134
+ s_t = time.time()
135
+
136
+ # tokenizer & backbone
137
+ self.img_size = config["image_size"]
138
+ if not config["test_only"]:
139
+ self.visual_tokenizer = get_visual_tokenizer(config)
140
+ kwargs = {}
141
+ if "encoder_attention_heads" in config:
142
+ kwargs["encoder_attention_heads"] = config["encoder_attention_heads"]
143
+ if "atorch_config" in config and config["atorch_config"]:
144
+ checkpoint_activations = False # ?
145
+ else:
146
+ checkpoint_activations = config["checkpoint_activations"]
147
+ args = eval(f'_get_{config["beit_version"]}_config')(
148
+ img_size=config["image_size"],
149
+ patch_size=config["patch_size"],
150
+ vocab_size=config["vocab_size"],
151
+ encoder_layers=config["encoder_layers"],
152
+ encoder_embed_dim=config["encoder_embed_dim"],
153
+ checkpoint_activations=checkpoint_activations,
154
+ share_layer=config["share_layer"],
155
+ share_attn=config["share_attn"],
156
+ deepnorm=config["deepnorm"],
157
+ mask_ratio=config["mask_ratio"],
158
+ max_text_len=config["max_text_len"],
159
+ one_attn=config["one_attn"],
160
+ **kwargs,
161
+ )
162
+ self.num_features = args.encoder_embed_dim
163
+ self.out_features = config["out_embed_dim"]
164
+ self.cap_onlytext = config["cap_onlytext"]
165
+ self.lang = config["lang"]
166
+ self.num_frames = config["num_frames"]
167
+ self.tokenizer_type = config["tokenizer_type"]
168
+ self.text_tokenizer = get_pretrained_tokenizer(self.tokenizer_type, from_pretrained=config["tokenizer"]) # noqa
169
+ print("BEiT args", args.__dict__)
170
+ self.backbone = ts_backbone(args)
171
+
172
+ self.use_vl = config["beit3_vl_layers"] > 0
173
+ if self.use_vl:
174
+ args.encoder_layers = config["beit3_vl_layers"]
175
+ self.backbone_vl = Encoder(args)
176
+
177
+ self.norm = nn.LayerNorm(self.num_features, eps=1e-6)
178
+
179
+ # task layers
180
+ self.pooler = heads.Pooler(self.num_features)
181
+ self.pooler.apply(objectives.init_weights)
182
+
183
+ # contrastive loss (or sampling for global hard negative)
184
+ if config["loss_names"]["itc"] > 0:
185
+ self.itc_text_proj = heads.ITCHead(self.num_features, self.out_features)
186
+ self.itc_image_proj = heads.ITCHead(self.num_features, self.out_features)
187
+ self.itc_text_proj.apply(objectives.init_weights)
188
+ self.itc_image_proj.apply(objectives.init_weights)
189
+
190
+ self.itc_vl_text_proj = heads.ITCHead(self.num_features, self.out_features)
191
+ self.itc_vl_image_proj = heads.ITCHead(self.num_features, self.out_features)
192
+ self.itc_vl_text_proj.apply(objectives.init_weights)
193
+ self.itc_vl_image_proj.apply(objectives.init_weights)
194
+
195
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
196
+ self.logit_vl_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
197
+
198
+ lp_s_t = time.time()
199
+
200
+ self.load_pretrained_weight()
201
+ load_pretrain_time = time.time() - lp_s_t
202
+
203
+ self.current_tasks = list()
204
+
205
+ # ===================== load downstream (test_only) ======================
206
+
207
+ if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
208
+ rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
209
+ ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
210
+
211
+ state_dict = None
212
+
213
+ for state_dict_key in ("state_dict", "module", "model"):
214
+ if state_dict_key in ckpt:
215
+ rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
216
+ state_dict = ckpt[state_dict_key]
217
+ break
218
+ if state_dict_key == "module":
219
+ state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
220
+ if state_dict_key == "state_dict":
221
+ state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
222
+ if state_dict is None:
223
+ if list(ckpt.keys())[0].startswith('_forward_module.'):
224
+ rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
225
+ state_dict = convert_deepspeed_ckpt(ckpt, self.backbone.vision_embed.num_position_embeddings())
226
+ else:
227
+ rank_zero_info("Read state dict from ckpt. ")
228
+ state_dict = ckpt
229
+
230
+ missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
231
+ rank_zero_info("missing_keys: {}".format(missing_keys))
232
+ rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
233
+
234
+ construct_time = time.time() - s_t
235
+ print(
236
+ f"Process {os.getpid()}. VLMo Constructor time: {construct_time}s;",
237
+ f"load_pretrain_time: {load_pretrain_time}s",
238
+ flush=True,
239
+ )
240
+ # coalesce backbone calls
241
+ self._coalesce_backbone = config["coalesce_backbone"]
242
+ self._mask_data = config["mask_data"]
243
+ self._backbone_inputs = {}
244
+ self._backbone_inputs_current_size = 0
245
+ self._backbone_inputs_keys = {}
246
+ self._backbone_outputs = None
247
+ self._default_attn_masks = {}
248
+ self._itc_group = None
249
+ self._itc_aggregate_dict = None
250
+ self._itc_mask = config["itc_mask"]
251
+ self._local_loss = config["local_loss"]
252
+ self._aggregate_nodes = config["aggregate_nodes"]
253
+ self.accumulated_batches_reached = False
254
+ vlmo_utils.set_task(self)
255
+ self._only_itc_single_machine = (
256
+ self._aggregate_nodes > 0 and len(self.current_tasks) == 1 and "itc" in self.current_tasks
257
+ )
258
+ self._split_data_for_imagemlm = config["split_data_for_imagemlm"]
259
+ self.log_metric_steps = config["log_metric_steps"]
260
+
261
+ def _init_weights(self, m):
262
+ if isinstance(m, nn.Linear):
263
+ trunc_normal_(m.weight, std=0.02)
264
+ if isinstance(m, nn.Linear) and m.bias is not None:
265
+ nn.init.constant_(m.bias, 0)
266
+ elif isinstance(m, nn.LayerNorm):
267
+ nn.init.constant_(m.bias, 0)
268
+ nn.init.constant_(m.weight, 1.0)
269
+
270
+ def fix_init_weight(self):
271
+ def rescale(param, layer_id):
272
+ param.div_(math.sqrt(2.0 * layer_id))
273
+
274
+ for layer_id, layer in enumerate(self.backbone.encoder.layers):
275
+ rescale(layer.self_attn.v_proj.A.weight.data, layer_id + 1)
276
+ rescale(layer.self_attn.v_proj.B.weight.data, layer_id + 1)
277
+ rescale(layer.self_attn.out_proj.A.weight.data, layer_id + 1)
278
+ rescale(layer.self_attn.out_proj.B.weight.data, layer_id + 1)
279
+ rescale(layer.ffn.A.fc2.weight.data, layer_id + 1)
280
+ rescale(layer.ffn.B.fc2.weight.data, layer_id + 1)
281
+
282
+ if self.use_vl:
283
+ pre_layers = len(self.backbone.encoder.layers) + 1
284
+ for layer_id, layer in enumerate(self.backbone_vl.layers):
285
+ rescale(layer.self_attn.v_proj.A.weight.data, layer_id + pre_layers)
286
+ rescale(layer.self_attn.v_proj.B.weight.data, layer_id + pre_layers)
287
+ rescale(layer.self_attn.out_proj.A.weight.data, layer_id + pre_layers)
288
+ rescale(layer.self_attn.out_proj.B.weight.data, layer_id + pre_layers)
289
+ rescale(layer.ffn.A.fc2.weight.data, layer_id + pre_layers)
290
+ rescale(layer.ffn.B.fc2.weight.data, layer_id + pre_layers)
291
+
292
+ def load_pretrained_weight(self):
293
+ if self.hparams.config["load_path"] != "" and not self.hparams.config["test_only"]:
294
+ config = self.hparams.config
295
+ ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
296
+ rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
297
+
298
+ state_dict = None
299
+
300
+ for state_dict_key in ("state_dict", "module", "model"):
301
+ if state_dict_key in ckpt:
302
+ rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
303
+ state_dict = ckpt[state_dict_key]
304
+ break
305
+ if state_dict_key == "module":
306
+ state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
307
+ if state_dict_key == "state_dict":
308
+ state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
309
+ if state_dict is None:
310
+ if list(ckpt.keys())[0].startswith('_forward_module.'):
311
+ rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
312
+ state_dict = convert_deepspeed_ckpt(ckpt,
313
+ self.backbone.vision_embed.num_position_embeddings())
314
+ else:
315
+ rank_zero_info("Read state dict from ckpt. ")
316
+ state_dict = ckpt
317
+
318
+ missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
319
+ missing_keys = [k for k in missing_keys if "itc_teacher" not in k]
320
+ rank_zero_info("missing_keys: {}".format(missing_keys))
321
+ rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
322
+
323
+ def infer_text(
324
+ self,
325
+ batch,
326
+ mask_text=False,
327
+ ):
328
+ do_mlm = "_mlm" if mask_text else ""
329
+ text_ids = batch[f"text_ids{do_mlm}"]
330
+ text_labels = batch[f"text_labels{do_mlm}"]
331
+ text_masks = batch["text_masks"]
332
+ text_embed = self.backbone.text_embed(text_ids)
333
+ text_padding_position = 1 - text_masks
334
+ lffn_hiddens = self.backbone(
335
+ textual_tokens=text_ids,
336
+ text_padding_position=text_padding_position,
337
+ )["encoder_out"]
338
+ vlffn_hiddens = self.backbone_vl(
339
+ src_tokens=None,
340
+ token_embeddings=lffn_hiddens,
341
+ encoder_padding_mask=text_padding_position,
342
+ multiway_split_position=-1,
343
+ )["encoder_out"]
344
+
345
+ cls_feats = self.itc_text_proj(lffn_hiddens[:, 0])
346
+ cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
347
+
348
+ cls_vlffn_feats = self.itc_vl_text_proj(vlffn_hiddens[:, 0])
349
+ cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)
350
+
351
+ ret = {
352
+ "cls_feats": cls_feats,
353
+ "cls_vlffn_feats": cls_vlffn_feats,
354
+ "text_embed": text_embed,
355
+ }
356
+
357
+ return ret
358
+
359
+ def infer_image(
360
+ self,
361
+ batch,
362
+ mask_image=False,
363
+ image_token_type_idx=1,
364
+ image_embeds=None,
365
+ image_masks=None,
366
+ ):
367
+ if f"image_{image_token_type_idx - 1}" in batch:
368
+ imgkey = f"image_{image_token_type_idx - 1}"
369
+ else:
370
+ imgkey = "image"
371
+
372
+ img = batch[imgkey][0]
373
+ if mask_image:
374
+ image_masks = batch[f"{imgkey}_masks"][0].flatten(1)
375
+
376
+ with torch.no_grad():
377
+ img = self.visual_tokenizer.pre_process(img)
378
+ quantize, embed_ind, _ = self.visual_tokenizer.encode(img)
379
+ image_ids = embed_ind.view(img.shape[0], -1)
380
+
381
+ image_labels = torch.full_like(image_ids, -100)
382
+ bool_masked_pos = image_masks.to(torch.bool)
383
+ image_labels[bool_masked_pos] = image_ids[bool_masked_pos]
384
+
385
+ img_tensor = img_norm(img)
386
+ vffn_hiddens = self.backbone(visual_tokens=img_tensor)["encoder_out"]
387
+ vlffn_hiddens = self.backbone_vl(
388
+ src_tokens=None,
389
+ token_embeddings=vffn_hiddens,
390
+ multiway_split_position=-1,
391
+ )["encoder_out"]
392
+
393
+ cls_feats = self.itc_image_proj(vffn_hiddens[:, 0])
394
+ cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
395
+
396
+ cls_vlffn_feats = self.itc_vl_image_proj(vlffn_hiddens[:, 0])
397
+ cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)
398
+
399
+ ret = {
400
+ "image_feats": vffn_hiddens,
401
+ "cls_feats": cls_feats,
402
+ "cls_vlffn_feats": cls_vlffn_feats,
403
+ }
404
+
405
+ return ret
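Taken together, `infer_text` and `infer_image` each return L2-normalized `cls_feats`, so cross-modal retrieval reduces to a matrix of cosine similarities. Below is a minimal sketch, assuming an already constructed and checkpoint-loaded module `model` that exposes these two methods and batches shaped the way the code above expects (`text_ids`, `text_masks`, `image`, ...); the helper name and shapes are illustrative, not part of the repository.

```python
import torch

@torch.no_grad()
def retrieval_scores(model, text_batch, image_batch):
    # Both helpers return unit-norm projections, so a dot product is a cosine similarity.
    text_feats = model.infer_text(text_batch)["cls_feats"]       # (N_text, D)
    image_feats = model.infer_image(image_batch)["cls_feats"]    # (N_image, D)
    scores = text_feats @ image_feats.t()                        # scores[i, j] = sim(text i, image j)
    top1 = scores.argmax(dim=-1)                                 # best image per query text (R@1-style)
    return scores, top1
```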
vlmo/modules/vlmo_utils.py ADDED
@@ -0,0 +1,12 @@
1
+ def set_task(pl_module):
2
+ pl_module.current_tasks = [k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1]
3
+ return
4
+
5
+
6
+ def no_sync_module_apply(module, fn):
7
+ """FSDP's module.apply uses _unshard_params_recurse, which syncs params across ranks.
8
+ Use this function when the applied fn does not need params to be synced across ranks.
9
+ """
10
+ for child in module.children():
11
+ fn(child)
12
+ no_sync_module_apply(child, fn)
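Unlike `nn.Module.apply`, the helper above calls `fn` on every descendant but never on the root module, and it walks `children()` directly so FSDP never all-gathers sharded parameters. A tiny sketch with a plain (non-FSDP) module to show the traversal; the tagging callback is purely illustrative, and the import path is assumed from this repository's layout.

```python
import torch.nn as nn
from vlmo.modules.vlmo_utils import no_sync_module_apply  # assumed import path

def tag(module):
    # Mark a visited submodule; the root module itself is never passed to fn.
    module._visited = True

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2), nn.ReLU()))
no_sync_module_apply(model, tag)

assert not hasattr(model, "_visited")  # root is skipped
assert all(hasattr(m, "_visited") for m in model.modules() if m is not model)
```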
vlmo/tokenizer/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # coding: utf-8
2
+ # Copyright (c) Antfin, Inc. All rights reserved.
3
+
4
+ from __future__ import absolute_import
5
+ from __future__ import division
6
+ from __future__ import print_function
vlmo/tokenizer/sp.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7fe3bcc8d284fcb782691411e8b6fd4f45d7245565b094de6ab795e66bcd32f
3
+ size 2270960
vlmo/tokenizer/tokenization_glm.py ADDED
@@ -0,0 +1,307 @@
1
+ import os
2
+ from shutil import copyfile
3
+ from typing import Optional, Tuple, List, Union
4
+
5
+ import sentencepiece as spm
6
+ import torch
7
+ from transformers import PreTrainedTokenizer
8
+ from transformers.models.auto.tokenization_auto import get_tokenizer_config
9
+ from transformers.tokenization_utils_base import BatchEncoding
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class GLMBatchEncoding(BatchEncoding):
16
+ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
17
+ """
18
+ Send all values to device by calling `v.to(device)` (PyTorch only).
19
+
20
+ Args:
21
+ device (`str` or `torch.device`): The device to put the tensors on.
22
+
23
+ Returns:
24
+ [`BatchEncoding`]: The same instance after modification.
25
+ """
26
+
27
+ # This check catches things like APEX blindly calling "to" on all inputs to a module
28
+ # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
29
+ # into a HalfTensor
30
+ if isinstance(device, str) or isinstance(device, int):
31
+ #if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
32
+ self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()}
33
+ else:
34
+ logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
35
+ return self
36
+
37
+
38
+ class GLMTokenizerMixin:
39
+ @property
40
+ def sop_token(self) -> Optional[str]:
41
+ return "<|startofpiece|>"
42
+
43
+ @property
44
+ def sop_token_id(self) -> Optional[int]:
45
+ """
46
+ `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.
47
+ """
48
+ return self.convert_tokens_to_ids(self.sop_token)
49
+
50
+ @property
51
+ def eop_token(self) -> Optional[str]:
52
+ return "<|endofpiece|>"
53
+
54
+ @property
55
+ def eop_token_id(self) -> Optional[int]:
56
+ """
57
+ `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.
58
+ """
59
+ return self.convert_tokens_to_ids(self.eop_token)
60
+
61
+ @property
62
+ def gmask_token_id(self) -> int:
63
+ return self.convert_tokens_to_ids("[gMASK]")
64
+
65
+ @property
66
+ def smask_token_id(self) -> int:
67
+ return self.convert_tokens_to_ids("[sMASK]")
68
+
69
+ @property
70
+ def mask_token_ids(self):
71
+ return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]
72
+
73
+ def _build_input_for_multiple_choice(self, context, choices):
74
+ context_id = context["input_ids"]
75
+ if torch.is_tensor(context_id):
76
+ context_id = context_id.tolist()
77
+
78
+ division = len(context_id)
79
+ mask_position = context_id.index(self.mask_token_id)
80
+
81
+ token = torch.tensor(context_id, dtype=torch.long)
82
+ attention_mask = [context["attention_mask"].expand(division, -1)]
83
+ position_id = torch.arange(division, dtype=torch.long)
84
+ block_position_id = torch.zeros(division, dtype=torch.long)
85
+
86
+ choice_ids, choice_indices = [], []
87
+
88
+ for choice_str in choices:
89
+ choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],
90
+ dtype=torch.long)
91
+ choice_ids.append(choice)
92
+ choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))
93
+ attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))
94
+
95
+ token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))
96
+ position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))
97
+ block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))
98
+
99
+ attention_mask = torch.block_diag(*attention_mask)
100
+ attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0)
101
+
102
+ return {
103
+ "input_ids": token,
104
+ "position_ids": torch.stack((position_id, block_position_id)),
105
+ "attention_mask": attention_mask,
106
+ "choice_ids": choice_ids,
107
+ "choice_indices": choice_indices
108
+ }
109
+
110
+ def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):
111
+ pad_length = max_seq_length - len(tokens)
112
+ attention_mask = torch.nn.functional.pad(
113
+ attention_mask,
114
+ (0, pad_length, 0, pad_length),
115
+ mode="constant",
116
+ value=0,
117
+ )
118
+ tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))
119
+ position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)
120
+ return tokens, position_ids, attention_mask
121
+
122
+ def _collate(self, samples):
123
+ TILE = 1
124
+ length_to_pad = (max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE
125
+
126
+ token_batch, position_id_batch, attention_mask_batch = [], [], []
127
+ choices_batch, choice_target_ids_batch = [], []
128
+
129
+ for sample in samples:
130
+ token, position_id, attention_mask = self._pad_batch(
131
+ sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad
132
+ )
133
+ token_batch.append(token)
134
+ position_id_batch.append(position_id)
135
+ attention_mask_batch.append(attention_mask)
136
+ choices_batch.append(sample["choice_ids"])
137
+ choice_target_ids_batch.append(sample["choice_indices"])
138
+ return {
139
+ "input_ids": torch.stack(token_batch),
140
+ "position_ids": torch.stack(position_id_batch),
141
+ "attention_mask": torch.stack(attention_mask_batch).unsqueeze(1),
142
+ "choice_ids": choices_batch,
143
+ "choice_indices": choice_target_ids_batch,
144
+ }
145
+
146
+ def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):
147
+ samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input["input_ids"]))]
148
+ samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in
149
+ zip(samples, choices)]
150
+ inputs = self._collate(samples)
151
+ return GLMBatchEncoding(inputs)
152
+
153
+ def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False):
154
+ mask_ids = self.mask_token_ids
155
+ input_ids = model_input.input_ids
156
+ batch_size, seq_length = input_ids.shape[:2]
157
+ position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]
158
+ position_ids, block_position_ids = [], []
159
+ labels = None
160
+ if targets is not None:
161
+ is_batched = isinstance(targets, (list, tuple))
162
+ targets = self(targets, add_special_tokens=False, padding=False).input_ids
163
+ if not is_batched:
164
+ targets = [targets]
165
+ assert len(targets) == len(input_ids)
166
+ targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets]
167
+ if not padding:
168
+ max_gen_length = max(map(len, targets))
169
+ targets = [[self.sop_token_id] + target for target in targets]
170
+ labels = [target[1:] for target in targets]
171
+ targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]
172
+ labels = [label + [-100] * (max_gen_length - len(label)) for label in labels]
173
+ targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)
174
+ labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)
175
+ labels = torch.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1)
176
+ for i in range(batch_size):
177
+ mask_positions = []
178
+ for mask_id in mask_ids:
179
+ mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()
180
+ if not mask_positions:
181
+ raise ValueError("Cannot find mask token in the input")
182
+ mask_positions.sort()
183
+ mask_pos = mask_positions[0]
184
+ position_ids.append(position_id + [mask_pos] * max_gen_length)
185
+ block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))
186
+ position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device)
187
+ block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device)
188
+ position_ids = torch.stack((position_ids, block_position_ids), dim=1)
189
+ attention_mask = model_input.attention_mask
190
+ attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)
191
+ generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)),
192
+ torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))],
193
+ dim=0).unsqueeze(0).expand(batch_size, -1, -1)
194
+ attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)
195
+ attention_mask = attention_mask.unsqueeze(1)
196
+ if targets is None:
197
+ input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)
198
+ else:
199
+ input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)
200
+ batch = {"input_ids": input_ids, "position_ids": position_ids}
201
+ if labels is None:
202
+ batch["generation_attention_mask"] = attention_mask
203
+ else:
204
+ batch["attention_mask"] = attention_mask
205
+ batch["labels"] = labels
206
+ return BatchEncoding(batch)
207
+
208
+ def encode_whitespaces(content):
209
+ for i in range(10, 1, -1):
210
+ content = content.replace(' '*i, f'<|blank_{i}|>')
211
+ return content
212
+
213
+ def decode_whitespaces(content):
214
+ for i in range(10, 1, -1):
215
+ content = content.replace(f'<|blank_{i}|>', ' '*i)
216
+ return content
217
+
218
+
219
+ class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
220
+ vocab_files_names = {"vocab_file": "sp.model"}
221
+ truncation_side: str = "left"
222
+
223
+ def __init__(self, vocab_file, **kwargs):
224
+ self.vocab_file = vocab_file
225
+ self.sp_model = spm.SentencePieceProcessor()
226
+ self.sp_model.Load(vocab_file)
227
+ super().__init__(**kwargs)
228
+
229
+ @property
230
+ def vocab_size(self):
231
+ return len(self.sp_model)
232
+
233
+ def get_vocab(self):
234
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
235
+ vocab.update(self.added_tokens_encoder)
236
+ return vocab
237
+
238
+ def _tokenize(self, text, **kwargs):
239
+ text = encode_whitespaces(text)
240
+ return self.sp_model.EncodeAsPieces(text)
241
+ #return self.sp_model.EncodeAsPieces(text, out_type=str)
242
+
243
+ def _convert_token_to_id(self, token):
244
+ """Converts a token (str) in an id using the vocab."""
245
+ return self.sp_model.PieceToId(token)
246
+
247
+ def _convert_id_to_token(self, index):
248
+ """Converts an index (integer) in a token (str) using the vocab."""
249
+ return self.sp_model.IdToPiece(index)
250
+
251
+ def convert_tokens_to_string(self, tokens):
252
+ res = self.sp_model.DecodeIds(tokens)
253
+ return decode_whitespaces(res)
254
+
255
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
256
+ if not os.path.isdir(save_directory):
257
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
258
+ return
259
+ out_vocab_file = os.path.join(
260
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
261
+ )
262
+
263
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
264
+ copyfile(self.vocab_file, out_vocab_file)
265
+ elif not os.path.isfile(self.vocab_file):
266
+ with open(out_vocab_file, "wb") as fi:
267
+ content_spiece_model = self.sp_model.serialized_model_proto()
268
+ fi.write(content_spiece_model)
269
+
270
+ return (out_vocab_file,)
271
+
272
+ def build_inputs_with_special_tokens(
273
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
274
+ ) -> List[int]:
275
+ """
276
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
277
+ adding special tokens. A BERT sequence has the following format:
278
+
279
+ - single sequence: ``[CLS] X [SEP]``
280
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
281
+
282
+ Args:
283
+ token_ids_0 (:obj:`List[int]`):
284
+ List of IDs to which the special tokens will be added.
285
+ token_ids_1 (:obj:`List[int]`, `optional`):
286
+ Optional second list of IDs for sequence pairs.
287
+
288
+ Returns:
289
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
290
+ """
291
+ assert token_ids_1 is None
292
+ cls = [self.cls_token_id]
293
+ eos = [self.eos_token_id]
294
+ return cls + token_ids_0 + eos
295
+
296
+
297
+ class GLMTokenizer:
298
+ @classmethod
299
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
300
+ tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
301
+ config_tokenizer_class = tokenizer_config.get("tokenizer_class")
302
+
303
+ if config_tokenizer_class == "GLMChineseTokenizer":
304
+ tokenizer_class = GLMChineseTokenizer
305
+ else:
306
+ raise NotImplementedError(f"Tokenizer type not implemented: {config_tokenizer_class}")
307
+ return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
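The dispatcher above reads `tokenizer_class` from `tokenizer_config.json` (shipped next to `sp.model` in this folder) and instantiates `GLMChineseTokenizer` on the SentencePiece model. A loading sketch; the local directory path, import path, and example prompt are assumptions, and `build_inputs_for_generation` requires a mask token (here `[MASK]`) to be present in the input.

```python
from vlmo.tokenizer.tokenization_glm import GLMTokenizer  # assumed import path

# Assumed local path: the directory in this repo holding sp.model and tokenizer_config.json.
tokenizer = GLMTokenizer.from_pretrained("vlmo/tokenizer")

enc = tokenizer("一只猫坐在垫子上 [MASK]", return_tensors="pt")
# Appends <|startofpiece|> and builds the 2D position ids plus the blank-filling attention mask.
gen_inputs = tokenizer.build_inputs_for_generation(enc, max_gen_length=32)
print(gen_inputs["input_ids"].shape, gen_inputs["position_ids"].shape)
```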
vlmo/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "name_or_path": "THUDM/glm-10b-chinese",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "cls_token": "[CLS]",
6
+ "mask_token": "[MASK]",
7
+ "unk_token": "[UNK]",
8
+ "add_prefix_space": false,
9
+ "tokenizer_class": "GLMChineseTokenizer",
10
+ "use_fast": false,
11
+ "auto_map": {
12
+ "AutoTokenizer": [
13
+ "tokenization_glm.GLMChineseTokenizer",
14
+ null
15
+ ]
16
+ }
17
+ }
vlmo/torchscale/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
vlmo/torchscale/architecture/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
vlmo/torchscale/architecture/config.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+
5
+ class EncoderConfig(object):
6
+ def __init__(self, **kwargs):
7
+ self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
8
+ self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
9
+ self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
10
+ self.encoder_layers = kwargs.pop("encoder_layers", 12)
11
+ self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
12
+ self.normalize_output = kwargs.pop("normalize_output", True)
13
+ self.activation_fn = kwargs.pop("activation_fn", "gelu")
14
+ self.dropout = kwargs.pop("dropout", 0.0)
15
+ self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
16
+ self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
17
+ self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
18
+ self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
19
+ self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
20
+ self.moe_freq = kwargs.pop("moe_freq", 0)
21
+ self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
22
+ self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
23
+ self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
24
+ self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
25
+ self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
26
+ self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
27
+ self.use_xmoe = kwargs.pop("use_xmoe", False)
28
+ self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
29
+ self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
30
+ self.deepnorm = kwargs.pop("deepnorm", False)
31
+ self.subln = kwargs.pop("subln", True)
32
+ self.bert_init = kwargs.pop("bert_init", False)
33
+ self.multiway = kwargs.pop("multiway", False)
34
+ self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
35
+ self.max_source_positions = kwargs.pop("max_source_positions", 1024)
36
+ self.no_output_layer = kwargs.pop("no_output_layer", False)
37
+ self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
38
+ self.share_layer = kwargs.pop("share_layer", False)
39
+ self.share_attn = kwargs.pop("share_attn", False)
40
+ self.mask_ratio = kwargs.pop("mask_ratio", 0)
41
+ self.max_text_len = kwargs.pop("max_text_len", 52)
42
+ self.one_attn = kwargs.pop('one_attn', False)
43
+
44
+
45
+ # Text
46
+ self.vocab_size = kwargs.pop("vocab_size", -1)
47
+ # Vision
48
+ self.img_size = kwargs.pop("img_size", 224)
49
+ self.patch_size = kwargs.pop("patch_size", 16)
50
+ self.in_chans = kwargs.pop("in_chans", 3)
51
+ # Fairscale
52
+ self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
53
+ self.fsdp = kwargs.pop("fsdp", False)
54
+ self.ddp_rank = kwargs.pop("ddp_rank", 0)
55
+ self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
56
+ self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
57
+
58
+ if self.deepnorm:
59
+ self.encoder_normalize_before = False
60
+ self.subln = False
61
+ if self.subln:
62
+ self.encoder_normalize_before = True
63
+ self.deepnorm = False
64
+ if self.use_xmoe:
65
+ self.moe_normalize_gate_prob_before_dropping = True
66
+ self.moe_second_expert_policy = "random"
67
+ assert self.moe_freq > 0 and self.moe_expert_count > 0
68
+
69
+ def override(self, args):
70
+ for hp in self.__dict__.keys():
71
+ if getattr(args, hp, None) is not None:
72
+ self.__dict__[hp] = getattr(args, hp, None)
73
+
74
+
75
+ class DecoderConfig(object):
76
+ def __init__(self, **kwargs):
77
+ self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
78
+ self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
79
+ self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
80
+ self.decoder_layers = kwargs.pop("decoder_layers", 12)
81
+ self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
82
+ self.activation_fn = kwargs.pop("activation_fn", "gelu")
83
+ self.dropout = kwargs.pop("dropout", 0.0)
84
+ self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
85
+ self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
86
+ self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
87
+ self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
88
+ self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
89
+ self.moe_freq = kwargs.pop("moe_freq", 0)
90
+ self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
91
+ self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
92
+ self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
93
+ self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
94
+ self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
95
+ self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
96
+ self.use_xmoe = kwargs.pop("use_xmoe", False)
97
+ self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
98
+ self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
99
+ self.deepnorm = kwargs.pop("deepnorm", False)
100
+ self.subln = kwargs.pop("subln", True)
101
+ self.bert_init = kwargs.pop("bert_init", False)
102
+ self.multiway = kwargs.pop("multiway", False)
103
+ self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
104
+ self.max_target_positions = kwargs.pop("max_target_positions", 1024)
105
+ self.no_output_layer = kwargs.pop("no_output_layer", False)
106
+ self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
107
+ # Text
108
+ self.vocab_size = kwargs.pop("vocab_size", -1)
109
+ # Fairscale
110
+ self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
111
+ self.fsdp = kwargs.pop("fsdp", False)
112
+ self.ddp_rank = kwargs.pop("ddp_rank", 0)
113
+ self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
114
+ self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
115
+
116
+ if self.deepnorm:
117
+ self.decoder_normalize_before = False
118
+ self.subln = False
119
+ if self.subln:
120
+ self.decoder_normalize_before = True
121
+ self.deepnorm = False
122
+ if self.use_xmoe:
123
+ self.moe_normalize_gate_prob_before_dropping = True
124
+ self.moe_second_expert_policy = "random"
125
+ assert self.moe_freq > 0 and self.moe_expert_count > 0
126
+
127
+ def override(self, args):
128
+ for hp in self.__dict__.keys():
129
+ if getattr(args, hp, None) is not None:
130
+ self.__dict__[hp] = getattr(args, hp, None)
131
+
132
+
133
+ class EncoderDecoderConfig(object):
134
+ def __init__(self, **kwargs):
135
+ self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
136
+ self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
137
+ self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
138
+ self.encoder_layers = kwargs.pop("encoder_layers", 12)
139
+ self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
140
+ self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
141
+ self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
142
+ self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
143
+ self.decoder_layers = kwargs.pop("decoder_layers", 12)
144
+ self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
145
+ self.activation_fn = kwargs.pop("activation_fn", "gelu")
146
+ self.dropout = kwargs.pop("dropout", 0.0)
147
+ self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
148
+ self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
149
+ self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
150
+ self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
151
+ self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
152
+ self.moe_freq = kwargs.pop("moe_freq", 0)
153
+ self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
154
+ self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
155
+ self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
156
+ self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
157
+ self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
158
+ self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
159
+ self.use_xmoe = kwargs.pop("use_xmoe", False)
160
+ self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
161
+ self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
162
+ self.deepnorm = kwargs.pop("deepnorm", False)
163
+ self.subln = kwargs.pop("subln", True)
164
+ self.bert_init = kwargs.pop("bert_init", False)
165
+ self.multiway = kwargs.pop("multiway", False)
166
+ self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
167
+ self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
168
+ self.max_source_positions = kwargs.pop("max_source_positions", 1024)
169
+ self.max_target_positions = kwargs.pop("max_target_positions", 1024)
170
+ self.no_output_layer = kwargs.pop("no_output_layer", False)
171
+ self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
172
+ # Text
173
+ self.vocab_size = kwargs.pop("vocab_size", -1)
174
+ # Fairscale
175
+ self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
176
+ self.fsdp = kwargs.pop("fsdp", False)
177
+ self.ddp_rank = kwargs.pop("ddp_rank", 0)
178
+ self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
179
+ self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
180
+
181
+ if self.deepnorm:
182
+ self.encoder_normalize_before = False
183
+ self.decoder_normalize_before = False
184
+ self.subln = False
185
+ if self.subln:
186
+ self.encoder_normalize_before = True
187
+ self.decoder_normalize_before = True
188
+ self.deepnorm = False
189
+ if self.use_xmoe:
190
+ self.moe_normalize_gate_prob_before_dropping = True
191
+ self.moe_second_expert_policy = "random"
192
+ assert self.moe_freq > 0 and self.moe_expert_count > 0
193
+
194
+ def override(self, args):
195
+ for hp in self.__dict__.keys():
196
+ if getattr(args, hp, None) is not None:
197
+ self.__dict__[hp] = getattr(args, hp, None)
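All three config classes resolve mutually exclusive normalization schemes in `__init__`: `deepnorm=True` turns off pre-norm and sub-LayerNorm, while the default `subln=True` forces pre-norm and clears `deepnorm`. A small sketch showing the resulting flags, which follow directly from the constructor logic above.

```python
from vlmo.torchscale.architecture.config import EncoderConfig

default_cfg = EncoderConfig()
# subln defaults to True, so pre-norm is forced on and deepnorm stays off.
assert default_cfg.subln and default_cfg.encoder_normalize_before and not default_cfg.deepnorm

deepnorm_cfg = EncoderConfig(deepnorm=True)
# deepnorm wins: post-norm residuals, sub-LayerNorm disabled.
assert deepnorm_cfg.deepnorm and not deepnorm_cfg.subln and not deepnorm_cfg.encoder_normalize_before
```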
vlmo/torchscale/architecture/decoder.py ADDED
@@ -0,0 +1,428 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from fairscale.nn import checkpoint_wrapper, wrap
10
+
11
+ from vlmo.torchscale.architecture.utils import init_bert_params
12
+ from vlmo.torchscale.component.droppath import DropPath
13
+ from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
14
+ from vlmo.torchscale.component.multihead_attention import MultiheadAttention
15
+ from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
16
+ #from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
17
+ #from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate
18
+
19
+ try:
20
+ from apex.normalization import FusedLayerNorm as LayerNorm
21
+ except ModuleNotFoundError:
22
+ from torch.nn import LayerNorm
23
+
24
+
25
+ class DecoderLayer(nn.Module):
26
+ def __init__(
27
+ self,
28
+ args,
29
+ depth,
30
+ is_moe_layer=False,
31
+ is_encoder_decoder=False,
32
+ ):
33
+ super().__init__()
34
+ self.args = args
35
+ self.embed_dim = args.decoder_embed_dim
36
+ self.dropout_module = torch.nn.Dropout(args.dropout)
37
+
38
+ if args.drop_path_rate > 0:
39
+ drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
40
+ self.drop_path = DropPath(drop_path_prob)
41
+ else:
42
+ self.drop_path = None
43
+
44
+ self.self_attn = self.build_self_attention(self.embed_dim, args)
45
+
46
+ self.normalize_before = args.decoder_normalize_before
47
+
48
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
49
+
50
+ if not is_encoder_decoder:
51
+ self.encoder_attn = None
52
+ self.encoder_attn_layer_norm = None
53
+ else:
54
+ self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
55
+ self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
56
+
57
+ self.is_moe_layer = is_moe_layer
58
+ self.ffn_dim = args.decoder_ffn_embed_dim
59
+
60
+ if not self.is_moe_layer:
61
+ self.ffn = self.build_ffn(
62
+ self.embed_dim,
63
+ self.args,
64
+ )
65
+ else:
66
+ if args.moe_top1_expert:
67
+ gate = Top1Gate(
68
+ self.embed_dim,
69
+ args.moe_expert_count,
70
+ use_fp32=args.moe_gating_use_fp32,
71
+ moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
72
+ use_xmoe=args.use_xmoe,
73
+ )
74
+ else:
75
+ gate = Top2Gate(
76
+ self.embed_dim,
77
+ args.moe_expert_count,
78
+ args.moe_gating_use_fp32,
79
+ args.moe_second_expert_policy,
80
+ args.moe_normalize_gate_prob_before_dropping,
81
+ args.moe_eval_capacity_token_fraction,
82
+ use_xmoe=args.use_xmoe,
83
+ )
84
+ experts = make_experts(args, self.embed_dim, self.ffn_dim)
85
+ self.moe_layer = MOELayer(gate, experts, args)
86
+
87
+ self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
88
+
89
+ if args.deepnorm:
90
+ if is_encoder_decoder:
91
+ self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
92
+ else:
93
+ self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
94
+ else:
95
+ self.alpha = 1.0
96
+
97
+ def build_ffn(self, embed_dim, args):
98
+ return FeedForwardNetwork(
99
+ embed_dim,
100
+ self.ffn_dim,
101
+ args.activation_fn,
102
+ args.dropout,
103
+ args.activation_dropout,
104
+ args.layernorm_eps,
105
+ args.subln,
106
+ )
107
+
108
+ def build_self_attention(self, embed_dim, args):
109
+ return MultiheadAttention(
110
+ args,
111
+ embed_dim,
112
+ args.decoder_attention_heads,
113
+ dropout=args.attention_dropout,
114
+ self_attention=True,
115
+ encoder_decoder_attention=False,
116
+ subln=args.subln,
117
+ )
118
+
119
+ def build_encoder_attention(self, embed_dim, args):
120
+ return MultiheadAttention(
121
+ args,
122
+ embed_dim,
123
+ args.decoder_attention_heads,
124
+ dropout=args.attention_dropout,
125
+ self_attention=False,
126
+ encoder_decoder_attention=True,
127
+ subln=args.subln,
128
+ )
129
+
130
+ def residual_connection(self, x, residual):
131
+ return residual * self.alpha + x
132
+
133
+ def forward(
134
+ self,
135
+ x,
136
+ encoder_out=None,
137
+ encoder_padding_mask=None,
138
+ incremental_state=None,
139
+ self_attn_mask=None,
140
+ self_attn_padding_mask=None,
141
+ self_attn_rel_pos=None,
142
+ cross_attn_rel_pos=None,
143
+ ):
144
+ residual = x
145
+ if self.normalize_before:
146
+ x = self.self_attn_layer_norm(x)
147
+
148
+ x, attn = self.self_attn(
149
+ query=x,
150
+ key=x,
151
+ value=x,
152
+ key_padding_mask=self_attn_padding_mask,
153
+ incremental_state=incremental_state,
154
+ attn_mask=self_attn_mask,
155
+ rel_pos=self_attn_rel_pos,
156
+ )
157
+ x = self.dropout_module(x)
158
+
159
+ if self.drop_path is not None:
160
+ x = self.drop_path(x)
161
+
162
+ x = self.residual_connection(x, residual)
163
+ if not self.normalize_before:
164
+ x = self.self_attn_layer_norm(x)
165
+
166
+ if self.encoder_attn is not None and encoder_out is not None:
167
+ residual = x
168
+ if self.normalize_before:
169
+ x = self.encoder_attn_layer_norm(x)
170
+
171
+ x, attn = self.encoder_attn(
172
+ query=x,
173
+ key=encoder_out,
174
+ value=encoder_out,
175
+ key_padding_mask=encoder_padding_mask,
176
+ incremental_state=None,
177
+ rel_pos=cross_attn_rel_pos,
178
+ )
179
+ x = self.dropout_module(x)
180
+
181
+ if self.drop_path is not None:
182
+ x = self.drop_path(x)
183
+
184
+ x = self.residual_connection(x, residual)
185
+ if not self.normalize_before:
186
+ x = self.encoder_attn_layer_norm(x)
187
+
188
+ residual = x
189
+ if self.normalize_before:
190
+ x = self.final_layer_norm(x)
191
+ if not self.is_moe_layer:
192
+ x = self.ffn(x)
193
+ l_aux = None
194
+ else:
195
+ x, l_aux = self.moe_layer(x)
196
+
197
+ if self.drop_path is not None:
198
+ x = self.drop_path(x)
199
+
200
+ x = self.residual_connection(x, residual)
201
+ if not self.normalize_before:
202
+ x = self.final_layer_norm(x)
203
+
204
+ return x, attn, None, l_aux
205
+
206
+
207
+ class Decoder(nn.Module):
208
+ def __init__(
209
+ self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
210
+ ):
211
+ super().__init__(**kwargs)
212
+ self.args = args
213
+
214
+ self.dropout_module = torch.nn.Dropout(args.dropout)
215
+
216
+ embed_dim = args.decoder_embed_dim
217
+ self.embed_dim = embed_dim
218
+ self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
219
+
220
+ self.embed_tokens = embed_tokens
221
+ self.embed_positions = embed_positions
222
+
223
+ if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
224
+ self.output_projection = self.build_output_projection(args)
225
+ else:
226
+ self.output_projection = output_projection
227
+
228
+ if args.layernorm_embedding:
229
+ self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
230
+ else:
231
+ self.layernorm_embedding = None
232
+
233
+ self.layers = nn.ModuleList([])
234
+
235
+ moe_freq = args.moe_freq
236
+ for i in range(args.decoder_layers):
237
+ is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
238
+ self.layers.append(
239
+ self.build_decoder_layer(
240
+ args,
241
+ depth=i,
242
+ is_moe_layer=is_moe_layer,
243
+ is_encoder_decoder=is_encoder_decoder,
244
+ )
245
+ )
246
+
247
+ self.num_layers = len(self.layers)
248
+
249
+ if args.decoder_normalize_before:
250
+ self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
251
+ else:
252
+ self.layer_norm = None
253
+
254
+ self.self_attn_relative_position = None
255
+ self.cross_attn_relative_position = None
256
+
257
+ if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
258
+ self.self_attn_relative_position = RelativePositionBias(
259
+ num_buckets=args.rel_pos_buckets,
260
+ max_distance=args.max_rel_pos,
261
+ n_heads=args.decoder_attention_heads,
262
+ )
263
+ if is_encoder_decoder:
264
+ self.cross_attn_relative_position = RelativePositionBias(
265
+ num_buckets=args.rel_pos_buckets,
266
+ max_distance=args.max_rel_pos,
267
+ n_heads=args.decoder_attention_heads,
268
+ )
269
+
270
+ if args.bert_init:
271
+ self.apply(init_bert_params)
272
+
273
+ if args.deepnorm:
274
+ if is_encoder_decoder:
275
+ init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
276
+ else:
277
+ init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
278
+ for name, p in self.named_parameters():
279
+ if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
280
+ p.data.div_(init_scale)
281
+
282
+ if args.subln:
283
+ if is_encoder_decoder:
284
+ init_scale = math.sqrt(math.log(args.decoder_layers * 3))
285
+ else:
286
+ init_scale = math.sqrt(math.log(args.decoder_layers * 2))
287
+ for name, p in self.named_parameters():
288
+ if "encoder_attn" in name:
289
+ continue
290
+ if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
291
+ p.data.mul_(init_scale)
292
+
293
+ def build_output_projection(
294
+ self,
295
+ args,
296
+ ):
297
+ if args.share_decoder_input_output_embed:
298
+ output_projection = torch.nn.Linear(
299
+ self.embed_tokens.weight.shape[1],
300
+ self.embed_tokens.weight.shape[0],
301
+ bias=False,
302
+ )
303
+ output_projection.weight = self.embed_tokens.weight
304
+ else:
305
+ output_projection = torch.nn.Linear(args.decoder_embed_dim, args.vocab_size, bias=False)
306
+ torch.nn.init.normal_(output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5)
307
+ return output_projection
308
+
309
+ def build_decoder_layer(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
310
+ layer = DecoderLayer(
311
+ args,
312
+ depth,
313
+ is_moe_layer=is_moe_layer,
314
+ is_encoder_decoder=is_encoder_decoder,
315
+ )
316
+ if args.checkpoint_activations:
317
+ layer = checkpoint_wrapper(layer)
318
+ if args.fsdp:
319
+ layer = wrap(layer)
320
+ return layer
321
+
322
+ def forward_embedding(
323
+ self,
324
+ tokens,
325
+ token_embedding=None,
326
+ incremental_state=None,
327
+ ):
328
+ positions = None
329
+ if self.embed_positions is not None:
330
+ positions = self.embed_positions(tokens, incremental_state=incremental_state)
331
+
332
+ if incremental_state is not None:
333
+ tokens = tokens[:, -1:]
334
+ if positions is not None:
335
+ positions = positions[:, -1:]
336
+
337
+ if token_embedding is None:
338
+ token_embedding = self.embed_tokens(tokens)
339
+
340
+ x = embed = self.embed_scale * token_embedding
341
+
342
+ if positions is not None:
343
+ x += positions
344
+
345
+ if self.layernorm_embedding is not None:
346
+ x = self.layernorm_embedding(x)
347
+
348
+ x = self.dropout_module(x)
349
+
350
+ return x, embed
351
+
352
+ def forward(
353
+ self,
354
+ prev_output_tokens,
355
+ self_attn_padding_mask=None,
356
+ encoder_out=None,
357
+ incremental_state=None,
358
+ features_only=False,
359
+ return_all_hiddens=False,
360
+ token_embeddings=None,
361
+ **kwargs
362
+ ):
363
+ # embed tokens and positions
364
+ x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
365
+
366
+ # relative position
367
+ self_attn_rel_pos_bias = None
368
+ slen = prev_output_tokens.size(1)
369
+ if self.self_attn_relative_position is not None:
370
+ self_attn_rel_pos_bias = self.self_attn_relative_position(batch_size=x.size(0), qlen=slen, klen=slen)
371
+ if incremental_state is not None:
372
+ self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
373
+ cross_attn_rel_pos_bias = None
374
+ if self.cross_attn_relative_position is not None:
375
+ cross_attn_rel_pos_bias = self.cross_attn_relative_position(
376
+ batch_size=x.size(0),
377
+ qlen=slen,
378
+ klen=encoder_out["encoder_out"].size(1),
379
+ )
380
+ if incremental_state is not None:
381
+ cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
382
+
383
+ # decoder layers
384
+ inner_states = [x]
385
+
386
+ if encoder_out is None:
387
+ l_aux = []
388
+ else:
389
+ l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
390
+
391
+ for idx, layer in enumerate(self.layers):
392
+ if incremental_state is None:
393
+ self_attn_mask = torch.triu(
394
+ torch.zeros([x.size(1), x.size(1)]).float().fill_(float("-inf")).type_as(x),
395
+ 1,
396
+ )
397
+ else:
398
+ self_attn_mask = None
399
+ if idx not in incremental_state:
400
+ incremental_state[idx] = {}
401
+
402
+ x, layer_attn, _, l_aux_i = layer(
403
+ x,
404
+ encoder_out["encoder_out"] if encoder_out is not None else None,
405
+ encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
406
+ incremental_state[idx] if incremental_state is not None else None,
407
+ self_attn_mask=self_attn_mask,
408
+ self_attn_padding_mask=self_attn_padding_mask,
409
+ self_attn_rel_pos=self_attn_rel_pos_bias,
410
+ cross_attn_rel_pos=cross_attn_rel_pos_bias,
411
+ )
412
+ l_aux.append(l_aux_i)
413
+ inner_states.append(x)
414
+
415
+ if self.layer_norm is not None:
416
+ x = self.layer_norm(x)
417
+
418
+ if not features_only:
419
+ x = self.output_layer(x)
420
+
421
+ return x, {
422
+ "inner_states": inner_states,
423
+ "l_aux": l_aux,
424
+ "attn": None,
425
+ }
426
+
427
+ def output_layer(self, features):
428
+ return self.output_projection(features)
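When no incremental state is supplied, `Decoder.forward` rebuilds the causal mask on every call: an upper-triangular matrix of `-inf` above the diagonal that is added to the attention logits so each position can only attend to itself and earlier tokens. A standalone sketch of just that mask construction, mirroring the `torch.triu` expression in the layer loop above.

```python
import torch

seq_len = 4
x = torch.zeros(1, seq_len, 8)  # stand-in for the decoder input; only dtype/device matter here
self_attn_mask = torch.triu(
    torch.zeros([seq_len, seq_len]).float().fill_(float("-inf")).type_as(x),
    1,
)
print(self_attn_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```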
vlmo/torchscale/architecture/encoder.py ADDED
@@ -0,0 +1,489 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from fairscale.nn import checkpoint_wrapper, wrap
10
+
11
+ try:
12
+ from apex.normalization import FusedLayerNorm as LayerNorm
13
+ except ModuleNotFoundError:
14
+ from torch.nn import LayerNorm
15
+
16
+ from vlmo.torchscale.architecture.utils import init_bert_params
17
+ from vlmo.torchscale.component.droppath import DropPath
18
+ from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
19
+ from vlmo.torchscale.component.multihead_attention import MultiheadAttention
20
+ from vlmo.torchscale.component.multiway_network import MultiwayWrapper, set_split_position
21
+ from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
22
+ #from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
23
+ #from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate
24
+ # from vlmo.modules.vlmo_utils import no_sync_module_apply
25
+ from pytorch_lightning.utilities.rank_zero import rank_zero_info
26
+
27
+ def no_sync_module_apply(module, fn):
28
+ """FSDP's module.apply uses _unshard_params_recurse, which syncs params across ranks.
29
+ Use this function when the applied fn does not need params to be synced across ranks.
30
+ """
31
+ for child in module.children():
32
+ fn(child)
33
+ no_sync_module_apply(child, fn)
34
+
35
+ class EncoderLayer(nn.Module):
36
+ def __init__(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
37
+ super().__init__()
38
+ self.args = args
39
+ self.embed_dim = args.encoder_embed_dim
40
+ self.self_attn = self.build_self_attention(self.embed_dim, args) if attn is None else attn
41
+ self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
42
+ self.dropout_module = torch.nn.Dropout(args.dropout)
43
+
44
+ if args.drop_path_rate > 0:
45
+ drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
46
+ self.drop_path = DropPath(drop_path_prob)
47
+ else:
48
+ self.drop_path = None
49
+
50
+ self.normalize_before = args.encoder_normalize_before
51
+ self.is_moe_layer = is_moe_layer
52
+ self.ffn_dim = args.encoder_ffn_embed_dim
53
+
54
+ if not self.is_moe_layer:
55
+ self.ffn = MultiwayWrapper(
56
+ args,
57
+ self.build_ffn(
58
+ self.embed_dim,
59
+ self.args,
60
+ ),
61
+ )
62
+ else:
63
+ assert not self.args.multiway
64
+ if args.moe_top1_expert:
65
+ gate = Top1Gate(
66
+ self.embed_dim,
67
+ args.moe_expert_count,
68
+ use_fp32=args.moe_gating_use_fp32,
69
+ moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
70
+ use_xmoe=args.use_xmoe,
71
+ )
72
+ else:
73
+ gate = Top2Gate(
74
+ self.embed_dim,
75
+ args.moe_expert_count,
76
+ args.moe_gating_use_fp32,
77
+ args.moe_second_expert_policy,
78
+ args.moe_normalize_gate_prob_before_dropping,
79
+ args.moe_eval_capacity_token_fraction,
80
+ use_xmoe=args.use_xmoe,
81
+ )
82
+ experts = make_experts(args, self.embed_dim, self.ffn_dim)
83
+ self.moe_layer = MOELayer(gate, experts, args)
84
+ self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
85
+
86
+ if args.deepnorm:
87
+ if is_encoder_decoder:
88
+ self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
89
+ else:
90
+ self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
91
+ else:
92
+ self.alpha = 1.0
93
+
94
+ def build_ffn(self, embed_dim, args):
95
+ return FeedForwardNetwork(
96
+ embed_dim,
97
+ self.ffn_dim,
98
+ args.activation_fn,
99
+ args.dropout,
100
+ args.activation_dropout,
101
+ args.layernorm_eps,
102
+ args.subln,
103
+ )
104
+
105
+ def build_self_attention(self, embed_dim, args):
106
+ return MultiheadAttention(
107
+ args,
108
+ embed_dim,
109
+ args.encoder_attention_heads,
110
+ dropout=args.attention_dropout,
111
+ self_attention=True,
112
+ encoder_decoder_attention=False,
113
+ subln=args.subln,
114
+ one_attn=args.one_attn,
115
+ )
116
+
117
+ def residual_connection(self, x, residual):
118
+ return residual * self.alpha + x
119
+
120
+ def forward(
121
+ self,
122
+ x,
123
+ encoder_padding_mask,
124
+ attn_mask=None,
125
+ rel_pos=None,
126
+ multiway_split_position=None,
127
+ incremental_state=None,
128
+ ):
129
+ if multiway_split_position is not None:
130
+ assert self.args.multiway
131
+ no_sync_module_apply(self, set_split_position(multiway_split_position))
132
+
133
+ if attn_mask is not None:
134
+ # float16: -1e8 equal 0
135
+ attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
136
+
137
+ residual = x
138
+ if self.normalize_before:
139
+ x = self.self_attn_layer_norm(x)
140
+ x, _ = self.self_attn(
141
+ query=x,
142
+ key=x,
143
+ value=x,
144
+ key_padding_mask=encoder_padding_mask,
145
+ attn_mask=attn_mask,
146
+ rel_pos=rel_pos,
147
+ incremental_state=incremental_state,
148
+ )
149
+ x = self.dropout_module(x)
150
+
151
+ if self.drop_path is not None:
152
+ x = self.drop_path(x)
153
+
154
+ x = self.residual_connection(x, residual)
155
+ if not self.normalize_before:
156
+ x = self.self_attn_layer_norm(x)
157
+
158
+ residual = x
159
+ if self.normalize_before:
160
+ x = self.final_layer_norm(x)
161
+ if not self.is_moe_layer:
162
+ x = self.ffn(x)
163
+ l_aux = None
164
+ else:
165
+ x = x.transpose(0, 1)
166
+ x, l_aux = self.moe_layer(x)
167
+ x = x.transpose(0, 1)
168
+
169
+ if self.drop_path is not None:
170
+ x = self.drop_path(x)
171
+
172
+ x = self.residual_connection(x, residual)
173
+ if not self.normalize_before:
174
+ x = self.final_layer_norm(x)
175
+ return x, l_aux
176
+
177
+
178
+ class Encoder(nn.Module):
179
+ def __init__(
180
+ self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
181
+ ):
182
+ self.args = args
183
+ super().__init__(**kwargs)
184
+
185
+ self.dropout_module = torch.nn.Dropout(args.dropout)
186
+
187
+ embed_dim = args.encoder_embed_dim
188
+ self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
189
+ self.mask_ratio = args.mask_ratio
190
+ self.max_text_len = args.max_text_len
191
+ self.vision_len = (args.img_size // args.patch_size) * (args.img_size // args.patch_size)
192
+
193
+ self.embed_tokens = embed_tokens
194
+ self.embed_positions = embed_positions
195
+
196
+ if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
197
+ self.output_projection = self.build_output_projection(args)
198
+ else:
199
+ self.output_projection = output_projection
200
+
201
+ if args.layernorm_embedding:
202
+ self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1)
203
+ else:
204
+ self.layernorm_embedding = None
205
+
206
+ self.layers = nn.ModuleList([])
207
+ if self.args.share_layer:
208
+ single_layer = self.build_encoder_layer(
209
+ args, depth=0, is_moe_layer=False, is_encoder_decoder=is_encoder_decoder
210
+ )
211
+ for i in range(args.encoder_layers):
212
+ self.layers.append(single_layer)
213
+ elif self.args.share_attn:
214
+ moe_freq = args.moe_freq
215
+ embed_dim = args.encoder_embed_dim
216
+ shared_attn = self.build_self_attention(embed_dim, self.args)
217
+ for i in range(args.encoder_layers):
218
+ is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
219
+ self.layers.append(
220
+ self.build_encoder_layer(
221
+ args,
222
+ depth=i,
223
+ attn=shared_attn,
224
+ is_moe_layer=is_moe_layer,
225
+ is_encoder_decoder=is_encoder_decoder,
226
+ )
227
+ )
228
+
229
+ else:
230
+ moe_freq = args.moe_freq
231
+ for i in range(args.encoder_layers):
232
+ is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
233
+ self.layers.append(
234
+ self.build_encoder_layer(
235
+ args,
236
+ depth=i,
237
+ is_moe_layer=is_moe_layer,
238
+ is_encoder_decoder=is_encoder_decoder,
239
+ )
240
+ )
241
+ self.num_layers = len(self.layers)
242
+
243
+ if args.encoder_normalize_before and args.normalize_output:
244
+ self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
245
+ else:
246
+ self.layer_norm = None
247
+
248
+ if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
249
+ self.relative_position = RelativePositionBias(
250
+ num_buckets=args.rel_pos_buckets,
251
+ max_distance=args.max_rel_pos,
252
+ n_heads=args.encoder_attention_heads,
253
+ )
254
+ else:
255
+ self.relative_position = None
256
+
257
+ if args.bert_init:
258
+ self.apply(init_bert_params)
259
+
260
+ if args.deepnorm:
261
+ if is_encoder_decoder:
262
+ init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
263
+ else:
264
+ init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
265
+ for name, p in self.named_parameters():
266
+ if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
267
+ p.data.div_(init_scale)
268
+
269
+ if args.subln:
270
+ if is_encoder_decoder:
271
+ init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
272
+ else:
273
+ init_scale = math.sqrt(math.log(args.encoder_layers * 2))
274
+ for name, p in self.named_parameters():
275
+ if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
276
+ p.data.mul_(init_scale)
277
+
278
+ def random_masking(self, x, mask_ratio):
279
+ N, L, D = x.shape # batch, length, dim
280
+ len_keep = int(L * (1 - mask_ratio))
281
+
282
+ noise = torch.rand(N, L - 1, device=x.device)
283
+ ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device, dtype=int)
284
+ ids_keep = ids_shuffle[:, :len_keep]
285
+
286
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
287
+
288
+ x0 = x[:, 0, :]
289
+ x0 = x0.reshape(N, 1, D)
290
+ x_masked_add = torch.cat([x0, x_masked], axis=1)
291
+ return x_masked_add, ids_keep
292
+
293
+ def build_self_attention(self, embed_dim, args):
294
+ return MultiheadAttention(
295
+ args,
296
+ embed_dim,
297
+ args.encoder_attention_heads,
298
+ dropout=args.attention_dropout,
299
+ self_attention=True,
300
+ encoder_decoder_attention=False,
301
+ subln=args.subln,
302
+ one_attn=args.one_attn,
303
+ )
304
+
305
+ def build_output_projection(
306
+ self,
307
+ args,
308
+ ):
309
+ if args.share_encoder_input_output_embed:
310
+ assert args.encoder_embedding_type == "language"
311
+ output_projection = torch.nn.Linear(
312
+ self.embed_tokens.weight.shape[1],
313
+ self.embed_tokens.weight.shape[0],
314
+ bias=False,
315
+ )
316
+ output_projection.weight = self.embed_tokens.weight
317
+ else:
318
+ output_projection = torch.nn.Linear(args.encoder_embed_dim, args.vocab_size, bias=False)
319
+ torch.nn.init.normal_(output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5)
320
+ return output_projection
321
+
322
+ def checkpointing_and_params_allgather(
323
+ self,
324
+ origin_layer,
325
+ ):
326
+ origin_forward = origin_layer.forward
327
+
328
+ from deepspeed import checkpointing
329
+ def forward(*args, **kwargs):
330
+ # note: deepspeed checkpointing does not support keyword arguments
331
+ ret = checkpointing.checkpoint(origin_forward, *args, **kwargs)
332
+ return ret
333
+
334
+ return forward
335
+
336
+ def build_encoder_layer(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
337
+ layer = EncoderLayer(
338
+ args,
339
+ depth,
340
+ attn,
341
+ is_moe_layer=is_moe_layer,
342
+ is_encoder_decoder=is_encoder_decoder,
343
+ )
344
+ if args.checkpoint_activations:
345
+ rank_zero_info("EncoderLayer params: %s", sum(p.numel() for p in layer.parameters() if p.requires_grad))
346
+ layer = checkpoint_wrapper(layer)
347
+ # layer.ffn = checkpoint_wrapper(layer.ffn,)
348
+ if args.fsdp:
349
+ layer = wrap(layer)
350
+ return layer
351
+
352
+ def checkpointing_layers(self):
353
+ for i, layer in enumerate(self.layers):
354
+ rank_zero_info(f"Checkpointing wrapper EncoderLayers: {i}")
355
+ self.layers[i] = checkpoint_wrapper(layer)
356
+
357
+ def forward_embedding(
358
+ self,
359
+ src_tokens,
360
+ token_embedding=None,
361
+ positions=None,
362
+ ):
363
+ if token_embedding is None:
364
+ token_embedding = self.embed_tokens(src_tokens)
365
+ x = embed = self.embed_scale * token_embedding
366
+ if self.embed_positions is not None:
367
+ if src_tokens is not None:
368
+ x = embed + self.embed_positions(src_tokens, positions=positions)
369
+ else:
370
+ x = embed + self.embed_positions(x, positions=positions)
371
+ is_flip, ids_keep = 0, None
372
+ if self.mask_ratio > 0:
373
+ if x.shape[1] == self.vision_len + 1:
374
+ x, ids_keep = self.random_masking(x, self.mask_ratio)
375
+ is_flip = 1
376
+ elif x.shape[1] == self.vision_len + self.max_text_len + 1:
377
+ vision_tokens = x[:, : self.vision_len + 1, :]
378
+ vision_tokens, ids_keep = self.random_masking(vision_tokens, self.mask_ratio)
379
+ x = torch.cat(
380
+ [
381
+ vision_tokens,
382
+ x[
383
+ :,
384
+ self.vision_len + 1 :,
385
+ ],
386
+ ],
387
+ dim=1,
388
+ )
389
+ is_flip = 2
390
+ if self.layernorm_embedding is not None:
391
+ x = self.layernorm_embedding(x)
392
+ x = self.dropout_module(x)
393
+ return x, embed, ids_keep, is_flip
394
+
395
+ def forward(
396
+ self,
397
+ src_tokens,
398
+ encoder_padding_mask=None,
399
+ attn_mask=None,
400
+ return_all_hiddens=False,
401
+ token_embeddings=None,
402
+ multiway_split_position=None,
403
+ features_only=False,
404
+ incremental_state=None,
405
+ positions=None,
406
+ **kwargs
407
+ ):
408
+ assert src_tokens is not None or token_embeddings is not None
409
+
410
+ if encoder_padding_mask is None:
411
+ if src_tokens is not None:
412
+ encoder_padding_mask = torch.zeros_like(src_tokens, device=src_tokens.device).bool()
413
+ else:
414
+ encoder_padding_mask = torch.zeros(
415
+ [token_embeddings.size(0), token_embeddings.size(1)],
416
+ device=token_embeddings.device,
417
+ ).bool()
418
+
419
+ if multiway_split_position is not None:
420
+ assert self.args.multiway
421
+ no_sync_module_apply(self, set_split_position(multiway_split_position))
422
+
423
+ x, encoder_embedding, ids_keep, is_flip = self.forward_embedding(src_tokens, token_embeddings, positions)
424
+ if is_flip > 0:
425
+ if is_flip == 2:
426
+ text_ids = (
427
+ torch.arange(
428
+ self.vision_len + 1, self.vision_len + 1 + self.max_text_len, device=x.device, dtype=torch.int64
429
+ )
430
+ .unsqueeze(0)
431
+ .repeat(ids_keep.shape[0], 1)
432
+ )
433
+ cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
434
+ ids_keep = torch.cat([cls_ids, ids_keep, text_ids], dim=1)
435
+ elif is_flip == 1:
436
+ cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
437
+ ids_keep = torch.cat([cls_ids, ids_keep], dim=1)
438
+ if encoder_padding_mask is not None:
439
+ encoder_padding_mask = torch.gather(encoder_padding_mask, dim=1, index=ids_keep)
440
+ if attn_mask is not None:
441
+ attn_mask = torch.gather(
442
+ attn_mask, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, attn_mask.shape[-1])
443
+ )
444
+ attn_mask = torch.gather(attn_mask, dim=2, index=ids_keep.unsqueeze(1).repeat(1, attn_mask.shape[1], 1))
445
+ if multiway_split_position > 0:
446
+ multiway_split_position = ids_keep.shape[1] - self.max_text_len
447
+ x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
448
+
449
+ encoder_states = []
450
+
451
+ if return_all_hiddens:
452
+ encoder_states.append(x)
453
+
454
+ rel_pos_bias = None
455
+ if self.relative_position is not None:
456
+ rel_pos_bias = self.relative_position(batch_size=x.size(0), qlen=x.size(1), klen=x.size(1))
457
+
458
+ l_aux = []
459
+ for idx, layer in enumerate(self.layers):
460
+ x, l_aux_i = layer(
461
+ x,
462
+ encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
463
+ attn_mask=attn_mask,
464
+ rel_pos=rel_pos_bias,
465
+ multiway_split_position=multiway_split_position,
466
+ incremental_state=incremental_state[idx] if incremental_state is not None else None,
467
+ )
468
+ if return_all_hiddens:
469
+ assert encoder_states is not None
470
+ encoder_states.append(x)
471
+ l_aux.append(l_aux_i)
472
+
473
+ if multiway_split_position is not None:
474
+ assert self.args.multiway
475
+ no_sync_module_apply(self, set_split_position(multiway_split_position))
476
+ if self.layer_norm is not None:
477
+ x = self.layer_norm(x)
478
+
479
+ if not features_only and self.output_projection is not None:
480
+ x = self.output_projection(x)
481
+
482
+ return {
483
+ "encoder_out": x,
484
+ "encoder_embedding": encoder_embedding,
485
+ "encoder_padding_mask": encoder_padding_mask,
486
+ "encoder_states": encoder_states,
487
+ "l_aux": l_aux,
488
+ "multiway_split_position": multiway_split_position,
489
+ }
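For reference, a minimal standalone sketch of the CLS-preserving random masking implemented by `random_masking` above (toy shapes, chosen only for illustration; the index shift keeps position 0, the CLS token, out of the shuffle):

import torch

N, L, D = 2, 5, 4                              # batch, 1 CLS + 4 vision tokens, dim
mask_ratio = 0.5
x = torch.randn(N, L, D)

len_keep = int(L * (1 - mask_ratio))           # number of non-CLS tokens to keep
noise = torch.rand(N, L - 1)                   # one noise value per non-CLS token
ids_shuffle = torch.argsort(noise, dim=1) + 1  # shift by 1 so the CLS slot is never sampled
ids_keep = ids_shuffle[:, :len_keep]

x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
x_out = torch.cat([x[:, :1, :], x_masked], dim=1)  # CLS is always re-attached at the front
print(x_out.shape)                             # torch.Size([2, 3, 4]): CLS + len_keep kept tokens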
vlmo/torchscale/architecture/encoder_decoder.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch.nn as nn
5
+
6
+ from vlmo.torchscale.architecture.decoder import Decoder
7
+ from vlmo.torchscale.architecture.encoder import Encoder
8
+
9
+
10
+ class EncoderDecoder(nn.Module):
11
+ def __init__(
12
+ self,
13
+ args,
14
+ encoder_embed_tokens=None,
15
+ encoder_embed_positions=None,
16
+ decoder_embed_tokens=None,
17
+ decoder_embed_positions=None,
18
+ output_projection=None,
19
+ **kwargs
20
+ ):
21
+ super().__init__()
22
+ self.args = args
23
+ if args.share_all_embeddings:
24
+ args.share_decoder_input_output_embed = True
25
+
26
+ self.encoder = Encoder(args, encoder_embed_tokens, encoder_embed_positions, is_encoder_decoder=True, **kwargs)
27
+
28
+ if args.share_all_embeddings and decoder_embed_tokens is None:
29
+ decoder_embed_tokens = self.encoder.embed_tokens
30
+
31
+ self.decoder = Decoder(
32
+ args, decoder_embed_tokens, decoder_embed_positions, output_projection, is_encoder_decoder=True, **kwargs
33
+ )
34
+
35
+ def forward(self, src_tokens, prev_output_tokens, return_all_hiddens=False, features_only=False, **kwargs):
36
+ encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
37
+ decoder_out = self.decoder(
38
+ prev_output_tokens,
39
+ encoder_out=encoder_out,
40
+ features_only=features_only,
41
+ return_all_hiddens=return_all_hiddens,
42
+ )
43
+ return decoder_out
vlmo/torchscale/architecture/utils.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch.nn as nn
5
+
6
+ from vlmo.torchscale.component.multihead_attention import MultiheadAttention
7
+ from vlmo.torchscale.component.multiway_network import MultiwayNetwork
8
+
9
+
10
+ def init_bert_params(module):
11
+ def normal_(data):
12
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
13
+
14
+ if isinstance(module, nn.Linear):
15
+ normal_(module.weight.data)
16
+ if module.bias is not None:
17
+ module.bias.data.zero_()
18
+ if isinstance(module, nn.Embedding):
19
+ normal_(module.weight.data)
20
+ if module.padding_idx is not None:
21
+ module.weight.data[module.padding_idx].zero_()
22
+ if isinstance(module, MultiheadAttention):
23
+ if isinstance(module.q_proj, MultiwayNetwork):
24
+ normal_(module.q_proj.A.weight.data)
25
+ normal_(module.q_proj.B.weight.data)
26
+ normal_(module.k_proj.A.weight.data)
27
+ normal_(module.k_proj.B.weight.data)
28
+ normal_(module.v_proj.A.weight.data)
29
+ normal_(module.v_proj.B.weight.data)
30
+ else:
31
+ normal_(module.q_proj.weight.data)
32
+ normal_(module.k_proj.weight.data)
33
+ normal_(module.v_proj.weight.data)
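A small usage sketch for `init_bert_params` (assuming the `vlmo` package from this commit is importable): it is intended to be passed to `nn.Module.apply`, which visits every submodule.

import torch.nn as nn
from vlmo.torchscale.architecture.utils import init_bert_params

# Toy module for illustration; in the repo this is applied to the whole encoder.
model = nn.Sequential(nn.Embedding(100, 32, padding_idx=0), nn.Linear(32, 32))
model.apply(init_bert_params)  # Linear/Embedding weights -> N(0, 0.02); biases and the padding row -> 0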
vlmo/torchscale/component/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
vlmo/torchscale/component/droppath.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch.nn as nn
5
+ from timm.models.layers import drop_path
6
+
7
+
8
+ class DropPath(nn.Module):
9
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
10
+
11
+ def __init__(self, drop_prob=None):
12
+ super(DropPath, self).__init__()
13
+ self.drop_prob = drop_prob
14
+
15
+ def forward(self, x):
16
+ return drop_path(x, self.drop_prob, self.training)
17
+
18
+ def extra_repr(self):
19
+ return "p={}".format(self.drop_prob)
vlmo/torchscale/component/embedding.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class VisionLanguageEmbedding(nn.Module):
10
+ def __init__(self, text_embed, vision_embed):
11
+ super().__init__()
12
+ self.text_embed = text_embed
13
+ self.vision_embed = vision_embed
14
+
15
+ def forward(self, textual_tokens, visual_tokens, **kwargs):
16
+ if textual_tokens is None:
17
+ return self.vision_embed(visual_tokens)
18
+
19
+ if visual_tokens is None:
20
+ return self.text_embed(textual_tokens)
21
+
22
+ x1 = self.vision_embed(visual_tokens)
23
+ x2 = self.text_embed(textual_tokens)
24
+
25
+ return torch.cat([x1, x2], dim=1)
26
+
27
+
28
+ class VisionEmbedding(nn.Module):
29
+ """Image to Patch Embedding"""
30
+
31
+ def __init__(
32
+ self,
33
+ img_size=224,
34
+ patch_size=16,
35
+ in_chans=3,
36
+ embed_dim=768,
37
+ contain_mask_token=False,
38
+ prepend_cls_token=False,
39
+ ):
40
+ super().__init__()
41
+ img_size = (img_size, img_size)
42
+ patch_size = (patch_size, patch_size)
43
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
44
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
45
+ self.img_size = img_size
46
+ self.patch_size = patch_size
47
+ self.num_patches = num_patches
48
+
49
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
50
+
51
+ if contain_mask_token:
52
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
53
+ else:
54
+ self.mask_token = None
55
+
56
+ if prepend_cls_token:
57
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
58
+ else:
59
+ self.cls_token = None
60
+
61
+ def num_position_embeddings(self):
62
+ if self.cls_token is None:
63
+ return self.num_patches
64
+ else:
65
+ return self.num_patches + 1
66
+
67
+ def forward(self, x, masked_position=None, **kwargs):
68
+ B, C, H, W = x.shape
69
+ x = self.proj(x).flatten(2).transpose(1, 2)
70
+
71
+ batch_size, seq_len, _ = x.size()
72
+
73
+ if masked_position is not None:
74
+ assert self.mask_token is not None
75
+ mask_token = self.mask_token.expand(batch_size, seq_len, -1)
76
+ w = masked_position.unsqueeze(-1).type_as(mask_token)
77
+ x = x * (1 - w) + mask_token * w
78
+
79
+ if self.cls_token is not None:
80
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
81
+ x = torch.cat((cls_tokens, x), dim=1)
82
+
83
+ return x
84
+
85
+
86
+ class TextEmbedding(nn.Embedding):
87
+ def reset_parameters(self):
88
+ nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
89
+ self._fill_padding_idx_with_zero()
90
+
91
+
92
+ class PositionalEmbedding(nn.Embedding):
93
+ def forward(
94
+ self,
95
+ x,
96
+ positions=None,
97
+ **kwargs,
98
+ ):
99
+ if positions is None:
100
+ # being consistent with Fairseq, which starts from 2.
101
+ positions = torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
102
+ return F.embedding(
103
+ positions,
104
+ self.weight,
105
+ self.padding_idx,
106
+ self.max_norm,
107
+ self.norm_type,
108
+ self.scale_grad_by_freq,
109
+ self.sparse,
110
+ )
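A quick shape check for `VisionEmbedding` (illustrative values; assumes the package layout added in this commit):

import torch
from vlmo.torchscale.component.embedding import VisionEmbedding

embed = VisionEmbedding(img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                        contain_mask_token=True, prepend_cls_token=True)
images = torch.randn(2, 3, 224, 224)
tokens = embed(images)
print(tokens.shape)                     # torch.Size([2, 197, 768]): 1 CLS + 14*14 patch tokens
print(embed.num_position_embeddings())  # 197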
vlmo/torchscale/component/feedforward_network.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ try:
9
+ from apex.normalization import FusedLayerNorm as LayerNorm
10
+ except ModuleNotFoundError:
11
+ from torch.nn import LayerNorm
12
+
13
+
14
+ class set_torch_seed(object):
15
+ def __init__(self, seed):
16
+ assert isinstance(seed, int)
17
+ self.rng_state = self.get_rng_state()
18
+
19
+ torch.manual_seed(seed)
20
+ if torch.cuda.is_available():
21
+ torch.cuda.manual_seed(seed)
22
+
23
+ def get_rng_state(self):
24
+ state = {"torch_rng_state": torch.get_rng_state()}
25
+ if torch.cuda.is_available():
26
+ state["cuda_rng_state"] = torch.cuda.get_rng_state()
27
+ return state
28
+
29
+ def set_rng_state(self, state):
30
+ torch.set_rng_state(state["torch_rng_state"])
31
+ if torch.cuda.is_available():
32
+ torch.cuda.set_rng_state(state["cuda_rng_state"])
33
+
34
+ def __enter__(self):
35
+ return self
36
+
37
+ def __exit__(self, *exc):
38
+ self.set_rng_state(self.rng_state)
39
+
40
+
41
+ def make_experts(args, embed_dim, expert_ffn_dim):
42
+ world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
43
+ expert_list = []
44
+ ddp_rank = args.ddp_rank
45
+ start_seed = torch.randint(1000000, (1,)).item()
46
+ # at least as many experts as GPUs
47
+ if args.moe_expert_count >= world_size:
48
+ assert args.moe_expert_count % world_size == 0, f"{args.moe_expert_count}, {world_size}"
49
+ local_moe_expert_count = args.moe_expert_count // world_size
50
+ for i in range(local_moe_expert_count):
51
+ with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
52
+ expert_list.append(
53
+ FeedForwardNetwork(
54
+ embed_dim,
55
+ expert_ffn_dim,
56
+ args.activation_fn,
57
+ args.dropout,
58
+ args.activation_dropout,
59
+ args.layernorm_eps,
60
+ args.subln,
61
+ )
62
+ )
63
+ else:
64
+ assert world_size % args.moe_expert_count == 0, f"{world_size}, {args.moe_expert_count}"
65
+
66
+ with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
67
+ expert_list.append(
68
+ FeedForwardNetwork(
69
+ embed_dim,
70
+ expert_ffn_dim,
71
+ args.activation_fn,
72
+ args.dropout,
73
+ args.activation_dropout,
74
+ args.layernorm_eps,
75
+ args.subln,
76
+ )
77
+ )
78
+ experts = nn.ModuleList(expert_list)
79
+ return experts
80
+
81
+
82
+ def get_activation_fn(activation):
83
+ if activation == "relu":
84
+ return F.relu
85
+ elif activation == "gelu":
86
+ return F.gelu
87
+ else:
88
+ raise NotImplementedError
89
+
90
+
91
+ class FeedForwardNetwork(nn.Module):
92
+ def __init__(
93
+ self,
94
+ embed_dim,
95
+ ffn_dim,
96
+ activation_fn,
97
+ dropout,
98
+ activation_dropout,
99
+ layernorm_eps,
100
+ subln=False,
101
+ ):
102
+ super().__init__()
103
+ self.embed_dim = embed_dim
104
+ self.activation_fn = get_activation_fn(activation=str(activation_fn))
105
+ self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
106
+ self.dropout_module = torch.nn.Dropout(dropout)
107
+ self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
108
+ self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
109
+ self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
110
+
111
+ def reset_parameters(self):
112
+ self.fc1.reset_parameters()
113
+ self.fc2.reset_parameters()
114
+ if self.ffn_layernorm is not None:
115
+ self.ffn_layernorm.reset_parameters()
116
+
117
+ def forward(self, x):
118
+ # x = x.reshape(-1, x.size(-1))
119
+ x = self.fc1(x)
120
+ # x = self.activation_fn(x.float()).type_as(x)
121
+ x = self.activation_fn(x)
122
+ x = self.activation_dropout_module(x)
123
+ if self.ffn_layernorm is not None:
124
+ x = self.ffn_layernorm(x)
125
+ x = self.fc2(x)
126
+ # x = x.view(x_shape)
127
+ x = self.dropout_module(x)
128
+ return x
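A quick shape check for `FeedForwardNetwork` (illustrative values; `subln=True` enables the extra sub-LayerNorm between fc1 and fc2):

import torch
from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork

ffn = FeedForwardNetwork(embed_dim=768, ffn_dim=3072, activation_fn="gelu",
                         dropout=0.1, activation_dropout=0.0,
                         layernorm_eps=1e-5, subln=True)
x = torch.randn(2, 16, 768)
print(ffn(x).shape)  # torch.Size([2, 16, 768]): the FFN preserves the embedding dimension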
vlmo/torchscale/component/multihead_attention.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ try:
11
+ from apex.normalization import FusedLayerNorm as LayerNorm
12
+ except ModuleNotFoundError:
13
+ from torch.nn import LayerNorm
14
+
15
+ from .multiway_network import MultiwayWrapper
16
+ from .xpos_relative_position import XPOS
17
+
18
+
19
+ class MultiheadAttention(nn.Module):
20
+ def __init__(
21
+ self,
22
+ args,
23
+ embed_dim,
24
+ num_heads,
25
+ dropout=0.0,
26
+ self_attention=False,
27
+ encoder_decoder_attention=False,
28
+ subln=False,
29
+ one_attn=False,
30
+ ):
31
+ super().__init__()
32
+ self.args = args
33
+ self.embed_dim = embed_dim
34
+ self.num_heads = num_heads
35
+ self.head_dim = embed_dim // num_heads
36
+ self.scaling = self.head_dim ** (-0.5)
37
+ self.self_attention = self_attention
38
+ self.encoder_decoder_attention = encoder_decoder_attention
39
+ assert self.self_attention ^ self.encoder_decoder_attention
40
+ if one_attn:
41
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
42
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
43
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
44
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
45
+ else:
46
+ self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
47
+ self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
48
+ self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
49
+ # self.qkv_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim*3, bias=True))
50
+ self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
51
+ self.inner_attn_ln = (
52
+ MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
53
+ if subln and self.self_attention
54
+ else None
55
+ )
56
+ self.dropout_module = torch.nn.Dropout(dropout)
57
+ self.xpos = XPOS(self.head_dim, args.xpos_scale_base) if args.xpos_rel_pos and self.self_attention else None
58
+
59
+ def reset_parameters(self):
60
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
61
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
62
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
63
+ nn.init.xavier_uniform_(self.out_proj.weight)
64
+ nn.init.constant_(self.out_proj.bias, 0.0)
65
+
66
+ def forward(
67
+ self,
68
+ query,
69
+ key,
70
+ value,
71
+ incremental_state=None,
72
+ key_padding_mask=None,
73
+ attn_mask=None,
74
+ rel_pos=None,
75
+ ):
76
+ bsz, tgt_len, embed_dim = query.size()
77
+ src_len = tgt_len
78
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
79
+
80
+ key_bsz, src_len, _ = key.size()
81
+ assert key_bsz == bsz, f"{query.size(), key.size()}"
82
+ assert value is not None
83
+ assert (bsz, src_len) == value.shape[:2]
84
+ # if query is key and key is value:
85
+ # qkv = self.qkv_proj(query)
86
+ # else:
87
+ # # W*(q+k+v) = W(q) + W(k) + W(v)
88
+ # qkv = self.qkv_proj(query+key+value)
89
+ # q,k,v = qkv.split(self.embed_dim, dim=-1)
90
+
91
+ q = self.q_proj(query)
92
+ k = self.k_proj(key)
93
+ v = self.v_proj(value)
94
+
95
+ q = (q * self.scaling).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
96
+ k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
97
+ v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
98
+ q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
99
+ k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
100
+ v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
101
+
102
+ if incremental_state is not None:
103
+ if "prev_key" in incremental_state:
104
+ prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
105
+ prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
106
+ k = torch.cat([prev_key, k], dim=1)
107
+ v = torch.cat([prev_value, v], dim=1)
108
+ incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
109
+ incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
110
+ src_len = k.size(1)
111
+
112
+ if self.xpos is not None:
113
+ if incremental_state is not None:
114
+ offset = src_len - 1
115
+ else:
116
+ offset = 0
117
+ k = self.xpos(k, offset=0, downscale=True)
118
+ q = self.xpos(q, offset=offset, downscale=False)
119
+
120
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
121
+
122
+ if attn_mask is not None:
123
+ attn_weights = torch.nan_to_num(attn_weights)
124
+ if len(attn_mask.shape) != len(attn_weights.shape):
125
+ attn_mask = attn_mask.unsqueeze(0)
126
+ else:
127
+ attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)
128
+ attn_weights += attn_mask
129
+
130
+ if key_padding_mask is not None:
131
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
132
+ attn_weights = attn_weights.masked_fill(
133
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
134
+ float("-inf"),
135
+ )
136
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
137
+
138
+ if rel_pos is not None:
139
+ rel_pos = rel_pos.view(attn_weights.size())
140
+ attn_weights = attn_weights + rel_pos
141
+
142
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
143
+ attn_probs = self.dropout_module(attn_weights)
144
+
145
+ attn = torch.bmm(attn_probs, v)
146
+ attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
147
+
148
+ if self.inner_attn_ln is not None:
149
+ attn = self.inner_attn_ln(attn)
150
+
151
+ attn = self.out_proj(attn)
152
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
153
+
154
+ return attn, attn_weights
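A minimal sketch of calling `MultiheadAttention` as self-attention. Only the config fields this class actually reads are stubbed with a `SimpleNamespace` here, which is an assumption made for illustration; the real model passes the full encoder config.

import torch
from types import SimpleNamespace
from vlmo.torchscale.component.multihead_attention import MultiheadAttention

args = SimpleNamespace(multiway=False, layernorm_eps=1e-5,
                       xpos_rel_pos=False, xpos_scale_base=512)
attn = MultiheadAttention(args, embed_dim=768, num_heads=12, dropout=0.0,
                          self_attention=True, subln=False)
x = torch.randn(2, 10, 768)      # (batch, seq_len, embed_dim)
out, weights = attn(x, x, x)
print(out.shape, weights.shape)  # [2, 10, 768] and [12, 2, 10, 10]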
vlmo/torchscale/component/multiway_network.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import copy
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ def MultiwayWrapper(args, module, dim=1):
11
+ if args.multiway:
12
+ return MultiwayNetwork(module, dim=dim)
13
+ return module
14
+
15
+
16
+ def set_split_position(position):
17
+ def apply_fn(module):
18
+ if hasattr(module, "split_position"):
19
+ module.split_position = position
20
+
21
+ return apply_fn
22
+
23
+
24
+ class MultiwayNetwork(nn.Module):
25
+ def __init__(self, module, dim=1):
26
+ super().__init__()
27
+ self.dim = dim
28
+ self.A = module
29
+ self.B = copy.deepcopy(module)
30
+ self.B.reset_parameters()
31
+ self.split_position = -1
32
+
33
+ def forward(self, x, **kwargs):
34
+ if self.split_position == -1:
35
+ return self.A(x, **kwargs)
36
+ if self.split_position == 0:
37
+ return self.B(x, **kwargs)
38
+ x1, x2 = torch.split(
39
+ x,
40
+ [self.split_position, x.size(self.dim) - self.split_position],
41
+ dim=self.dim,
42
+ )
43
+ # x1, x2 = x[:self.split_position], x[self.split_position:]
44
+ y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
45
+ return torch.cat([y1, y2], dim=self.dim)
46
+
47
+
48
+ class MutliwayEmbedding(MultiwayNetwork):
49
+ def __init__(self, modules, dim=1):
50
+ super(MultiwayNetwork, self).__init__()
51
+ self.dim = dim
52
+ assert len(modules) == 2
53
+ self.A = modules[0]
54
+ self.B = modules[1]
55
+ self.split_position = -1
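A small sketch of the multiway ("modality expert") routing: tokens before the split position go through expert A, the rest through expert B.

import torch
import torch.nn as nn
from vlmo.torchscale.component.multiway_network import MultiwayNetwork, set_split_position

net = MultiwayNetwork(nn.Linear(8, 8), dim=1)  # A and B start as two copies of the same layer
net.apply(set_split_position(3))               # first 3 positions -> A, remaining positions -> B
x = torch.randn(2, 5, 8)
print(net(x).shape)                            # torch.Size([2, 5, 8])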
vlmo/torchscale/component/relative_position_bias.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class RelativePositionBias(nn.Module):
11
+ def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12):
12
+ super().__init__()
13
+ self.bidirectional = bidirectional
14
+ self.num_buckets = num_buckets
15
+ self.max_distance = max_distance
16
+ self.n_heads = n_heads
17
+ self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
18
+
19
+ @staticmethod
20
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
21
+ ret = 0
22
+ n = -relative_position
23
+ if bidirectional:
24
+ num_buckets //= 2
25
+ ret += (n < 0).to(torch.long) * num_buckets
26
+ n = torch.abs(n)
27
+ else:
28
+ n = torch.max(n, torch.zeros_like(n))
29
+
30
+ max_exact = num_buckets // 2
31
+ is_small = n < max_exact
32
+
33
+ val_if_large = max_exact + (
34
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
35
+ ).to(torch.long)
36
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
37
+
38
+ ret += torch.where(is_small, n, val_if_large)
39
+ return ret
40
+
41
+ def compute_bias(self, qlen, klen, step=None):
42
+ step = 0 if step is None else step
43
+ context_position = torch.arange(
44
+ step,
45
+ step + qlen,
46
+ dtype=torch.long,
47
+ device=self.relative_attention_bias.weight.device,
48
+ )[:, None]
49
+ memory_position = torch.arange(klen, dtype=torch.long, device=self.relative_attention_bias.weight.device)[
50
+ None, :
51
+ ]
52
+ relative_position = memory_position - context_position # shape (qlen, klen)
53
+
54
+ rp_bucket = self._relative_position_bucket(
55
+ relative_position, # shape (qlen, klen)
56
+ bidirectional=self.bidirectional,
57
+ num_buckets=self.num_buckets,
58
+ max_distance=self.max_distance,
59
+ )
60
+ rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
61
+ values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
62
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
63
+ return values
64
+
65
+ def forward(self, batch_size, qlen, klen, step=None):
66
+ # shape (batch * num_heads, qlen, klen)
67
+ return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
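A shape sketch for `RelativePositionBias`; the returned tensor is added to the attention logits inside the encoder layers.

from vlmo.torchscale.component.relative_position_bias import RelativePositionBias

rel_pos = RelativePositionBias(bidirectional=True, num_buckets=32, max_distance=128, n_heads=12)
bias = rel_pos(batch_size=2, qlen=10, klen=10)
print(bias.shape)  # torch.Size([24, 10, 10]): (batch * n_heads, qlen, klen)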
vlmo/torchscale/component/xpos_relative_position.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ def fixed_pos_embedding(x):
9
+ seq_len, dim = x.shape
10
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
11
+ sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
12
+ return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
13
+
14
+
15
+ def rotate_every_two(x):
16
+ x1 = x[:, :, ::2]
17
+ x2 = x[:, :, 1::2]
18
+ x = torch.stack((-x2, x1), dim=-1)
19
+ return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
20
+
21
+
22
+ def duplicate_interleave(m):
23
+ """
24
+ A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
25
+ """
26
+ dim0 = m.shape[0]
27
+ m = m.view(-1, 1) # flatten the matrix
28
+ m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
29
+ m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
30
+ return m
31
+
32
+
33
+ def apply_rotary_pos_emb(x, sin, cos, scale=1):
34
+ sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
35
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
36
+ return (x * cos) + (rotate_every_two(x) * sin)
37
+
38
+
39
+ class XPOS(nn.Module):
40
+ def __init__(self, head_dim, scale_base=512):
41
+ super().__init__()
42
+ self.head_dim = head_dim
43
+ self.scale_base = scale_base
44
+ self.register_buffer("scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim))
45
+
46
+ def forward(self, x, offset=0, downscale=False):
47
+ length = x.shape[1]
48
+ min_pos = -(length + offset) // 2
49
+ max_pos = length + offset + min_pos
50
+ scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
51
+ sin, cos = fixed_pos_embedding(scale)
52
+
53
+ if scale.shape[0] > length:
54
+ scale = scale[-length:]
55
+ sin = sin[-length:]
56
+ cos = cos[-length:]
57
+
58
+ if downscale:
59
+ scale = 1 / scale
60
+
61
+ x = apply_rotary_pos_emb(x, sin, cos, scale)
62
+ return x
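A minimal sketch of applying `XPOS` to query/key tensors shaped the way `MultiheadAttention` uses them:

import torch
from vlmo.torchscale.component.xpos_relative_position import XPOS

xpos = XPOS(head_dim=64, scale_base=512)
q = torch.randn(24, 10, 64)                 # (batch * num_heads, seq_len, head_dim)
k = torch.randn(24, 10, 64)
q_rot = xpos(q, offset=0, downscale=False)  # queries are upscaled
k_rot = xpos(k, offset=0, downscale=True)   # keys are downscaled
print(q_rot.shape, k_rot.shape)             # both keep the input shape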
vlmo/torchscale/model/BEiT3.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from vlmo.torchscale.architecture.encoder import Encoder
8
+ from vlmo.torchscale.component.embedding import (
9
+ PositionalEmbedding,
10
+ TextEmbedding,
11
+ VisionEmbedding,
12
+ )
13
+ from vlmo.torchscale.component.multiway_network import MutliwayEmbedding
14
+
15
+
16
+ class BEiT3(nn.Module):
17
+ def __init__(self, args, **kwargs):
18
+ super().__init__()
19
+ self.args = args
20
+ assert args.multiway
21
+ assert args.vocab_size > 0
22
+ assert not args.share_encoder_input_output_embed
23
+ self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
24
+ self.vision_embed = VisionEmbedding(
25
+ args.img_size,
26
+ args.patch_size,
27
+ args.in_chans,
28
+ args.encoder_embed_dim,
29
+ contain_mask_token=True,
30
+ prepend_cls_token=True,
31
+ )
32
+ # being consistent with Fairseq, which starts from 2 for position embedding
33
+ embed_positions = MutliwayEmbedding(
34
+ modules=[
35
+ PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
36
+ PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
37
+ ],
38
+ dim=1,
39
+ )
40
+ self.encoder = Encoder(
41
+ args,
42
+ embed_tokens=None,
43
+ embed_positions=embed_positions,
44
+ output_projection=None,
45
+ is_encoder_decoder=False,
46
+ )
47
+
48
+ def forward(
49
+ self,
50
+ textual_tokens=None,
51
+ visual_tokens=None,
52
+ text_padding_position=None,
53
+ attn_mask=None,
54
+ vision_masked_position=None,
55
+ incremental_state=None,
56
+ positions=None,
57
+ ):
58
+ assert textual_tokens is not None or visual_tokens is not None
59
+
60
+ if textual_tokens is None:
61
+ x = self.vision_embed(visual_tokens, vision_masked_position)
62
+ encoder_padding_mask = None
63
+ multiway_split_position = -1
64
+ elif visual_tokens is None:
65
+ x = self.text_embed(textual_tokens)
66
+ encoder_padding_mask = text_padding_position
67
+ multiway_split_position = 0
68
+ else:
69
+ x1 = self.vision_embed(visual_tokens, vision_masked_position)
70
+ multiway_split_position = x1.size(1)
71
+ x2 = self.text_embed(textual_tokens)
72
+ diff = x1.shape[0] // x2.shape[0]
73
+ if diff != 1:
74
+ x2 = torch.repeat_interleave(x2, diff, dim=0)
75
+ text_padding_position = torch.repeat_interleave(text_padding_position, diff, dim=0)
76
+ x = torch.cat([x1, x2], dim=1)
77
+ if text_padding_position is not None:
78
+ encoder_padding_mask = torch.cat(
79
+ [
80
+ torch.zeros(x1.shape[:-1], device=x1.device, dtype=torch.bool),
81
+ text_padding_position,
82
+ ],
83
+ dim=1,
84
+ )
85
+ else:
86
+ encoder_padding_mask = None
87
+ encoder_out = self.encoder(
88
+ src_tokens=None,
89
+ encoder_padding_mask=encoder_padding_mask,
90
+ attn_mask=attn_mask,
91
+ token_embeddings=x,
92
+ multiway_split_position=multiway_split_position,
93
+ incremental_state=incremental_state,
94
+ positions=positions,
95
+ )
96
+ return encoder_out
vlmo/torchscale/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) 2022 Microsoft
2
+ # Licensed under The MIT License [see LICENSE for details]
vlmo/transforms/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from .pixelbert import (
2
+ pixelbert_transform,
3
+ pixelbert_transform_randaug,
4
+ )
5
+ from .square_transform import (
6
+ square_transform,
7
+ square_transform_randaug,
8
+ )
9
+
10
+ _transforms = {
11
+ "pixelbert": pixelbert_transform,
12
+ "pixelbert_randaug": pixelbert_transform_randaug,
13
+ "square_transform": square_transform,
14
+ "square_transform_randaug": square_transform_randaug,
15
+ }
16
+
17
+
18
+ def keys_to_transforms(keys: list, size=224):
19
+ return [_transforms[key](size=size) for key in keys]
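Transforms are looked up by key from this registry; a minimal usage sketch (assuming the `vlmo` package from this commit is importable):

from PIL import Image
from vlmo.transforms import keys_to_transforms

transform = keys_to_transforms(["square_transform"], size=224)[0]
img = Image.new("RGB", (640, 480), color=(128, 128, 128))
tensor = transform(img)
print(tensor.shape)  # torch.Size([3, 224, 224])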
vlmo/transforms/pixelbert.py ADDED
@@ -0,0 +1,30 @@
1
+ from .utils import (
2
+ inception_normalize,
3
+ MinMaxResize,
4
+ )
5
+ from torchvision import transforms
6
+ from .randaug import RandAugment
7
+
8
+
9
+ def pixelbert_transform(size=800):
10
+ longer = int((1333 / 800) * size)
11
+ return transforms.Compose(
12
+ [
13
+ MinMaxResize(shorter=size, longer=longer),
14
+ transforms.ToTensor(),
15
+ inception_normalize,
16
+ ]
17
+ )
18
+
19
+
20
+ def pixelbert_transform_randaug(size=800):
21
+ longer = int((1333 / 800) * size)
22
+ trs = transforms.Compose(
23
+ [
24
+ MinMaxResize(shorter=size, longer=longer),
25
+ transforms.ToTensor(),
26
+ inception_normalize,
27
+ ]
28
+ )
29
+ trs.transforms.insert(0, RandAugment(2, 9))
30
+ return trs
vlmo/transforms/randaug.py ADDED
@@ -0,0 +1,271 @@
1
+ # code in this file is adapted from rpmcruz/autoaugment
2
+ # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3
+ import random
4
+
5
+ import PIL
6
+
7
+ # from PIL import ImageOps, ImageEnhance, ImageDraw
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+
12
+
13
+ def ShearX(img, v): # [-0.3, 0.3]
14
+ assert -0.3 <= v <= 0.3
15
+ if random.random() > 0.5:
16
+ v = -v
17
+ return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
18
+
19
+
20
+ def ShearY(img, v): # [-0.3, 0.3]
21
+ assert -0.3 <= v <= 0.3
22
+ if random.random() > 0.5:
23
+ v = -v
24
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
25
+
26
+
27
+ def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
28
+ assert -0.45 <= v <= 0.45
29
+ if random.random() > 0.5:
30
+ v = -v
31
+ v = v * img.size[0]
32
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33
+
34
+
35
+ def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
36
+ assert 0 <= v
37
+ if random.random() > 0.5:
38
+ v = -v
39
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
40
+
41
+
42
+ def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
43
+ assert -0.45 <= v <= 0.45
44
+ if random.random() > 0.5:
45
+ v = -v
46
+ v = v * img.size[1]
47
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
48
+
49
+
50
+ def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
51
+ assert 0 <= v
52
+ if random.random() > 0.5:
53
+ v = -v
54
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
55
+
56
+
57
+ def Rotate(img, v): # [-30, 30]
58
+ assert -30 <= v <= 30
59
+ if random.random() > 0.5:
60
+ v = -v
61
+ return img.rotate(v)
62
+
63
+
64
+ def AutoContrast(img, _):
65
+ return PIL.ImageOps.autocontrast(img)
66
+
67
+
68
+ def Invert(img, _):
69
+ return PIL.ImageOps.invert(img)
70
+
71
+
72
+ def Equalize(img, _):
73
+ return PIL.ImageOps.equalize(img)
74
+
75
+
76
+ def Flip(img, _): # not from the paper
77
+ return PIL.ImageOps.mirror(img)
78
+
79
+
80
+ def Solarize(img, v): # [0, 256]
81
+ assert 0 <= v <= 256
82
+ return PIL.ImageOps.solarize(img, v)
83
+
84
+
85
+ def SolarizeAdd(img, addition=0, threshold=128):
86
+ img_np = np.array(img).astype(int)
87
+ img_np = img_np + addition
88
+ img_np = np.clip(img_np, 0, 255)
89
+ img_np = img_np.astype(np.uint8)
90
+ img = Image.fromarray(img_np)
91
+ return PIL.ImageOps.solarize(img, threshold)
92
+
93
+
94
+ def Posterize(img, v): # [4, 8]
95
+ v = int(v)
96
+ v = max(1, v)
97
+ return PIL.ImageOps.posterize(img, v)
98
+
99
+
100
+ def Contrast(img, v): # [0.1,1.9]
101
+ assert 0.1 <= v <= 1.9
102
+ return PIL.ImageEnhance.Contrast(img).enhance(v)
103
+
104
+
105
+ def Color(img, v): # [0.1,1.9]
106
+ assert 0.1 <= v <= 1.9
107
+ return PIL.ImageEnhance.Color(img).enhance(v)
108
+
109
+
110
+ def Brightness(img, v): # [0.1,1.9]
111
+ assert 0.1 <= v <= 1.9
112
+ return PIL.ImageEnhance.Brightness(img).enhance(v)
113
+
114
+
115
+ def Sharpness(img, v): # [0.1,1.9]
116
+ assert 0.1 <= v <= 1.9
117
+ return PIL.ImageEnhance.Sharpness(img).enhance(v)
118
+
119
+
120
+ def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
121
+ assert 0.0 <= v <= 0.2
122
+ if v <= 0.0:
123
+ return img
124
+
125
+ v = v * img.size[0]
126
+ return CutoutAbs(img, v)
127
+
128
+
129
+ def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
130
+ # assert 0 <= v <= 20
131
+ if v < 0:
132
+ return img
133
+ w, h = img.size
134
+ x0 = np.random.uniform(w)
135
+ y0 = np.random.uniform(h)
136
+
137
+ x0 = int(max(0, x0 - v / 2.0))
138
+ y0 = int(max(0, y0 - v / 2.0))
139
+ x1 = min(w, x0 + v)
140
+ y1 = min(h, y0 + v)
141
+
142
+ xy = (x0, y0, x1, y1)
143
+ color = (125, 123, 114)
144
+ # color = (0, 0, 0)
145
+ img = img.copy()
146
+ PIL.ImageDraw.Draw(img).rectangle(xy, color)
147
+ return img
148
+
149
+
150
+ def SamplePairing(imgs): # [0, 0.4]
151
+ def f(img1, v):
152
+ i = np.random.choice(len(imgs))
153
+ img2 = PIL.Image.fromarray(imgs[i])
154
+ return PIL.Image.blend(img1, img2, v)
155
+
156
+ return f
157
+
158
+
159
+ def Identity(img, v):
160
+ return img
161
+
162
+
163
+ def augment_list(): # 16 operations and their ranges
164
+ # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
165
+ # l = [
166
+ # (Identity, 0., 1.0),
167
+ # (ShearX, 0., 0.3), # 0
168
+ # (ShearY, 0., 0.3), # 1
169
+ # (TranslateX, 0., 0.33), # 2
170
+ # (TranslateY, 0., 0.33), # 3
171
+ # (Rotate, 0, 30), # 4
172
+ # (AutoContrast, 0, 1), # 5
173
+ # (Invert, 0, 1), # 6
174
+ # (Equalize, 0, 1), # 7
175
+ # (Solarize, 0, 110), # 8
176
+ # (Posterize, 4, 8), # 9
177
+ # # (Contrast, 0.1, 1.9), # 10
178
+ # (Color, 0.1, 1.9), # 11
179
+ # (Brightness, 0.1, 1.9), # 12
180
+ # (Sharpness, 0.1, 1.9), # 13
181
+ # # (Cutout, 0, 0.2), # 14
182
+ # # (SamplePairing(imgs), 0, 0.4), # 15
183
+ # ]
184
+
185
+ # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
186
+ l = [
187
+ (AutoContrast, 0, 1),
188
+ (Equalize, 0, 1),
189
+ # (Invert, 0, 1),
190
+ (Rotate, 0, 30),
191
+ (Posterize, 0, 4),
192
+ (Solarize, 0, 256),
193
+ (SolarizeAdd, 0, 110),
194
+ (Color, 0.1, 1.9),
195
+ (Contrast, 0.1, 1.9),
196
+ (Brightness, 0.1, 1.9),
197
+ (Sharpness, 0.1, 1.9),
198
+ (ShearX, 0.0, 0.3),
199
+ (ShearY, 0.0, 0.3),
200
+ # (CutoutAbs, 0, 40),
201
+ (TranslateXabs, 0.0, 100),
202
+ (TranslateYabs, 0.0, 100),
203
+ ]
204
+
205
+ return l
206
+
207
+
208
+ class Lighting(object):
209
+ """Lighting noise(AlexNet - style PCA - based noise)"""
210
+
211
+ def __init__(self, alphastd, eigval, eigvec):
212
+ self.alphastd = alphastd
213
+ self.eigval = torch.Tensor(eigval)
214
+ self.eigvec = torch.Tensor(eigvec)
215
+
216
+ def __call__(self, img):
217
+ if self.alphastd == 0:
218
+ return img
219
+
220
+ alpha = img.new().resize_(3).normal_(0, self.alphastd)
221
+ rgb = (
222
+ self.eigvec.type_as(img)
223
+ .clone()
224
+ .mul(alpha.view(1, 3).expand(3, 3))
225
+ .mul(self.eigval.view(1, 3).expand(3, 3))
226
+ .sum(1)
227
+ .squeeze()
228
+ )
229
+
230
+ return img.add(rgb.view(3, 1, 1).expand_as(img))
231
+
232
+
233
+ class CutoutDefault(object):
234
+ """
235
+ Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
236
+ """
237
+
238
+ def __init__(self, length):
239
+ self.length = length
240
+
241
+ def __call__(self, img):
242
+ h, w = img.size(1), img.size(2)
243
+ mask = np.ones((h, w), np.float32)
244
+ y = np.random.randint(h)
245
+ x = np.random.randint(w)
246
+
247
+ y1 = np.clip(y - self.length // 2, 0, h)
248
+ y2 = np.clip(y + self.length // 2, 0, h)
249
+ x1 = np.clip(x - self.length // 2, 0, w)
250
+ x2 = np.clip(x + self.length // 2, 0, w)
251
+
252
+ mask[y1:y2, x1:x2] = 0.0
253
+ mask = torch.from_numpy(mask)
254
+ mask = mask.expand_as(img)
255
+ img *= mask
256
+ return img
257
+
258
+
259
+ class RandAugment:
260
+ def __init__(self, n, m):
261
+ self.n = n
262
+ self.m = m # [0, 30]
263
+ self.augment_list = augment_list()
264
+
265
+ def __call__(self, img):
266
+ ops = random.choices(self.augment_list, k=self.n)
267
+ for op, minval, maxval in ops:
268
+ val = (float(self.m) / 30) * float(maxval - minval) + minval
269
+ img = op(img, val)
270
+
271
+ return img
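A minimal usage sketch for `RandAugment` (the extra PIL submodule imports are only there because the op functions above reference them through the top-level `PIL` namespace):

import numpy as np
from PIL import Image, ImageOps, ImageEnhance, ImageDraw  # noqa: F401
from vlmo.transforms.randaug import RandAugment

augment = RandAugment(2, 9)  # apply 2 randomly chosen ops per call, magnitude 9 out of 30
img = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
out = augment(img)
print(out.size)              # (224, 224); the result is still a PIL image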
vlmo/transforms/randaugment.py ADDED
@@ -0,0 +1,334 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ # aug functions
6
+ def identity_func(img):
7
+ return img
8
+
9
+
10
+ def autocontrast_func(img, cutoff=0):
11
+ """
12
+ same output as PIL.ImageOps.autocontrast
13
+ """
14
+ n_bins = 256
15
+
16
+ def tune_channel(ch):
17
+ n = ch.size
18
+ cut = cutoff * n // 100
19
+ if cut == 0:
20
+ high, low = ch.max(), ch.min()
21
+ else:
22
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23
+ low = np.argwhere(np.cumsum(hist) > cut)
24
+ low = 0 if low.shape[0] == 0 else low[0]
25
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27
+ if high <= low:
28
+ table = np.arange(n_bins)
29
+ else:
30
+ scale = (n_bins - 1) / (high - low)
31
+ offset = -low * scale
32
+ table = np.arange(n_bins) * scale + offset
33
+ table[table < 0] = 0
34
+ table[table > n_bins - 1] = n_bins - 1
35
+ table = table.clip(0, 255).astype(np.uint8)
36
+ return table[ch]
37
+
38
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
39
+ out = cv2.merge(channels)
40
+ return out
41
+
42
+
43
+ def equalize_func(img):
44
+ """
45
+ same output as PIL.ImageOps.equalize
46
+ PIL's implementation is different from cv2.equalize
47
+ """
48
+ n_bins = 256
49
+
50
+ def tune_channel(ch):
51
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52
+ non_zero_hist = hist[hist != 0].reshape(-1)
53
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54
+ if step == 0:
55
+ return ch
56
+ n = np.empty_like(hist)
57
+ n[0] = step // 2
58
+ n[1:] = hist[:-1]
59
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
60
+ return table[ch]
61
+
62
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
63
+ out = cv2.merge(channels)
64
+ return out
65
+
66
+
67
+ def rotate_func(img, degree, fill=(0, 0, 0)):
68
+ """
69
+ like PIL, rotate by degree, not radians
70
+ """
71
+ H, W = img.shape[0], img.shape[1]
72
+ center = W / 2, H / 2
73
+ M = cv2.getRotationMatrix2D(center, degree, 1)
74
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
75
+ return out
76
+
77
+
78
+ def solarize_func(img, thresh=128):
79
+ """
80
+ same output as PIL.ImageOps.solarize
81
+ """
82
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
83
+ table = table.clip(0, 255).astype(np.uint8)
84
+ out = table[img]
85
+ return out
86
+
87
+
88
+ def color_func(img, factor):
89
+ """
90
+ same output as PIL.ImageEnhance.Color
91
+ """
92
+ # implementation according to PIL definition, quite slow
93
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
94
+ # out = blend(degenerate, img, factor)
95
+ # M = (
96
+ # np.eye(3) * factor
97
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
98
+ # )[np.newaxis, np.newaxis, :]
99
+ M = np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]) * factor + np.float32(
100
+ [[0.114], [0.587], [0.299]]
101
+ )
102
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
103
+ return out
104
+
105
+
106
+ def contrast_func(img, factor):
107
+ """
108
+ same output as PIL.ImageEnhance.Contrast
109
+ """
110
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
111
+ table = np.array([(el - mean) * factor + mean for el in range(256)]).clip(0, 255).astype(np.uint8)
112
+ out = table[img]
113
+ return out
114
+
115
+
116
+ def brightness_func(img, factor):
117
+ """
118
+ same output as PIL.ImageEnhance.Brightness
119
+ """
120
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
121
+ out = table[img]
122
+ return out
123
+
124
+
125
+ def sharpness_func(img, factor):
126
+ """
127
+ The differences between this result and PIL are all on the 4 boundaries; the center
128
+ area is the same
129
+ """
130
+ kernel = np.ones((3, 3), dtype=np.float32)
131
+ kernel[1][1] = 5
132
+ kernel /= 13
133
+ degenerate = cv2.filter2D(img, -1, kernel)
134
+ if factor == 0.0:
135
+ out = degenerate
136
+ elif factor == 1.0:
137
+ out = img
138
+ else:
139
+ out = img.astype(np.float32)
140
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
141
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
142
+ out = out.astype(np.uint8)
143
+ return out
144
+
145
+
146
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
147
+ H, W = img.shape[0], img.shape[1]
148
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
149
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
150
+ return out
151
+
152
+
153
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
154
+ """
155
+ same output as PIL.Image.transform
156
+ """
157
+ H, W = img.shape[0], img.shape[1]
158
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
159
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
160
+ return out
161
+
162
+
163
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
164
+ """
165
+ same output as PIL.Image.transform
166
+ """
167
+ H, W = img.shape[0], img.shape[1]
168
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
169
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
170
+ return out
171
+
172
+
173
+ def posterize_func(img, bits):
174
+ """
175
+ same output as PIL.ImageOps.posterize
176
+ """
177
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
178
+ return out
179
+
180
+
181
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
182
+ H, W = img.shape[0], img.shape[1]
183
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
184
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
185
+ return out
186
+
187
+
188
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
189
+ replace = np.array(replace, dtype=np.uint8)
190
+ H, W = img.shape[0], img.shape[1]
191
+ rh, rw = np.random.random(2)
192
+ pad_size = pad_size // 2
193
+ ch, cw = int(rh * H), int(rw * W)
194
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
195
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
196
+ out = img.copy()
197
+ out[x1:x2, y1:y2, :] = replace
198
+ return out
199
+
200
+
201
+ # level to args
202
+ def enhance_level_to_args(MAX_LEVEL):
203
+ def level_to_args(level):
204
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
205
+
206
+ return level_to_args
207
+
208
+
209
+ def shear_level_to_args(MAX_LEVEL, replace_value):
210
+ def level_to_args(level):
211
+ level = (level / MAX_LEVEL) * 0.3
212
+ if np.random.random() > 0.5:
213
+ level = -level
214
+ return (level, replace_value)
215
+
216
+ return level_to_args
217
+
218
+
219
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
220
+ def level_to_args(level):
221
+ level = (level / MAX_LEVEL) * float(translate_const)
222
+ if np.random.random() > 0.5:
223
+ level = -level
224
+ return (level, replace_value)
225
+
226
+ return level_to_args
227
+
228
+
229
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
230
+ def level_to_args(level):
231
+ level = int((level / MAX_LEVEL) * cutout_const)
232
+ return (level, replace_value)
233
+
234
+ return level_to_args
235
+
236
+
237
+ def solarize_level_to_args(MAX_LEVEL):
238
+ def level_to_args(level):
239
+ level = int((level / MAX_LEVEL) * 256)
240
+ return (level,)
241
+
242
+ return level_to_args
243
+
244
+
245
+ def none_level_to_args(level):
246
+ return ()
247
+
248
+
249
+ def posterize_level_to_args(MAX_LEVEL):
250
+ def level_to_args(level):
251
+ level = int((level / MAX_LEVEL) * 4)
252
+ return (level,)
253
+
254
+ return level_to_args
255
+
256
+
257
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
258
+ def level_to_args(level):
259
+ level = (level / MAX_LEVEL) * 30
260
+ if np.random.random() < 0.5:
261
+ level = -level
262
+ return (level, replace_value)
263
+
264
+ return level_to_args
265
+
266
+
267
+ func_dict = {
268
+ "Identity": identity_func,
269
+ "AutoContrast": autocontrast_func,
270
+ "Equalize": equalize_func,
271
+ "Rotate": rotate_func,
272
+ "Solarize": solarize_func,
273
+ "Color": color_func,
274
+ "Contrast": contrast_func,
275
+ "Brightness": brightness_func,
276
+ "Sharpness": sharpness_func,
277
+ "ShearX": shear_x_func,
278
+ "TranslateX": translate_x_func,
279
+ "TranslateY": translate_y_func,
280
+ "Posterize": posterize_func,
281
+ "ShearY": shear_y_func,
282
+ }
283
+
284
+ translate_const = 10
285
+ MAX_LEVEL = 10
286
+ replace_value = (128, 128, 128)
287
+ arg_dict = {
288
+ "Identity": none_level_to_args,
289
+ "AutoContrast": none_level_to_args,
290
+ "Equalize": none_level_to_args,
291
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
292
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
293
+ "Color": enhance_level_to_args(MAX_LEVEL),
294
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
295
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
296
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
297
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
298
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
299
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
300
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
301
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
302
+ }
303
+
304
+
305
+ class RandomAugment(object):
306
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
307
+ self.N = N
308
+ self.M = M
309
+ self.isPIL = isPIL
310
+ if augs:
311
+ self.augs = augs
312
+ else:
313
+ self.augs = list(arg_dict.keys())
314
+
315
+ def get_random_ops(self):
316
+ sampled_ops = np.random.choice(self.augs, self.N)
317
+ return [(op, 0.5, self.M) for op in sampled_ops]
318
+
319
+ def __call__(self, img):
320
+ if self.isPIL:
321
+ img = np.array(img)
322
+ ops = self.get_random_ops()
323
+ for name, prob, level in ops:
324
+ if np.random.random() > prob:
325
+ continue
326
+ args = arg_dict[name](level)
327
+ img = func_dict[name](img, *args)
328
+ return img
329
+
330
+
331
+ if __name__ == "__main__":
332
+ a = RandomAugment()
333
+ img = np.random.randn(32, 32, 3)
334
+ a(img)
vlmo/transforms/square_transform.py ADDED
@@ -0,0 +1,41 @@
1
+ # code in this file is adapted from the ALBEF repo (https://github.com/salesforce/ALBEF)
2
+
3
+ from torchvision import transforms
4
+ from .randaugment import RandomAugment
5
+ from PIL import Image
6
+
7
+
8
+ def square_transform(size=224):
9
+ return transforms.Compose(
10
+ [
11
+ transforms.Resize((size, size), interpolation=Image.BICUBIC),
12
+ transforms.ToTensor(),
13
+ ]
14
+ )
15
+
16
+
17
+ def square_transform_randaug(size=224):
18
+ return transforms.Compose(
19
+ [
20
+ transforms.RandomResizedCrop(size, scale=(0.8, 1.0), interpolation=Image.BICUBIC),
21
+ transforms.RandomHorizontalFlip(),
22
+ RandomAugment(
23
+ 2,
24
+ 7,
25
+ isPIL=True,
26
+ augs=[
27
+ "Identity",
28
+ "AutoContrast",
29
+ "Equalize",
30
+ "Brightness",
31
+ "Sharpness",
32
+ "ShearX",
33
+ "ShearY",
34
+ "TranslateX",
35
+ "TranslateY",
36
+ "Rotate",
37
+ ],
38
+ ),
39
+ transforms.ToTensor(),
40
+ ]
41
+ )
vlmo/transforms/utils.py ADDED
@@ -0,0 +1,56 @@
1
+ from torchvision import transforms
2
+ from PIL import Image
3
+
4
+
5
+ class MinMaxResize:
6
+ def __init__(self, shorter=800, longer=1333):
7
+ self.min = shorter
8
+ self.max = longer
9
+
10
+ def __call__(self, x):
11
+ w, h = x.size
12
+ scale = self.min / min(w, h)
13
+ if h < w:
14
+ newh, neww = self.min, scale * w
15
+ else:
16
+ newh, neww = scale * h, self.min
17
+
18
+ if max(newh, neww) > self.max:
19
+ scale = self.max / max(newh, neww)
20
+ newh = newh * scale
21
+ neww = neww * scale
22
+
23
+ newh, neww = int(newh + 0.5), int(neww + 0.5)
24
+ newh, neww = newh // 32 * 32, neww // 32 * 32
25
+
26
+ return x.resize((neww, newh), resample=Image.BICUBIC)
27
+
28
+
29
+ class UnNormalize(object):
30
+ def __init__(self, mean, std):
31
+ self.mean = mean
32
+ self.std = std
33
+
34
+ def __call__(self, tensor):
35
+ """
36
+ Args:
37
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
38
+ Returns:
39
+ Tensor: Normalized image.
40
+ """
41
+ for t, m, s in zip(tensor, self.mean, self.std):
42
+ t.mul_(s).add_(m)
43
+ # The normalize code -> t.sub_(m).div_(s)
44
+ return tensor
45
+
46
+
47
+ # This is the simple maximum-entropy normalization used in the Inception paper
48
+ inception_normalize = transforms.Compose([transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
49
+
50
+ # ViT uses simple non-biased inception normalization
51
+ # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132
52
+ inception_unnormalize = transforms.Compose([UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
53
+
54
+ cn_clip_normalize = transforms.Compose(
55
+ [transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]
56
+ )
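A quick check of `MinMaxResize`, which keeps the aspect ratio, bounds the shorter/longer sides, and snaps both to multiples of 32:

from PIL import Image
from vlmo.transforms.utils import MinMaxResize

resize = MinMaxResize(shorter=800, longer=1333)
img = Image.new("RGB", (1920, 1080))
out = resize(img)
print(out.size)  # (1312, 736): both sides are multiples of 32, longer side capped by 1333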
vlmo/utils/__init__.py ADDED
File without changes