from contextlib import contextmanager
from typing import *
import math
from ..modules import sparse as sp
from ..utils.elastic_utils import ElasticModuleMixin


class SparseTransformerElasticMixin(ElasticModuleMixin):
    """Elastic-memory hooks for sparse transformer modules.

    Implements, on top of ``ElasticModuleMixin``, an input-size measure for
    ``sp.SparseTensor`` inputs and a context manager that toggles per-block
    gradient checkpointing.

    Assumes the host class provides ``self.blocks``: a sized, iterable
    container of blocks that honour a ``use_checkpoint`` attribute.
    """

    def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
        """Return the size of ``x`` as the number of rows of ``x.feats``.

        Extra positional/keyword arguments are accepted and ignored.
        """
        # NOTE(review): presumably one feature row per active sparse element
        # — confirm against sp.SparseTensor.
        return x.feats.shape[0]

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """Temporarily enable checkpointing on a leading prefix of ``self.blocks``.

        Args:
            mem_ratio: Requested ratio. ``1.0`` is a no-op: no flags are
                touched and ``1.0`` is yielded. Lower values enable
                checkpointing on more leading blocks.

        Yields:
            The ratio actually realised with whole-block granularity:
            ``1 - (k - 1) / num_blocks``, where ``k`` is the number of
            leading blocks whose ``use_checkpoint`` was set to True.

        On exit, including exit via an exception raised inside the ``with``
        body, every block's ``use_checkpoint`` flag is restored to the value
        it had on entry.
        """
        if mem_ratio == 1.0:
            yield 1.0
            return

        num_blocks = len(self.blocks)
        if num_blocks == 0:
            # Nothing can be checkpointed; also avoids a division by zero
            # in the exact-ratio computation below.
            yield 1.0
            return

        # The +1 matches the "(k - 1)" in the exact ratio: the first
        # checkpointed block is not counted toward the reported saving.
        # NOTE(review): rationale for the +1 is not visible here — confirm.
        num_checkpoint_blocks = min(
            math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks
        )
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks

        # Remember each block's current flag so exit restores it rather than
        # unconditionally disabling checkpointing.
        previous_flags = [
            getattr(block, "use_checkpoint", False) for block in self.blocks
        ]
        for i, block in enumerate(self.blocks):
            block.use_checkpoint = i < num_checkpoint_blocks
        try:
            yield exact_mem_ratio
        finally:
            # Runs even if the with-body raises, so blocks are never left in
            # an unintended checkpointing state.
            for block, flag in zip(self.blocks, previous_flags):
                block.use_checkpoint = flag