# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import random

import pytest
import torch

from xformers.factory.model_factory import xFormer, xFormerConfig
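
# Toy dimensions for the dummy models and inputs used in these tests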
BATCH = 2
SEQ = 64
EMB = 48
VOCAB = 16
DEVICES = (
[torch.device("cpu")]
if not torch.cuda.is_available()
else [torch.device("cuda")] # save a bit on CI, we have separate cpu and gpu jobs
)
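
# A small encoder stack: vocab position encoding, linformer self-attention and an MLP
# feedforward; the "reversible" flag is toggled per test by _rev_config below.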
_test_config_encoder = {
"reversible": False,
"block_type": "encoder",
"dim_model": EMB,
"position_encoding_config": {
"name": "vocab",
"seq_len": SEQ,
"vocab_size": VOCAB,
"dim_model": EMB,
},
"num_layers": 3,
"multi_head_config": {
"num_heads": 4,
"residual_dropout": 0,
"attention": {
"name": "linformer",
"dropout": 0,
"causal": True,
"seq_len": SEQ,
},
"dim_model": EMB,
},
"feedforward_config": {
"name": "MLP",
"dropout": 0,
"activation": "relu",
"hidden_layer_multiplier": 4,
"dim_model": EMB,
},
}
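
# A matching decoder stack: masked self-attention plus cross-attention (both linformer),
# with the same MLP feedforward; it does not carry a "reversible" flag.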
_test_config_decoder = {
"block_type": "decoder",
"dim_model": EMB,
"position_encoding_config": {
"name": "vocab",
"seq_len": SEQ,
"vocab_size": VOCAB,
"dim_model": EMB,
},
"num_layers": 2,
"multi_head_config_masked": {
"num_heads": 4,
"residual_dropout": 0,
"dim_model": EMB,
"attention": {
"name": "linformer",
"dropout": 0,
"causal": True,
"seq_len": SEQ,
},
},
"multi_head_config_cross": {
"num_heads": 4,
"residual_dropout": 0,
"dim_model": EMB,
"attention": {
"name": "linformer",
"dropout": 0,
"causal": True,
"seq_len": SEQ,
},
},
"feedforward_config": {
"name": "MLP",
"dropout": 0,
"activation": "relu",
"hidden_layer_multiplier": 4,
"dim_model": EMB,
},
}

# Test an encoder/decoder stack and a pure encoder
_test_configs = [
[_test_config_encoder, _test_config_decoder],
[_test_config_encoder],
]
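

# Toggle the "reversible" flag on every encoder block of a config list; only the
# encoder config above exposes this flag.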
def _rev_config(config, flag: bool):
for c in filter(
lambda x: x["block_type"] == "encoder",
config,
):
c["reversible"] = flag
return config
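

# Forward-pass smoke test: a reversible and a non-reversible stack built from the same
# config should both accept the same dummy input.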
@pytest.mark.parametrize("config", _test_configs)
@pytest.mark.parametrize("device", DEVICES)
def test_reversible_runs(config, device):
# Build both a reversible and non-reversible model
model_non_reversible = xFormer.from_config(
xFormerConfig(_rev_config(config, False))
).to(device)
model_reversible = xFormer.from_config(xFormerConfig(_rev_config(config, True))).to(
device
)
# Dummy inputs, test a forward
inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)
_ = model_non_reversible(inputs)
_ = model_reversible(inputs)
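

# Building a stack that mixes reversible and non-reversible encoder blocks must fail.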
@pytest.mark.parametrize("device", DEVICES)
def test_reversible_no_alternate(device):
# Check that we cannot build a non-coherent stack
with pytest.raises(AssertionError):
rev = dict(_test_config_encoder) # we need to make a copy
rev["reversible"] = True
non_rev = dict(_test_config_encoder)
non_rev["reversible"] = False
_ = xFormer.from_config(xFormerConfig([rev, non_rev])).to(device)
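

# Train both variants on a trivial toy task and check that each of them makes progress.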
@pytest.mark.parametrize("config", _test_configs)
@pytest.mark.parametrize("device", DEVICES)
def test_reversible_train(config, device):
torch.manual_seed(0)
random.seed(0)
    # Dummy data: run a short training loop to check that both models can fit the same
    # toy task to a comparable extent. This is not rigorous, just a sanity check.
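    # Half of each batch is all-zero tokens with target 0, the other half random tokens
    # with target 1; the order of the two halves is randomized.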
def data():
input_a = torch.zeros((BATCH, SEQ), device=device).to(torch.int)
input_b = (torch.rand((BATCH, SEQ), device=device) * VOCAB).abs().to(torch.int)
target_a = torch.zeros((BATCH, SEQ), device=device)
target_b = torch.ones((BATCH, SEQ), device=device)
if random.random() > 0.5:
return torch.cat([input_a, input_b], dim=0), torch.cat(
[target_a, target_b], dim=0
)
return torch.cat([input_b, input_a], dim=0), torch.cat(
[target_b, target_a], dim=0
)
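
    # Run one optimizer step on a fresh batch and return the loss value.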
def step(model: torch.nn.Module, optim: torch.optim.Optimizer):
batch, target = data()
model.train()
optim.zero_grad()
outputs = model(batch)
loss = torch.norm(torch.mean(outputs, dim=-1) - target)
loss.backward()
# Clip grad and error out if we're producing NaNs, part of the unit test
torch.nn.utils.clip_grad_norm_(
model.parameters(), 10.0, norm_type=2.0, error_if_nonfinite=True
)
optim.step()
return loss.item()
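
    # Compute the loss on a fresh batch in eval mode, without updating the model.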
def evaluate(model: torch.nn.Module):
batch, target = data()
model.eval()
outputs = model(batch)
return torch.norm(torch.mean(outputs, dim=-1) - target).item()
# Build both a reversible and non-reversible model
model_non_reversible = xFormer.from_config(
xFormerConfig(_rev_config(config, False))
).to(device)
model_reversible = xFormer.from_config(xFormerConfig(_rev_config(config, True))).to(
device
)
optim_rev = torch.optim.SGD(model_reversible.parameters(), lr=1e-3, momentum=0.9)
optim_non_rev = torch.optim.SGD(
model_non_reversible.parameters(), lr=1e-3, momentum=0.9
)
# Check that both models can be trained to comparable results
eval_start_rev = evaluate(model_reversible)
eval_start_non_rev = evaluate(model_non_reversible)
for i in range(100):
print(i, " reversible: ", step(model_reversible, optim_rev))
print(i, " non reversible: ", step(model_non_reversible, optim_non_rev))
    # Check that both models learned this dummy example; the criteria below are
    # deliberately loose rather than a fixed threshold.
eval_stop_rev = evaluate(model_reversible)
eval_stop_non_rev = evaluate(model_non_reversible)
if len(config) < 2: # only check the encoder case
train_ratio_rev = eval_start_rev / eval_stop_rev
train_ratio_non_rev = eval_start_non_rev / eval_stop_non_rev
        # Assert that both train ratios are > 1 (both models actually trained), and that
        # the reversible model improved more than the non-reversible one on this dummy task.
assert train_ratio_rev > 1
assert train_ratio_non_rev > 1
assert train_ratio_rev > train_ratio_non_rev