Mariam-Elz commited on
Commit
0f7e7a6
·
verified ·
1 Parent(s): 26198af

Upload imagedream/ldm/util.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. imagedream/ldm/util.py +226 -0
imagedream/ldm/util.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import random
4
+ import torch
5
+ import numpy as np
6
+ from collections import abc
7
+
8
+ import multiprocessing as mp
9
+ from threading import Thread
10
+ from queue import Queue
11
+
12
+ from inspect import isfunction
13
+ from PIL import Image, ImageDraw, ImageFont
14
+
15
+
16
+ def log_txt_as_img(wh, xc, size=10):
17
+ # wh a tuple of (width, height)
18
+ # xc a list of captions to plot
19
+ b = len(xc)
20
+ txts = list()
21
+ for bi in range(b):
22
+ txt = Image.new("RGB", wh, color="white")
23
+ draw = ImageDraw.Draw(txt)
24
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
25
+ nc = int(40 * (wh[0] / 256))
26
+ lines = "\n".join(
27
+ xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
28
+ )
29
+
30
+ try:
31
+ draw.text((0, 0), lines, fill="black", font=font)
32
+ except UnicodeEncodeError:
33
+ print("Cant encode string for logging. Skipping.")
34
+
35
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
36
+ txts.append(txt)
37
+ txts = np.stack(txts)
38
+ txts = torch.tensor(txts)
39
+ return txts
40
+
41
+
42
+ def ismap(x):
43
+ if not isinstance(x, torch.Tensor):
44
+ return False
45
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
46
+
47
+
48
+ def isimage(x):
49
+ if not isinstance(x, torch.Tensor):
50
+ return False
51
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
52
+
53
+
54
+ def exists(x):
55
+ return x is not None
56
+
57
+
58
+ def default(val, d):
59
+ if exists(val):
60
+ return val
61
+ return d() if isfunction(d) else d
62
+
63
+
64
+ def mean_flat(tensor):
65
+ """
66
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
67
+ Take the mean over all non-batch dimensions.
68
+ """
69
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
70
+
71
+
72
+ def count_params(model, verbose=False):
73
+ total_params = sum(p.numel() for p in model.parameters())
74
+ if verbose:
75
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
76
+ return total_params
77
+
78
+
79
+ def instantiate_from_config(config):
80
+ if not "target" in config:
81
+ if config == "__is_first_stage__":
82
+ return None
83
+ elif config == "__is_unconditional__":
84
+ return None
85
+ raise KeyError("Expected key `target` to instantiate.")
86
+ # import pdb; pdb.set_trace()
87
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
88
+
89
+
90
+ def get_obj_from_str(string, reload=False):
91
+ module, cls = string.rsplit(".", 1)
92
+ # import pdb; pdb.set_trace()
93
+ if reload:
94
+ module_imp = importlib.import_module(module)
95
+ importlib.reload(module_imp)
96
+ return getattr(importlib.import_module(module, package=None), cls)
97
+
98
+
99
+ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
100
+ # create dummy dataset instance
101
+
102
+ # run prefetching
103
+ if idx_to_fn:
104
+ res = func(data, worker_id=idx)
105
+ else:
106
+ res = func(data)
107
+ Q.put([idx, res])
108
+ Q.put("Done")
109
+
110
+
111
+ def parallel_data_prefetch(
112
+ func: callable,
113
+ data,
114
+ n_proc,
115
+ target_data_type="ndarray",
116
+ cpu_intensive=True,
117
+ use_worker_id=False,
118
+ ):
119
+ # if target_data_type not in ["ndarray", "list"]:
120
+ # raise ValueError(
121
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
122
+ # )
123
+ if isinstance(data, np.ndarray) and target_data_type == "list":
124
+ raise ValueError("list expected but function got ndarray.")
125
+ elif isinstance(data, abc.Iterable):
126
+ if isinstance(data, dict):
127
+ print(
128
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
129
+ )
130
+ data = list(data.values())
131
+ if target_data_type == "ndarray":
132
+ data = np.asarray(data)
133
+ else:
134
+ data = list(data)
135
+ else:
136
+ raise TypeError(
137
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
138
+ )
139
+
140
+ if cpu_intensive:
141
+ Q = mp.Queue(1000)
142
+ proc = mp.Process
143
+ else:
144
+ Q = Queue(1000)
145
+ proc = Thread
146
+ # spawn processes
147
+ if target_data_type == "ndarray":
148
+ arguments = [
149
+ [func, Q, part, i, use_worker_id]
150
+ for i, part in enumerate(np.array_split(data, n_proc))
151
+ ]
152
+ else:
153
+ step = (
154
+ int(len(data) / n_proc + 1)
155
+ if len(data) % n_proc != 0
156
+ else int(len(data) / n_proc)
157
+ )
158
+ arguments = [
159
+ [func, Q, part, i, use_worker_id]
160
+ for i, part in enumerate(
161
+ [data[i : i + step] for i in range(0, len(data), step)]
162
+ )
163
+ ]
164
+ processes = []
165
+ for i in range(n_proc):
166
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
167
+ processes += [p]
168
+
169
+ # start processes
170
+ print(f"Start prefetching...")
171
+ import time
172
+
173
+ start = time.time()
174
+ gather_res = [[] for _ in range(n_proc)]
175
+ try:
176
+ for p in processes:
177
+ p.start()
178
+
179
+ k = 0
180
+ while k < n_proc:
181
+ # get result
182
+ res = Q.get()
183
+ if res == "Done":
184
+ k += 1
185
+ else:
186
+ gather_res[res[0]] = res[1]
187
+
188
+ except Exception as e:
189
+ print("Exception: ", e)
190
+ for p in processes:
191
+ p.terminate()
192
+
193
+ raise e
194
+ finally:
195
+ for p in processes:
196
+ p.join()
197
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
198
+
199
+ if target_data_type == "ndarray":
200
+ if not isinstance(gather_res[0], np.ndarray):
201
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
202
+
203
+ # order outputs
204
+ return np.concatenate(gather_res, axis=0)
205
+ elif target_data_type == "list":
206
+ out = []
207
+ for r in gather_res:
208
+ out.extend(r)
209
+ return out
210
+ else:
211
+ return gather_res
212
+
213
+ def set_seed(seed=None):
214
+ random.seed(seed)
215
+ np.random.seed(seed)
216
+ if seed is not None:
217
+ torch.manual_seed(seed)
218
+ torch.cuda.manual_seed_all(seed)
219
+
220
+ def add_random_background(image, bg_color=None):
221
+ bg_color = np.random.rand() * 255 if bg_color is None else bg_color
222
+ image = np.array(image)
223
+ rgb, alpha = image[..., :3], image[..., 3:]
224
+ alpha = alpha.astype(np.float32) / 255.0
225
+ image_new = rgb * alpha + bg_color * (1 - alpha)
226
+ return Image.fromarray(image_new.astype(np.uint8))