Using the Reversible block
==========================

Intro
-----

This block applies to residual paths, and was first proposed by Gomez et al. ([1]_).
Its application in the Transformer ([3]_) context was first proposed in the `Reformer` ([2]_) paper,
and is largely unrelated to the other proposals from that paper (LSH and chunked MLP processing).
We use and very lightly adapt the implementation by `Robin Bruegger`_ and some blocks from LucidRains_.

A reversible layer requires two inputs (x1, x2) and produces two outputs (y1, y2)
via two functions F and G, following the relations

::

    y1 = x1 + F(x2)
    y2 = x2 + G(y1)

In turn, this means that (x1, x2) can be recovered from (y1, y2) (see [1]_ for details)

::

    x2 = y2 - G(y1)  # Note that another FW-like pass is needed
    x1 = y1 - F(x2)
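
To make the mechanics concrete, here is a minimal numeric sketch of the relations above, with arbitrary stand-ins for F and G (the functions are hypothetical, chosen only for illustration):

.. code-block:: python

    import torch

    F = lambda x: torch.tanh(x)        # stand-in for F (e.g. attention)
    G = lambda x: 0.5 * torch.relu(x)  # stand-in for G (e.g. feedforward)

    x1, x2 = torch.randn(4), torch.randn(4)

    # Forward pass
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)

    # Inverse: the inputs are recovered from the outputs alone,
    # at the cost of re-evaluating G and F
    x2_rec = y2 - G(y1)
    x1_rec = y1 - F(x2_rec)

    assert torch.allclose(x1, x1_rec) and torch.allclose(x2, x2_rec)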

The effect is comparable to activation checkpointing, in that it opens up a tradeoff between GPU memory
and compute. One benefit is that no extra wrapping is needed: all the residual paths are naturally checkpointed.
In a distributed setting, freeing up GPU memory can make it possible to use fewer GPUs, and the saved communication cost can more than make up for the extra compute.
Moreover, if your model is made of a stack of reversible blocks, the activation memory requirement does not increase with the number of blocks.
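
For comparison, this is what the explicit wrap looks like with stock PyTorch activation checkpointing; each checkpointed block trades stored activations for a recomputation in the backward pass, whereas reversible residual blocks obtain the same tradeoff without wrapping anything by hand:

.. code-block:: python

    import torch
    from torch.utils.checkpoint import checkpoint

    block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
    x = torch.randn(8, 16, requires_grad=True)

    # Activations inside `block` are not stored; they are recomputed
    # during the backward pass
    y = checkpoint(block, x, use_reentrant=False)
    y.sum().backward()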

Transformer
-----------

Considering the multi-head attention and feedforward blocks (including their residual paths), one can set F to be MHA (+ layer norm) and G to be Feedforward (+ layer norm), and arrive at something very close (but not exactly identical) to the original Transformer formulation from Vaswani et al. ([3]_), as follows

::

    y1 = x1 + MHA(x2)
    y2 = x2 + Feedforward(y1)
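
As a sketch of the Transformer case above, F can wrap multi-head attention and G the feedforward, each with a (pre-) layer norm. The modules below are hypothetical stand-ins built on stock PyTorch, not xFormers components:

.. code-block:: python

    import torch
    import torch.nn as nn

    class AttentionF(nn.Module):
        def __init__(self, dim: int, heads: int = 4):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

        def forward(self, x):
            x = self.norm(x)
            out, _ = self.attn(x, x, x)  # self-attention
            return out

    class FeedforwardG(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.mlp = nn.Sequential(
                nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
            )

        def forward(self, x):
            return self.mlp(self.norm(x))

    dim = 64
    f, g = AttentionF(dim), FeedforwardG(dim)
    x1 = x2 = torch.randn(2, 16, dim)  # (batch, seq, model dim)

    y1 = x1 + f(x2)
    y2 = x2 + g(y1)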

A difference is that the residual path in the feedforward now operates on the original input rather than on the MHA output,
but in practice, if `dim(x1) == dim(x2) == dim(model)`, the accuracy should not be affected, as verified in [2]_ and in xFormers.

In practice
-----------

This repository exposes two main helpers in `xformers.components.reversible`: `ReversibleBlock` and `ReversibleSequence`. `ReversibleBlock` takes `f` and `g` as defined above, and `ReversibleSequence` combines such blocks sequentially, similarly to `torch.nn.ModuleList`.

.. code-block:: python

    import torch
    import torch.nn as nn

    class ReversibleBlock(nn.Module):
        def __init__(self, f: nn.Module, g: nn.Module):
            ...

        def forward(self, x: torch.Tensor, f_args={}, g_args={}):
            ...

    class ReversibleSequence(nn.Module):
        def __init__(self, blocks: nn.ModuleList):
            ...

        def forward(self, x, arg_route=(True, False), **kwargs):
            """
            arg_route: whether to route the kwargs to f and g
            """
            ...
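
A hypothetical usage sketch follows. It assumes the LucidRains-style convention (which this implementation adapts) of duplicating the input along the last dimension to form the (x1, x2) pair; the sub-blocks here are simple stand-ins for MHA and feedforward:

.. code-block:: python

    import torch
    import torch.nn as nn

    from xformers.components.reversible import ReversibleBlock, ReversibleSequence

    dim = 64

    def sub_block() -> nn.Module:
        # stand-in for an MHA or feedforward sub-block (hypothetical)
        return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))

    blocks = nn.ModuleList(
        [ReversibleBlock(sub_block(), sub_block()) for _ in range(4)]
    )
    seq = ReversibleSequence(blocks)

    x = torch.randn(2, 16, dim)
    x = torch.cat([x, x], dim=-1)  # duplicate the input into the (x1, x2) pair
    y = seq(x)                     # shape (2, 16, 2 * dim)
    y1, y2 = y.chunk(2, dim=-1)    # split back if needed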

Reversible layers are also exposed as a boolean option when building complete xFormers models through the (optional) model factory, as defined in `xformers.factory.model_factory`. Please note that reversible layers are not yet compatible with the use of multiple forward passes and DDP.

.. code-block:: python

    from typing import Union

    class xFormerStackConfig:
        block_config: Union[xFormerEncoderConfig, xFormerDecoderConfig]
        num_layers: int
        reversible: bool  # whether the sequence of layers becomes reversible
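
Roughly speaking, when `reversible` is set, the stack of layers is assembled into a single `ReversibleSequence` instead of a plain sequential container. The helper below is a loose, hypothetical sketch of that idea, not the actual `model_factory` code:

.. code-block:: python

    import torch.nn as nn

    from xformers.components.reversible import ReversibleBlock, ReversibleSequence

    def build_stack(f_blocks, g_blocks, reversible: bool) -> nn.Module:
        # Hypothetical helper: f_blocks / g_blocks hold the per-layer attention
        # and feedforward sub-blocks. Only the reversible path is sketched here.
        if reversible:
            blocks = nn.ModuleList(
                [ReversibleBlock(f, g) for f, g in zip(f_blocks, g_blocks)]
            )
            return ReversibleSequence(blocks)

        raise NotImplementedError("standard residual wrapping elided from this sketch")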

.. [1] Gomez, A. N., Ren, M., Urtasun, R., & Grosse, R. B. (2017).
       The reversible residual network: Backpropagation without storing activations.

.. [2] Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020).
       Reformer: The efficient Transformer.

.. [3] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017).
       Attention is all you need.

.. _Robin Bruegger: https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
.. _LucidRains: https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py