# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# CREDITS: Initially suggested by Jason Ramapuram, see
# https://github.com/facebookresearch/xformers/issues/203

import pickle
from copy import deepcopy

import pytest
from torch import nn

from xformers.factory import xFormer, xFormerConfig

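# A single two-layer encoder block configuration. The feedforward "name"
# below is a placeholder: it is overwritten per test case in ViT.__init__.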
test_config = [
    {
        "reversible": False,
        "block_type": "encoder",
        "num_layers": 2,
        "dim_model": 768,
        "residual_norm_style": "pre",
        "multi_head_config": {
            "num_heads": 12,
            "residual_dropout": 0.1,
            "use_rotary_embeddings": True,
            "attention": {
                "name": "scaled_dot_product",
                "dropout": 0.1,
                "causal": False,
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": 0.1,
            "activation": "gelu",
            "hidden_layer_multiplier": 4,
        },
    }
]


class ViT(nn.Module):
    def __init__(self, mlp):
        super().__init__()
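        # Patch the requested feedforward block into the shared module-level
        # config, then build the encoder stack from it.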
        test_config[0]["feedforward_config"]["name"] = mlp
        xformer_config = xFormerConfig(test_config)
        self.xformer = xFormer.from_config(xformer_config)


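# Feedforward variants to exercise. Other registered feedforward blocks
# (for instance "FusedMLP", where its Triton dependency is available) could
# presumably be appended here to widen coverage.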
MLPs = ["MLP"]


@pytest.mark.parametrize("mlp", MLPs)
def test_pickling(mlp):
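    # Building the model, pickling it, and deep-copying it must all succeed
    # without raising (regression test for the issue referenced in the header).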
    test = ViT(mlp)
    _ = pickle.dumps(test)
    _ = deepcopy(test)