File size: 2,749 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# pylint: disable=no-member,no-self-argument,no-method-argument
from typing import Optional, Callable
import torch
import torch_directml # pylint: disable=import-error
import modules.dml.amp as amp
from .utils import rDevice, get_device
from .device import Device
from .Generator import Generator
from .device_properties import DeviceProperties


def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
    """Query memory information for *device* through ``AMDMemoryProvider``.

    Args:
        device: Device reference resolved with ``get_device``; ``None`` is
            passed straight through to ``get_device``.

    Returns:
        Whatever ``AMDMemoryProvider.mem_get_info`` returns for the resolved
        device index (annotated as a pair of ints).
    """
    # The provider is imported inside the function, so the AMD-specific module
    # is only loaded when this function is actually called.
    from .memory_amd import AMDMemoryProvider
    query = AMDMemoryProvider.mem_get_info
    target = get_device(device)
    return query(target.index)


def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
    """Derive a memory pair for *device* from ``DirectML.memory_provider``.

    Args:
        device: Device reference resolved with ``get_device``; ``None`` is
            passed straight through to ``get_device``.

    Returns:
        ``(total_committed - dedicated_usage, total_committed)`` computed from
        the mapping returned by the provider's ``get_memory``.
        NOTE(review): presumably this is a (free, total) pair in bytes —
        confirm against the provider implementation.
    """
    # DirectML.memory_provider is None by default and has to be assigned
    # elsewhere before this function can be used.
    provider = DirectML.memory_provider
    usage = provider.get_memory(get_device(device).index)
    committed = usage["total_committed"]
    dedicated = usage["dedicated_usage"]
    return (committed - dedicated, committed)


def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
    """Return a fixed placeholder memory pair, ignoring *device*.

    Args:
        device: Accepted for signature compatibility only; never read.

    Returns:
        A pair whose two entries are both 8589934592 (8 * 1024**3).
    """
    placeholder_bytes = 8 * 1024 ** 3  # == 8589934592
    return (placeholder_bytes, placeholder_bytes)


class DirectML:
    """Namespace class bundling DirectML device and memory helpers.

    The functions defined here take neither ``self`` nor ``cls``: they are
    written to be called on the class itself (for example
    ``DirectML.device_count()``), not on an instance.
    NOTE(review): the helper names look like they mirror the ``torch.cuda``
    API surface — confirm with the callers.
    """

    # Re-exports of the sibling module and classes imported at file level.
    amp = amp
    device = Device
    Generator = Generator

    # Device handed out by current_device() when set; while it is None (or
    # otherwise falsy) current_device() falls back to default_device().
    context_device: Optional[torch.device] = None

    # Autocast settings; they are not read anywhere inside this class.
    # NOTE(review): presumably consumed by the amp module — confirm.
    is_autocast_enabled = False
    autocast_gpu_dtype = torch.float16

    # Read by the module-level pdh_mem_get_info(); None until assigned elsewhere.
    memory_provider = None

    def is_available() -> bool:
        """Return ``torch_directml.is_available()``."""
        available = torch_directml.is_available()
        return available

    def is_directml_device(device: torch.device) -> bool:
        """Return True when ``device.type`` equals ``"privateuseone"``."""
        device_type = device.type
        return device_type == "privateuseone"

    def has_float64_support(device: Optional[rDevice]=None) -> bool:
        """Ask torch_directml whether the resolved device supports float64."""
        index = get_device(device).index
        return torch_directml.has_float64_support(index)

    def device_count() -> int:
        """Return ``torch_directml.device_count()``."""
        count = torch_directml.device_count()
        return count

    def current_device() -> torch.device:
        """Return ``context_device`` if it is truthy, else the default device."""
        active = DirectML.context_device
        if active:
            return active
        return DirectML.default_device()

    def default_device() -> torch.device:
        """Return the torch_directml device for torch_directml's default index."""
        default_index = torch_directml.default_device()
        return torch_directml.device(default_index)

    def get_device_string(device: Optional[rDevice]=None) -> str:
        """Return ``"privateuseone:<index>"`` for the resolved device."""
        index = get_device(device).index
        return f"privateuseone:{index}"

    def get_device_name(device: Optional[rDevice]=None) -> str:
        """Return the name torch_directml reports for the resolved device index."""
        index = get_device(device).index
        return torch_directml.device_name(index)

    def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties:
        """Wrap the resolved device in a ``DeviceProperties`` object."""
        resolved = get_device(device)
        return DeviceProperties(resolved)

    def memory_stats(device: Optional[rDevice]=None):
        """Return a constant stats mapping with both counters at zero.

        *device* is accepted but never read.
        """
        stats = {
            "num_ooms": 0,
            "num_alloc_retries": 0,
        }
        return stats

    # Bound to the module-level placeholder mem_get_info at class creation.
    mem_get_info: Callable = mem_get_info

    def memory_allocated(device: Optional[rDevice]=None) -> int:
        """Return the summed ``torch_directml.gpu_memory`` readings times ``1 << 20``.

        NOTE(review): the scale factor suggests the readings are in MiB and
        the result in bytes — confirm against torch_directml.
        """
        scale = 1 << 20
        index = get_device(device).index
        readings = torch_directml.gpu_memory(index)
        return sum(readings) * scale

    def max_memory_allocated(device: Optional[rDevice]=None):
        """Return the same value as ``memory_allocated`` for *device*.

        No separate peak is tracked; the original author's rationale is that
        DirectML does not empty GPU memory.
        """
        current = DirectML.memory_allocated(device)
        return current

    def reset_peak_memory_stats(device: Optional[rDevice]=None):
        """Do nothing and return None; *device* is never read."""
        return None