# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import random

import pytest
import torch

from xformers.factory.model_factory import xFormer, xFormerConfig

BATCH = 2
SEQ = 64
EMB = 48
VOCAB = 16

DEVICES = (
    [torch.device("cpu")]
    if not torch.cuda.is_available()
    else [torch.device("cuda")]  # save a bit on CI, we have separate cpu and gpu jobs
)

_test_config_encoder = {
    "reversible": False,
    "block_type": "encoder",
    "dim_model": EMB,
    "position_encoding_config": {
        "name": "vocab",
        "seq_len": SEQ,
        "vocab_size": VOCAB,
        "dim_model": EMB,
    },
    "num_layers": 3,
    "multi_head_config": {
        "num_heads": 4,
        "residual_dropout": 0,
        "attention": {
            "name": "linformer",
            "dropout": 0,
            "causal": True,
            "seq_len": SEQ,
        },
        "dim_model": EMB,
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
        "dim_model": EMB,
    },
}
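
# Note: the "name" fields above pick components out of the xformers registries
# ("linformer" attention, "MLP" feedforward); "reversible" is the block-level
# flag that the tests below toggle through _rev_config.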

_test_config_decoder = {
    "block_type": "decoder",
    "dim_model": EMB,
    "position_encoding_config": {
        "name": "vocab",
        "seq_len": SEQ,
        "vocab_size": VOCAB,
        "dim_model": EMB,
    },
    "num_layers": 2,
    "multi_head_config_masked": {
        "num_heads": 4,
        "residual_dropout": 0,
        "dim_model": EMB,
        "attention": {
            "name": "linformer",
            "dropout": 0,
            "causal": True,
            "seq_len": SEQ,
        },
    },
    "multi_head_config_cross": {
        "num_heads": 4,
        "residual_dropout": 0,
        "dim_model": EMB,
        "attention": {
            "name": "linformer",
            "dropout": 0,
            "causal": True,
            "seq_len": SEQ,
        },
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
        "dim_model": EMB,
    },
}
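
# The decoder config mirrors the encoder one, except that it carries two
# attention blocks (a causal masked self-attention and a cross-attention over
# the encoder memory) and no "reversible" key: only encoder blocks get the
# flag flipped by _rev_config below.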

# Test a pure encoder and an encoder/decoder stack
_test_configs = [
    [_test_config_encoder, _test_config_decoder],
    [_test_config_encoder],
]


def _rev_config(config, flag: bool):
    for c in filter(
        lambda x: x["block_type"] == "encoder",
        config,
    ):
        c["reversible"] = flag

    return config
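
# Careful: _rev_config mutates the shared dicts in place rather than copying
# them, which is why test_reversible_no_alternate below duplicates
# _test_config_encoder with dict() before flipping the flag.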


@pytest.mark.parametrize("config", _test_configs)
@pytest.mark.parametrize("device", DEVICES)
def test_reversible_runs(config, device):
    # Build both a reversible and a non-reversible model
    model_non_reversible = xFormer.from_config(
        xFormerConfig(_rev_config(config, False))
    ).to(device)
    model_reversible = xFormer.from_config(xFormerConfig(_rev_config(config, True))).to(
        device
    )

    # Dummy inputs, test a forward
    inputs = (torch.rand((BATCH, SEQ), device=device) * 10).abs().to(torch.int)

    _ = model_non_reversible(inputs)
    _ = model_reversible(inputs)
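
    # A stronger check one could add here (sketch; assumes the factory emits
    # one EMB-dim vector per token, which test_reversible_train below relies
    # on through its mean over the last dimension):
    #   assert model_reversible(inputs).shape == (BATCH, SEQ, EMB)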


@pytest.mark.parametrize("device", DEVICES)
def test_reversible_no_alternate(device):
    # Check that we cannot build an incoherent stack (mixing reversible and
    # non-reversible encoder blocks is not supported)
    with pytest.raises(AssertionError):
        rev = dict(_test_config_encoder)  # we need to make a copy
        rev["reversible"] = True

        non_rev = dict(_test_config_encoder)
        non_rev["reversible"] = False

        _ = xFormer.from_config(xFormerConfig([rev, non_rev])).to(device)


@pytest.mark.parametrize("config", _test_configs)
@pytest.mark.parametrize("device", DEVICES)
def test_reversible_train(config, device):
    torch.manual_seed(0)
    random.seed(0)

    # Dummy inputs, then some training to check that both models can fit the
    # same toy task to some extent. This is not super scientific, more of a
    # foolproof catch.
    def data():
        # All-zero sequences are labeled 0, random sequences are labeled 1;
        # the order of the two halves of the batch is shuffled per draw
        input_a = torch.zeros((BATCH, SEQ), device=device).to(torch.int)
        input_b = (torch.rand((BATCH, SEQ), device=device) * VOCAB).abs().to(torch.int)
        target_a = torch.zeros((BATCH, SEQ), device=device)
        target_b = torch.ones((BATCH, SEQ), device=device)

        if random.random() > 0.5:
            return torch.cat([input_a, input_b], dim=0), torch.cat(
                [target_a, target_b], dim=0
            )
        return torch.cat([input_b, input_a], dim=0), torch.cat(
            [target_b, target_a], dim=0
        )

    def step(model: torch.nn.Module, optim: torch.optim.Optimizer):
        batch, target = data()
        model.train()
        optim.zero_grad()

        outputs = model(batch)
        loss = torch.norm(torch.mean(outputs, dim=-1) - target)
        loss.backward()

        # Clip grad and error out if we're producing NaNs, part of the unit test
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), 10.0, norm_type=2.0, error_if_nonfinite=True
        )
        optim.step()
        return loss.item()

    def evaluate(model: torch.nn.Module):
        batch, target = data()
        model.eval()
        outputs = model(batch)
        return torch.norm(torch.mean(outputs, dim=-1) - target).item()

    # Build both a reversible and a non-reversible model
    model_non_reversible = xFormer.from_config(
        xFormerConfig(_rev_config(config, False))
    ).to(device)
    model_reversible = xFormer.from_config(xFormerConfig(_rev_config(config, True))).to(
        device
    )

    optim_rev = torch.optim.SGD(model_reversible.parameters(), lr=1e-3, momentum=0.9)
    optim_non_rev = torch.optim.SGD(
        model_non_reversible.parameters(), lr=1e-3, momentum=0.9
    )

    # Check that both models can be trained to comparable results
    eval_start_rev = evaluate(model_reversible)
    eval_start_non_rev = evaluate(model_non_reversible)

    for i in range(100):
        print(i, " reversible: ", step(model_reversible, optim_rev))
        print(i, " non reversible: ", step(model_non_reversible, optim_non_rev))

    # Check that we can classify this dummy example (arbitrary threshold)
    eval_stop_rev = evaluate(model_reversible)
    eval_stop_non_rev = evaluate(model_non_reversible)

    if len(config) < 2:  # only check the pure-encoder case
        train_ratio_rev = eval_start_rev / eval_stop_rev
        train_ratio_non_rev = eval_start_non_rev / eval_stop_non_rev

        # Assert that the train ratio is > 1 (the loss went down),
        # and that reversible trained at least as well as non-reversible
        # (it is actually better on this dummy test)
        assert train_ratio_rev > 1
        assert train_ratio_non_rev > 1
        assert train_ratio_rev > train_ratio_non_rev
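

# For context (a sketch, not part of the test suite above): the point of
# reversible blocks is trading compute for memory, since per-layer activations
# are recomputed during the backward pass instead of being stored. A
# hypothetical way to eyeball the difference on a CUDA device:
def _peak_training_memory(model: torch.nn.Module, device: torch.device) -> int:
    torch.cuda.reset_peak_memory_stats(device)
    inputs = (torch.rand((BATCH, SEQ), device=device) * VOCAB).to(torch.int)
    model(inputs).mean().backward()  # the backward pass is where the saving shows
    return torch.cuda.max_memory_allocated(device)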