# test/modules/dml/amp/autocast_mode.py
import importlib
from typing import Any, Optional
import torch
# Ops whose tensor arguments are converted to the active autocast dtype
# before the original implementation is called.
ops = [
    "torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv",
    "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul",
    "torch.linalg.multi_dot", "torch.nn.functional.conv1d",
    "torch.nn.functional.conv2d", "torch.nn.functional.conv3d",
    "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d",
    "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell",
    "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul",
    "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding",
]

# Maps a source dtype to the dtypes it may be cast to under autocast.
supported_cast_pairs = {
    torch.float16: (torch.float32,),
    torch.float32: (torch.float16,),
}
def forward(op, args: tuple, kwargs: dict):
    # Wrapper installed in place of the original op: when autocast is active,
    # cast every tensor argument before delegating to the original callable.
    if not torch.dml.is_autocast_enabled:
        return op(*args, **kwargs)
    args = list(map(cast, args))
    for kwarg in kwargs:
        kwargs[kwarg] = cast(kwargs[kwarg])
    return op(*args, **kwargs)
def cast(tensor: torch.Tensor):
    # Non-tensor arguments pass through untouched.
    if not torch.is_tensor(tensor):
        return tensor
    dtype: torch.dtype = tensor.dtype
    # Only convert when the source dtype and the target autocast dtype form a
    # supported pair; everything else (e.g. integer index tensors) is left alone.
    if dtype not in supported_cast_pairs or (
        torch.dml.autocast_gpu_dtype != dtype
        and torch.dml.autocast_gpu_dtype not in supported_cast_pairs[dtype]
    ):
        return tensor
    return tensor.type(torch.dml.autocast_gpu_dtype)
def cond(op: str):
    # Resolve the dotted path (e.g. "torch.nn.functional.linear"), then rebind
    # the attribute to a lambda that routes the call through forward().
    if isinstance(op, str):
        func_path = op.split('.')
        # Import the longest importable module prefix of the path.
        for i in range(len(func_path) - 1, -1, -1):
            try:
                resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                break
            except ImportError:
                pass
        # Walk the remaining attributes down to the object that owns the op.
        for attr_name in func_path[i:-1]:
            resolved_obj = getattr(resolved_obj, attr_name)
        op = getattr(resolved_obj, func_path[-1])
        setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs))
# Patch every listed op at import time so it honours the DML autocast state.
for o in ops:
    cond(o)
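# A quick way to see the patching take effect (an illustrative sketch, not part
# of the original module; it assumes this file has been imported into a
# DML-enabled PyTorch build):
#
#     import torch
#     print(torch.mm)   # the forwarding lambda installed by cond(), not the builtin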
class autocast:
    """Context manager mirroring torch.cuda.amp.autocast for the DML backend."""
    prev: bool
    fast_dtype: torch.dtype = torch.float16
    prev_fast_dtype: torch.dtype

    def __init__(self, dtype: Optional[torch.dtype] = torch.float16):
        self.fast_dtype = dtype

    def __enter__(self):
        # Save the previous autocast state so nested contexts restore correctly.
        self.prev = torch.dml.is_autocast_enabled
        self.prev_fast_dtype = torch.dml.autocast_gpu_dtype
        torch.dml.is_autocast_enabled = True
        torch.dml.autocast_gpu_dtype = self.fast_dtype

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        # Restore whatever state was active before entering the context.
        torch.dml.is_autocast_enabled = self.prev
        torch.dml.autocast_gpu_dtype = self.prev_fast_dtype
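# Minimal usage sketch (an assumption, not part of the original module; it
# presumes a torch-directml build where "dml" tensors exist and torch.dml
# carries the is_autocast_enabled / autocast_gpu_dtype attributes used above):
#
#     a = torch.randn(4, 4, device="dml")
#     b = torch.randn(4, 4, device="dml")
#     with autocast(dtype=torch.float16):
#         c = torch.mm(a, b)   # float32 inputs are cast to float16 by forward()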