Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .DS_Store +0 -0
- .gitattributes +1 -0
- .msc +0 -0
- .mv +1 -0
- README.md +94 -0
- configuration.json +1 -0
- m2_encoder_1B.ckpt +3 -0
- ms_wrapper.py +219 -0
- requirements.txt +14 -0
- res/effect.png +3 -0
- vlmo/.DS_Store +0 -0
- vlmo/Encoder_0.4B.json +17 -0
- vlmo/README.md +10 -0
- vlmo/__init__.py +0 -0
- vlmo/config.py +165 -0
- vlmo/modules/__init__.py +1 -0
- vlmo/modules/heads.py +24 -0
- vlmo/modules/modeling_utils.py +179 -0
- vlmo/modules/multiway_transformer.py +396 -0
- vlmo/modules/objectives.py +12 -0
- vlmo/modules/vlmo_module.py +405 -0
- vlmo/modules/vlmo_utils.py +12 -0
- vlmo/tokenizer/__init__.py +6 -0
- vlmo/tokenizer/sp.model +3 -0
- vlmo/tokenizer/tokenization_glm.py +307 -0
- vlmo/tokenizer/tokenizer_config.json +17 -0
- vlmo/torchscale/__init__.py +2 -0
- vlmo/torchscale/architecture/__init__.py +2 -0
- vlmo/torchscale/architecture/config.py +197 -0
- vlmo/torchscale/architecture/decoder.py +428 -0
- vlmo/torchscale/architecture/encoder.py +489 -0
- vlmo/torchscale/architecture/encoder_decoder.py +43 -0
- vlmo/torchscale/architecture/utils.py +33 -0
- vlmo/torchscale/component/__init__.py +2 -0
- vlmo/torchscale/component/droppath.py +19 -0
- vlmo/torchscale/component/embedding.py +110 -0
- vlmo/torchscale/component/feedforward_network.py +128 -0
- vlmo/torchscale/component/multihead_attention.py +154 -0
- vlmo/torchscale/component/multiway_network.py +55 -0
- vlmo/torchscale/component/relative_position_bias.py +67 -0
- vlmo/torchscale/component/xpos_relative_position.py +62 -0
- vlmo/torchscale/model/BEiT3.py +96 -0
- vlmo/torchscale/model/__init__.py +2 -0
- vlmo/transforms/__init__.py +19 -0
- vlmo/transforms/pixelbert.py +30 -0
- vlmo/transforms/randaug.py +271 -0
- vlmo/transforms/randaugment.py +334 -0
- vlmo/transforms/square_transform.py +41 -0
- vlmo/transforms/utils.py +56 -0
- vlmo/utils/__init__.py +0 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+res/effect.png filter=lfs diff=lfs merge=lfs -text
.msc
ADDED
Binary file (4.2 kB)
.mv
ADDED
@@ -0,0 +1 @@
Revision:master,CreatedAt:1742457812
README.md
ADDED
@@ -0,0 +1,94 @@
---
tasks:
- multi-modal-embedding
- image-text-retrieval
domain:
- multi-modal
frameworks:
- pytorch
backbone:
- transformers
metrics:
- R@1
license: apache-2.0
tags:
- Ant Group
- multi-modal-embedding
widgets:
  - inputs:
      - validator:
          max_words: 52
        type: text
        title: 查询文本
    output:
      maximize: false
    examples:
      - name: 1
        inputs:
          - data: 戴眼镜的猫
      - name: 2
        inputs:
          - data: 一个在逛公园的女孩
    task: multi-modal-embedding
---

## Model Description
M2-Encoder is a powerful bilingual (Chinese-English) multi-modal model. It is trained on BM-6B, a dataset we built that contains 6 billion image-text pairs (3 billion Chinese + 3 billion English), and supports zero-shot cross-modal image-text retrieval (text-to-image and image-to-text) as well as zero-shot image classification.

Model results are shown below:

![effect](./res/effect.png)

## Intended Use and Scope
This model is mainly intended for:
1. Image-to-text or text-to-image retrieval. Taking text-to-image retrieval as an example: extract features for the whole image gallery with M2-Encoder in advance; given a text query, extract its features with M2-Encoder and compute similarity against the stored gallery features.
2. Zero-shot open-set image classification: given an image and a list of candidate labels, output the label that best matches the image according to image-label similarity.


## How to Use

### Code Example
```
# Create a new environment (Python 3.8)
conda create -n m2-encoder python=3.8
source activate m2-encoder

# Clone the project
cd /YourPath/
git clone https://github.com/alipay/Ant-Multi-Modal-Framework

# Install dependencies
cd ./Ant-Multi-Modal-Framework/prj/M2_Encoder/
pip install -r requirements.txt

# Run the demo; the corresponding model weights are downloaded automatically via model_scope
python run.py
```

### Model Limitations and Possible Biases
The model is trained on its dataset and may exhibit biases; please evaluate it yourself before deciding how to use it.

## Training Data
BM-6B dataset: 6 billion cleaned, high-quality bilingual (Chinese-English) image-text pairs, split roughly evenly between Chinese and English at about 3 billion each. See the [technical report](https://arxiv.org/abs/2401.15896) for details on how the dataset was collected and constructed.

## Training Procedure
Training through the ModelScope interface is not supported yet; stay tuned.


### Training
Not supported yet.

## Evaluation and Results
The model reaches SOTA on both zero-shot cross-modal image-text retrieval and zero-shot classification.


### Related Papers and Citation
If you find this model helpful, please consider citing the related paper:
```
@misc{guo2024m2encoder,
      title={M2-Encoder: Advancing Bilingual Image-Text Understanding by Large-scale Efficient Pretraining},
      author={Qingpei Guo and Furong Xu and Hanxiao Zhang and Wang Ren and Ziping Ma and Lin Ju and Jian Wang and Jingdong Chen and Ming Yang},
      year={2024},
      url={https://arxiv.org/abs/2401.15896},
}
```
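The retrieval workflow described in the README (embed the image gallery offline, then embed a text query and rank by similarity) can be sketched in a few lines. This is a minimal illustration, assuming the embedding pipeline defined in ms_wrapper.py below, which returns `img_embedding` / `text_embedding` tensors; the gallery URL is just the demo image.

```python
# Minimal retrieval sketch; assumes the "multi-modal-embeddings" pipeline from ms_wrapper.py.
import torch
from modelscope.pipelines import pipeline
from modelscope.preprocessors.image import load_image

pipe = pipeline("multi-modal-embeddings", model="M2Cognition/M2-Encoder")

# 1) Embed the image gallery once, offline.
gallery_urls = [
    "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg",
]
gallery_feats = torch.cat(
    [pipe({"img": load_image(url)})["img_embedding"] for url in gallery_urls], dim=0
)

# 2) Embed the text query and rank gallery images by cosine similarity.
query_feats = pipe({"text": ["戴眼镜的猫"]})["text_embedding"]
gallery_feats = torch.nn.functional.normalize(gallery_feats, dim=-1)
query_feats = torch.nn.functional.normalize(query_feats, dim=-1)
scores = query_feats @ gallery_feats.t()   # [num_queries, num_images]
best = scores.argmax(dim=-1)               # index of the best-matching image per query
print(scores, best)
```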
configuration.json
ADDED
@@ -0,0 +1 @@
{"framework":"Pytorch","task":"multi-modal-embeddings","pipeline":{"type":"multi-modal-embedding-pipeline"},"allow_remote": true}
m2_encoder_1B.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac4c5d9a0e44fff05f0ccadf54617b7809f489bc401212abd836c8d075047e9b
size 2921990385
ms_wrapper.py
ADDED
@@ -0,0 +1,219 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import os

from modelscope.models.base import TorchModel
from modelscope.preprocessors.base import Preprocessor
from modelscope.pipelines.base import Model, Pipeline
from modelscope.utils.config import Config
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.models.builder import MODELS
from modelscope.preprocessors.image import load_image


from vlmo.utils.beit_utils import load_from_config


@PIPELINES.register_module(
    "multi-modal-embeddings", module_name="multi-modal-embedding-pipeline"
)
class MyCustomPipeline(Pipeline):
    """Give simple introduction to this pipeline.

    Examples:

    >>> from modelscope.pipelines import pipeline
    >>> input = "Hello, ModelScope!"
    >>> my_pipeline = pipeline('my-task', 'my-model-id')
    >>> result = my_pipeline(input)

    """

    def __init__(self, model, preprocessor=None, **kwargs):
        """
        use `model` and `preprocessor` to create a custom pipeline for prediction
        Args:
            model: model id on modelscope hub.
            preprocessor: the class of method be init_preprocessor
        """
        super().__init__(model=model, auto_collate=False)
        self.model_dir = model
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        # model_config = {
        #     "loss_names": {"itc": 1},
        #     "encoder_layers": 9,
        #     "beit3_vl_layers": 3,
        #     "tokenizer_type": "GLMChineseTokenizer",
        #     "tokenizer": os.path.join(self.model_dir, "./vlmo/tokenizer"),
        #     "vocab_size": 115244,
        #     "whole_word_masking": True,
        #     "precision": 32,
        #     "test_only": True,
        #     "flash_attn": True,
        #     "model_path": os.path.join(self.model_dir, "m2_encoder_1B.ckpt"),
        #     "modelscope": {"model_id": "M2Cognition/M2-Encoder-Large"},
        #     "model_file": "m2_encoder_1B.ckpt",
        # }
        model_config = {
            "loss_names": {"itc": 1},
            "beit_version": "large",
            "encoder_embed_dim": 1024,
            "out_embed_dim": 1024,
            "encoder_layers": 21,
            "beit3_vl_layers": 3,
            # "image_size": 224,
            "visual_mask_size": 14,
            "tokenizer_type": "GLMChineseTokenizer",
            "tokenizer": os.path.join(self.model_dir, "./vlmo/tokenizer"),
            "vocab_size": 115244,
            "whole_word_masking": False,
            "precision": 32,
            "test_only": True,
            "flash_attn": True,
            "model_path": os.path.join(self.model_dir, "m2_encoder_1B.ckpt"),
            "modelscope": {
                "model_id": "M2Cognition/M2_Encoder_Large"
            },
            "model_file": "m2_encoder_1B.ckpt"
        }
        model, processors = load_from_config(model_config)
        self.model = model
        self.model.to(self._device).eval()
        self._tokenizer, self._img_processor = processors

    def _sanitize_parameters(self, **pipeline_parameters):
        """
        this method should sanitize the keyword args to preprocessor params,
        forward params and postprocess params on '__call__' or '_process_single' method
        considered to be a normal classmethod with default implementation / output

        Default Returns:
            Dict[str, str]: preprocess_params = {}
            Dict[str, str]: forward_params = {}
            Dict[str, str]: postprocess_params = pipeline_parameters
        """
        return {}, pipeline_parameters, {}

    def _check_input(self, inputs):
        pass

    def _check_output(self, outputs):
        pass

    def forward(self, forward_params):
        """Provide default implementation using self.model and user can reimplement it"""
        # print("forward_params", forward_params)
        labels = forward_params.get("label_list", "")
        labels = labels.split(",")
        if len(labels) > 1 and labels[0] != "":
            txt_encoding = self._tokenizer(
                labels,
                padding="max_length",
                truncation=True,
                max_length=self.model.hparams.config["max_text_len"],
                return_special_tokens_mask=True,
            )
            txt_data = {
                "text_ids": torch.tensor(txt_encoding["input_ids"]).to(self._device),
                "text_masks": torch.tensor(txt_encoding["attention_mask"]).to(
                    self._device
                ),
                "text_labels": None,
            }
            txt_feats = self.model.infer_text(txt_data)["cls_vlffn_feats"]
            image = forward_params["image"]
            image = load_image(image)
            img = self._img_processor(image).unsqueeze(0)
            img_data = {"image": [img.to(self._device)]}
            img_feats = self.model.infer_image(img_data)["cls_vlffn_feats"]
            logits_per_image = self.model.logit_scale.exp() * img_feats @ txt_feats.t()
            probs = logits_per_image.softmax(dim=-1).detach().cpu()
            index = probs.max(dim=-1)[1][0]
            label = labels[index]
            return {"text": label, "scores": probs.numpy().tolist()[0]}
        else:
            rets = {}
            if "text" in forward_params:
                text = forward_params.get("text")
                txt_encoding = self._tokenizer(
                    text,
                    padding="max_length",
                    truncation=True,
                    max_length=self.model.hparams.config["max_text_len"],
                    return_special_tokens_mask=True,
                )
                txt_data = {
                    "text_ids": torch.tensor(txt_encoding["input_ids"]).to(
                        self._device
                    ),
                    "text_masks": torch.tensor(txt_encoding["attention_mask"]).to(
                        self._device
                    ),
                    "text_labels": None,
                }
                txt_feats = self.model.infer_text(txt_data)["cls_vlffn_feats"]
                rets.update({"text_embedding": txt_feats.detach()})
            if "img" in forward_params:
                input_img = forward_params["img"]
                img = self._img_processor(input_img).unsqueeze(0)
                img_data = {"image": [img.to(self._device)]}
                img_feats = self.model.infer_image(img_data)["cls_vlffn_feats"]
                rets.update({"img_embedding": img_feats.detach()})

            return rets

    def preprocess(self, inputs):
        return inputs

    def postprocess(self, inputs):
        """If current pipeline support model reuse, common postprocess
            code should be write here.

        Args:
            inputs: input data

        Return:
            dict of results: a dict containing outputs of model, each
                output should have the standard output name.
        """
        return inputs


"""
# Tips: usr_config_path is the temporary save configuration location, after upload modelscope hub, it is the model_id
usr_config_path = "/tmp/snapdown/"
config = Config(
    {
        "framework": "pytorch",
        "task": "multi-modal-embeddings",
        "model": {"type": "m2-encoder"},
        "pipeline": {"type": "multi-modal-embedding-pipeline"},
        "allow_remote": True,
    }
)
config.dump("/tmp/snapdown/" + "configuration.json")
"""

if __name__ == "__main__":
    from modelscope.pipelines import pipeline
    from modelscope.preprocessors.image import load_image

    model = "M2Cognition/M2-Encoder"
    pipe = pipeline("multi-modal-embeddings", model=model)
    input = {
        "image": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg",
        "label_list": "杰尼龟,妙蛙种子,小火龙,皮卡丘",
    }
    demo = pipe(input)
    print("demo output", demo)
    inputs = {"text": ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]}
    output = pipe(inputs)
    print("text output", output)
    input_img = load_image(
        "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
    )  # accepts the example Pikachu image URL or a local image path; returns a PIL.Image
    inputs = {"img": input_img}
    img_embedding = pipe(inputs)  # 2D Tensor, [num_images, feature_dim]
    print("image output", img_embedding)
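For reference, the label-list branch of `MyCustomPipeline.forward()` above reduces to a CLIP-style softmax over scaled image-text similarities. A standalone sketch with dummy tensors (the 1024-d size matches `out_embed_dim`; everything else is illustrative):

```python
# Zero-shot classification scoring as performed in forward(), shown on dummy features:
# 1 image, 4 candidate labels, 1024-d embeddings.
import torch

img_feats = torch.nn.functional.normalize(torch.randn(1, 1024), dim=-1)  # stand-in for infer_image(...)["cls_vlffn_feats"]
txt_feats = torch.nn.functional.normalize(torch.randn(4, 1024), dim=-1)  # stand-in for infer_text(...)["cls_vlffn_feats"]
logit_scale = torch.tensor(1 / 0.07)                                     # value of model.logit_scale.exp() at initialization

logits_per_image = logit_scale * img_feats @ txt_feats.t()  # [1, 4]
probs = logits_per_image.softmax(dim=-1)                    # probability over the label list
index = probs.max(dim=-1)[1][0]                             # index of the best-matching label
print(probs, index)
```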
requirements.txt
ADDED
@@ -0,0 +1,14 @@
torch
pytorch_lightning<=2.0.8
transformers
Pillow
tqdm
einops
sacred
timm
torchvision
fairscale
numpy
opencv-python
sentencepiece
modelscope
res/effect.png
ADDED
Image file (stored with Git LFS)
vlmo/.DS_Store
ADDED
Binary file (6.15 kB)
vlmo/Encoder_0.4B.json
ADDED
@@ -0,0 +1,17 @@
{
  "loss_names": {"itc": 1},
  "encoder_layers": 9,
  "beit3_vl_layers": 3,
  "tokenizer_type": "GLMChineseTokenizer",
  "tokenizer": "./vlmo/tokenizer",
  "vocab_size": 115244,
  "whole_word_masking": true,
  "precision": 32,
  "test_only": true,
  "flash_attn": true,
  "model_path": "m2_encoder_0.4B.ckpt",
  "modelscope": {
    "model_id": "M2Cognition/M2-Encoder"
  },
  "model_file": "m2_encoder_0.4B.ckpt"
}
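One possible way to consume this config is to load the JSON and hand the dict to `load_from_config`, mirroring how ms_wrapper.py builds the 1B model; this is a sketch under the assumption that `vlmo.utils.beit_utils.load_from_config` accepts the same dict layout and fills in the remaining defaults:

```python
# Sketch: build the 0.4B encoder from this JSON config (assumed call pattern, see ms_wrapper.py).
import json
from vlmo.utils.beit_utils import load_from_config

with open("vlmo/Encoder_0.4B.json") as f:
    model_config = json.load(f)

model, processors = load_from_config(model_config)  # assumed to match the ms_wrapper.py usage
tokenizer, img_processor = processors
model.eval()
```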
vlmo/README.md
ADDED
@@ -0,0 +1,10 @@
---
license: Apache License 2.0
---
###### This model currently uses the default introduction template and is in the "pre-release" stage; the page is visible only to its owner.
###### Please complete the model card according to the [model contribution guide](https://www.modelscope.cn/docs/%E5%A6%82%E4%BD%95%E6%92%B0%E5%86%99%E5%A5%BD%E7%94%A8%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8D%A1%E7%89%87). The ModelScope platform will display the model card once it is completed. Thank you for your understanding.

#### Clone with HTTP
```bash
git clone https://www.modelscope.cn/M2Cognition/M2_Encoder_demo.git
```
vlmo/__init__.py
ADDED
File without changes
vlmo/config.py
ADDED
@@ -0,0 +1,165 @@
from sacred import Experiment

ex = Experiment("VLMo")


def _loss_names(d):
    ret = {
        "itm": 0,  # image-text matching loss
        "itc": 0,  # image-text contrastive loss
        "caption": 0,  # image captioning loss
        "mvlm": 0,  # masked language modeling loss
        "textmlm": 0,  # text-only masked language modeling
        "imagemlm": 0,  # image-only masked language modeling
        "vqa": 0,
        "nlvr2": 0,
        "irtr": 0,  # retrieval task ft
    }
    ret.update(d)
    return ret


@ex.config
def config():
    exp_name = "vlmo"
    seed = 1
    datasets = ["coco", "vg", "sbu", "gcc"]  # dataset name, the definition can refer to: vlmo/datamodules/__init__.py  # noqa
    loss_names = _loss_names({"itm": 0, "itc": 0, "mvlm": 0})  # training loss
    batch_size = 1024  # this is a desired batch size; pl trainer will accumulate gradients.

    # BEiT-v3 setting
    encoder_layers = 12  # the layer number of backbone
    encoder_embed_dim = 768  # the hidden size of tokenizer
    out_embed_dim = 768  # the hidden size of output embedding
    beit_version = "base"  # model size: base(0.4B)|large(1B)|huge(10B)
    beit3_vl_layers = 3  # the layer number of vl_backbone
    deepnorm_init = True  # init method
    share_layer = False  # if share the weight between layer within backbone
    share_attn = False  # if share the attention weight of different layer
    one_attn = False  # if share the attention weight of vision and language

    # Image setting
    train_transform_keys = ["square_transform_randaug"]  # train transform: refer to vlmo/transforms/__init__.py
    val_transform_keys = ["square_transform"]  # test transform: refer to vlmo/transforms/__init__.py
    image_size = 224  # image size
    reclip_image_size = None  # reclip image size
    patch_size = 16  # patch size
    draw_false_image = 0  # if get negative image
    image_only = False  # only input image
    text_only = False  # only input text

    # Video setting, video_num_frm is not None means video input
    video_num_frm = None

    # Visual tokenizer setting based on beit2
    tokenizer_model = "beit2_visual_tokenizer"
    codebook_size = 8192
    codebook_dim = 32
    visual_mask_size = 14
    visual_mask_num = 80

    # Text Setting
    lang = 'cn'  # language for zero-shot imagenet testing: cn|en
    vqav2_label_size = 3129
    max_text_len = 40  # the number of characters
    max_text_len_of_initckpt = 196
    tokenizer_type = "BertTokenizer"  # Chinese text
    vocab_size = 21128
    tokenizer = "./vocab.txt"
    whole_word_masking = True
    mlm_prob = 0.15  # language mask ratio
    draw_false_text = 0
    mvlm_prob = 0.50  # vision-language mlm task
    mask_ratio = 0  # flip: mask ratio for image

    # cap setting
    cap_onlytext = False  # default caption image to text

    # imagemlm setting
    split_data_for_imagemlm = False  # if True, split a batch data to two parts, and the first part for imagemlm.

    # itc setting
    itc_mask = False  # itc use masked token
    aggregate_nodes = -1  # aggregate nodes num for compute_itc, default -1 is for all nodes

    # Transformer Setting
    model_arch = "vlmo_base_patch16"
    drop_path_rate = 0.1

    # Downstream Setting
    get_recall_metric = False
    get_recall_rerank_metric = False
    get_zeroshot_metric = False
    get_muge_feat = False
    get_f30k_feat = False
    k_test = 32

    # PL Trainer Setting
    resume_from = None
    fast_dev_run = False
    val_check_interval = 1.0
    test_only = False
    use_sharded_training = False
    resume_during_training = False
    save_top_k = 10
    every_n_train_steps = 2000  # the step to save checkpoint
    log_metric_steps = 100  # the step to log metric

    # below params varies with the environment
    use_pcache = False  # data storage method: pcache or nas
    pcache_root = ""
    # main_site: pcache://multimodalproxyi-pool.cz50c.alipay.com:39999/mnt/
    # public_cloud: pcache://pcache_public_cloud.pcache.local:39999/mnt/abc7c88079a60b45ddfce7afa40720b7/
    gpu_env = "main_site"  # public_cloud or main_site
    data_root = ""  # data root for data list

    log_dir = "result"
    per_gpu_batchsize = 4  # you should define this manually with per_gpu_batch_size=#
    num_gpus = 1
    num_nodes = 1
    load_path = ""
    num_workers = 8
    precision = 16
    local_run = True
    flash_attn = False
    deepspeed_config = None  # "ds_config.json"
    coalesce_backbone = False
    mask_data = "v+l"  # 'v+l': choose input of imagemlm+textmlm task, 'vl': choose input of mvlm task.
    communication_benchmark = False
    checkpoint_activations = False

    # dataset setting
    single_cap = True  # if have only one caption
    random_one = False  # if choose one caption from caption list

    # ITC setting
    itc_feats_name = "cls_vlffn_feats"  # feat for itc loss
    itc_distill = ""
    itc_distill_dim = 1024
    itc_teacher_weights = ""

    # mup training setting
    mup = False
    base_encoder_embed_dim = 1
    delta_encoder_embed_dim = 2
    mup_encoder_attention_heads = 1
    base_encoder_ffn_embed_dim = 1
    delta_encoder_ffn_embed_dim = 2

    # atorch
    atorch_config = None
    compile_op = False
    optimizer_state_shard_save = False
    model_state_shard_save = False

    # itc loss
    local_loss = False
    use_dual_softmax = False

    num_frames = 1
    # ----------------------- LMM pretraining config -----------------------

    # norm setting
    deepnorm = False
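The defaults above are registered on a sacred `Experiment`; a typical entry point consumes them roughly as follows (the actual run.py is not part of this upload, so the names here are illustrative):

```python
# Sketch of consuming the sacred Experiment defined in vlmo/config.py.
import copy
from vlmo.config import ex


@ex.automain
def main(_config):
    # _config is the dict assembled from the @ex.config function, after applying any
    # command-line overrides, e.g.:  python run.py with beit_version=large precision=32
    config = copy.deepcopy(_config)
    print(config["beit_version"], config["encoder_layers"], config["loss_names"])
```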
vlmo/modules/__init__.py
ADDED
@@ -0,0 +1 @@
from .vlmo_module import VLMo
vlmo/modules/heads.py
ADDED
@@ -0,0 +1,24 @@
import torch.nn as nn


class Pooler(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class ITCHead(nn.Module):
    def __init__(self, hidden_size, out_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, out_size, bias=False)

    def forward(self, x):
        x = self.fc(x)
        return x
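Usage sketch for the two heads with dummy encoder outputs (sizes are illustrative): `Pooler` applies a tanh-activated linear layer to the first ([CLS]) token, and `ITCHead` is a bias-free projection to the retrieval embedding space.

```python
import torch
from vlmo.modules.heads import Pooler, ITCHead

hidden_states = torch.randn(2, 197, 768)  # [batch, tokens, hidden]
pooled = Pooler(768)(hidden_states)       # [2, 768], tanh(W * CLS + b)
embed = ITCHead(768, 1024)(pooled)        # [2, 1024], bias-free linear projection
print(pooled.shape, embed.shape)
```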
vlmo/modules/modeling_utils.py
ADDED
@@ -0,0 +1,179 @@
# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------'

import math
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_ as __call_trunc_normal_

from vlmo.torchscale.model.BEiT3 import BEiT3
from vlmo.torchscale.architecture.config import EncoderConfig


def trunc_normal_(tensor, mean=0.0, std=1.0):
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)


def _get_base_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    encoder_layers=12,
    encoder_embed_dim=768,
    encoder_attention_heads=12,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )


def _get_large_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    encoder_layers=24,
    encoder_embed_dim=1024,
    encoder_attention_heads=16,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )


def _get_huge_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=30522,
    encoder_layers=32,
    encoder_embed_dim=4096,
    encoder_attention_heads=32,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )


class BEiT3Wrapper(nn.Module):
    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        self.beit3 = BEiT3(args)
        self.apply(self._init_weights)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_num_layers(self):
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            "pos_embed",
            "cls_token",
            "beit3.encoder.embed_positions.A.weight",
            "beit3.vision_embed.cls_token",
            "logit_scale",
        }

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
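A quick sketch of the config factories, using the overrides that ms_wrapper.py passes for the 1B ("large") variant; it assumes torchscale's `EncoderConfig` exposes its keyword arguments as attributes:

```python
from vlmo.modules.modeling_utils import _get_large_config

# Overrides follow the 1B config in ms_wrapper.py; attribute access on the returned
# EncoderConfig is an assumption about torchscale's config object.
args = _get_large_config(
    encoder_layers=21,
    encoder_embed_dim=1024,
    vocab_size=115244,
    max_text_len=40,
)
print(args.encoder_layers, args.encoder_embed_dim, args.encoder_ffn_embed_dim)  # 21 1024 4096
```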
vlmo/modules/multiway_transformer.py
ADDED
@@ -0,0 +1,396 @@
""" Vision Transformer (ViT) in PyTorch

A PyTorch implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929

The official jax code is released and available at https://github.com/google-research/vision_transformer

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
  for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877

Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from pytorch_lightning.utilities.rank_zero import rank_zero_info


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None, relative_position_bias=None):
        B, N, C = x.shape

        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q.float() @ k.float().transpose(-2, -1)

        if relative_position_bias is not None:
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            mask = mask.bool()
            attn = attn.masked_fill(~mask[:, None, None, :], float("-inf"))
        attn = attn.softmax(dim=-1).type_as(x)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        with_vlffn=False,
        layer_scale_init_values=0.1,
        max_text_len=40,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2_text = norm_layer(dim)
        self.norm2_imag = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_text = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp_imag = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp_vl = None
        if with_vlffn:
            self.mlp_vl = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop=drop,
            )
            self.norm2_vl = norm_layer(dim)

        self.gamma_1 = (
            nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_values is not None
            else 1.0
        )
        self.gamma_2 = (
            nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_values is not None
            else 1.0
        )

        self.max_text_len = max_text_len

    def forward(self, x, mask=None, modality_type=None, relative_position_bias=None):
        x = x + self.drop_path(
            self.gamma_1 * self.attn(self.norm1(x), mask=mask, relative_position_bias=relative_position_bias)
        )

        if modality_type == "image":
            x = x + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x)))
        elif modality_type == "text":
            x = x + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x)))
        else:
            if self.mlp_vl is None:
                x_text = x[:, : self.max_text_len]
                x_imag = x[:, self.max_text_len :]
                x_text = x_text + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x_text)))
                x_imag = x_imag + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x_imag)))
                x = torch.cat([x_text, x_imag], dim=1)
            else:
                x = x + self.drop_path(self.gamma_2 * self.mlp_vl(self.norm2_vl(x)))

        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        no_patch_embed_bias=False,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False if no_patch_embed_bias else True,
        )

    def forward(self, x):
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # FIXME look at relaxing size constraints
        x = self.proj(x)
        return x


class MultiWayTransformer(nn.Module):
    """Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        need_relative_position_embed=True,
        use_abs_pos_emb=False,
        layer_scale_init_values=0.1,
        vlffn_start_layer_index=10,
        config=None,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            need_relative_position_embed (bool): enable relative position bias on self-attention
            use_abs_pos_emb (bool): enable abs pos emb
            layer_scale_init_values (float or None): layer scale init values, set None to disable
            vlffn_start_layer_index (int): vl-ffn start index
            config: (dict): other hyper from pytorch-lighting
        """
        super().__init__()
        drop_path_rate = drop_path_rate if config is None else config["drop_path_rate"]
        rank_zero_info("drop path rate: {}".format(drop_path_rate))
        self.use_abs_pos_emb = use_abs_pos_emb
        self.need_relative_position_embed = need_relative_position_embed

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.vlffn_start_layer_index = vlffn_start_layer_index
        if config["loss_names"]["textmlm"] > 0:
            self.vlffn_start_layer_index = depth
            rank_zero_info(
                "Set vlffn_start_layer_index={} for text-only pretraining".format(self.vlffn_start_layer_index)
            )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if self.use_abs_pos_emb else None
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    with_vlffn=(i >= self.vlffn_start_layer_index),
                    layer_scale_init_values=layer_scale_init_values,
                    max_text_len=config["max_text_len"],
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def visual_embed(self, _x):
        x = self.patch_embed(_x)
        x = x.flatten(2).transpose(1, 2)
        B, L, _ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        x_mask = torch.ones(x.shape[0], x.shape[1])

        return x, x_mask


# VLMo base/p16
@register_model
def vlmo_base_patch16(pretrained=False, **kwargs):
    img_size = kwargs.pop("img_size", 224)
    model = MultiWayTransformer(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        vlffn_start_layer_index=10,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model
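Forward sketch through a single multiway `Block` with dummy tokens: the attention is shared across modalities, while the FFN is routed by `modality_type` (sizes below are illustrative):

```python
import torch
from vlmo.modules.multiway_transformer import Block

block = Block(dim=768, num_heads=12, qkv_bias=True, max_text_len=40)
tokens = torch.randn(2, 40 + 197, 768)  # 40 text tokens followed by 197 image tokens

img_only = block(tokens[:, 40:], modality_type="image")  # image FFN branch
txt_only = block(tokens[:, :40], modality_type="text")   # text FFN branch
fused = block(tokens)   # no mlp_vl in this block, so the text and image parts are
                        # split at max_text_len, passed through their own FFNs, and re-concatenated
print(img_only.shape, txt_only.shape, fused.shape)
```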
vlmo/modules/objectives.py
ADDED
@@ -0,0 +1,12 @@
import torch.nn as nn


def init_weights(module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
vlmo/modules/vlmo_module.py
ADDED
@@ -0,0 +1,405 @@
import math
import os
import time

import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import torch.nn as nn
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from timm.models import create_model
from transformers import AutoTokenizer, BertTokenizer, XLMRobertaTokenizer  # noqa
from vlmo.modules import heads, objectives, vlmo_utils
from vlmo.tokenizer.tokenization_glm import GLMChineseTokenizer  # noqa
from vlmo.torchscale.architecture.encoder import Encoder
from vlmo.torchscale.model.BEiT3 import BEiT3 as ts_backbone
from vlmo.transforms.utils import inception_normalize as img_norm

from .modeling_utils import _get_base_config, _get_large_config, _get_huge_config, trunc_normal_  # noqa


def convert_pl_ckpt(state_dict, num_visual_token=197):
    print("start convert_pl_ckpt!!!")
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if "visual_tokenizer" in key:
            continue
        elif "backbone.encoder.embed_positions.A.weight" in key:
            if value.shape[0] < num_visual_token + 2:
                N = value.shape[0] - 3
                dim = value.shape[-1]
                class_pos_embed = value[:3, ]
                patch_pos_embed = value[3:, ]
                w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                patch_pos_embed = patch_pos_embed.float()
                patch_pos_embed = nn.functional.interpolate(
                    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                    size=(w0, h0),
                    mode="area",
                )
                patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
                new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
                new_state_dict[key] = new_value
                print("reshape ", key, "raw shape: ", value.shape, "new shape: ", new_value.shape, num_visual_token)
            elif value.shape[0] > num_visual_token + 2:
                new_state_dict[key] = value[: num_visual_token + 2, :]
                print("first ", key, "raw shape: ", value.shape, new_state_dict[key].shape, num_visual_token)
            else:
                new_state_dict[key] = value
                print("raw shape")
        else:
            new_state_dict[key] = state_dict[key]

    return new_state_dict


def convert_deepspeed_ckpt(state_dict, num_visual_token=197):
    new_state_dict = {}
    for key in state_dict:
        if key.startswith("_forward_module."):
            new_key = key[len("_forward_module."):]
            value = state_dict[key]
            new_state_dict[new_key] = value
            if "visual_tokenizer.encoder.pos_embed" in new_key or "visual_tokenizer.decoder.pos_embed" in new_key:
                if value.shape[1] != num_visual_token:
                    N = value.shape[1] - 1
                    dim = value.shape[-1]
                    class_pos_embed = value[:, 0]
                    patch_pos_embed = value[:, 1:]
                    w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                    patch_pos_embed = patch_pos_embed.float()
                    patch_pos_embed = nn.functional.interpolate(
                        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                        size=(w0, h0),
                        mode="area",
                    )
                    patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
                    new_value = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
                    new_state_dict[new_key] = new_value
                    print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
            if "backbone.encoder.embed_positions.A.weight" in new_key:
                if value.shape[1] != num_visual_token + 2:
                    N = value.shape[0] - 3
                    dim = value.shape[-1]
                    class_pos_embed = value[:3, ]
                    patch_pos_embed = value[3:, ]
                    w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                    patch_pos_embed = patch_pos_embed.float()
                    patch_pos_embed = nn.functional.interpolate(
                        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                        size=(w0, h0),
                        mode="area",
                    )
                    patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
                    new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
                    new_state_dict[new_key] = new_value
                    print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)

        else:
            new_state_dict[key] = state_dict[key]

    return new_state_dict


def get_visual_tokenizer(config):
    tokenizer_name = config["tokenizer_model"]
    print(f"Creating visual tokenizer: {tokenizer_name}")
    model = create_model(
        config["tokenizer_model"],
        img_size=config["image_size"],
        n_code=config["codebook_size"],
        code_dim=config["codebook_dim"],
    ).eval()
    return model


def get_pretrained_tokenizer(tokenizer_type, from_pretrained):
    _Tokenizer = eval(f"{tokenizer_type}")
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            _Tokenizer.from_pretrained(from_pretrained)
        torch.distributed.barrier()
    return _Tokenizer.from_pretrained(from_pretrained)


class VLMo(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        s_t = time.time()

        # tokenizer & backbone
        self.img_size = config["image_size"]
        if not config["test_only"]:
            self.visual_tokenizer = get_visual_tokenizer(config)
        kwargs = {}
        if "encoder_attention_heads" in config:
            kwargs["encoder_attention_heads"] = config["encoder_attention_heads"]
        if "atorch_config" in config and config["atorch_config"]:
            checkpoint_activations = False  # ?
        else:
            checkpoint_activations = config["checkpoint_activations"]
        args = eval(f'_get_{config["beit_version"]}_config')(
            img_size=config["image_size"],
            patch_size=config["patch_size"],
            vocab_size=config["vocab_size"],
            encoder_layers=config["encoder_layers"],
            encoder_embed_dim=config["encoder_embed_dim"],
            checkpoint_activations=checkpoint_activations,
            share_layer=config["share_layer"],
            share_attn=config["share_attn"],
            deepnorm=config["deepnorm"],
            mask_ratio=config["mask_ratio"],
            max_text_len=config["max_text_len"],
            one_attn=config["one_attn"],
            **kwargs,
        )
        self.num_features = args.encoder_embed_dim
        self.out_features = config["out_embed_dim"]
        self.cap_onlytext = config["cap_onlytext"]
        self.lang = config["lang"]
        self.num_frames = config["num_frames"]
        self.tokenizer_type = config["tokenizer_type"]
        self.text_tokenizer = get_pretrained_tokenizer(self.tokenizer_type, from_pretrained=config["tokenizer"])  # noqa
        print("BEiT args", args.__dict__)
        self.backbone = ts_backbone(args)

        self.use_vl = config["beit3_vl_layers"] > 0
        if self.use_vl:
            args.encoder_layers = config["beit3_vl_layers"]
            self.backbone_vl = Encoder(args)

        self.norm = nn.LayerNorm(self.num_features, eps=1e-6)

        # task layers
        self.pooler = heads.Pooler(self.num_features)
        self.pooler.apply(objectives.init_weights)

        # contrastive loss (or sampling for global hard negative)
        if config["loss_names"]["itc"] > 0:
            self.itc_text_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_image_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_text_proj.apply(objectives.init_weights)
            self.itc_image_proj.apply(objectives.init_weights)

            self.itc_vl_text_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_vl_image_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_vl_text_proj.apply(objectives.init_weights)
            self.itc_vl_image_proj.apply(objectives.init_weights)

            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            self.logit_vl_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        lp_s_t = time.time()

        self.load_pretrained_weight()
        load_pretrain_time = time.time() - lp_s_t

        self.current_tasks = list()

        # ===================== load downstream (test_only) ======================

        if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
            rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")

            state_dict = None

            for state_dict_key in ("state_dict", "module", "model"):
                if state_dict_key in ckpt:
                    rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
                    state_dict = ckpt[state_dict_key]
                    break
            if state_dict_key == "module":
                state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict_key == "state_dict":
                state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
|
222 |
+
if state_dict is None:
|
223 |
+
if list(ckpt.keys())[0].startswith('_forward_module.'):
|
224 |
+
rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
|
225 |
+
state_dict = convert_deepspeed_ckpt(ckpt, self.backbone.vision_embed.num_position_embeddings())
|
226 |
+
else:
|
227 |
+
rank_zero_info("Read state dict from ckpt. ")
|
228 |
+
state_dict = ckpt
|
229 |
+
|
230 |
+
missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
|
231 |
+
rank_zero_info("missing_keys: {}".format(missing_keys))
|
232 |
+
rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
|
233 |
+
|
234 |
+
construct_time = time.time() - s_t
|
235 |
+
print(
|
236 |
+
f"Process {os.getpid()}. VLMo Constructor time: {construct_time}s;",
|
237 |
+
f"load_pretrain_time: {load_pretrain_time}s",
|
238 |
+
flush=True,
|
239 |
+
)
|
240 |
+
# coalesce backbone calls
|
241 |
+
self._coalesce_backbone = config["coalesce_backbone"]
|
242 |
+
self._mask_data = config["mask_data"]
|
243 |
+
self._backbone_inputs = {}
|
244 |
+
self._backbone_inputs_current_size = 0
|
245 |
+
self._backbone_inputs_keys = {}
|
246 |
+
self._backbone_outputs = None
|
247 |
+
self._default_attn_masks = {}
|
248 |
+
self._itc_group = None
|
249 |
+
self._itc_aggregate_dict = None
|
250 |
+
self._itc_mask = config["itc_mask"]
|
251 |
+
self._local_loss = config["local_loss"]
|
252 |
+
self._aggregate_nodes = config["aggregate_nodes"]
|
253 |
+
self.accumulated_batches_reached = False
|
254 |
+
vlmo_utils.set_task(self)
|
255 |
+
self._only_itc_single_machine = (
|
256 |
+
self._aggregate_nodes > 0 and len(self.current_tasks) == 1 and "itc" in self.current_tasks
|
257 |
+
)
|
258 |
+
self._split_data_for_imagemlm = config["split_data_for_imagemlm"]
|
259 |
+
self.log_metric_steps = config["log_metric_steps"]
|
260 |
+
|
261 |
+
def _init_weights(self, m):
|
262 |
+
if isinstance(m, nn.Linear):
|
263 |
+
trunc_normal_(m.weight, std=0.02)
|
264 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
265 |
+
nn.init.constant_(m.bias, 0)
|
266 |
+
elif isinstance(m, nn.LayerNorm):
|
267 |
+
nn.init.constant_(m.bias, 0)
|
268 |
+
nn.init.constant_(m.weight, 1.0)
|
269 |
+
|
270 |
+
def fix_init_weight(self):
|
271 |
+
def rescale(param, layer_id):
|
272 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
273 |
+
|
274 |
+
for layer_id, layer in enumerate(self.backbone.encoder.layers):
|
275 |
+
rescale(layer.self_attn.v_proj.A.weight.data, layer_id + 1)
|
276 |
+
rescale(layer.self_attn.v_proj.B.weight.data, layer_id + 1)
|
277 |
+
rescale(layer.self_attn.out_proj.A.weight.data, layer_id + 1)
|
278 |
+
rescale(layer.self_attn.out_proj.B.weight.data, layer_id + 1)
|
279 |
+
rescale(layer.ffn.A.fc2.weight.data, layer_id + 1)
|
280 |
+
rescale(layer.ffn.B.fc2.weight.data, layer_id + 1)
|
281 |
+
|
282 |
+
if self.use_vl:
|
283 |
+
pre_layers = len(self.backbone.encoder.layers) + 1
|
284 |
+
for layer_id, layer in enumerate(self.backbone_vl.layers):
|
285 |
+
rescale(layer.self_attn.v_proj.A.weight.data, layer_id + pre_layers)
|
286 |
+
rescale(layer.self_attn.v_proj.B.weight.data, layer_id + pre_layers)
|
287 |
+
rescale(layer.self_attn.out_proj.A.weight.data, layer_id + pre_layers)
|
288 |
+
rescale(layer.self_attn.out_proj.B.weight.data, layer_id + pre_layers)
|
289 |
+
rescale(layer.ffn.A.fc2.weight.data, layer_id + pre_layers)
|
290 |
+
rescale(layer.ffn.B.fc2.weight.data, layer_id + pre_layers)
|
291 |
+
|
292 |
+
def load_pretrained_weight(self):
|
293 |
+
if self.hparams.config["load_path"] != "" and not self.hparams.config["test_only"]:
|
294 |
+
config = self.hparams.config
|
295 |
+
ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
|
296 |
+
rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
|
297 |
+
|
298 |
+
state_dict = None
|
299 |
+
|
300 |
+
for state_dict_key in ("state_dict", "module", "model"):
|
301 |
+
if state_dict_key in ckpt:
|
302 |
+
rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
|
303 |
+
state_dict = ckpt[state_dict_key]
|
304 |
+
break
|
305 |
+
if state_dict_key == "module":
|
306 |
+
state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
|
307 |
+
if state_dict_key == "state_dict":
|
308 |
+
state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
|
309 |
+
if state_dict is None:
|
310 |
+
if list(ckpt.keys())[0].startswith('_forward_module.'):
|
311 |
+
rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
|
312 |
+
state_dict = convert_deepspeed_ckpt(ckpt,
|
313 |
+
self.backbone.vision_embed.num_position_embeddings())
|
314 |
+
else:
|
315 |
+
rank_zero_info("Read state dict from ckpt. ")
|
316 |
+
state_dict = ckpt
|
317 |
+
|
318 |
+
missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
|
319 |
+
missing_keys = [k for k in missing_keys if "itc_teacher" not in k]
|
320 |
+
rank_zero_info("missing_keys: {}".format(missing_keys))
|
321 |
+
rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
|
322 |
+
|
323 |
+
def infer_text(
|
324 |
+
self,
|
325 |
+
batch,
|
326 |
+
mask_text=False,
|
327 |
+
):
|
328 |
+
do_mlm = "_mlm" if mask_text else ""
|
329 |
+
text_ids = batch[f"text_ids{do_mlm}"]
|
330 |
+
text_labels = batch[f"text_labels{do_mlm}"]
|
331 |
+
text_masks = batch[f"text_masks"]
|
332 |
+
text_embed = self.backbone.text_embed(text_ids)
|
333 |
+
text_padding_position = 1 - text_masks
|
334 |
+
lffn_hiddens = self.backbone(
|
335 |
+
textual_tokens=text_ids,
|
336 |
+
text_padding_position=text_padding_position,
|
337 |
+
)["encoder_out"]
|
338 |
+
vlffn_hiddens = self.backbone_vl(
|
339 |
+
src_tokens=None,
|
340 |
+
token_embeddings=lffn_hiddens,
|
341 |
+
encoder_padding_mask=text_padding_position,
|
342 |
+
multiway_split_position=-1,
|
343 |
+
)["encoder_out"]
|
344 |
+
|
345 |
+
cls_feats = self.itc_text_proj(lffn_hiddens[:, 0])
|
346 |
+
cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
|
347 |
+
|
348 |
+
cls_vlffn_feats = self.itc_vl_text_proj(vlffn_hiddens[:, 0])
|
349 |
+
cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)
|
350 |
+
|
351 |
+
ret = {
|
352 |
+
"cls_feats": cls_feats,
|
353 |
+
"cls_vlffn_feats": cls_vlffn_feats,
|
354 |
+
"text_embed": text_embed,
|
355 |
+
}
|
356 |
+
|
357 |
+
return ret
|
358 |
+
|
359 |
+
def infer_image(
|
360 |
+
self,
|
361 |
+
batch,
|
362 |
+
mask_image=False,
|
363 |
+
image_token_type_idx=1,
|
364 |
+
image_embeds=None,
|
365 |
+
image_masks=None,
|
366 |
+
):
|
367 |
+
if f"image_{image_token_type_idx - 1}" in batch:
|
368 |
+
imgkey = f"image_{image_token_type_idx - 1}"
|
369 |
+
else:
|
370 |
+
imgkey = "image"
|
371 |
+
|
372 |
+
img = batch[imgkey][0]
|
373 |
+
if mask_image:
|
374 |
+
image_masks = batch[f"{imgkey}_masks"][0].flatten(1)
|
375 |
+
|
376 |
+
with torch.no_grad():
|
377 |
+
img = self.visual_tokenizer.pre_process(img)
|
378 |
+
quantize, embed_ind, _ = self.visual_tokenizer.encode(img)
|
379 |
+
image_ids = embed_ind.view(img.shape[0], -1)
|
380 |
+
|
381 |
+
image_labels = torch.full_like(image_ids, -100)
|
382 |
+
bool_masked_pos = image_masks.to(torch.bool)
|
383 |
+
image_labels[bool_masked_pos] = image_ids[bool_masked_pos]
|
384 |
+
|
385 |
+
img_tensor = img_norm(img)
|
386 |
+
vffn_hiddens = self.backbone(visual_tokens=img_tensor)["encoder_out"]
|
387 |
+
vlffn_hiddens = self.backbone_vl(
|
388 |
+
src_tokens=None,
|
389 |
+
token_embeddings=vffn_hiddens,
|
390 |
+
multiway_split_position=-1,
|
391 |
+
)["encoder_out"]
|
392 |
+
|
393 |
+
cls_feats = self.itc_image_proj(vffn_hiddens[:, 0])
|
394 |
+
cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
|
395 |
+
|
396 |
+
cls_vlffn_feats = self.itc_vl_image_proj(vlffn_hiddens[:, 0])
|
397 |
+
cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)
|
398 |
+
|
399 |
+
ret = {
|
400 |
+
"image_feats": vffn_hiddens,
|
401 |
+
"cls_feats": cls_feats,
|
402 |
+
"cls_vlffn_feats": cls_vlffn_feats,
|
403 |
+
}
|
404 |
+
|
405 |
+
return ret
|
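The two `infer_*` paths above each return L2-normalized CLS projections from both the single-modality encoder and the fused VL block. Below is a minimal retrieval sketch, assuming an already-constructed `model` and a `batch` dict carrying the same keys the module reads (`text_ids`, `text_labels`, `text_masks`, `image`); averaging the two similarity matrices is an illustrative choice, not the repo's official inference entry point.

```python
import torch


@torch.no_grad()
def itc_similarity(model, batch):
    """Score every text in the batch against every image in the batch."""
    model.eval()
    txt = model.infer_text(batch)    # normalized text CLS projections
    img = model.infer_image(batch)   # normalized image CLS projections
    # Both heads L2-normalize, so a dot product is a cosine similarity.
    sim = txt["cls_feats"] @ img["cls_feats"].t()
    sim_vl = txt["cls_vlffn_feats"] @ img["cls_vlffn_feats"].t()
    # Hypothetical fusion rule: average the single-modality and VL scores.
    return (sim + sim_vl) / 2
```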
vlmo/modules/vlmo_utils.py
ADDED
@@ -0,0 +1,12 @@
def set_task(pl_module):
    pl_module.current_tasks = [k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1]
    return


def no_sync_module_apply(module, fn):
    """FSDP's Module.apply goes through _unshard_params_recurse, which syncs params across ranks.

    Use this function instead when the applied fn does not need params to be synced across ranks.
    """
    for child in module.children():
        fn(child)
        no_sync_module_apply(child, fn)
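A hedged example of how `no_sync_module_apply` might be used: applying a cheap per-module function to an FSDP-wrapped model without the cross-rank unshard that `nn.Module.apply` would trigger. The bias-zeroing function below is illustrative only, and note that, unlike `Module.apply`, the root module itself is not visited.

```python
import torch.nn as nn

from vlmo.modules.vlmo_utils import no_sync_module_apply


def _zero_linear_bias(m):
    # Example hook: reset Linear biases in place; any side-effect-free fn works.
    if isinstance(m, nn.Linear) and m.bias is not None:
        nn.init.zeros_(m.bias)


# fsdp_wrapped_model is a placeholder for an FSDP-wrapped nn.Module:
# no_sync_module_apply(fsdp_wrapped_model, _zero_linear_bias)
```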
vlmo/tokenizer/__init__.py
ADDED
@@ -0,0 +1,6 @@
# coding: utf-8
# Copyright (c) Antfin, Inc. All rights reserved.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
vlmo/tokenizer/sp.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7fe3bcc8d284fcb782691411e8b6fd4f45d7245565b094de6ab795e66bcd32f
size 2270960
vlmo/tokenizer/tokenization_glm.py
ADDED
@@ -0,0 +1,307 @@
1 |
+
import os
|
2 |
+
from shutil import copyfile
|
3 |
+
from typing import Optional, Tuple, List, Union
|
4 |
+
|
5 |
+
import sentencepiece as spm
|
6 |
+
import torch
|
7 |
+
from transformers import PreTrainedTokenizer
|
8 |
+
from transformers.models.auto.tokenization_auto import get_tokenizer_config
|
9 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
10 |
+
from transformers.utils import logging
|
11 |
+
|
12 |
+
logger = logging.get_logger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
class GLMBatchEncoding(BatchEncoding):
|
16 |
+
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
|
17 |
+
"""
|
18 |
+
Send all values to device by calling `v.to(device)` (PyTorch only).
|
19 |
+
|
20 |
+
Args:
|
21 |
+
device (`str` or `torch.device`): The device to put the tensors on.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
[`BatchEncoding`]: The same instance after modification.
|
25 |
+
"""
|
26 |
+
|
27 |
+
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
28 |
+
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
29 |
+
# into a HalfTensor
|
30 |
+
if isinstance(device, str) or isinstance(device, int):
|
31 |
+
#if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
|
32 |
+
self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()}
|
33 |
+
else:
|
34 |
+
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
35 |
+
return self
|
36 |
+
|
37 |
+
|
38 |
+
class GLMTokenizerMixin:
|
39 |
+
@property
|
40 |
+
def sop_token(self) -> Optional[str]:
|
41 |
+
return "<|startofpiece|>"
|
42 |
+
|
43 |
+
@property
|
44 |
+
def sop_token_id(self) -> Optional[int]:
|
45 |
+
"""
|
46 |
+
`Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.
|
47 |
+
"""
|
48 |
+
return self.convert_tokens_to_ids(self.sop_token)
|
49 |
+
|
50 |
+
@property
|
51 |
+
def eop_token(self) -> Optional[str]:
|
52 |
+
return "<|endofpiece|>"
|
53 |
+
|
54 |
+
@property
|
55 |
+
def eop_token_id(self) -> Optional[int]:
|
56 |
+
"""
|
57 |
+
`Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.
|
58 |
+
"""
|
59 |
+
return self.convert_tokens_to_ids(self.eop_token)
|
60 |
+
|
61 |
+
@property
|
62 |
+
def gmask_token_id(self) -> int:
|
63 |
+
return self.convert_tokens_to_ids("[gMASK]")
|
64 |
+
|
65 |
+
@property
|
66 |
+
def smask_token_id(self) -> int:
|
67 |
+
return self.convert_tokens_to_ids("[sMASK]")
|
68 |
+
|
69 |
+
@property
|
70 |
+
def mask_token_ids(self):
|
71 |
+
return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]
|
72 |
+
|
73 |
+
def _build_input_for_multiple_choice(self, context, choices):
|
74 |
+
context_id = context["input_ids"]
|
75 |
+
if torch.is_tensor(context_id):
|
76 |
+
context_id = context_id.tolist()
|
77 |
+
|
78 |
+
division = len(context_id)
|
79 |
+
mask_position = context_id.index(self.mask_token_id)
|
80 |
+
|
81 |
+
token = torch.tensor(context_id, dtype=torch.long)
|
82 |
+
attention_mask = [context["attention_mask"].expand(division, -1)]
|
83 |
+
position_id = torch.arange(division, dtype=torch.long)
|
84 |
+
block_position_id = torch.zeros(division, dtype=torch.long)
|
85 |
+
|
86 |
+
choice_ids, choice_indices = [], []
|
87 |
+
|
88 |
+
for choice_str in choices:
|
89 |
+
choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],
|
90 |
+
dtype=torch.long)
|
91 |
+
choice_ids.append(choice)
|
92 |
+
choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))
|
93 |
+
attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))
|
94 |
+
|
95 |
+
token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))
|
96 |
+
position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))
|
97 |
+
block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))
|
98 |
+
|
99 |
+
attention_mask = torch.block_diag(*attention_mask)
|
100 |
+
attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0)
|
101 |
+
|
102 |
+
return {
|
103 |
+
"input_ids": token,
|
104 |
+
"position_ids": torch.stack((position_id, block_position_id)),
|
105 |
+
"attention_mask": attention_mask,
|
106 |
+
"choice_ids": choice_ids,
|
107 |
+
"choice_indices": choice_indices
|
108 |
+
}
|
109 |
+
|
110 |
+
def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):
|
111 |
+
pad_length = max_seq_length - len(tokens)
|
112 |
+
attention_mask = torch.nn.functional.pad(
|
113 |
+
attention_mask,
|
114 |
+
(0, pad_length, 0, pad_length),
|
115 |
+
mode="constant",
|
116 |
+
value=0,
|
117 |
+
)
|
118 |
+
tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))
|
119 |
+
position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)
|
120 |
+
return tokens, position_ids, attention_mask
|
121 |
+
|
122 |
+
def _collate(self, samples):
|
123 |
+
TILE = 1
|
124 |
+
length_to_pad = (max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE
|
125 |
+
|
126 |
+
token_batch, position_id_batch, attention_mask_batch = [], [], []
|
127 |
+
choices_batch, choice_target_ids_batch = [], []
|
128 |
+
|
129 |
+
for sample in samples:
|
130 |
+
token, position_id, attention_mask = self._pad_batch(
|
131 |
+
sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad
|
132 |
+
)
|
133 |
+
token_batch.append(token)
|
134 |
+
position_id_batch.append(position_id)
|
135 |
+
attention_mask_batch.append(attention_mask)
|
136 |
+
choices_batch.append(sample["choice_ids"])
|
137 |
+
choice_target_ids_batch.append(sample["choice_indices"])
|
138 |
+
return {
|
139 |
+
"input_ids": torch.stack(token_batch),
|
140 |
+
"position_ids": torch.stack(position_id_batch),
|
141 |
+
"attention_mask": torch.stack(attention_mask_batch).unsqueeze(1),
|
142 |
+
"choice_ids": choices_batch,
|
143 |
+
"choice_indices": choice_target_ids_batch,
|
144 |
+
}
|
145 |
+
|
146 |
+
def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):
|
147 |
+
samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input["input_ids"]))]
|
148 |
+
samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in
|
149 |
+
zip(samples, choices)]
|
150 |
+
inputs = self._collate(samples)
|
151 |
+
return GLMBatchEncoding(inputs)
|
152 |
+
|
153 |
+
def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False):
|
154 |
+
mask_ids = self.mask_token_ids
|
155 |
+
input_ids = model_input.input_ids
|
156 |
+
batch_size, seq_length = input_ids.shape[:2]
|
157 |
+
position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]
|
158 |
+
position_ids, block_position_ids = [], []
|
159 |
+
labels = None
|
160 |
+
if targets is not None:
|
161 |
+
is_batched = isinstance(targets, (list, tuple))
|
162 |
+
targets = self(targets, add_special_tokens=False, padding=False).input_ids
|
163 |
+
if not is_batched:
|
164 |
+
targets = [targets]
|
165 |
+
assert len(targets) == len(input_ids)
|
166 |
+
targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets]
|
167 |
+
if not padding:
|
168 |
+
max_gen_length = max(map(len, targets))
|
169 |
+
targets = [[self.sop_token_id] + target for target in targets]
|
170 |
+
labels = [target[1:] for target in targets]
|
171 |
+
targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]
|
172 |
+
labels = [label + [-100] * (max_gen_length - len(label)) for label in labels]
|
173 |
+
targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)
|
174 |
+
labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)
|
175 |
+
labels = torch.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1)
|
176 |
+
for i in range(batch_size):
|
177 |
+
mask_positions = []
|
178 |
+
for mask_id in mask_ids:
|
179 |
+
mask_positions += (input_ids[i] == mask_id).nonzero(as_tuple=True)[0].tolist()
|
180 |
+
if not mask_positions:
|
181 |
+
raise ValueError("Cannot find mask token in the input")
|
182 |
+
mask_positions.sort()
|
183 |
+
mask_pos = mask_positions[0]
|
184 |
+
position_ids.append(position_id + [mask_pos] * max_gen_length)
|
185 |
+
block_position_ids.append(block_position_id + list(range(1, max_gen_length + 1)))
|
186 |
+
position_ids = torch.tensor(position_ids, dtype=input_ids.dtype, device=input_ids.device)
|
187 |
+
block_position_ids = torch.tensor(block_position_ids, dtype=input_ids.dtype, device=input_ids.device)
|
188 |
+
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
|
189 |
+
attention_mask = model_input.attention_mask
|
190 |
+
attention_mask = attention_mask.unsqueeze(1).expand(-1, seq_length + max_gen_length, -1)
|
191 |
+
generation_attention_mask = torch.cat([attention_mask.new_zeros((seq_length, max_gen_length)),
|
192 |
+
torch.tril(attention_mask.new_ones((max_gen_length, max_gen_length)))],
|
193 |
+
dim=0).unsqueeze(0).expand(batch_size, -1, -1)
|
194 |
+
attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)
|
195 |
+
attention_mask = attention_mask.unsqueeze(1)
|
196 |
+
if targets is None:
|
197 |
+
input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)
|
198 |
+
else:
|
199 |
+
input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)
|
200 |
+
batch = {"input_ids": input_ids, "position_ids": position_ids}
|
201 |
+
if labels is None:
|
202 |
+
batch["generation_attention_mask"] = attention_mask
|
203 |
+
else:
|
204 |
+
batch["attention_mask"] = attention_mask
|
205 |
+
batch["labels"] = labels
|
206 |
+
return BatchEncoding(batch)
|
207 |
+
|
208 |
+
def encode_whitespaces(content):
|
209 |
+
for i in range(10, 1, -1):
|
210 |
+
content = content.replace(' '*i, f'<|blank_{i}|>')
|
211 |
+
return content
|
212 |
+
|
213 |
+
def decode_whitespaces(content):
|
214 |
+
for i in range(10, 1, -1):
|
215 |
+
content = content.replace(f'<|blank_{i}|>', ' '*i)
|
216 |
+
return content
|
217 |
+
|
218 |
+
|
219 |
+
class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
|
220 |
+
vocab_files_names = {"vocab_file": "sp.model"}
|
221 |
+
truncation_side: str = "left"
|
222 |
+
|
223 |
+
def __init__(self, vocab_file, **kwargs):
|
224 |
+
self.vocab_file = vocab_file
|
225 |
+
self.sp_model = spm.SentencePieceProcessor()
|
226 |
+
self.sp_model.Load(vocab_file)
|
227 |
+
super().__init__(**kwargs)
|
228 |
+
|
229 |
+
@property
|
230 |
+
def vocab_size(self):
|
231 |
+
return len(self.sp_model)
|
232 |
+
|
233 |
+
def get_vocab(self):
|
234 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
235 |
+
vocab.update(self.added_tokens_encoder)
|
236 |
+
return vocab
|
237 |
+
|
238 |
+
def _tokenize(self, text, **kwargs):
|
239 |
+
text = encode_whitespaces(text)
|
240 |
+
return self.sp_model.EncodeAsPieces(text)
|
241 |
+
#return self.sp_model.EncodeAsPieces(text, out_type=str)
|
242 |
+
|
243 |
+
def _convert_token_to_id(self, token):
|
244 |
+
"""Converts a token (str) in an id using the vocab."""
|
245 |
+
return self.sp_model.PieceToId(token)
|
246 |
+
|
247 |
+
def _convert_id_to_token(self, index):
|
248 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
249 |
+
return self.sp_model.IdToPiece(index)
|
250 |
+
|
251 |
+
def convert_tokens_to_string(self, tokens):
|
252 |
+
res = self.sp_model.DecodeIds(tokens)
|
253 |
+
return decode_whitespaces(res)
|
254 |
+
|
255 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
256 |
+
if not os.path.isdir(save_directory):
|
257 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
258 |
+
return
|
259 |
+
out_vocab_file = os.path.join(
|
260 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
|
261 |
+
)
|
262 |
+
|
263 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
264 |
+
copyfile(self.vocab_file, out_vocab_file)
|
265 |
+
elif not os.path.isfile(self.vocab_file):
|
266 |
+
with open(out_vocab_file, "wb") as fi:
|
267 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
268 |
+
fi.write(content_spiece_model)
|
269 |
+
|
270 |
+
return (out_vocab_file,)
|
271 |
+
|
272 |
+
def build_inputs_with_special_tokens(
|
273 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
274 |
+
) -> List[int]:
|
275 |
+
"""
|
276 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
277 |
+
adding special tokens. A BERT sequence has the following format:
|
278 |
+
|
279 |
+
- single sequence: ``[CLS] X [SEP]``
|
280 |
+
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
281 |
+
|
282 |
+
Args:
|
283 |
+
token_ids_0 (:obj:`List[int]`):
|
284 |
+
List of IDs to which the special tokens will be added.
|
285 |
+
token_ids_1 (:obj:`List[int]`, `optional`):
|
286 |
+
Optional second list of IDs for sequence pairs.
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
290 |
+
"""
|
291 |
+
assert token_ids_1 is None
|
292 |
+
cls = [self.cls_token_id]
|
293 |
+
eos = [self.eos_token_id]
|
294 |
+
return cls + token_ids_0 + eos
|
295 |
+
|
296 |
+
|
297 |
+
class GLMTokenizer:
|
298 |
+
@classmethod
|
299 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
300 |
+
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
301 |
+
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
|
302 |
+
|
303 |
+
if config_tokenizer_class == "GLMChineseTokenizer":
|
304 |
+
tokenizer_class = GLMChineseTokenizer
|
305 |
+
else:
|
306 |
+
raise NotImplementedError("Not implemented tokenizer type:", config_tokenizer_class)
|
307 |
+
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
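The `<|blank_i|>` markers defined above round-trip runs of 2 to 10 spaces through SentencePiece, which otherwise collapses whitespace. A quick check of the two helpers (the sample string is arbitrary):

```python
from vlmo.tokenizer.tokenization_glm import decode_whitespaces, encode_whitespaces

s = "def f():\n    return 1"
enc = encode_whitespaces(s)            # "def f():\n<|blank_4|>return 1"
assert decode_whitespaces(enc) == s    # decoding restores the original spaces
```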
vlmo/tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,17 @@
{
  "name_or_path": "THUDM/glm-10b-chinese",
  "eos_token": "<|endoftext|>",
  "pad_token": "<|endoftext|>",
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "unk_token": "[UNK]",
  "add_prefix_space": false,
  "tokenizer_class": "GLMChineseTokenizer",
  "use_fast": false,
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_glm.GLMChineseTokenizer",
      null
    ]
  }
}
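This config is what routes `GLMTokenizer.from_pretrained` (and `AutoTokenizer` via `auto_map`) to `GLMChineseTokenizer`, which then loads `sp.model` from the same directory. A hypothetical local-loading sketch; the directory path, sample text, and the `max_length=52` choice (matching the default `max_text_len`) are placeholders:

```python
from vlmo.tokenizer.tokenization_glm import GLMTokenizer

tokenizer = GLMTokenizer.from_pretrained("./vlmo/tokenizer")
encoded = tokenizer("一只猫在草地上", padding="max_length", truncation=True, max_length=52)
print(encoded["input_ids"][:10])
```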
vlmo/torchscale/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
vlmo/torchscale/architecture/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
vlmo/torchscale/architecture/config.py
ADDED
@@ -0,0 +1,197 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


class EncoderConfig(object):
    def __init__(self, **kwargs):
        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
        self.encoder_layers = kwargs.pop("encoder_layers", 12)
        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
        self.normalize_output = kwargs.pop("normalize_output", True)
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", False)
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
        self.share_layer = kwargs.pop("share_layer", False)
        self.share_attn = kwargs.pop("share_attn", False)
        self.mask_ratio = kwargs.pop("mask_ratio", 0)
        self.max_text_len = kwargs.pop("max_text_len", 52)
        self.one_attn = kwargs.pop('one_attn', False)

        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Vision
        self.img_size = kwargs.pop("img_size", 224)
        self.patch_size = kwargs.pop("patch_size", 16)
        self.in_chans = kwargs.pop("in_chans", 3)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)
        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)

        if self.deepnorm:
            self.encoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.encoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)


class DecoderConfig(object):
    def __init__(self, **kwargs):
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
        self.decoder_layers = kwargs.pop("decoder_layers", 12)
        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", False)
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)
        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)

        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)


class EncoderDecoderConfig(object):
    def __init__(self, **kwargs):
        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
        self.encoder_layers = kwargs.pop("encoder_layers", 12)
        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
        self.decoder_layers = kwargs.pop("decoder_layers", 12)
        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", False)
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
        self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)
        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)

        if self.deepnorm:
            self.encoder_normalize_before = False
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.encoder_normalize_before = True
            self.decoder_normalize_before = True
            self.deepnorm = False
        if self.use_xmoe:
            self.moe_normalize_gate_prob_before_dropping = True
            self.moe_second_expert_policy = "random"
            assert self.moe_freq > 0 and self.moe_expert_count > 0

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)
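`EncoderConfig` silently reconciles `deepnorm` and `subln`, which are mutually exclusive residual/normalization schemes. A small sanity check of that constructor logic; the layer and width values are chosen only for illustration:

```python
from vlmo.torchscale.architecture.config import EncoderConfig

cfg = EncoderConfig(encoder_layers=24, encoder_embed_dim=1024, deepnorm=True)
assert cfg.deepnorm is True and cfg.subln is False
assert cfg.encoder_normalize_before is False   # deepnorm switches to post-LN residuals

cfg = EncoderConfig(encoder_layers=24, encoder_embed_dim=1024)  # subln defaults to True
assert cfg.subln is True and cfg.deepnorm is False
assert cfg.encoder_normalize_before is True    # sub-LN keeps pre-LN blocks
```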
vlmo/torchscale/architecture/decoder.py
ADDED
@@ -0,0 +1,428 @@
1 |
+
# Copyright (c) 2022 Microsoft
|
2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from fairscale.nn import checkpoint_wrapper, wrap
|
10 |
+
|
11 |
+
from vlmo.torchscale.architecture.utils import init_bert_params
|
12 |
+
from vlmo.torchscale.component.droppath import DropPath
|
13 |
+
from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
14 |
+
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
|
15 |
+
from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
|
16 |
+
#from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
|
17 |
+
#from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
18 |
+
|
19 |
+
try:
|
20 |
+
from apex.normalization import FusedLayerNorm as LayerNorm
|
21 |
+
except ModuleNotFoundError:
|
22 |
+
from torch.nn import LayerNorm
|
23 |
+
|
24 |
+
|
25 |
+
class DecoderLayer(nn.Module):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
args,
|
29 |
+
depth,
|
30 |
+
is_moe_layer=False,
|
31 |
+
is_encoder_decoder=False,
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
self.args = args
|
35 |
+
self.embed_dim = args.decoder_embed_dim
|
36 |
+
self.dropout_module = torch.nn.Dropout(args.dropout)
|
37 |
+
|
38 |
+
if args.drop_path_rate > 0:
|
39 |
+
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
|
40 |
+
self.drop_path = DropPath(drop_path_prob)
|
41 |
+
else:
|
42 |
+
self.drop_path = None
|
43 |
+
|
44 |
+
self.self_attn = self.build_self_attention(self.embed_dim, args)
|
45 |
+
|
46 |
+
self.normalize_before = args.decoder_normalize_before
|
47 |
+
|
48 |
+
self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
|
49 |
+
|
50 |
+
if not is_encoder_decoder:
|
51 |
+
self.encoder_attn = None
|
52 |
+
self.encoder_attn_layer_norm = None
|
53 |
+
else:
|
54 |
+
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
|
55 |
+
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
|
56 |
+
|
57 |
+
self.is_moe_layer = is_moe_layer
|
58 |
+
self.ffn_dim = args.decoder_ffn_embed_dim
|
59 |
+
|
60 |
+
if not self.is_moe_layer:
|
61 |
+
self.ffn = self.build_ffn(
|
62 |
+
self.embed_dim,
|
63 |
+
self.args,
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
if args.moe_top1_expert:
|
67 |
+
gate = Top1Gate(
|
68 |
+
self.embed_dim,
|
69 |
+
args.moe_expert_count,
|
70 |
+
use_fp32=args.moe_gating_use_fp32,
|
71 |
+
moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
|
72 |
+
use_xmoe=args.use_xmoe,
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
gate = Top2Gate(
|
76 |
+
self.embed_dim,
|
77 |
+
args.moe_expert_count,
|
78 |
+
args.moe_gating_use_fp32,
|
79 |
+
args.moe_second_expert_policy,
|
80 |
+
args.moe_normalize_gate_prob_before_dropping,
|
81 |
+
args.moe_eval_capacity_token_fraction,
|
82 |
+
use_xmoe=args.use_xmoe,
|
83 |
+
)
|
84 |
+
experts = make_experts(args, self.embed_dim, self.ffn_dim)
|
85 |
+
self.moe_layer = MOELayer(gate, experts, args)
|
86 |
+
|
87 |
+
self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
|
88 |
+
|
89 |
+
if args.deepnorm:
|
90 |
+
if is_encoder_decoder:
|
91 |
+
self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
|
92 |
+
else:
|
93 |
+
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
|
94 |
+
else:
|
95 |
+
self.alpha = 1.0
|
96 |
+
|
97 |
+
def build_ffn(self, embed_dim, args):
|
98 |
+
return FeedForwardNetwork(
|
99 |
+
embed_dim,
|
100 |
+
self.ffn_dim,
|
101 |
+
args.activation_fn,
|
102 |
+
args.dropout,
|
103 |
+
args.activation_dropout,
|
104 |
+
args.layernorm_eps,
|
105 |
+
args.subln,
|
106 |
+
)
|
107 |
+
|
108 |
+
def build_self_attention(self, embed_dim, args):
|
109 |
+
return MultiheadAttention(
|
110 |
+
args,
|
111 |
+
embed_dim,
|
112 |
+
args.decoder_attention_heads,
|
113 |
+
dropout=args.attention_dropout,
|
114 |
+
self_attention=True,
|
115 |
+
encoder_decoder_attention=False,
|
116 |
+
subln=args.subln,
|
117 |
+
)
|
118 |
+
|
119 |
+
def build_encoder_attention(self, embed_dim, args):
|
120 |
+
return MultiheadAttention(
|
121 |
+
args,
|
122 |
+
embed_dim,
|
123 |
+
args.decoder_attention_heads,
|
124 |
+
dropout=args.attention_dropout,
|
125 |
+
self_attention=False,
|
126 |
+
encoder_decoder_attention=True,
|
127 |
+
subln=args.subln,
|
128 |
+
)
|
129 |
+
|
130 |
+
def residual_connection(self, x, residual):
|
131 |
+
return residual * self.alpha + x
|
132 |
+
|
133 |
+
def forward(
|
134 |
+
self,
|
135 |
+
x,
|
136 |
+
encoder_out=None,
|
137 |
+
encoder_padding_mask=None,
|
138 |
+
incremental_state=None,
|
139 |
+
self_attn_mask=None,
|
140 |
+
self_attn_padding_mask=None,
|
141 |
+
self_attn_rel_pos=None,
|
142 |
+
cross_attn_rel_pos=None,
|
143 |
+
):
|
144 |
+
residual = x
|
145 |
+
if self.normalize_before:
|
146 |
+
x = self.self_attn_layer_norm(x)
|
147 |
+
|
148 |
+
x, attn = self.self_attn(
|
149 |
+
query=x,
|
150 |
+
key=x,
|
151 |
+
value=x,
|
152 |
+
key_padding_mask=self_attn_padding_mask,
|
153 |
+
incremental_state=incremental_state,
|
154 |
+
attn_mask=self_attn_mask,
|
155 |
+
rel_pos=self_attn_rel_pos,
|
156 |
+
)
|
157 |
+
x = self.dropout_module(x)
|
158 |
+
|
159 |
+
if self.drop_path is not None:
|
160 |
+
x = self.drop_path(x)
|
161 |
+
|
162 |
+
x = self.residual_connection(x, residual)
|
163 |
+
if not self.normalize_before:
|
164 |
+
x = self.self_attn_layer_norm(x)
|
165 |
+
|
166 |
+
if self.encoder_attn is not None and encoder_out is not None:
|
167 |
+
residual = x
|
168 |
+
if self.normalize_before:
|
169 |
+
x = self.encoder_attn_layer_norm(x)
|
170 |
+
|
171 |
+
x, attn = self.encoder_attn(
|
172 |
+
query=x,
|
173 |
+
key=encoder_out,
|
174 |
+
value=encoder_out,
|
175 |
+
key_padding_mask=encoder_padding_mask,
|
176 |
+
incremental_state=None,
|
177 |
+
rel_pos=cross_attn_rel_pos,
|
178 |
+
)
|
179 |
+
x = self.dropout_module(x)
|
180 |
+
|
181 |
+
if self.drop_path is not None:
|
182 |
+
x = self.drop_path(x)
|
183 |
+
|
184 |
+
x = self.residual_connection(x, residual)
|
185 |
+
if not self.normalize_before:
|
186 |
+
x = self.encoder_attn_layer_norm(x)
|
187 |
+
|
188 |
+
residual = x
|
189 |
+
if self.normalize_before:
|
190 |
+
x = self.final_layer_norm(x)
|
191 |
+
if not self.is_moe_layer:
|
192 |
+
x = self.ffn(x)
|
193 |
+
l_aux = None
|
194 |
+
else:
|
195 |
+
x, l_aux = self.moe_layer(x)
|
196 |
+
|
197 |
+
if self.drop_path is not None:
|
198 |
+
x = self.drop_path(x)
|
199 |
+
|
200 |
+
x = self.residual_connection(x, residual)
|
201 |
+
if not self.normalize_before:
|
202 |
+
x = self.final_layer_norm(x)
|
203 |
+
|
204 |
+
return x, attn, None, l_aux
|
205 |
+
|
206 |
+
|
207 |
+
class Decoder(nn.Module):
|
208 |
+
def __init__(
|
209 |
+
self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
|
210 |
+
):
|
211 |
+
super().__init__(**kwargs)
|
212 |
+
self.args = args
|
213 |
+
|
214 |
+
self.dropout_module = torch.nn.Dropout(args.dropout)
|
215 |
+
|
216 |
+
embed_dim = args.decoder_embed_dim
|
217 |
+
self.embed_dim = embed_dim
|
218 |
+
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
|
219 |
+
|
220 |
+
self.embed_tokens = embed_tokens
|
221 |
+
self.embed_positions = embed_positions
|
222 |
+
|
223 |
+
if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
|
224 |
+
self.output_projection = self.build_output_projection(args)
|
225 |
+
else:
|
226 |
+
self.output_projection = output_projection
|
227 |
+
|
228 |
+
if args.layernorm_embedding:
|
229 |
+
self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
|
230 |
+
else:
|
231 |
+
self.layernorm_embedding = None
|
232 |
+
|
233 |
+
self.layers = nn.ModuleList([])
|
234 |
+
|
235 |
+
moe_freq = args.moe_freq
|
236 |
+
for i in range(args.decoder_layers):
|
237 |
+
is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
|
238 |
+
self.layers.append(
|
239 |
+
self.build_decoder_layer(
|
240 |
+
args,
|
241 |
+
depth=i,
|
242 |
+
is_moe_layer=is_moe_layer,
|
243 |
+
is_encoder_decoder=is_encoder_decoder,
|
244 |
+
)
|
245 |
+
)
|
246 |
+
|
247 |
+
self.num_layers = len(self.layers)
|
248 |
+
|
249 |
+
if args.decoder_normalize_before:
|
250 |
+
self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
|
251 |
+
else:
|
252 |
+
self.layer_norm = None
|
253 |
+
|
254 |
+
self.self_attn_relative_position = None
|
255 |
+
self.cross_attn_relative_position = None
|
256 |
+
|
257 |
+
if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
|
258 |
+
self.self_attn_relative_position = RelativePositionBias(
|
259 |
+
num_buckets=args.rel_pos_buckets,
|
260 |
+
max_distance=args.max_rel_pos,
|
261 |
+
n_heads=args.decoder_attention_heads,
|
262 |
+
)
|
263 |
+
if is_encoder_decoder:
|
264 |
+
self.cross_attn_relative_position = RelativePositionBias(
|
265 |
+
num_buckets=args.rel_pos_buckets,
|
266 |
+
max_distance=args.max_rel_pos,
|
267 |
+
n_heads=args.decoder_attention_heads,
|
268 |
+
)
|
269 |
+
|
270 |
+
if args.bert_init:
|
271 |
+
self.apply(init_bert_params)
|
272 |
+
|
273 |
+
if args.deepnorm:
|
274 |
+
if is_encoder_decoder:
|
275 |
+
init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
|
276 |
+
else:
|
277 |
+
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
|
278 |
+
for name, p in self.named_parameters():
|
279 |
+
if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
|
280 |
+
p.data.div_(init_scale)
|
281 |
+
|
282 |
+
if args.subln:
|
283 |
+
if is_encoder_decoder:
|
284 |
+
init_scale = math.sqrt(math.log(args.decoder_layers * 3))
|
285 |
+
else:
|
286 |
+
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
|
287 |
+
for name, p in self.named_parameters():
|
288 |
+
if "encoder_attn" in name:
|
289 |
+
continue
|
290 |
+
if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
|
291 |
+
p.data.mul_(init_scale)
|
292 |
+
|
293 |
+
def build_output_projection(
|
294 |
+
self,
|
295 |
+
args,
|
296 |
+
):
|
297 |
+
if args.share_decoder_input_output_embed:
|
298 |
+
output_projection = torch.nn.Linear(
|
299 |
+
self.embed_tokens.weight.shape[1],
|
300 |
+
self.embed_tokens.weight.shape[0],
|
301 |
+
bias=False,
|
302 |
+
)
|
303 |
+
output_projection.weight = self.embed_tokens.weight
|
304 |
+
else:
|
305 |
+
output_projection = torch.nn.Linear(args.decoder_embed_dim, args.vocab_size, bias=False)
|
306 |
+
torch.nn.init.normal_(output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5)
|
307 |
+
return output_projection
|
308 |
+
|
309 |
+
def build_decoder_layer(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
|
310 |
+
layer = DecoderLayer(
|
311 |
+
args,
|
312 |
+
depth,
|
313 |
+
is_moe_layer=is_moe_layer,
|
314 |
+
is_encoder_decoder=is_encoder_decoder,
|
315 |
+
)
|
316 |
+
if args.checkpoint_activations:
|
317 |
+
layer = checkpoint_wrapper(layer)
|
318 |
+
if args.fsdp:
|
319 |
+
layer = wrap(layer)
|
320 |
+
return layer
|
321 |
+
|
322 |
+
def forward_embedding(
|
323 |
+
self,
|
324 |
+
tokens,
|
325 |
+
token_embedding=None,
|
326 |
+
incremental_state=None,
|
327 |
+
):
|
328 |
+
positions = None
|
329 |
+
if self.embed_positions is not None:
|
330 |
+
positions = self.embed_positions(tokens, incremental_state=incremental_state)
|
331 |
+
|
332 |
+
if incremental_state is not None:
|
333 |
+
tokens = tokens[:, -1:]
|
334 |
+
if positions is not None:
|
335 |
+
positions = positions[:, -1:]
|
336 |
+
|
337 |
+
if token_embedding is None:
|
338 |
+
token_embedding = self.embed_tokens(tokens)
|
339 |
+
|
340 |
+
x = embed = self.embed_scale * token_embedding
|
341 |
+
|
342 |
+
if positions is not None:
|
343 |
+
x += positions
|
344 |
+
|
345 |
+
if self.layernorm_embedding is not None:
|
346 |
+
x = self.layernorm_embedding(x)
|
347 |
+
|
348 |
+
x = self.dropout_module(x)
|
349 |
+
|
350 |
+
return x, embed
|
351 |
+
|
352 |
+
def forward(
|
353 |
+
self,
|
354 |
+
prev_output_tokens,
|
355 |
+
self_attn_padding_mask=None,
|
356 |
+
encoder_out=None,
|
357 |
+
incremental_state=None,
|
358 |
+
features_only=False,
|
359 |
+
return_all_hiddens=False,
|
360 |
+
token_embeddings=None,
|
361 |
+
**kwargs
|
362 |
+
):
|
363 |
+
# embed tokens and positions
|
364 |
+
x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
|
365 |
+
|
366 |
+
# relative position
|
367 |
+
self_attn_rel_pos_bias = None
|
368 |
+
slen = prev_output_tokens.size(1)
|
369 |
+
if self.self_attn_relative_position is not None:
|
370 |
+
self_attn_rel_pos_bias = self.self_attn_relative_position(batch_size=x.size(0), qlen=slen, klen=slen)
|
371 |
+
if incremental_state is not None:
|
372 |
+
self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
|
373 |
+
cross_attn_rel_pos_bias = None
|
374 |
+
if self.cross_attn_relative_position is not None:
|
375 |
+
cross_attn_rel_pos_bias = self.cross_attn_relative_position(
|
376 |
+
batch_size=x.size(0),
|
377 |
+
qlen=slen,
|
378 |
+
klen=encoder_out["encoder_out"].size(1),
|
379 |
+
)
|
380 |
+
if incremental_state is not None:
|
381 |
+
cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
|
382 |
+
|
383 |
+
# decoder layers
|
384 |
+
inner_states = [x]
|
385 |
+
|
386 |
+
if encoder_out is None:
|
387 |
+
l_aux = []
|
388 |
+
else:
|
389 |
+
l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
|
390 |
+
|
391 |
+
for idx, layer in enumerate(self.layers):
|
392 |
+
if incremental_state is None:
|
393 |
+
self_attn_mask = torch.triu(
|
394 |
+
torch.zeros([x.size(1), x.size(1)]).float().fill_(float("-inf")).type_as(x),
|
395 |
+
1,
|
396 |
+
)
|
397 |
+
else:
|
398 |
+
self_attn_mask = None
|
399 |
+
if idx not in incremental_state:
|
400 |
+
incremental_state[idx] = {}
|
401 |
+
|
402 |
+
x, layer_attn, _, l_aux_i = layer(
|
403 |
+
x,
|
404 |
+
encoder_out["encoder_out"] if encoder_out is not None else None,
|
405 |
+
encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
|
406 |
+
incremental_state[idx] if incremental_state is not None else None,
|
407 |
+
self_attn_mask=self_attn_mask,
|
408 |
+
self_attn_padding_mask=self_attn_padding_mask,
|
409 |
+
self_attn_rel_pos=self_attn_rel_pos_bias,
|
410 |
+
cross_attn_rel_pos=cross_attn_rel_pos_bias,
|
411 |
+
)
|
412 |
+
l_aux.append(l_aux_i)
|
413 |
+
inner_states.append(x)
|
414 |
+
|
415 |
+
if self.layer_norm is not None:
|
416 |
+
x = self.layer_norm(x)
|
417 |
+
|
418 |
+
if not features_only:
|
419 |
+
x = self.output_layer(x)
|
420 |
+
|
421 |
+
return x, {
|
422 |
+
"inner_states": inner_states,
|
423 |
+
"l_aux": l_aux,
|
424 |
+
"attn": None,
|
425 |
+
}
|
426 |
+
|
427 |
+
def output_layer(self, features):
|
428 |
+
return self.output_projection(features)
|
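A minimal usage sketch (not part of the diff) of how the incremental decoding path above is typically driven: the per-layer incremental_state dict caches keys/values, and the decoder slices prev_output_tokens to the last position itself. The decoder instance, encoder_out dict, and bos/eos ids are assumed to come from the caller.

import torch

@torch.no_grad()
def greedy_decode(decoder, encoder_out, bos, eos, max_len=32):
    # hypothetical helper: greedy generation with the key/value cache above
    tokens = torch.full((1, 1), bos, dtype=torch.long)
    incremental_state = {}  # filled per layer index inside Decoder.forward
    for _ in range(max_len):
        logits, _ = decoder(tokens, encoder_out=encoder_out, incremental_state=incremental_state)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == eos:
            break
    return tokens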
vlmo/torchscale/architecture/encoder.py
ADDED
@@ -0,0 +1,489 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math

import numpy as np
import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper, wrap

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from vlmo.torchscale.architecture.utils import init_bert_params
from vlmo.torchscale.component.droppath import DropPath
from vlmo.torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.multiway_network import MultiwayWrapper, set_split_position
from vlmo.torchscale.component.relative_position_bias import RelativePositionBias
# from vlmo.torchscale.component.xmoe.moe_layer import MOELayer
# from vlmo.torchscale.component.xmoe.routing import Top1Gate, Top2Gate
# from vlmo.modules.vlmo_utils import no_sync_module_apply
from pytorch_lightning.utilities.rank_zero import rank_zero_info


def no_sync_module_apply(module, fn):
    """FSDP module .apply will use _unshard_params_recurse which will sync params across ranks.
    using this function when apply fn is unnecessary to sync params across ranks.
    """
    for child in module.children():
        fn(child)
        no_sync_module_apply(child, fn)


class EncoderLayer(nn.Module):
    def __init__(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = self.build_self_attention(self.embed_dim, args) if attn is None else attn
        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
        self.dropout_module = torch.nn.Dropout(args.dropout)

        if args.drop_path_rate > 0:
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.normalize_before = args.encoder_normalize_before
        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.encoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = MultiwayWrapper(
                args,
                self.build_ffn(
                    self.embed_dim,
                    self.args,
                ),
            )
        else:
            assert not self.args.multiway
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)
        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))

        if args.deepnorm:
            if is_encoder_decoder:
                self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
            else:
                self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self, embed_dim, args):
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.layernorm_eps,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def residual_connection(self, x, residual):
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_padding_mask,
        attn_mask=None,
        rel_pos=None,
        multiway_split_position=None,
        incremental_state=None,
    ):
        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        if attn_mask is not None:
            # float16: -1e8 equal 0
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            rel_pos=rel_pos,
            incremental_state=incremental_state,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            x = x.transpose(0, 1)
            x, l_aux = self.moe_layer(x)
            x = x.transpose(0, 1)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, l_aux


class Encoder(nn.Module):
    def __init__(
        self, args, embed_tokens=None, embed_positions=None, output_projection=None, is_encoder_decoder=False, **kwargs
    ):
        self.args = args
        super().__init__(**kwargs)

        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.encoder_embed_dim
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
        self.mask_ratio = args.mask_ratio
        self.max_text_len = args.max_text_len
        self.vision_len = (args.img_size // args.patch_size) * (args.img_size // args.patch_size)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])
        if self.args.share_layer:
            single_layer = self.build_encoder_layer(
                args, depth=0, is_moe_layer=False, is_encoder_decoder=is_encoder_decoder
            )
            for i in range(args.encoder_layers):
                self.layers.append(single_layer)
        elif self.args.share_attn:
            moe_freq = args.moe_freq
            embed_dim = args.encoder_embed_dim
            shared_attn = self.build_self_attention(embed_dim, self.args)
            for i in range(args.encoder_layers):
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        attn=shared_attn,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )

        else:
            moe_freq = args.moe_freq
            for i in range(args.encoder_layers):
                is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
                self.layers.append(
                    self.build_encoder_layer(
                        args,
                        depth=i,
                        is_moe_layer=is_moe_layer,
                        is_encoder_decoder=is_encoder_decoder,
                    )
                )
        self.num_layers = len(self.layers)

        if args.encoder_normalize_before and args.normalize_output:
            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
        else:
            self.layer_norm = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.encoder_attention_heads,
            )
        else:
            self.relative_position = None

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            if is_encoder_decoder:
                init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
            else:
                init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.div_(init_scale)

        if args.subln:
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
            else:
                init_scale = math.sqrt(math.log(args.encoder_layers * 2))
            for name, p in self.named_parameters():
                if "fc1" in name or "fc2" in name or "out_proj" in name or "v_proj" in name:
                    p.data.mul_(init_scale)

    def random_masking(self, x, mask_ratio):
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L - 1, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device, dtype=int)
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        x0 = x[:, 0, :]
        x0 = x0.reshape(N, 1, D)
        x_masked_add = torch.cat([x0, x_masked], axis=1)
        return x_masked_add, ids_keep

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
            one_attn=args.one_attn,
        )

    def build_output_projection(
        self,
        args,
    ):
        if args.share_encoder_input_output_embed:
            assert args.encoder_embedding_type == "language"
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(args.encoder_embed_dim, args.vocab_size, bias=False)
            torch.nn.init.normal_(output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5)
        return output_projection

    def checkpointing_and_params_allgather(
        self,
        origin_layer,
    ):
        origin_forward = origin_layer.forward

        from deepspeed import checkpointing

        def forward(*args, **kwargs):
            # deepspeed checkpoint not support kwargs
            ret = checkpointing.checkpoint(origin_forward, *args, **kwargs)
            return ret

        return forward

    def build_encoder_layer(self, args, depth, attn=None, is_moe_layer=False, is_encoder_decoder=False):
        layer = EncoderLayer(
            args,
            depth,
            attn,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            rank_zero_info("EncoderLayer params: %s", sum(p.numel() for p in layer.parameters() if p.requires_grad))
            layer = checkpoint_wrapper(layer)
            # layer.ffn = checkpoint_wrapper(layer.ffn,)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def checkpointing_layers(self):
        for i, layer in enumerate(self.layers):
            rank_zero_info(f"Checkpointing wrapper EncoderLayers: {i}")
            self.layers[i] = checkpoint_wrapper(layer)

    def forward_embedding(
        self,
        src_tokens,
        token_embedding=None,
        positions=None,
    ):
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            if src_tokens is not None:
                x = embed + self.embed_positions(src_tokens, positions=positions)
            else:
                x = embed + self.embed_positions(x, positions=positions)
        is_flip, ids_keep = 0, None
        if self.mask_ratio > 0:
            if x.shape[1] == self.vision_len + 1:
                x, ids_keep = self.random_masking(x, self.mask_ratio)
                is_flip = 1
            elif x.shape[1] == self.vision_len + self.max_text_len + 1:
                vision_tokens = x[:, : self.vision_len + 1, :]
                vision_tokens, ids_keep = self.random_masking(vision_tokens, self.mask_ratio)
                x = torch.cat(
                    [
                        vision_tokens,
                        x[
                            :,
                            self.vision_len + 1 :,
                        ],
                    ],
                    dim=1,
                )
                is_flip = 2
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed, ids_keep, is_flip

    def forward(
        self,
        src_tokens,
        encoder_padding_mask=None,
        attn_mask=None,
        return_all_hiddens=False,
        token_embeddings=None,
        multiway_split_position=None,
        features_only=False,
        incremental_state=None,
        positions=None,
        **kwargs
    ):
        assert src_tokens is not None or token_embeddings is not None

        if encoder_padding_mask is None:
            if src_tokens is not None:
                encoder_padding_mask = torch.zeros_like(src_tokens, device=src_tokens.device).bool()
            else:
                encoder_padding_mask = torch.zeros(
                    [token_embeddings.size(0), token_embeddings.size(1)],
                    device=token_embeddings.device,
                ).bool()

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))

        x, encoder_embedding, ids_keep, is_flip = self.forward_embedding(src_tokens, token_embeddings, positions)
        if is_flip > 0:
            if is_flip == 2:
                text_ids = (
                    torch.arange(
                        self.vision_len + 1, self.vision_len + 1 + self.max_text_len, device=x.device, dtype=torch.int64
                    )
                    .unsqueeze(0)
                    .repeat(ids_keep.shape[0], 1)
                )
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep, text_ids], dim=1)
            elif is_flip == 1:
                cls_ids = torch.zeros(ids_keep.shape[0], 1, device=x.device, dtype=torch.int64)
                ids_keep = torch.cat([cls_ids, ids_keep], dim=1)
            if encoder_padding_mask is not None:
                encoder_padding_mask = torch.gather(encoder_padding_mask, dim=1, index=ids_keep)
            if attn_mask is not None:
                attn_mask = torch.gather(
                    attn_mask, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, attn_mask.shape[-1])
                )
                attn_mask = torch.gather(attn_mask, dim=2, index=ids_keep.unsqueeze(1).repeat(1, attn_mask.shape[1], 1))
            if multiway_split_position > 0:
                multiway_split_position = ids_keep.shape[1] - self.max_text_len
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        rel_pos_bias = None
        if self.relative_position is not None:
            rel_pos_bias = self.relative_position(batch_size=x.size(0), qlen=x.size(1), klen=x.size(1))

        l_aux = []
        for idx, layer in enumerate(self.layers):
            x, l_aux_i = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
                attn_mask=attn_mask,
                rel_pos=rel_pos_bias,
                multiway_split_position=multiway_split_position,
                incremental_state=incremental_state[idx] if incremental_state is not None else None,
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            l_aux.append(l_aux_i)

        if multiway_split_position is not None:
            assert self.args.multiway
            no_sync_module_apply(self, set_split_position(multiway_split_position))
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only and self.output_projection is not None:
            x = self.output_projection(x)

        return {
            "encoder_out": x,
            "encoder_embedding": encoder_embedding,
            "encoder_padding_mask": encoder_padding_mask,
            "encoder_states": encoder_states,
            "l_aux": l_aux,
            "multiway_split_position": multiway_split_position,
        }
vlmo/torchscale/architecture/encoder_decoder.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch.nn as nn

from vlmo.torchscale.architecture.decoder import Decoder
from vlmo.torchscale.architecture.encoder import Encoder


class EncoderDecoder(nn.Module):
    def __init__(
        self,
        args,
        encoder_embed_tokens=None,
        encoder_embed_positions=None,
        decoder_embed_tokens=None,
        decoder_embed_positions=None,
        output_projection=None,
        **kwargs
    ):
        super().__init__()
        self.args = args
        if args.share_all_embeddings:
            args.share_decoder_input_output_embed = True

        self.encoder = Encoder(args, encoder_embed_tokens, encoder_embed_positions, is_encoder_decoder=True, **kwargs)

        if args.share_all_embeddings and decoder_embed_tokens is None:
            decoder_embed_tokens = self.encoder.embed_tokens

        self.decoder = Decoder(
            args, decoder_embed_tokens, decoder_embed_positions, output_projection, is_encoder_decoder=True, **kwargs
        )

    def forward(self, src_tokens, prev_output_tokens, return_all_hiddens=False, features_only=False, **kwargs):
        encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out
vlmo/torchscale/architecture/utils.py
ADDED
@@ -0,0 +1,33 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch.nn as nn

from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.multiway_network import MultiwayNetwork


def init_bert_params(module):
    def normal_(data):
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        if isinstance(module.q_proj, MultiwayNetwork):
            normal_(module.q_proj.A.weight.data)
            normal_(module.q_proj.B.weight.data)
            normal_(module.k_proj.A.weight.data)
            normal_(module.k_proj.B.weight.data)
            normal_(module.v_proj.A.weight.data)
            normal_(module.v_proj.B.weight.data)
        else:
            normal_(module.q_proj.weight.data)
            normal_(module.k_proj.weight.data)
            normal_(module.v_proj.weight.data)
vlmo/torchscale/component/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
vlmo/torchscale/component/droppath.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch.nn as nn
from timm.models.layers import drop_path


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return "p={}".format(self.drop_prob)
vlmo/torchscale/component/embedding.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn
import torch.nn.functional as F


class VisionLanguageEmbedding(nn.Module):
    def __init__(self, text_embed, vision_embed):
        super().__init__()
        self.text_embed = text_embed
        self.vision_embed = vision_embed

    def forward(self, textual_tokens, visual_tokens, **kwargs):
        if textual_tokens is None:
            return self.vision_embed(visual_tokens)

        if visual_tokens is None:
            return self.text_embed(textual_tokens)

        x1 = self.vision_embed(visual_tokens)
        x2 = self.text_embed(textual_tokens)

        return torch.cat([x1, x2], dim=1)


class VisionEmbedding(nn.Module):
    """Image to Patch Embedding"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        contain_mask_token=False,
        prepend_cls_token=False,
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        if contain_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.mask_token = None

        if prepend_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

    def num_position_embeddings(self):
        if self.cls_token is None:
            return self.num_patches
        else:
            return self.num_patches + 1

    def forward(self, x, masked_position=None, **kwargs):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)

        batch_size, seq_len, _ = x.size()

        if masked_position is not None:
            assert self.mask_token is not None
            mask_token = self.mask_token.expand(batch_size, seq_len, -1)
            w = masked_position.unsqueeze(-1).type_as(mask_token)
            x = x * (1 - w) + mask_token * w

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_tokens, x), dim=1)

        return x


class TextEmbedding(nn.Embedding):
    def reset_parameters(self):
        nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
        self._fill_padding_idx_with_zero()


class PositionalEmbedding(nn.Embedding):
    def forward(
        self,
        x,
        positions=None,
        **kwargs,
    ):
        if positions is None:
            # being consistent with Fairseq, which starts from 2.
            positions = torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
vlmo/torchscale/component/feedforward_network.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm


class set_torch_seed(object):
    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = self.get_rng_state()

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def get_rng_state(self):
        state = {"torch_rng_state": torch.get_rng_state()}
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
        return state

    def set_rng_state(self, state):
        torch.set_rng_state(state["torch_rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(state["cuda_rng_state"])

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.set_rng_state(self.rng_state)


def make_experts(args, embed_dim, expert_ffn_dim):
    world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
    expert_list = []
    ddp_rank = args.ddp_rank
    start_seed = torch.randint(1000000, (1,)).item()
    # at least as many experts than gpus
    if args.moe_expert_count >= world_size:
        assert args.moe_expert_count % world_size == 0, f"{args.moe_expert_count}, {world_size}"
        local_moe_expert_count = args.moe_expert_count // world_size
        for i in range(local_moe_expert_count):
            with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
                expert_list.append(
                    FeedForwardNetwork(
                        embed_dim,
                        expert_ffn_dim,
                        args.activation_fn,
                        args.dropout,
                        args.activation_dropout,
                        args.layernorm_eps,
                        args.subln,
                    )
                )
    else:
        assert world_size % args.moe_expert_count == 0, f"{world_size}, {args.moe_expert_count}"

        with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
            expert_list.append(
                FeedForwardNetwork(
                    embed_dim,
                    expert_ffn_dim,
                    args.activation_fn,
                    args.dropout,
                    args.activation_dropout,
                    args.layernorm_eps,
                    args.subln,
                )
            )
    experts = nn.ModuleList(expert_list)
    return experts


def get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise NotImplementedError


class FeedForwardNetwork(nn.Module):
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
        layernorm_eps,
        subln=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        if self.ffn_layernorm is not None:
            self.ffn_layernorm.reset_parameters()

    def forward(self, x):
        # x = x.reshape(-1, x.size(-1))
        x = self.fc1(x)
        # x = self.activation_fn(x.float()).type_as(x)
        x = self.activation_fn(x)
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        # x = x.view(x_shape)
        x = self.dropout_module(x)
        return x
vlmo/torchscale/component/multihead_attention.py
ADDED
@@ -0,0 +1,154 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math

import torch
import torch.nn.functional as F
from torch import nn

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from .multiway_network import MultiwayWrapper
from .xpos_relative_position import XPOS


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        args,
        embed_dim,
        num_heads,
        dropout=0.0,
        self_attention=False,
        encoder_decoder_attention=False,
        subln=False,
        one_attn=False,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** (-0.5)
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert self.self_attention ^ self.encoder_decoder_attention
        if one_attn:
            self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        else:
            self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            # self.qkv_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim*3, bias=True))
            self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        self.inner_attn_ln = (
            MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
            if subln and self.self_attention
            else None
        )
        self.dropout_module = torch.nn.Dropout(dropout)
        self.xpos = XPOS(self.head_dim, args.xpos_scale_base) if args.xpos_rel_pos and self.self_attention else None

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query,
        key,
        value,
        incremental_state=None,
        key_padding_mask=None,
        attn_mask=None,
        rel_pos=None,
    ):
        bsz, tgt_len, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"

        key_bsz, src_len, _ = key.size()
        assert key_bsz == bsz, f"{query.size(), key.size()}"
        assert value is not None
        assert bsz, src_len == value.shape[:2]
        # if query is key and key is value:
        #     qkv = self.qkv_proj(query)
        # else:
        #     # W*(q+k+v) = W(q) + W(k) + W(v)
        #     qkv = self.qkv_proj(query+key+value)
        # q,k,v = qkv.split(self.embed_dim, dim=-1)

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        q = (q * self.scaling).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)

        if incremental_state is not None:
            if "prev_key" in incremental_state:
                prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
                prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
                k = torch.cat([prev_key, k], dim=1)
                v = torch.cat([prev_value, v], dim=1)
            incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            src_len = k.size(1)

        if self.xpos is not None:
            if incremental_state is not None:
                offset = src_len - 1
            else:
                offset = 0
            k = self.xpos(k, offset=0, downscale=True)
            q = self.xpos(q, offset=offset, downscale=False)

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        if attn_mask is not None:
            attn_weights = torch.nan_to_num(attn_weights)
            if len(attn_mask.shape) != len(attn_weights.shape):
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if rel_pos is not None:
            rel_pos = rel_pos.view(attn_weights.size())
            attn_weights = attn_weights + rel_pos

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        attn = torch.bmm(attn_probs, v)
        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

        if self.inner_attn_ln is not None:
            attn = self.inner_attn_ln(attn)

        attn = self.out_proj(attn)
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)

        return attn, attn_weights
vlmo/torchscale/component/multiway_network.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import copy

import torch
import torch.nn as nn


def MultiwayWrapper(args, module, dim=1):
    if args.multiway:
        return MultiwayNetwork(module, dim=dim)
    return module


def set_split_position(position):
    def apply_fn(module):
        if hasattr(module, "split_position"):
            module.split_position = position

    return apply_fn


class MultiwayNetwork(nn.Module):
    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.A = module
        self.B = copy.deepcopy(module)
        self.B.reset_parameters()
        self.split_position = -1

    def forward(self, x, **kwargs):
        if self.split_position == -1:
            return self.A(x, **kwargs)
        if self.split_position == 0:
            return self.B(x, **kwargs)
        x1, x2 = torch.split(
            x,
            [self.split_position, x.size(self.dim) - self.split_position],
            dim=self.dim,
        )
        # x1, x2 = x[:self.split_position], x[self.split_position:]
        y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
        return torch.cat([y1, y2], dim=self.dim)


class MutliwayEmbedding(MultiwayNetwork):
    def __init__(self, modules, dim=1):
        super(MultiwayNetwork, self).__init__()
        self.dim = dim
        assert len(modules) == 2
        self.A = modules[0]
        self.B = modules[1]
        self.split_position = -1
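A small illustration (not part of the diff) of what MultiwayNetwork does with split_position: tokens before the split go through branch A, the rest through branch B. The sizes below are arbitrary.

import torch
import torch.nn as nn
from vlmo.torchscale.component.multiway_network import MultiwayNetwork, set_split_position

layer = MultiwayNetwork(nn.Linear(8, 8), dim=1)  # A and its deep-copied twin B
x = torch.randn(2, 10, 8)                        # e.g. 6 vision tokens followed by 4 text tokens
layer.apply(set_split_position(6))               # route x[:, :6] through A and x[:, 6:] through B
y = layer(x)                                     # same shape as x: (2, 10, 8)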
vlmo/torchscale/component/relative_position_bias.py
ADDED
@@ -0,0 +1,67 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math

import torch
import torch.nn as nn


class RelativePositionBias(nn.Module):
    def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen, step=None):
        step = 0 if step is None else step
        context_position = torch.arange(
            step,
            step + qlen,
            dtype=torch.long,
            device=self.relative_attention_bias.weight.device,
        )[:, None]
        memory_position = torch.arange(klen, dtype=torch.long, device=self.relative_attention_bias.weight.device)[
            None, :
        ]
        relative_position = memory_position - context_position  # shape (qlen, klen)

        rp_bucket = self._relative_position_bucket(
            relative_position,  # shape (qlen, klen)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
        return values

    def forward(self, batch_size, qlen, klen, step=None):
        # shape (batch * num_heads, qlen, klen)
        return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
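For orientation (not part of the diff): the module buckets token distances and returns a per-head additive bias shaped to match the flattened attention scores. A toy shape check, with arbitrary sizes:

import torch
from vlmo.torchscale.component.relative_position_bias import RelativePositionBias

bias = RelativePositionBias(num_buckets=32, max_distance=128, n_heads=12)
b = bias(batch_size=2, qlen=5, klen=5)  # added to attention logits in MultiheadAttention
print(b.shape)                          # torch.Size([24, 5, 5]) = (batch * n_heads, qlen, klen)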
vlmo/torchscale/component/xpos_relative_position.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn


def fixed_pos_embedding(x):
    seq_len, dim = x.shape
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)


def rotate_every_two(x):
    x1 = x[:, :, ::2]
    x2 = x[:, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')\


def duplicate_interleave(m):
    """
    A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
    """
    dim0 = m.shape[0]
    m = m.view(-1, 1)  # flatten the matrix
    m = m.repeat(1, 2)  # repeat all elements into the 2nd dimension
    m = m.view(dim0, -1)  # reshape into a matrix, interleaving the copy
    return m


def apply_rotary_pos_emb(x, sin, cos, scale=1):
    sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
    return (x * cos) + (rotate_every_two(x) * sin)


class XPOS(nn.Module):
    def __init__(self, head_dim, scale_base=512):
        super().__init__()
        self.head_dim = head_dim
        self.scale_base = scale_base
        self.register_buffer("scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim))

    def forward(self, x, offset=0, downscale=False):
        length = x.shape[1]
        min_pos = -(length + offset) // 2
        max_pos = length + offset + min_pos
        scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
        sin, cos = fixed_pos_embedding(scale)

        if scale.shape[0] > length:
            scale = scale[-length:]
            sin = sin[-length:]
            cos = cos[-length:]

        if downscale:
            scale = 1 / scale

        x = apply_rotary_pos_emb(x, sin, cos, scale)
        return x
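A quick shape-level sketch (not part of the diff) of how XPOS is applied to the per-head query/key tensors inside MultiheadAttention above; all sizes are arbitrary.

import torch
from vlmo.torchscale.component.xpos_relative_position import XPOS

xpos = XPOS(head_dim=64, scale_base=512)
q = torch.randn(8, 16, 64)              # (batch * heads, seq_len, head_dim)
k = torch.randn(8, 16, 64)
q = xpos(q, offset=0, downscale=False)  # queries get the exponential scale
k = xpos(k, offset=0, downscale=True)   # keys get its inverse, as in the attention module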
vlmo/torchscale/model/BEiT3.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn

from vlmo.torchscale.architecture.encoder import Encoder
from vlmo.torchscale.component.embedding import (
    PositionalEmbedding,
    TextEmbedding,
    VisionEmbedding,
)
from vlmo.torchscale.component.multiway_network import MutliwayEmbedding


class BEiT3(nn.Module):
    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        assert args.multiway
        assert args.vocab_size > 0
        assert not args.share_encoder_input_output_embed
        self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
        self.vision_embed = VisionEmbedding(
            args.img_size,
            args.patch_size,
            args.in_chans,
            args.encoder_embed_dim,
            contain_mask_token=True,
            prepend_cls_token=True,
        )
        # being consistent with Fairseq, which starts from 2 for position embedding
        embed_positions = MutliwayEmbedding(
            modules=[
                PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
                PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
            ],
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(
        self,
        textual_tokens=None,
        visual_tokens=None,
        text_padding_position=None,
        attn_mask=None,
        vision_masked_position=None,
        incremental_state=None,
        positions=None,
    ):
        assert textual_tokens is not None or visual_tokens is not None

        if textual_tokens is None:
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)
            diff = x1.shape[0] // x2.shape[0]
            if diff != 1:
                x2 = torch.repeat_interleave(x2, diff, dim=0)
                text_padding_position = torch.repeat_interleave(text_padding_position, diff, dim=0)
            x = torch.cat([x1, x2], dim=1)
            if text_padding_position is not None:
                encoder_padding_mask = torch.cat(
                    [
                        torch.zeros(x1.shape[:-1], device=x1.device, dtype=torch.bool),
                        text_padding_position,
                    ],
                    dim=1,
                )
            else:
                encoder_padding_mask = None
        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
            incremental_state=incremental_state,
            positions=positions,
        )
        return encoder_out
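A hedged end-to-end sketch (not part of the diff). It assumes `beit3` is an already-constructed BEiT3 instance (built from the config in vlmo/torchscale/architecture/config.py); the tensor shapes are illustrative only and simply mirror the forward signature above.

import torch

images = torch.randn(2, 3, 224, 224)            # pixels -> patch embeddings inside VisionEmbedding
text = torch.randint(0, 1000, (2, 40))          # token ids
padding = torch.zeros(2, 40, dtype=torch.bool)  # True where text positions are padding

out = beit3(textual_tokens=text, visual_tokens=images, text_padding_position=padding)
x = out["encoder_out"]                          # (batch, 1 + num_patches + text_len, embed_dim)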
vlmo/torchscale/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
vlmo/transforms/__init__.py
ADDED
@@ -0,0 +1,19 @@
from .pixelbert import (
    pixelbert_transform,
    pixelbert_transform_randaug,
)
from .square_transform import (
    square_transform,
    square_transform_randaug,
)

_transforms = {
    "pixelbert": pixelbert_transform,
    "pixelbert_randaug": pixelbert_transform_randaug,
    "square_transform": square_transform,
    "square_transform_randaug": square_transform_randaug,
}


def keys_to_transforms(keys: list, size=224):
    return [_transforms[key](size=size) for key in keys]
vlmo/transforms/pixelbert.py
ADDED
@@ -0,0 +1,30 @@
from .utils import (
    inception_normalize,
    MinMaxResize,
)
from torchvision import transforms
from .randaug import RandAugment


def pixelbert_transform(size=800):
    longer = int((1333 / 800) * size)
    return transforms.Compose(
        [
            MinMaxResize(shorter=size, longer=longer),
            transforms.ToTensor(),
            inception_normalize,
        ]
    )


def pixelbert_transform_randaug(size=800):
    longer = int((1333 / 800) * size)
    trs = transforms.Compose(
        [
            MinMaxResize(shorter=size, longer=longer),
            transforms.ToTensor(),
            inception_normalize,
        ]
    )
    trs.transforms.insert(0, RandAugment(2, 9))
    return trs
vlmo/transforms/randaug.py
ADDED
@@ -0,0 +1,271 @@
# code in this file is adapted from rpmcruz/autoaugment
# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
import random

import PIL

# from PIL import ImageOps, ImageEnhance, ImageDraw
import numpy as np
import torch
from PIL import Image


def ShearX(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))


def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random.random() > 0.5:
        v = -v
    v = v * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateXabs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random.random() > 0.5:
        v = -v
    v = v * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def TranslateYabs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def Rotate(img, v):  # [-30, 30]
    assert -30 <= v <= 30
    if random.random() > 0.5:
        v = -v
    return img.rotate(v)


def AutoContrast(img, _):
    return PIL.ImageOps.autocontrast(img)


def Invert(img, _):
    return PIL.ImageOps.invert(img)


def Equalize(img, _):
    return PIL.ImageOps.equalize(img)


def Flip(img, _):  # not from the paper
    return PIL.ImageOps.mirror(img)


def Solarize(img, v):  # [0, 256]
    assert 0 <= v <= 256
    return PIL.ImageOps.solarize(img, v)


def SolarizeAdd(img, addition=0, threshold=128):
    img_np = np.array(img).astype(int)  # np.int was removed in recent NumPy; the builtin int matches the original intent
    img_np = img_np + addition
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)


def Posterize(img, v):  # [4, 8]
    v = int(v)
    v = max(1, v)
    return PIL.ImageOps.posterize(img, v)


def Contrast(img, v):  # [0.1, 1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Contrast(img).enhance(v)


def Color(img, v):  # [0.1, 1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Color(img).enhance(v)


def Brightness(img, v):  # [0.1, 1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Sharpness(img, v):  # [0.1, 1.9]
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Sharpness(img).enhance(v)


def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
    assert 0.0 <= v <= 0.2
    if v <= 0.0:
        return img

    v = v * img.size[0]
    return CutoutAbs(img, v)


def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
    # assert 0 <= v <= 20
    if v < 0:
        return img
    w, h = img.size
    x0 = np.random.uniform(w)
    y0 = np.random.uniform(h)

    x0 = int(max(0, x0 - v / 2.0))
    y0 = int(max(0, y0 - v / 2.0))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)

    xy = (x0, y0, x1, y1)
    color = (125, 123, 114)
    # color = (0, 0, 0)
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img


def SamplePairing(imgs):  # [0, 0.4]
    def f(img1, v):
        i = np.random.choice(len(imgs))
        img2 = PIL.Image.fromarray(imgs[i])
        return PIL.Image.blend(img1, img2, v)

    return f


def Identity(img, v):
    return img


def augment_list():  # 16 operations and their ranges
    # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
    # l = [
    #     (Identity, 0., 1.0),
    #     (ShearX, 0., 0.3),  # 0
    #     (ShearY, 0., 0.3),  # 1
    #     (TranslateX, 0., 0.33),  # 2
    #     (TranslateY, 0., 0.33),  # 3
    #     (Rotate, 0, 30),  # 4
    #     (AutoContrast, 0, 1),  # 5
    #     (Invert, 0, 1),  # 6
    #     (Equalize, 0, 1),  # 7
    #     (Solarize, 0, 110),  # 8
    #     (Posterize, 4, 8),  # 9
    #     # (Contrast, 0.1, 1.9),  # 10
    #     (Color, 0.1, 1.9),  # 11
    #     (Brightness, 0.1, 1.9),  # 12
    #     (Sharpness, 0.1, 1.9),  # 13
    #     # (Cutout, 0, 0.2),  # 14
    #     # (SamplePairing(imgs), 0, 0.4),  # 15
    # ]

    # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
    l = [
        (AutoContrast, 0, 1),
        (Equalize, 0, 1),
        # (Invert, 0, 1),
        (Rotate, 0, 30),
        (Posterize, 0, 4),
        (Solarize, 0, 256),
        (SolarizeAdd, 0, 110),
        (Color, 0.1, 1.9),
        (Contrast, 0.1, 1.9),
        (Brightness, 0.1, 1.9),
        (Sharpness, 0.1, 1.9),
        (ShearX, 0.0, 0.3),
        (ShearY, 0.0, 0.3),
        # (CutoutAbs, 0, 40),
        (TranslateXabs, 0.0, 100),
        (TranslateYabs, 0.0, 100),
    ]

    return l


class Lighting(object):
    """Lighting noise (AlexNet-style, PCA-based noise)"""

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = torch.Tensor(eigval)
        self.eigvec = torch.Tensor(eigvec)

    def __call__(self, img):
        if self.alphastd == 0:
            return img

        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        rgb = (
            self.eigvec.type_as(img)
            .clone()
            .mul(alpha.view(1, 3).expand(3, 3))
            .mul(self.eigval.view(1, 3).expand(3, 3))
            .sum(1)
            .squeeze()
        )

        return img.add(rgb.view(3, 1, 1).expand_as(img))


class CutoutDefault(object):
    """
    Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1:y2, x1:x2] = 0.0
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img


class RandAugment:
    def __init__(self, n, m):
        self.n = n
        self.m = m  # [0, 30]
        self.augment_list = augment_list()

    def __call__(self, img):
        ops = random.choices(self.augment_list, k=self.n)
        for op, minval, maxval in ops:
            val = (float(self.m) / 30) * float(maxval - minval) + minval
            img = op(img, val)

        return img
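An illustrative usage sketch (not part of the upload). The extra PIL submodule imports are a precaution: the module above only does `import PIL`, so importing ImageOps/ImageEnhance/ImageDraw explicitly guarantees those attributes are available when the ops run:

    from PIL import Image, ImageOps, ImageEnhance, ImageDraw  # noqa: F401  (loads the PIL submodules the ops rely on)
    from vlmo.transforms.randaug import RandAugment

    augment = RandAugment(2, 9)  # apply 2 randomly chosen ops at magnitude 9 (out of 30)
    img = Image.new("RGB", (256, 256), (127, 127, 127))
    augmented = augment(img)     # still a PIL image of the same size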
vlmo/transforms/randaugment.py
ADDED
@@ -0,0 +1,334 @@
import cv2
import numpy as np


# aug functions
def identity_func(img):
    return img


def autocontrast_func(img, cutoff=0):
    """
    same output as PIL.ImageOps.autocontrast
    """
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def equalize_func(img):
    """
    same output as PIL.ImageOps.equalize
    PIL's implementation is different from cv2.equalize
    """
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            return ch
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def rotate_func(img, degree, fill=(0, 0, 0)):
    """
    like PIL, rotate by degree, not radians
    """
    H, W = img.shape[0], img.shape[1]
    center = W / 2, H / 2
    M = cv2.getRotationMatrix2D(center, degree, 1)
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
    return out


def solarize_func(img, thresh=128):
    """
    same output as PIL.ImageOps.solarize
    """
    table = np.array([el if el < thresh else 255 - el for el in range(256)])
    table = table.clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def color_func(img, factor):
    """
    same output as PIL.ImageEnhance.Color
    """
    # implementation according to PIL definition, quite slow
    # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
    # out = blend(degenerate, img, factor)
    # M = (
    #     np.eye(3) * factor
    #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
    # )[np.newaxis, np.newaxis, :]
    M = np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]) * factor + np.float32(
        [[0.114], [0.587], [0.299]]
    )
    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
    return out


def contrast_func(img, factor):
    """
    same output as PIL.ImageEnhance.Contrast
    """
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    table = np.array([(el - mean) * factor + mean for el in range(256)]).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def brightness_func(img, factor):
    """
    same output as PIL.ImageEnhance.Brightness
    """
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def sharpness_func(img, factor):
    """
    The differences between this result and PIL are all on the 4 boundaries; the center
    areas are the same
    """
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        out = out.astype(np.uint8)
    return out


def shear_x_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, factor, 0], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_x_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform
    """
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, -offset], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_y_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform
    """
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [0, 1, -offset]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def posterize_func(img, bits):
    """
    same output as PIL.ImageOps.posterize
    """
    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
    return out


def shear_y_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [factor, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def cutout_func(img, pad_size, replace=(0, 0, 0)):
    replace = np.array(replace, dtype=np.uint8)
    H, W = img.shape[0], img.shape[1]
    rh, rw = np.random.random(2)
    pad_size = pad_size // 2
    ch, cw = int(rh * H), int(rw * W)
    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
    out = img.copy()
    out[x1:x2, y1:y2, :] = replace
    return out


# level to args
def enhance_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        return ((level / MAX_LEVEL) * 1.8 + 0.1,)

    return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * cutout_const)
        return (level, replace_value)

    return level_to_args


def solarize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 256)
        return (level,)

    return level_to_args


def none_level_to_args(level):
    return ()


def posterize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 4)
        return (level,)

    return level_to_args


def rotate_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            level = -level
        return (level, replace_value)

    return level_to_args


func_dict = {
    "Identity": identity_func,
    "AutoContrast": autocontrast_func,
    "Equalize": equalize_func,
    "Rotate": rotate_func,
    "Solarize": solarize_func,
    "Color": color_func,
    "Contrast": contrast_func,
    "Brightness": brightness_func,
    "Sharpness": sharpness_func,
    "ShearX": shear_x_func,
    "TranslateX": translate_x_func,
    "TranslateY": translate_y_func,
    "Posterize": posterize_func,
    "ShearY": shear_y_func,
}

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
    "Identity": none_level_to_args,
    "AutoContrast": none_level_to_args,
    "Equalize": none_level_to_args,
    "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
    "Solarize": solarize_level_to_args(MAX_LEVEL),
    "Color": enhance_level_to_args(MAX_LEVEL),
    "Contrast": enhance_level_to_args(MAX_LEVEL),
    "Brightness": enhance_level_to_args(MAX_LEVEL),
    "Sharpness": enhance_level_to_args(MAX_LEVEL),
    "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
    "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "Posterize": posterize_level_to_args(MAX_LEVEL),
    "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
}


class RandomAugment(object):
    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img


if __name__ == "__main__":
    a = RandomAugment()
    img = np.random.randn(32, 32, 3)
    a(img)
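A short usage sketch (illustrative, not part of the upload) of the cv2-based RandomAugment above; the op subset and sizes are arbitrary:

    import numpy as np
    from vlmo.transforms.randaugment import RandomAugment

    aug = RandomAugment(
        N=2, M=7, isPIL=False,
        augs=["Identity", "AutoContrast", "Equalize", "Brightness", "Rotate"],
    )
    img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # HWC uint8 array, as the cv2 ops expect
    out = aug(img)  # same shape and dtype; each sampled op is applied with probability 0.5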
vlmo/transforms/square_transform.py
ADDED
@@ -0,0 +1,41 @@
# code in this file is adapted from the ALBEF repo (https://github.com/salesforce/ALBEF)

from torchvision import transforms
from .randaugment import RandomAugment
from PIL import Image


def square_transform(size=224):
    return transforms.Compose(
        [
            transforms.Resize((size, size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ]
    )


def square_transform_randaug(size=224):
    return transforms.Compose(
        [
            transforms.RandomResizedCrop(size, scale=(0.8, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(
                2,
                7,
                isPIL=True,
                augs=[
                    "Identity",
                    "AutoContrast",
                    "Equalize",
                    "Brightness",
                    "Sharpness",
                    "ShearX",
                    "ShearY",
                    "TranslateX",
                    "TranslateY",
                    "Rotate",
                ],
            ),
            transforms.ToTensor(),
        ]
    )
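A minimal usage sketch (not part of the diff); the input size is arbitrary:

    from PIL import Image
    from vlmo.transforms.square_transform import square_transform

    img = Image.new("RGB", (500, 300))
    tensor = square_transform(size=224)(img)
    print(tensor.shape)  # torch.Size([3, 224, 224]); every image is forced to a square of side `size`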
vlmo/transforms/utils.py
ADDED
@@ -0,0 +1,56 @@
from torchvision import transforms
from PIL import Image


class MinMaxResize:
    def __init__(self, shorter=800, longer=1333):
        self.min = shorter
        self.max = longer

    def __call__(self, x):
        w, h = x.size
        scale = self.min / min(w, h)
        if h < w:
            newh, neww = self.min, scale * w
        else:
            newh, neww = scale * h, self.min

        if max(newh, neww) > self.max:
            scale = self.max / max(newh, neww)
            newh = newh * scale
            neww = neww * scale

        newh, neww = int(newh + 0.5), int(neww + 0.5)
        newh, neww = newh // 32 * 32, neww // 32 * 32

        return x.resize((neww, newh), resample=Image.BICUBIC)


class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        # The normalize code -> t.sub_(m).div_(s)
        return tensor


# This is the simple maximum-entropy normalization used in the Inception paper
inception_normalize = transforms.Compose([transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

# ViT uses simple non-biased inception normalization
# https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132
inception_unnormalize = transforms.Compose([UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

cn_clip_normalize = transforms.Compose(
    [transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])]
)
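A quick round-trip sketch (illustrative only) showing that inception_unnormalize inverts inception_normalize; the input tensor is arbitrary, and y is cloned because UnNormalize modifies its argument in place:

    import torch
    from vlmo.transforms.utils import inception_normalize, inception_unnormalize

    x = torch.rand(3, 224, 224)                  # a tensor image in [0, 1]
    y = inception_normalize(x)                   # maps to roughly [-1, 1]
    x_back = inception_unnormalize(y.clone())    # undo the normalization
    print(torch.allclose(x, x_back, atol=1e-6))  # True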
vlmo/utils/__init__.py
ADDED
File without changes