# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import nullcontext
from copy import deepcopy

import pytest
import torch
from torch import nn

import xformers.ops
from xformers.checkpoint import (
    _optimize_runtime_with_given_memory,
    checkpoint,
    get_optimal_checkpoint_policy,
    list_operators,
    selective_checkpoint_wrapper,
)

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu"]
cuda_cap = (0, 0)

if torch.cuda.is_available():
    _devices.append("cuda")
    cuda_cap = torch.cuda.get_device_capability(_devices[1])


def _relu_policy(ctx, func, *args, **kwargs):
    return func == torch.ops.aten.relu.default


def _all_policy(ctx, func, *args, **kwargs):
    return True
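

# NOTE (assumption): the parametrize grids in this file are reconstructed from
# the fixture names in each test signature; the values below are illustrative,
# not necessarily the upstream reference grids.
@pytest.mark.parametrize("policy_fn", [None, _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("autocast", [True, False])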
def test_checkpoint(policy_fn, input_requires_grad, device, autocast):
    def build_module():
        return nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.ReLU(),
        ).to(device)

    module = nn.ModuleList([build_module() for i in range(10)])

    # Run model with and without checkpointing and verify gradients are
    # equivalent, regardless of whether inputs require grads or not.
    module_copy = deepcopy(module)
    inputs = torch.rand(32, 10, device=device)
    inputs_copy = inputs.clone()
    inputs.requires_grad_(input_requires_grad)
    inputs_copy.requires_grad_(input_requires_grad)
    out = inputs
    out_copy = inputs_copy

    with torch.autocast(device_type=device, enabled=autocast):
        for i in range(10):
            out = checkpoint(module[i], out, policy_fn=policy_fn)
            out_copy = module_copy[i](out_copy)

    assert torch.allclose(out, out_copy)
    out.sum().backward()
    out_copy.sum().backward()
    for p, p_copy in zip(module.parameters(), module_copy.parameters()):
        assert torch.allclose(p.grad, p_copy.grad)
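

# NOTE (assumption): grid reconstructed from the fixture names in the signature.
@pytest.mark.parametrize("policy_fn", [None, _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("grad_mode", [True, False])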
def test_checkpoint_with_grad(policy_fn, input_requires_grad, grad_mode):
    module = nn.Sequential(
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 10),
        nn.ReLU(),
    )

    # Run model with and without checkpointing and verify the outputs are
    # equivalent, regardless of whether inputs require grads or not.
    module_copy = deepcopy(module)
    inputs = torch.rand(32, 10)
    inputs_copy = inputs.clone()
    inputs.requires_grad_(input_requires_grad)
    inputs_copy.requires_grad_(input_requires_grad)
    out = inputs
    out_copy = inputs_copy

    with torch.set_grad_enabled(grad_mode):
        for i in range(10):
            out = checkpoint(module, out, policy_fn=policy_fn)
            out_copy = module_copy(out_copy)

    assert torch.allclose(out, out_copy)
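

# NOTE (assumption): grid reconstructed from the fixture names; the op list,
# the device list, and the @cuda_only marker are assumptions (memory-efficient
# attention requires CUDA here).
@cuda_only
@pytest.mark.parametrize("policy_fn", [None, _relu_policy, _all_policy])
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("autocast", [True, False])
@pytest.mark.parametrize(
    "op",
    [
        xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
        xformers.ops.MemoryEfficientAttentionCutlassOp,
    ],
)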
def test_checkpoint_attention(policy_fn, input_requires_grad, device, autocast, op):
    if (
        op[0].CUDA_MINIMUM_COMPUTE_CAPABILITY > cuda_cap
        or op[1].CUDA_MINIMUM_COMPUTE_CAPABILITY > cuda_cap
    ):
        pytest.skip("skipping operator not supported in this arch")

    if (
        op is xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        and torch.version.hip
    ):
        pytest.skip("FlashAttentionOp is not supported on ROCM!")

    if op is xformers.ops.MemoryEfficientAttentionCkOp:
        pytest.skip("Gradients are currently not supported by ck-tiled!")

    class Attn(nn.Module):
        def forward(self, x):
            out = xformers.ops.memory_efficient_attention(x, x, x, op=op)
            return out + x

    num_layers = 10
    dtype = torch.float32 if autocast else torch.float16
    modules = nn.Sequential(
        *[
            nn.Sequential(
                nn.Linear(10, 64),
                Attn(),
                nn.ReLU(),
                nn.Linear(64, 10),
                nn.ReLU(),
            )
            .to(device)
            .to(dtype)
            for _ in range(num_layers)
        ]
    )

    # Run model with and without checkpointing and verify gradients are
    # equivalent, regardless of whether inputs require grads or not.
    modules_copy = deepcopy(modules)
    inputs = torch.rand(32, 128, 10, dtype=dtype, device=device)
    inputs_copy = inputs.clone()
    inputs.requires_grad_(input_requires_grad)
    inputs_copy.requires_grad_(input_requires_grad)
    out = inputs
    out_copy = inputs_copy

    with torch.autocast(device_type=device, enabled=autocast):
        for i in range(num_layers):
            out = checkpoint(modules[i], out, policy_fn=policy_fn)
            out_copy = modules_copy[i](out_copy)

    assert torch.allclose(out, out_copy)
    out.sum().backward()
    out_copy.sum().backward()
    for p, p_copy in zip(modules.parameters(), modules_copy.parameters()):
        assert torch.allclose(
            p.grad, p_copy.grad
        ), f"{(p.grad - p_copy.grad).abs().max()}"

    if input_requires_grad:
        assert torch.allclose(inputs.grad, inputs_copy.grad)


def test_list_operators():
    module = nn.Sequential(
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 10),
        nn.ReLU(),
    )
    inputs = torch.rand(32, 10)
    operators = list_operators(module, inputs)
    operators_str = [str(x) for x in operators]
    ref = [
        "aten.t.default",
        "aten.addmm.default",
        "aten.relu.default",
        "aten.detach.default",
        "aten.t.default",
        "aten.addmm.default",
        "aten.relu.default",
        "aten.detach.default",
    ]
    assert operators_str == ref


def test_optimize_runtime_with_given_memory(max_memory, optimal_soln):
    # Synthetic trace: (op name, runtime, memory) for each recorded op.
    data = [
        ("aten.copy_", 5, 0),
        ("aten.add", 5, 100),
        ("aten.div", 8, 100),
        ("aten.mm", 15, 120),
        ("aten.native_dropout", 15, 0),
        ("aten.linear", 9, 100),
        ("aten.t", 1, 0),
        ("aten.relu_", 5, 0),
    ]

    inplace_ops = [(0, 0), (7, 5)]
    view_like_ops = [6]
    rand_ops = [4]

    runtimes = torch.tensor([x[1] for x in data], dtype=torch.float64)
    memory = torch.tensor([x[2] for x in data], dtype=torch.float64)

    out = _optimize_runtime_with_given_memory(
        memory,
        runtimes,
        max_memory,
        view_like_ops,
        inplace_ops,
        rand_ops,
        force_store_random=False,
    )
    torch.testing.assert_close(optimal_soln, out)


def _get_model_blocks(num_layers, dtype, device, inplace, random, first_inplace):
    modules = []

    class Add_(torch.nn.Module):
        def forward(self, x):
            return x.add_(1)

    for _ in range(num_layers):
        mods = [
            nn.Linear(10, 10),
            nn.CELU(inplace=inplace),
        ]
        if first_inplace:
            mods.insert(0, Add_())
        if random:
            mods.append(nn.Dropout())
        mods.append(nn.Linear(10, 10))
        if random:
            mods.append(nn.Dropout())
        mods.append(nn.CELU(inplace=inplace))
        modules.append(nn.Sequential(*mods).to(device).to(dtype))
    return modules


class _Model(torch.nn.Module):
    def __init__(self, blocks, policy_fn):
        super().__init__()
        self.blocks = torch.nn.ModuleList(blocks)
        self.policy_fn = policy_fn

    def forward(self, x):
        for b in self.blocks:
            x = checkpoint(b, x, policy_fn=self.policy_fn)
        return x
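

# NOTE (assumption): grid reconstructed from the fixture names; @cuda_only, the
# device list, and the memory_budget fractions are assumptions (the fp16 blocks
# built above are easiest to run on GPU).
@cuda_only
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("memory_budget", [0.0, 0.1, 0.5, 1.0])
@pytest.mark.parametrize("inplace", [False, True])
@pytest.mark.parametrize("random", [False, True])
@pytest.mark.parametrize("first_inplace", [False, True])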
def test_optimal_checkpoint_policy(
    device, memory_budget, inplace, random, first_inplace
):
    if first_inplace and inplace:
        pytest.skip("This case is degenerate and doesn't work with vanilla PyTorch")

    torch.manual_seed(42)
    dtype = torch.float16
    modules = _get_model_blocks(
        3, dtype, device, inplace=inplace, random=random, first_inplace=first_inplace
    )
    inputs = torch.rand(32, 128, 10, dtype=dtype, device=device)

    policy_fn = get_optimal_checkpoint_policy(
        modules[0], inputs, memory_budget=memory_budget
    )
    model = _Model(modules, policy_fn)
    model_ref = torch.nn.Sequential(*deepcopy(modules))

    grad = torch.rand_like(inputs)

    torch.manual_seed(42)
    out = model(inputs.clone())
    out.backward(grad)

    torch.manual_seed(42)
    out_ref = model_ref(inputs.clone())
    out_ref.backward(grad)

    torch.testing.assert_close(out, out_ref)
    for p, p_ref in zip(model.parameters(), model_ref.parameters()):
        torch.testing.assert_close(p.grad, p_ref.grad)
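

# NOTE (assumption): grid reconstructed from the fixture names; @cuda_only and
# the memory_budget values are assumptions, chosen so the fp16 model runs on GPU.
@cuda_only
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("no_grad", [False, True])
@pytest.mark.parametrize("memory_budget", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("inplace", [False, True])
@pytest.mark.parametrize("random", [False, True])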
def test_selective_checkpoint_wrapper_compile(
    device, no_grad, memory_budget, inplace, random
):
    torch.manual_seed(42)
    dtype = torch.float16
    modules = _get_model_blocks(
        3, dtype, device, inplace=inplace, random=random, first_inplace=False
    )
    inputs = torch.rand(32, 128, 10, dtype=dtype, device=device)

    model = torch.nn.Sequential(
        *[selective_checkpoint_wrapper(b, memory_budget=memory_budget) for b in modules]
    )
    model = torch.compile(model)
    model_ref = torch.nn.Sequential(*deepcopy(modules))

    grad = torch.rand_like(inputs)

    context = torch.no_grad() if no_grad else nullcontext()

    with context:
        torch.manual_seed(42)
        out = model(inputs.clone())
        if not no_grad:
            out.backward(grad)

        torch.manual_seed(42)
        out_ref = model_ref(inputs.clone())
        if not no_grad:
            out_ref.backward(grad)

    atol = 3e-4
    rtol = 1e-3
    torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)

    if no_grad:
        return

    for p, p_ref in zip(model.parameters(), model_ref.parameters()):
        atol = 4e-4
        rtol = 2e-3
        torch.testing.assert_close(p.grad, p_ref.grad, atol=atol, rtol=rtol)