import importlib
from typing import Any, Optional
import torch
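

# State defaults: a small addition, assuming torch.dml does not already
# define these flags (the autocast context manager below sets them on entry).
# Without defaults, forward() would raise AttributeError if a patched op ran
# before the first autocast block.
if not hasattr(torch.dml, "is_autocast_enabled"):
    torch.dml.is_autocast_enabled = False
if not hasattr(torch.dml, "autocast_gpu_dtype"):
    torch.dml.autocast_gpu_dtype = torch.float16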


# Ops wrapped by autocast: matrix multiplies, convolutions, RNN cells, and
# related operations whose tensor arguments are cast to the autocast dtype.
ops = [
    "torch.Tensor.__matmul__",
    "torch.addbmm",
    "torch.addmm",
    "torch.addmv",
    "torch.addr",
    "torch.baddbmm",
    "torch.bmm",
    "torch.chain_matmul",
    "torch.linalg.multi_dot",
    "torch.nn.functional.conv1d",
    "torch.nn.functional.conv2d",
    "torch.nn.functional.conv3d",
    "torch.nn.functional.conv_transpose1d",
    "torch.nn.functional.conv_transpose2d",
    "torch.nn.functional.conv_transpose3d",
    "torch.nn.GRUCell",
    "torch.nn.functional.linear",
    "torch.nn.LSTMCell",
    "torch.matmul",
    "torch.mm",
    "torch.mv",
    "torch.prelu",
    "torch.nn.RNNCell",
    "torch.embedding",
]
# Each key dtype may be cast to any dtype in its value tuple.
supported_cast_pairs = {
    torch.float16: (torch.float32,),
    torch.float32: (torch.float16,),
}


def forward(op, args: tuple, kwargs: dict):
    # Call the original op directly when autocast is disabled.
    if not torch.dml.is_autocast_enabled:
        return op(*args, **kwargs)
    # Cast every tensor argument, positional and keyword, to the autocast
    # dtype. Rebuilding kwargs avoids mutating the caller's dict in place.
    args = list(map(cast, args))
    kwargs = {key: cast(value) for key, value in kwargs.items()}
    return op(*args, **kwargs)


def cast(tensor: torch.Tensor):
    # Non-tensor arguments pass through untouched.
    if not torch.is_tensor(tensor):
        return tensor
    dtype: torch.dtype = tensor.dtype
    target: torch.dtype = torch.dml.autocast_gpu_dtype
    # Leave the tensor alone unless its dtype supports a cast to the target.
    if dtype not in supported_cast_pairs or (target != dtype and target not in supported_cast_pairs[dtype]):
        return tensor
    return tensor.type(target)


def cond(op: str):
    # Resolve a dotted path such as "torch.nn.functional.linear", stash the
    # original callable, and re-point the attribute at a wrapper that routes
    # calls through forward() so arguments get cast under autocast.
    if isinstance(op, str):
        func_path = op.split('.')
        # Import the longest importable module prefix of the path.
        for i in range(len(func_path) - 1, -1, -1):
            try:
                resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                break
            except ImportError:
                pass
        # Walk the remaining attributes down to the op's parent object.
        for attr_name in func_path[i:-1]:
            resolved_obj = getattr(resolved_obj, attr_name)
        op = getattr(resolved_obj, func_path[-1])
        setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs))


# Patch every listed op once at import time.
for o in ops:
    cond(o)


class autocast:
    prev: bool

    fast_dtype: torch.dtype = torch.float16
    prev_fast_dtype: torch.dtype

    def __init__(self, dtype: Optional[torch.dtype] = torch.float16):
        # Fall back to float16 when dtype is explicitly None, so the
        # Optional signature cannot leave fast_dtype unset.
        self.fast_dtype = dtype if dtype is not None else torch.float16

    def __enter__(self):
        # Remember the outer state so nested contexts restore it correctly.
        self.prev = torch.dml.is_autocast_enabled
        self.prev_fast_dtype = torch.dml.autocast_gpu_dtype
        torch.dml.is_autocast_enabled = True
        torch.dml.autocast_gpu_dtype = self.fast_dtype

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
        torch.dml.is_autocast_enabled = self.prev
        torch.dml.autocast_gpu_dtype = self.prev_fast_dtype
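

# Usage sketch (illustrative only): torch_directml.device() is how the
# torch-directml package exposes a DirectML device; the exact device setup
# is an assumption, not part of this module.
#
#     import torch_directml
#     dml = torch_directml.device()
#     a = torch.randn(4, 4, device=dml)
#     b = torch.randn(4, 4, device=dml)
#     with autocast(torch.float16):
#         c = torch.mm(a, b)  # patched op; a and b are cast to float16
#     assert c.dtype == torch.float16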