File size: 422 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import Optional
import torch
from .utils import rDevice, get_device


class Device:
    """Context manager that publishes a device as ``torch.dml.context_device``.

    The device handed to the constructor is resolved once through
    ``get_device`` and remembered on the instance. Entering the ``with``
    block stores it in ``torch.dml.context_device``; leaving the block
    resets that attribute to ``None``.

    NOTE(review): the result of ``get_device(None)`` is presumably the
    default device; confirm against ``utils.get_device``.
    NOTE(review): ``__exit__`` always resets to ``None`` and does not restore
    an outer value, so nested ``Device`` contexts do not stack.
    """

    def __init__(self, device: Optional[rDevice]=None) -> None:
        """Resolve ``device`` via ``get_device`` and keep it for ``__enter__``.

        ``__init__`` must return ``None``; returning the device here would
        make every instantiation raise ``TypeError``.
        """
        self.device = get_device(device)

    def __enter__(self, device: Optional[rDevice]=None):
        """Make the selected device the active ``torch.dml.context_device``.

        The ``with`` statement calls ``__enter__`` without arguments, so the
        device chosen at construction time is used. The optional ``device``
        parameter is kept for callers that invoke ``__enter__`` explicitly;
        when given, it is resolved with ``get_device`` and takes precedence.
        """
        if device is None:
            torch.dml.context_device = self.device
        else:
            torch.dml.context_device = get_device(device)

    def __exit__(self, t, v, tb):
        """Clear ``torch.dml.context_device``; exceptions are not suppressed."""
        torch.dml.context_device = None