# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components import Activation
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, build_feedforward
from xformers.components.feedforward.mixture_of_experts import GateConfig
from xformers.helpers.test_utils import init_torch_distributed_local
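
# Shared test dimensions and dropout, kept small so the forward/backward passes stay cheap.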
BATCH = 4
SEQ = 256
EMBD = 16
LATENT = 128
DROPOUT = 0.5

DEVICES = (
    [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)
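
# The registry is populated at import time; fail early if no feedforward layer was registered.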
assert FEEDFORWARD_REGISTRY.keys(), "Feedforward layers should have been registered"


@pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
@pytest.mark.parametrize("activation", [a.value for a in Activation])
@pytest.mark.parametrize("device", DEVICES)
def test_feedforward(
    feedforward_name: str, activation: Activation, device: torch.device
):
    test_config = {
        "name": feedforward_name,
        "dim_model": LATENT,
        "dropout": DROPOUT,
        "activation": activation,
        "hidden_layer_multiplier": 4,
        "number_of_experts": 4,  # MoE
        "gate": "top_2",  # MoE
    }

    if feedforward_name == "MixtureOfExperts":
        init_torch_distributed_local()

    # dummy, just check construction and dimensions in the FW pass
    ffw = build_feedforward(test_config)

    if ffw.requires_cuda and not device.type == "cuda":
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip("This MLP requires CUDA and current device does not match")

    inputs = torch.rand(BATCH, SEQ, LATENT, device=device)
    ffw = ffw.to(device)

    _ = ffw(inputs)
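
# Minimal expert module used to exercise the `expert_constructor` path in test_moe below.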
def get_expert():
    return torch.nn.Linear(LATENT, LATENT, bias=False)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires CUDA")
@pytest.mark.parametrize("gate", [g.value for g in GateConfig])
@pytest.mark.parametrize("number_of_local_experts", [4, None])
@pytest.mark.parametrize("expert_constructor", [None, get_expert])
def test_moe(gate, number_of_local_experts, expert_constructor):
    test_config = {
        "name": "MixtureOfExperts",
        "dim_model": LATENT,
        "dropout": DROPOUT,
        "activation": Activation.ReLU,
        "hidden_layer_multiplier": 4,
        "number_of_experts": 4,
        "number_of_local_experts": number_of_local_experts,
        "gate": gate,
        "expert_constructor": expert_constructor,
    }

    init_torch_distributed_local()

    # dummy, just check construction and dimensions in the FW pass
    ffw = build_feedforward(test_config)
    inputs = torch.rand(BATCH, SEQ, LATENT, device=torch.device("cuda"))
    ffw = ffw.to(torch.device("cuda"))

    outputs = ffw(inputs)
    loss = torch.sum(outputs)
    loss.backward()