# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch

from xformers.components import Activation
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, build_feedforward
from xformers.components.feedforward.mixture_of_experts import GateConfig
from xformers.helpers.test_utils import init_torch_distributed_local

BATCH = 4
SEQ = 256
EMBD = 16
LATENT = 128
DROPOUT = 0.5

DEVICES = (
    [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)

assert FEEDFORWARD_REGISTRY.keys(), "Feedforward layers should have been registered"


@pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
@pytest.mark.parametrize("activation", [a.value for a in Activation])
@pytest.mark.parametrize("device", DEVICES)
def test_feedforward(
    feedforward_name: str, activation: Activation, device: torch.device
):
    test_config = {
        "name": feedforward_name,
        "dim_model": LATENT,
        "dropout": DROPOUT,
        "activation": activation,
        "hidden_layer_multiplier": 4,
        "number_of_experts": 4,  # MoE
        "gate": "top_2",  # MoE
    }

    if feedforward_name == "MixtureOfExperts":
        init_torch_distributed_local()

    # Dummy test: just check construction and dimensions in the forward pass
    ffw = build_feedforward(test_config)

    if ffw.requires_cuda and device.type != "cuda":
        # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
        pytest.skip("This MLP requires CUDA and the current device does not match")

    inputs = torch.rand(BATCH, SEQ, LATENT, device=device)
    ffw = ffw.to(device)

    _ = ffw(inputs)


def get_expert():
    return torch.nn.Linear(LATENT, LATENT, bias=False)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires CUDA")
@pytest.mark.parametrize("gate", [g.value for g in GateConfig])
@pytest.mark.parametrize("number_of_local_experts", [None, 4])
@pytest.mark.parametrize("expert_constructor", [None, get_expert])
def test_moe(gate, number_of_local_experts, expert_constructor):
    test_config = {
        "name": "MixtureOfExperts",
        "dim_model": LATENT,
        "dropout": DROPOUT,
        "activation": Activation.ReLU,
        "hidden_layer_multiplier": 4,
        "number_of_experts": 4,
        "number_of_local_experts": number_of_local_experts,
        "gate": gate,
        "expert_constructor": expert_constructor,
    }

    init_torch_distributed_local()

    # Dummy test: just check construction and dimensions in the forward pass,
    # then make sure gradients flow through the expert routing
    ffw = build_feedforward(test_config)
    inputs = torch.rand(BATCH, SEQ, LATENT, device=torch.device("cuda"))
    ffw = ffw.to(torch.device("cuda"))

    outputs = ffw(inputs)
    loss = torch.sum(outputs)
    loss.backward()
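

# ---------------------------------------------------------------------------
# A minimal sketch, not part of the original suite: it reuses the same
# registry/config machinery as the tests above to check that a feedforward
# block preserves the (BATCH, SEQ, LATENT) input shape. It assumes "MLP" is a
# registered key in FEEDFORWARD_REGISTRY; swap in another registered
# feedforward name if that assumption does not hold.
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("device", DEVICES)
def test_mlp_preserves_shape(device: torch.device):
    config = {
        "name": "MLP",  # assumed registry key, see note above
        "dim_model": LATENT,
        "dropout": DROPOUT,
        "activation": Activation.ReLU,
        "hidden_layer_multiplier": 4,
    }
    ffw = build_feedforward(config).to(device)

    inputs = torch.rand(BATCH, SEQ, LATENT, device=device)
    outputs = ffw(inputs)

    # Feedforward blocks map dim_model back to dim_model, so shapes must match
    assert outputs.shape == inputs.shape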