# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext

import torch


def avoid_bfloat16_autocast_context():
    """
    If the current autocast context is bfloat16, return an autocast context
    that runs in float32 instead; otherwise return a null context.
    """
    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16:
        return torch.cuda.amp.autocast(dtype=torch.float32)
    else:
        return nullcontext()
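# Illustrative usage (an editorial sketch, not part of the upstream file): the helper
# returns a context manager, so it wraps an op that needs full float32 precision while
# training under bfloat16 autocast. `precision_sensitive_op` is a hypothetical stand-in.
#
#     with torch.cuda.amp.autocast(dtype=torch.bfloat16):
#         h = model(x)
#         with avoid_bfloat16_autocast_context():
#             loss = precision_sensitive_op(h)   # computed under a float32 autocast context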


def avoid_float16_autocast_context():
    """
    If the current autocast context is float16, return an autocast context that
    runs in bfloat16 if available (or float32 when TorchScript scripting/tracing);
    otherwise return a null context.
    """
    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return torch.cuda.amp.autocast(dtype=torch.float32)

        if torch.cuda.is_bf16_supported():
            return torch.cuda.amp.autocast(dtype=torch.bfloat16)
        else:
            return torch.cuda.amp.autocast(dtype=torch.float32)
    else:
        return nullcontext()
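# Illustrative usage (an editorial sketch): under float16 autocast, numerically fragile ops
# can be wrapped so they run in bfloat16 where supported, falling back to float32.
# `fragile_op` below is a hypothetical function used only for the sketch.
#
#     with torch.cuda.amp.autocast(dtype=torch.float16):
#         y = model(x)                       # runs in float16
#         with avoid_float16_autocast_context():
#             z = fragile_op(y)              # runs in bfloat16 (if supported) or float32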


def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x


def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
    if isinstance(x, torch.Tensor):
        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
    else:
        if isinstance(x, dict):
            new_dict = {}
            for k in x.keys():
                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
            return new_dict
        elif isinstance(x, tuple):
            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
        else:
            # Leave unsupported types (e.g. ints, strings, lists) unchanged instead of
            # silently returning None.
            return x
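# Illustrative usage (an editorial sketch): `cast_all` recurses through dicts and tuples,
# casting only tensors that match `from_dtype` and leaving everything else untouched.
#
#     batch = {"audio": torch.randn(2, 16000, dtype=torch.float16),
#              "lengths": (torch.tensor([16000, 12000]),)}
#     fp32_batch = cast_all(batch, from_dtype=torch.float16, to_dtype=torch.float32)
#     # fp32_batch["audio"].dtype == torch.float32; the integer length tensor is unchanged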


class CastToFloat(torch.nn.Module):
    """Wrapper that runs the wrapped module in float32 with autocast disabled,
    then casts the output back to the input's dtype."""

    def __init__(self, mod):
        super(CastToFloat, self).__init__()
        self.mod = mod

    def forward(self, x):
        with torch.cuda.amp.autocast(enabled=False):
            ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
        return ret
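# Illustrative usage (an editorial sketch): wrap a single-input layer so it always computes
# in float32, e.g. a normalization layer that is unstable in half precision.
#
#     norm = CastToFloat(torch.nn.LayerNorm(256))
#     with torch.cuda.amp.autocast(dtype=torch.float16):
#         y = norm(x)    # LayerNorm itself runs in float32; y keeps x's original dtype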


class CastToFloatAll(torch.nn.Module):
    """Wrapper that casts all positional inputs to float32, runs the wrapped module
    with autocast disabled, and casts all outputs back to the original dtype."""

    def __init__(self, mod):
        super(CastToFloatAll, self).__init__()
        self.mod = mod

    def forward(self, *args):
        from_dtype = args[0].dtype
        with torch.cuda.amp.autocast(enabled=False):
            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
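# Illustrative usage (an editorial sketch): CastToFloatAll handles modules that take several
# positional tensor arguments; every input is cast to float32 and every output is cast back
# to the dtype of the first input. `MultiInputBlock` is a hypothetical module.
#
#     block = CastToFloatAll(MultiInputBlock())
#     with torch.cuda.amp.autocast(dtype=torch.float16):
#         out = block(query, key, value)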