Mariam-Elz commited on
Commit
eaaa933
·
verified ·
1 Parent(s): 584bda0

Upload imagedream/ldm/modules/distributions/distributions.py with huggingface_hub

Browse files
imagedream/ldm/modules/distributions/distributions.py CHANGED
@@ -1,102 +1,102 @@
1
- import torch
2
- import numpy as np
3
-
4
-
5
- class AbstractDistribution:
6
- def sample(self):
7
- raise NotImplementedError()
8
-
9
- def mode(self):
10
- raise NotImplementedError()
11
-
12
-
13
- class DiracDistribution(AbstractDistribution):
14
- def __init__(self, value):
15
- self.value = value
16
-
17
- def sample(self):
18
- return self.value
19
-
20
- def mode(self):
21
- return self.value
22
-
23
-
24
- class DiagonalGaussianDistribution(object):
25
- def __init__(self, parameters, deterministic=False):
26
- self.parameters = parameters
27
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
- self.deterministic = deterministic
30
- self.std = torch.exp(0.5 * self.logvar)
31
- self.var = torch.exp(self.logvar)
32
- if self.deterministic:
33
- self.var = self.std = torch.zeros_like(self.mean).to(
34
- device=self.parameters.device
35
- )
36
-
37
- def sample(self):
38
- x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
- device=self.parameters.device
40
- )
41
- return x
42
-
43
- def kl(self, other=None):
44
- if self.deterministic:
45
- return torch.Tensor([0.0])
46
- else:
47
- if other is None:
48
- return 0.5 * torch.sum(
49
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
- dim=[1, 2, 3],
51
- )
52
- else:
53
- return 0.5 * torch.sum(
54
- torch.pow(self.mean - other.mean, 2) / other.var
55
- + self.var / other.var
56
- - 1.0
57
- - self.logvar
58
- + other.logvar,
59
- dim=[1, 2, 3],
60
- )
61
-
62
- def nll(self, sample, dims=[1, 2, 3]):
63
- if self.deterministic:
64
- return torch.Tensor([0.0])
65
- logtwopi = np.log(2.0 * np.pi)
66
- return 0.5 * torch.sum(
67
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
- dim=dims,
69
- )
70
-
71
- def mode(self):
72
- return self.mean
73
-
74
-
75
- def normal_kl(mean1, logvar1, mean2, logvar2):
76
- """
77
- source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
- Compute the KL divergence between two gaussians.
79
- Shapes are automatically broadcasted, so batches can be compared to
80
- scalars, among other use cases.
81
- """
82
- tensor = None
83
- for obj in (mean1, logvar1, mean2, logvar2):
84
- if isinstance(obj, torch.Tensor):
85
- tensor = obj
86
- break
87
- assert tensor is not None, "at least one argument must be a Tensor"
88
-
89
- # Force variances to be Tensors. Broadcasting helps convert scalars to
90
- # Tensors, but it does not work for torch.exp().
91
- logvar1, logvar2 = [
92
- x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
- for x in (logvar1, logvar2)
94
- ]
95
-
96
- return 0.5 * (
97
- -1.0
98
- + logvar2
99
- - logvar1
100
- + torch.exp(logvar1 - logvar2)
101
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
- )
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(
34
+ device=self.parameters.device
35
+ )
36
+
37
+ def sample(self):
38
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
+ device=self.parameters.device
40
+ )
41
+ return x
42
+
43
+ def kl(self, other=None):
44
+ if self.deterministic:
45
+ return torch.Tensor([0.0])
46
+ else:
47
+ if other is None:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
+ dim=[1, 2, 3],
51
+ )
52
+ else:
53
+ return 0.5 * torch.sum(
54
+ torch.pow(self.mean - other.mean, 2) / other.var
55
+ + self.var / other.var
56
+ - 1.0
57
+ - self.logvar
58
+ + other.logvar,
59
+ dim=[1, 2, 3],
60
+ )
61
+
62
+ def nll(self, sample, dims=[1, 2, 3]):
63
+ if self.deterministic:
64
+ return torch.Tensor([0.0])
65
+ logtwopi = np.log(2.0 * np.pi)
66
+ return 0.5 * torch.sum(
67
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
+ dim=dims,
69
+ )
70
+
71
+ def mode(self):
72
+ return self.mean
73
+
74
+
75
+ def normal_kl(mean1, logvar1, mean2, logvar2):
76
+ """
77
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
+ Compute the KL divergence between two gaussians.
79
+ Shapes are automatically broadcasted, so batches can be compared to
80
+ scalars, among other use cases.
81
+ """
82
+ tensor = None
83
+ for obj in (mean1, logvar1, mean2, logvar2):
84
+ if isinstance(obj, torch.Tensor):
85
+ tensor = obj
86
+ break
87
+ assert tensor is not None, "at least one argument must be a Tensor"
88
+
89
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
90
+ # Tensors, but it does not work for torch.exp().
91
+ logvar1, logvar2 = [
92
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
+ for x in (logvar1, logvar2)
94
+ ]
95
+
96
+ return 0.5 * (
97
+ -1.0
98
+ + logvar2
99
+ - logvar1
100
+ + torch.exp(logvar1 - logvar2)
101
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
+ )