File size: 1,744 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch

from modules.sd_hijack_utils import CondFunc

CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'privateuseone')

# https://github.com/microsoft/DirectML/issues/400
CondFunc('torch.Tensor.new', lambda orig, self, *args, **kwargs: orig(self.cpu(), *args, **kwargs).to(self.device), lambda orig, self, *args, **kwargs: torch.dml.is_directml_device(self.device))

_lerp = torch.lerp
def lerp(*args, **kwargs) -> torch.Tensor:
    rep = None
    for i in range(0, len(args)):
        if torch.is_tensor(args[i]):
            rep = args[i]
            break
    if rep is None:
        for key in kwargs:
            if torch.is_tensor(kwargs[key]):
                rep = kwargs[key]
                break
    if torch.dml.is_directml_device(rep.device):
        args = list(args)

        if rep.dtype == torch.float16:
            for i in range(len(args)):
                if torch.is_tensor(args[i]):
                    args[i] = args[i].float()
        for i in range(len(args)):
            if torch.is_tensor(args[i]):
                args[i] = args[i].cpu()

        if rep.dtype == torch.float16:
            for kwarg in kwargs:
                if torch.is_tensor(kwargs[kwarg]):
                    kwargs[kwarg] = kwargs[kwarg].float()
        for kwarg in kwargs:
            if torch.is_tensor(kwargs[kwarg]):
                kwargs[kwarg] = kwargs[kwarg].cpu()
        return _lerp(*args, **kwargs).to(rep.device).type(rep.dtype)
    return _lerp(*args, **kwargs)
torch.lerp = lerp