Flexstorydiff / xformers /tests /test_pickling.py
FlexTheAi's picture
Upload folder using huggingface_hub
e202b16 verified
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Initially suggested by Jason Ramapuram, see
# https://github.com/facebookresearch/xformers/issues/203
import pickle
from copy import deepcopy
import pytest
from torch import nn
from xformers.factory import xFormer, xFormerConfig
test_config = [
{
"reversible": False,
"block_type": "encoder",
"num_layers": 2,
"dim_model": 768,
"residual_norm_style": "pre",
"multi_head_config": {
"num_heads": 12,
"residual_dropout": 0.1,
"use_rotary_embeddings": True,
"attention": {
"name": "scaled_dot_product",
"dropout": 0.1,
"causal": False,
},
},
"feedforward_config": {
"name": "MLP",
"dropout": 0.1,
"activation": "gelu",
"hidden_layer_multiplier": 4,
},
}
]
class ViT(nn.Module):
def __init__(self, mlp):
super().__init__()
test_config[0]["feedforward_config"]["name"] = mlp
xformer_config = xFormerConfig(test_config)
self.xformer = xFormer.from_config(xformer_config)
MLPs = ["MLP"]
@pytest.mark.parametrize("mlp", MLPs)
def test_pickling(mlp):
test = ViT(mlp)
_ = pickle.dumps(test)
_ = deepcopy(test)