Spaces:
Running
on
Zero
Running
on
Zero
File size: 906 Bytes
87b3c4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from contextlib import contextmanager
from typing import *
import math
from ..modules import sparse as sp
from ..utils.elastic_utils import ElasticModuleMixin
class SparseTransformerElasticMixin(ElasticModuleMixin):
    """Elastic-memory mixin for sparse transformers that own a ``self.blocks`` sequence.

    Provides the hooks ``ElasticModuleMixin`` subclasses implement: an input-size
    measure and a context manager that enables checkpointing on a leading prefix of
    ``self.blocks`` to approximate a requested memory ratio.
    """

    def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
        """Return the size of the first dimension of ``x.feats``.

        NOTE(review): presumably this is the number of active sparse entries and is
        used by ``ElasticModuleMixin`` for memory budgeting -- confirm against the base
        class. Extra positional and keyword arguments are accepted and ignored.
        """
        return x.feats.shape[0]

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """Temporarily set ``use_checkpoint`` on the leading blocks to meet ``mem_ratio``.

        Args:
            mem_ratio: Requested memory ratio. Values >= 1.0 leave every block
                untouched. Smaller values set ``use_checkpoint = True`` on the first
                ``min(ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)`` blocks
                and ``False`` on the rest.

        Yields:
            The ratio implied by the chosen number of checkpointed blocks,
            ``1 - (num_checkpoint_blocks - 1) / num_blocks``, or ``1.0`` when no
            blocks are modified.

        On exit -- including when the body of the ``with`` statement raises -- every
        block's ``use_checkpoint`` is reset to ``False``.
        """
        # Ratios at or above 1.0 need no checkpointing. Testing only ``== 1.0`` let
        # slightly larger ratios still checkpoint block 0 and let larger ones report
        # a meaningless ratio above 1.
        if mem_ratio >= 1.0:
            yield 1.0
            return

        num_blocks = len(self.blocks)
        if num_blocks == 0:
            # Nothing to checkpoint; the ratio formula below would divide by zero.
            yield 1.0
            return

        # Clamp to the block count so very small (or negative) ratios checkpoint
        # every block instead of indexing past the end.
        num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks

        try:
            # Inside the ``try`` so a failure partway through is still reset below.
            for i, block in enumerate(self.blocks):
                block.use_checkpoint = i < num_checkpoint_blocks
            yield exact_mem_ratio
        finally:
            # ``@contextmanager`` re-raises body exceptions at the ``yield``; without
            # ``finally`` the blocks would stay in checkpoint mode after an error.
            for block in self.blocks:
                block.use_checkpoint = False
|