# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for rematerialization.

Incubates a version of tf.recompute_grad that is XLA compatible.

This file is based on the recompute_grad.py in the bigbird codebase [1]:
https://github.com/google-research/bigbird/blob/db06498ec8804c6438111938d8654b66ddaccd5d/bigbird/core/recompute_grad.py

[1] Big Bird: Transformers for Longer Sequences, NeurIPS 2020.
    Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris
    Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li
    Yang, Amr Ahmed.
"""
import collections
import os
import threading
from typing import Deque, List, NamedTuple, Optional, Sequence

from absl import logging
import tensorflow.compat.v2 as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.ops import custom_gradient


# Remove when https://github.com/tensorflow/tensorflow/pull/45298
# gets merged.
def get_variable_by_name(var_name):
  """Retrieves tf.Variable from name in MirroredStrategy (multi-gpu)."""

  # Get all global variables; this includes copies from different replicas.
  all_global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def _replica_filter(var):
    """Filters out variables from a different context."""
    try:
      return var_name == var.op.name
    except AttributeError:
      return False

  candidate_vars = list(filter(_replica_filter, all_global_vars))

  if len(candidate_vars) >= 1:
    # Filter out non-trainable variables.
    candidate_vars = [v for v in candidate_vars if v.trainable]
  else:
    raise ValueError('Unsuccessful at finding variable {}.'.format(var_name))

  if len(candidate_vars) == 1:
    return candidate_vars[0]
  elif len(candidate_vars) > 1:
    raise ValueError(
        'Unsuccessful at finding trainable variable {}. '
        'Number of candidates: {}. '
        'Candidates: {}'.format(var_name, len(candidate_vars), candidate_vars))
  else:
    # The variable is not trainable.
    return None


custom_gradient.get_variable_by_name = get_variable_by_name
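
# Illustrative sketch (not part of the original library): with the patch above,
# gradient code that only knows a variable's op name can recover the concrete
# trainable variable, which is what breaks under tf.distribute.MirroredStrategy.
# Inside a tf.compat.v1 graph, where variables are tracked in the
# GLOBAL_VARIABLES collection, the lookup would behave like:
#
#   with tf.compat.v1.Graph().as_default():
#     kernel = tf.compat.v1.get_variable('my_layer/kernel', shape=[3, 3])
#     assert custom_gradient.get_variable_by_name('my_layer/kernel') is kernel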
class RecomputeContext(
    NamedTuple('RecomputeContext', [
        ('is_recomputing', bool),
        ('seed', tf.Tensor),
        ('children', Deque['RecomputeContext']),
    ])):
  """Context for recomputation.

  Attributes:
    is_recomputing: Whether we are in a recomputation phase.
    seed: Scalar integer tensor that should be used with stateless random ops
      for deterministic behavior and correct computation of the gradient.
    children: Nested `RecomputeContext` instances. Used internally by
      `recompute_grad` to track nested instances of `RecomputeContext`.
  """

  def __enter__(self):
    return _context_stack.push(self)

  def __exit__(self, exc_type, exc_value, traceback):
    _context_stack.pop(self)
# Simplified version of `_DefaultStack` in
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py.
class _ContextStack(threading.local):
  """A thread-local stack for providing implicit recompute contexts."""

  def __init__(self):
    super(_ContextStack, self).__init__()
    self._stack = []

  def top(self) -> Optional[RecomputeContext]:
    return self._stack[-1] if self._stack else None

  def push(self, context: RecomputeContext):
    self._stack.append(context)
    return context

  def pop(self, context: RecomputeContext):
    if self._stack[-1] is not context:
      raise AssertionError('Nesting violated for RecomputeContext.')
    self._stack.pop()


_context_stack = _ContextStack()
def get_recompute_context() -> Optional[RecomputeContext]:
  """Returns the current recomputing context if it exists."""
  return _context_stack.top()
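
# Illustrative sketch (not part of the original library): a function that will
# be wrapped with `recompute_grad` (defined below) can read the context seed
# and feed it to stateless random ops, so the recomputation in the backward
# pass draws exactly the same randomness as the forward pass. The helper name
# `_example_stateless_dropout` is hypothetical.
def _example_stateless_dropout(x: tf.Tensor, rate: float = 0.1) -> tf.Tensor:
  """Sketch of seed-aware dropout for use inside a recomputed function."""
  context = get_recompute_context()
  if context is None:
    # Not running under `recompute_grad`; fall back to a fixed seed.
    seed = tf.constant([0, 0], dtype=tf.int32)
  else:
    # Reuse the context seed so both passes see identical random masks.
    seed = tf.stack([context.seed, tf.zeros_like(context.seed)])
  keep_mask = tf.cast(
      tf.random.stateless_uniform(tf.shape(x), seed=seed) >= rate, x.dtype)
  return x * keep_mask / (1.0 - rate)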
# Adapted from
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_util.py.
def _get_containing_xla_context(graph: tf.Graph) -> Optional[object]:
  """Returns the first ancestor `XLAControlFlowContext` in the `graph`."""
  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None
def _in_xla_context(graph: Optional[tf.Graph] = None) -> bool:
  """Detects whether we are in an XLA context."""
  if '--tf_xla_auto_jit=2' in os.environ.get('TF_XLA_FLAGS', ''):
    return True
  graph = tf.compat.v1.get_default_graph() if graph is None else graph
  while True:
    if _get_containing_xla_context(graph) is not None:
      return True
    try:
      graph = graph.outer_graph
    except AttributeError:
      return False
def _force_data_dependency(
    first_compute: Sequence[tf.Tensor],
    then_compute: Sequence[tf.Tensor]) -> List[tf.Tensor]:
  """Forces all of `then_compute` to depend on all of `first_compute`.

  Uses a dummy data dependency, which is useful when running on TPUs because
  XLA ignores control dependencies. Only supports float arguments.

  Args:
    first_compute: Sequence of `Tensor`s to be executed before `then_compute`.
    then_compute: Sequence of `Tensor`s to be executed after `first_compute`.

  Returns:
    Sequence of `Tensor`s with the same length as `then_compute`.

  Raises:
    ValueError: if ranks are unknown or types are not floating.
  """

  def _first_element(x):
    if x.shape.ndims is None:
      raise ValueError('Rank of Tensor %s must be known' % x)
    ndims = x.shape.ndims
    begin = tf.zeros(ndims, dtype=tf.int32)
    size = tf.ones(ndims, dtype=tf.int32)
    return tf.reshape(tf.slice(x, begin, size), [])

  first_compute_sum = tf.add_n(
      [_first_element(x) for x in first_compute if x is not None])
  dtype = first_compute_sum.dtype
  if not dtype.is_floating:
    raise ValueError('_force_data_dependency only supports floating dtypes.')
  zero = tf.cast(0.0, first_compute_sum.dtype) * first_compute_sum
  then_compute_sequence = [
      x + tf.cast(zero, x.dtype) if x is not None else None
      for x in tf.nest.flatten(then_compute)
  ]
  return tf.nest.pack_sequence_as(then_compute, then_compute_sequence)
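
# Illustrative sketch (not part of the original library): the dummy dependency
# is a zero built from the first element of each tensor in `first_compute`,
# added onto every tensor in `then_compute`. The values are unchanged, but XLA
# now sees a real data edge and cannot schedule `then_compute` before
# `first_compute`:
#
#   dresult = [tf.ones([2, 3])]
#   inputs = [tf.ones([2, 3]), tf.ones([4])]
#   deps = _force_data_dependency(dresult, inputs)
#   # Each deps[i] is numerically equal to inputs[i] but depends on dresult[0].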
def _make_seed_if_none(seed: Optional[tf.Tensor]) -> tf.Tensor:
  """Uses the global generator to make a seed if necessary."""
  if seed is not None:
    return seed
  generator = tf.random.experimental.get_global_generator()
  # The two seeds for stateless random ops don't have individual semantics and
  # are scrambled together, so providing one seed is fine. This makes it easier
  # for users to provide a local seed without worrying about integer overflow.
  # See `make_seeds` in
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/stateful_random_ops.py.
  try:
    return generator.uniform_full_int([], tf.int32, name='recompute_grad_seed')
  except (RuntimeError, TypeError, ValueError, tf.errors.NotFoundError) as e:
    # The above operation can fail for a number of reasons, e.g. when using
    # multiple graphs or toggling between eager and graph modes. Reset the
    # generator and retry.
    logging.warn('Resetting the generator. %s: %s', type(e), e)
    tf.random.experimental.set_global_generator(None)
    generator = tf.random.experimental.get_global_generator()
    return generator.uniform_full_int([], tf.int32, name='recompute_grad_seed')
def recompute_grad(f, seed=None):
  """An eager-compatible version of recompute_grad.

  For f(*args, **kwargs), this supports gradients with respect to args, or to
  any variables residing in the kwarg 'variables'. Note that for keras layer
  and model objects, this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`. When recomputing gradients through
  objects that inherit from keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor`
      outputs.
    seed: Optional seed for random ops. `seed` should be an integer scalar
      `Tensor`. When compiling to XLA, `seed` must have dtype `tf.int32`. If
      `seed` is not provided, one will be generated.

  Returns:
    A function `g` that wraps `f`, but which recomputes `f` on the backwards
    pass of a gradient call.
  """

  @tf.custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    # Detect when we're nested and in the backwards pass, so we don't generate
    # an additional seed.
    parent_context = get_recompute_context()
    if parent_context is not None and parent_context.is_recomputing:
      # Use the cached context in the recomputation phase.
      with parent_context.children.popleft()._replace(
          is_recomputing=True) as context:
        result = f(*args, **kwargs)
    else:
      with RecomputeContext(
          is_recomputing=False,
          seed=_make_seed_if_none(seed),
          children=collections.deque()) as context:
        result = f(*args, **kwargs)
      # In the forward pass, build up a tree of recomputation contexts.
      if parent_context is not None and not parent_context.is_recomputing:
        parent_context.children.append(context)

    def grad(*dresult, **grad_kwargs):
      """Gradient function calculation for inner function."""
      variables = grad_kwargs.pop('variables', None)
      if grad_kwargs:
        raise ValueError('Found unexpected kwargs for `grad`: ',
                         list(grad_kwargs.keys()))
      inputs, seed = list(args), context.seed
      if _in_xla_context():
        inputs = _force_data_dependency(
            tf.nest.flatten(dresult), inputs + [seed])
        seed = inputs.pop()
      # tf.keras.backend.set_learning_phase(1)
      with tf.GradientTape() as tape:
        tape.watch(inputs)
        if variables is not None:
          tape.watch(variables)
        with tf.control_dependencies(dresult):
          with context._replace(is_recomputing=True, seed=seed):
            result = f(*inputs, **kwargs)
      kw_vars = []
      if variables is not None:
        kw_vars = list(variables)
      grads = tape.gradient(
          result, list(inputs) + kw_vars, output_gradients=dresult)
      return grads[:len(inputs)], grads[len(inputs):]

    return result, grad

  return inner
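
# Example usage (illustrative sketch, not part of the original library): wrap a
# block's forward function so its activations are recomputed during the
# backward pass instead of being stored. `my_block` and the input `x` are
# placeholders.
#
#   my_block = tf.keras.Sequential([
#       tf.keras.layers.Dense(256, activation='relu'),
#       tf.keras.layers.Dense(256, activation='relu'),
#   ])
#   block_fn = recompute_grad(lambda t: my_block(t))
#
#   with tf.GradientTape() as tape:
#     loss = tf.reduce_sum(block_fn(x))
#   grads = tape.gradient(loss, my_block.trainable_variables)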