# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext

import torch


def avoid_bfloat16_autocast_context():
    """
    If the current autocast context is bfloat16, return an autocast context
    that forces float32 instead; otherwise return a no-op null context.
    """
    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16:
        return torch.cuda.amp.autocast(dtype=torch.float32)
    else:
        return nullcontext()
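
# Illustrative usage sketch (not part of the original file; `model` and `inputs`
# are placeholder names): force float32 for a block that needs more precision
# than bfloat16, even when a bfloat16 autocast context is active.
#
#     with avoid_bfloat16_autocast_context():
#         out = model(inputs)  # computed in float32 under bfloat16 autocast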


def avoid_float16_autocast_context():
    """
    If the current autocast context is float16, return an autocast context that
    uses bfloat16 instead when it is available (and we are not scripting or
    tracing under TorchScript), or float32 otherwise; else return a no-op null
    context.
    """
    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return torch.cuda.amp.autocast(dtype=torch.float32)
        if torch.cuda.is_bf16_supported():
            return torch.cuda.amp.autocast(dtype=torch.bfloat16)
        else:
            return torch.cuda.amp.autocast(dtype=torch.float32)
    else:
        return nullcontext()
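
# Illustrative usage sketch (not part of the original file; `model` and `inputs`
# are placeholder names): run a block that is prone to float16 overflow in
# bfloat16 (when supported) or float32 instead.
#
#     with avoid_float16_autocast_context():
#         out = model(inputs)  # never computed in float16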


def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x


def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
    # Recursively cast tensors of `from_dtype` nested inside dicts and tuples.
    if isinstance(x, torch.Tensor):
        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
    else:
        if isinstance(x, dict):
            new_dict = {}
            for k in x.keys():
                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
            return new_dict
        elif isinstance(x, tuple):
            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
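
# Illustrative usage sketch (not part of the original file): cast_all walks
# nested dicts/tuples and converts only tensors of `from_dtype`.
#
#     batch = {"x": torch.randn(2, 3).half(), "lengths": torch.tensor([3, 2])}
#     fp32_batch = cast_all(batch, from_dtype=torch.float16, to_dtype=torch.float32)
#     # fp32_batch["x"].dtype == torch.float32; "lengths" is left untouched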


class CastToFloat(torch.nn.Module):
    # Wrapper that runs the wrapped module in float32 with autocast disabled and
    # casts the output back to the input's dtype.
    def __init__(self, mod):
        super(CastToFloat, self).__init__()
        self.mod = mod

    def forward(self, x):
        with torch.cuda.amp.autocast(enabled=False):
            ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
        return ret
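
# Illustrative usage sketch (not part of the original file; `model` is a
# placeholder name): swap a numerically sensitive, single-input submodule for a
# wrapped copy so it always runs in float32 under mixed precision.
#
#     model.softmax = CastToFloat(model.softmax)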


class CastToFloatAll(torch.nn.Module):
    # Same as CastToFloat, but for modules taking multiple inputs: all tensor
    # arguments of the first argument's dtype are cast to float32, and float32
    # tensors in the output are cast back to that dtype.
    def __init__(self, mod):
        super(CastToFloatAll, self).__init__()
        self.mod = mod

    def forward(self, *args):
        from_dtype = args[0].dtype
        with torch.cuda.amp.autocast(enabled=False):
            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
            return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
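
# Illustrative usage sketch (not part of the original file; `Add` is a
# placeholder module): a multi-input module computes in float32 while callers
# still see the original dtype.
#
#     class Add(torch.nn.Module):
#         def forward(self, a, b):
#             return a + b
#
#     add = CastToFloatAll(Add())
#     out = add(torch.ones(4).half(), torch.ones(4).half())
#     # out.dtype == torch.float16, but the addition ran in float32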