Spaces:
Runtime error
Runtime error
import importlib | |
from typing import Any, Optional | |
import torch | |
# Dotted paths of the callables that are wrapped for autocasting when this
# module is imported (each entry is handed to ``cond``).
ops = [
    "torch.Tensor.__matmul__",
    "torch.addbmm",
    "torch.addmm",
    "torch.addmv",
    "torch.addr",
    "torch.baddbmm",
    "torch.bmm",
    "torch.chain_matmul",
    "torch.linalg.multi_dot",
    "torch.nn.functional.conv1d",
    "torch.nn.functional.conv2d",
    "torch.nn.functional.conv3d",
    "torch.nn.functional.conv_transpose1d",
    "torch.nn.functional.conv_transpose2d",
    "torch.nn.functional.conv_transpose3d",
    "torch.nn.GRUCell",
    "torch.nn.functional.linear",
    "torch.nn.LSTMCell",
    "torch.matmul",
    "torch.mm",
    "torch.mv",
    "torch.prelu",
    "torch.nn.RNNCell",
    "torch.embedding",
]
# For each source dtype, the tuple of autocast target dtypes that ``cast`` is
# allowed to convert a tensor of that dtype into. Dtypes that are not keys
# here are never converted.
supported_cast_pairs = {
    torch.float16: (torch.float32,),
    torch.float32: (torch.float16,),
}
def forward(op, args: tuple, kwargs: dict):
    """Call ``op``, passing its arguments through ``cast`` when autocast is on.

    Args:
        op: The callable to invoke.
        args: Positional arguments for ``op``.
        kwargs: Keyword arguments for ``op``. When autocast is enabled the
            values of this dict are replaced in place by their cast versions.

    Returns:
        Whatever ``op`` returns.
    """
    if torch.dml.is_autocast_enabled:
        args = [cast(positional) for positional in args]
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for name, value in kwargs.items():
            kwargs[name] = cast(value)
    return op(*args, **kwargs)
def cast(tensor: torch.Tensor):
    """Convert ``tensor`` to the active autocast dtype when that is permitted.

    Non-tensor values are returned untouched. A tensor is returned untouched
    when its dtype has no entry in ``supported_cast_pairs``, or when the
    autocast dtype is neither the tensor's own dtype nor one of the targets
    listed for it.

    Args:
        tensor: Any value; only ``torch.Tensor`` inputs are ever converted.

    Returns:
        The input itself, or the result of ``tensor.type(...)`` with the
        current ``torch.dml.autocast_gpu_dtype``.
    """
    if not torch.is_tensor(tensor):
        return tensor

    source_dtype: torch.dtype = tensor.dtype
    if source_dtype not in supported_cast_pairs:
        return tensor

    target_dtype = torch.dml.autocast_gpu_dtype
    if target_dtype != source_dtype and target_dtype not in supported_cast_pairs[source_dtype]:
        return tensor

    return tensor.type(target_dtype)
def cond(op: str):
    """Replace the callable named by a dotted path with an autocast wrapper.

    The longest importable leading part of ``op`` is imported as a module,
    the remaining components (all but the last) are resolved with ``getattr``,
    and the final attribute is rebound on its owner to a wrapper that sends
    every call through ``forward``. Values that are not strings are ignored.

    Args:
        op: Dotted path of the callable, e.g. ``"torch.nn.functional.linear"``.

    Raises:
        ImportError: If no leading part of ``op`` can be imported as a module.
        AttributeError: If a component after the module part does not exist.
    """
    if not isinstance(op, str):
        return

    func_path = op.split('.')

    # Try the longest candidate module name first and shorten it one component
    # at a time. The search stops at a single component: an empty prefix would
    # make import_module('') raise ValueError("Empty module name"), which is
    # not an ImportError and would hide the real problem.
    for i in range(len(func_path) - 1, 0, -1):
        try:
            resolved_obj = importlib.import_module('.'.join(func_path[:i]))
        except ImportError:
            continue
        break
    else:
        raise ImportError(f"cannot resolve a module from operation path {op!r}")

    # Walk through any non-module owners (e.g. a class) between the module
    # and the final attribute.
    for attr_name in func_path[i:-1]:
        resolved_obj = getattr(resolved_obj, attr_name)

    original = getattr(resolved_obj, func_path[-1])
    setattr(
        resolved_obj,
        func_path[-1],
        lambda *args, **kwargs: forward(original, args, kwargs),
    )
# Import-time side effect: wrap every operation listed in ``ops`` so that its
# calls are routed through ``forward``.
for o in ops:
    cond(o)
class autocast:
    """Context manager that switches the ``torch.dml`` autocast state on.

    On entry the current ``torch.dml.is_autocast_enabled`` and
    ``torch.dml.autocast_gpu_dtype`` values are saved, autocast is enabled and
    the requested dtype is installed. On exit the saved values are restored.

    NOTE(review): ``torch.dml`` is not defined in this file; it is assumed to
    be attached to ``torch`` elsewhere before this class is used.
    """

    # Value of torch.dml.is_autocast_enabled captured by __enter__.
    prev: bool
    # Dtype installed as torch.dml.autocast_gpu_dtype while the context is active.
    fast_dtype: torch.dtype = torch.float16
    # Value of torch.dml.autocast_gpu_dtype captured by __enter__.
    prev_fast_dtype: torch.dtype

    def __init__(self, dtype: Optional[torch.dtype] = torch.float16):
        """Store the dtype to install while the context is active.

        Args:
            dtype: Target autocast dtype; defaults to ``torch.float16``.
        """
        self.fast_dtype = dtype

    def __enter__(self):
        """Save the current autocast state, enable autocast and return ``self``."""
        self.prev = torch.dml.is_autocast_enabled
        self.prev_fast_dtype = torch.dml.autocast_gpu_dtype
        torch.dml.is_autocast_enabled = True
        torch.dml.autocast_gpu_dtype = self.fast_dtype
        # Returning self lets ``with autocast() as ctx:`` bind the manager
        # instead of None.
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        """Restore the autocast state saved by ``__enter__``.

        Returns ``None``, so an exception raised inside the ``with`` body
        propagates to the caller.
        """
        torch.dml.is_autocast_enabled = self.prev
        torch.dml.autocast_gpu_dtype = self.prev_fast_dtype