Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
LayerDrop as described in https://arxiv.org/abs/1909.11556.
"""
import torch | |
import torch.nn as nn | |
class LayerDropModuleList(nn.ModuleList):
    """
    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.

    We refresh the choice of which layers to drop every time we iterate
    over the LayerDropModuleList instance. During evaluation we always
    iterate over all layers.

    Usage::

        layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
        for layer in layers:  # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers:  # this might iterate over all layers
            x = layer(x)
        for layer in layers:  # this might not iterate over any layers
            x = layer(x)

    Args:
        p (float): probability of dropping out each layer while in training
            mode; ``p=0`` keeps every layer and ``p=1`` drops every layer
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(self, p, modules=None):
        super().__init__(modules)
        self.p = p

    def __iter__(self):
        # One uniform draw in [0, 1) per layer, refreshed on every iteration.
        # The draw is unconditional, so the global RNG is consumed the same
        # way in both training and eval mode.
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, m in enumerate(super().__iter__()):
            # Keep the layer iff its draw is >= p, so it is dropped with
            # probability exactly p. A strict ``>`` would also drop a draw
            # equal to p (e.g. a 0.0 draw when p == 0).
            if not self.training or dropout_probs[i] >= self.p:
                yield m