lehduong committed (verified)
Commit acf05ba · 1 Parent(s): 077684c

Delete diffusion/pipelines/image_processor.py with huggingface_hub
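
For reference, a file deletion like this can be scripted with the huggingface_hub client; the sketch below is illustrative only (the repo id is a placeholder, not taken from this page):

from huggingface_hub import HfApi

api = HfApi()  # uses the locally saved access token by default
# Remove the file from the repo in a single commit.
api.delete_file(
    path_in_repo="diffusion/pipelines/image_processor.py",
    repo_id="<namespace>/<repo>",  # placeholder: the target repo id is not shown in this commit view
    commit_message="Delete diffusion/pipelines/image_processor.py with huggingface_hub",
)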

diffusion/pipelines/image_processor.py DELETED
@@ -1,674 +0,0 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import math
- import warnings
- from typing import List, Optional, Tuple, Union
-
- import numpy as np
- import PIL.Image
- import torch
- import torch.nn.functional as F
- import torchvision.transforms as T
- from PIL import Image, ImageFilter, ImageOps
-
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
-
- from onediffusion.dataset.transforms import CenterCropResizeImage
-
- PipelineImageInput = Union[
-     PIL.Image.Image,
-     np.ndarray,
-     torch.Tensor,
-     List[PIL.Image.Image],
-     List[np.ndarray],
-     List[torch.Tensor],
- ]
-
- PipelineDepthInput = PipelineImageInput
-
-
- def is_valid_image(image):
-     return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
-
-
- def is_valid_image_imagelist(images):
-     # check if the image input is one of the supported formats for image and image list:
-     # it can be either one of below 3
-     # (1) a 4d pytorch tensor or numpy array,
-     # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
-     # (3) a list of valid image
-     if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
-         return True
-     elif is_valid_image(images):
-         return True
-     elif isinstance(images, list):
-         return all(is_valid_image(image) for image in images)
-     return False
-
-
- class VaeImageProcessorOneDiffuser(ConfigMixin):
-     """
-     Image processor for VAE.
-
-     Args:
-         do_resize (`bool`, *optional*, defaults to `True`):
-             Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
-             `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
-         vae_scale_factor (`int`, *optional*, defaults to `8`):
-             VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
-         resample (`str`, *optional*, defaults to `lanczos`):
-             Resampling filter to use when resizing the image.
-         do_normalize (`bool`, *optional*, defaults to `True`):
-             Whether to normalize the image to [-1,1].
-         do_binarize (`bool`, *optional*, defaults to `False`):
-             Whether to binarize the image to 0/1.
-         do_convert_rgb (`bool`, *optional*, defaults to `False`):
-             Whether to convert the images to RGB format.
-         do_convert_grayscale (`bool`, *optional*, defaults to `False`):
-             Whether to convert the images to grayscale format.
-     """
-
-     config_name = CONFIG_NAME
-
-     @register_to_config
-     def __init__(
-         self,
-         do_resize: bool = True,
-         vae_scale_factor: int = 8,
-         vae_latent_channels: int = 4,
-         resample: str = "lanczos",
-         do_normalize: bool = True,
-         do_binarize: bool = False,
-         do_convert_rgb: bool = False,
-         do_convert_grayscale: bool = False,
-     ):
-         super().__init__()
-         if do_convert_rgb and do_convert_grayscale:
-             raise ValueError(
-                 "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
-                 " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
-                 " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
-             )
-
-     @staticmethod
-     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-         """
-         Convert a numpy image or a batch of images to a PIL image.
-         """
-         if images.ndim == 3:
-             images = images[None, ...]
-         images = (images * 255).round().astype("uint8")
-         if images.shape[-1] == 1:
-             # special case for grayscale (single channel) images
-             pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-         else:
-             pil_images = [Image.fromarray(image) for image in images]
-
-         return pil_images
-
-     @staticmethod
-     def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-         """
-         Convert a PIL image or a list of PIL images to NumPy arrays.
-         """
-         if not isinstance(images, list):
-             images = [images]
-         images = [np.array(image).astype(np.float32) / 255.0 for image in images]
-         images = np.stack(images, axis=0)
-
-         return images
-
-     @staticmethod
-     def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
-         """
-         Convert a NumPy image to a PyTorch tensor.
-         """
-         if images.ndim == 3:
-             images = images[..., None]
-
-         images = torch.from_numpy(images.transpose(0, 3, 1, 2))
-         return images
-
-     @staticmethod
-     def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
-         """
-         Convert a PyTorch tensor to a NumPy image.
-         """
-         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
-         return images
-
-     @staticmethod
-     def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-         """
-         Normalize an image array to [-1,1].
-         """
-         return 2.0 * images - 1.0
-
-     @staticmethod
-     def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-         """
-         Denormalize an image array to [0,1].
-         """
-         return (images / 2 + 0.5).clamp(0, 1)
-
-     @staticmethod
-     def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
-         """
-         Converts a PIL image to RGB format.
-         """
-         image = image.convert("RGB")
-
-         return image
-
-     @staticmethod
-     def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
-         """
-         Converts a PIL image to grayscale format.
-         """
-         image = image.convert("L")
-
-         return image
-
-     @staticmethod
-     def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
-         """
-         Applies Gaussian blur to an image.
-         """
-         image = image.filter(ImageFilter.GaussianBlur(blur_factor))
-
-         return image
-
-     @staticmethod
-     def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
-         """
-         Finds a rectangular region that contains all masked areas in an image, and expands the region to match the
-         aspect ratio of the original image; for example, if the user drew a mask in a 128x32 region, and the
-         dimensions for processing are 512x512, the region will be expanded to 128x128.
-
-         Args:
-             mask_image (PIL.Image.Image): Mask image.
-             width (int): Width of the image to be processed.
-             height (int): Height of the image to be processed.
-             pad (int, optional): Padding to be added to the crop region. Defaults to 0.
-
-         Returns:
-             tuple: (x1, y1, x2, y2) representing a rectangular region that contains all masked areas in an image and
-             matches the original aspect ratio.
-         """
-
-         mask_image = mask_image.convert("L")
-         mask = np.array(mask_image)
-
-         # 1. find a rectangular region that contains all masked areas in an image
-         h, w = mask.shape
-         crop_left = 0
-         for i in range(w):
-             if not (mask[:, i] == 0).all():
-                 break
-             crop_left += 1
-
-         crop_right = 0
-         for i in reversed(range(w)):
-             if not (mask[:, i] == 0).all():
-                 break
-             crop_right += 1
-
-         crop_top = 0
-         for i in range(h):
-             if not (mask[i] == 0).all():
-                 break
-             crop_top += 1
-
-         crop_bottom = 0
-         for i in reversed(range(h)):
-             if not (mask[i] == 0).all():
-                 break
-             crop_bottom += 1
-
-         # 2. add padding to the crop region
-         x1, y1, x2, y2 = (
-             int(max(crop_left - pad, 0)),
-             int(max(crop_top - pad, 0)),
-             int(min(w - crop_right + pad, w)),
-             int(min(h - crop_bottom + pad, h)),
-         )
-
-         # 3. expand the crop region to match the aspect ratio of the image to be processed
-         ratio_crop_region = (x2 - x1) / (y2 - y1)
-         ratio_processing = width / height
-
-         if ratio_crop_region > ratio_processing:
-             desired_height = (x2 - x1) / ratio_processing
-             desired_height_diff = int(desired_height - (y2 - y1))
-             y1 -= desired_height_diff // 2
-             y2 += desired_height_diff - desired_height_diff // 2
-             if y2 >= mask_image.height:
-                 diff = y2 - mask_image.height
-                 y2 -= diff
-                 y1 -= diff
-             if y1 < 0:
-                 y2 -= y1
-                 y1 -= y1
-             if y2 >= mask_image.height:
-                 y2 = mask_image.height
-         else:
-             desired_width = (y2 - y1) * ratio_processing
-             desired_width_diff = int(desired_width - (x2 - x1))
-             x1 -= desired_width_diff // 2
-             x2 += desired_width_diff - desired_width_diff // 2
-             if x2 >= mask_image.width:
-                 diff = x2 - mask_image.width
-                 x2 -= diff
-                 x1 -= diff
-             if x1 < 0:
-                 x2 -= x1
-                 x1 -= x1
-             if x2 >= mask_image.width:
-                 x2 = mask_image.width
-
-         return x1, y1, x2, y2
-
-     def _resize_and_fill(
-         self,
-         image: PIL.Image.Image,
-         width: int,
-         height: int,
-     ) -> PIL.Image.Image:
-         """
-         Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
-         the image within the dimensions, filling empty space with data from the image.
-
-         Args:
-             image: The image to resize.
-             width: The width to resize the image to.
-             height: The height to resize the image to.
-         """
-
-         ratio = width / height
-         src_ratio = image.width / image.height
-
-         src_w = width if ratio < src_ratio else image.width * height // image.height
-         src_h = height if ratio >= src_ratio else image.height * width // image.width
-
-         resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
-         res = Image.new("RGB", (width, height))
-         res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-
-         if ratio < src_ratio:
-             fill_height = height // 2 - src_h // 2
-             if fill_height > 0:
-                 res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
-                 res.paste(
-                     resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
-                     box=(0, fill_height + src_h),
-                 )
-         elif ratio > src_ratio:
-             fill_width = width // 2 - src_w // 2
-             if fill_width > 0:
-                 res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
-                 res.paste(
-                     resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
-                     box=(fill_width + src_w, 0),
-                 )
-
-         return res
-
-     def _resize_and_crop(
-         self,
-         image: PIL.Image.Image,
-         width: int,
-         height: int,
-     ) -> PIL.Image.Image:
-         """
-         Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
-         the image within the dimensions, cropping the excess.
-
-         Args:
-             image: The image to resize.
-             width: The width to resize the image to.
-             height: The height to resize the image to.
-         """
-         ratio = width / height
-         src_ratio = image.width / image.height
-
-         src_w = width if ratio > src_ratio else image.width * height // image.height
-         src_h = height if ratio <= src_ratio else image.height * width // image.width
-
-         resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
-         res = Image.new("RGB", (width, height))
-         res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-         return res
-
-     def resize(
-         self,
-         image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-         height: int,
-         width: int,
-         resize_mode: str = "default",  # "default", "fill", "crop"
-     ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
-         """
-         Resize image.
-
-         Args:
-             image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                 The image input, can be a PIL image, numpy array or pytorch tensor.
-             height (`int`):
-                 The height to resize to.
-             width (`int`):
-                 The width to resize to.
-             resize_mode (`str`, *optional*, defaults to `default`):
-                 The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                 within the specified width and height, and it may not maintain the original aspect ratio. If `fill`,
-                 will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
-                 then center the image within the dimensions, filling empty space with data from the image. If `crop`,
-                 will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
-                 then center the image within the dimensions, cropping the excess. Note that resize_mode `fill` and
-                 `crop` are only supported for PIL image input.
-
-         Returns:
-             `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
-                 The resized image.
-         """
-         if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
-             raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
-         if isinstance(image, PIL.Image.Image):
-             if resize_mode == "default":
-                 image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
-             elif resize_mode == "fill":
-                 image = self._resize_and_fill(image, width, height)
-             elif resize_mode == "crop":
-                 image = self._resize_and_crop(image, width, height)
-             else:
-                 raise ValueError(f"resize_mode {resize_mode} is not supported")
-
-         elif isinstance(image, torch.Tensor):
-             image = torch.nn.functional.interpolate(
-                 image,
-                 size=(height, width),
-             )
-         elif isinstance(image, np.ndarray):
-             image = self.numpy_to_pt(image)
-             image = torch.nn.functional.interpolate(
-                 image,
-                 size=(height, width),
-             )
-             image = self.pt_to_numpy(image)
-         return image
-
-     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
-         """
-         Create a mask.
-
-         Args:
-             image (`PIL.Image.Image`):
-                 The image input, should be a PIL image.
-
-         Returns:
-             `PIL.Image.Image`:
-                 The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
-         """
-         image[image < 0.5] = 0
-         image[image >= 0.5] = 1
-
-         return image
-
-     def get_default_height_width(
-         self,
-         image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-     ) -> Tuple[int, int]:
-         """
-         This function returns the height and width that are downscaled to the next integer multiple of
-         `vae_scale_factor`.
-
-         Args:
-             image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                 The image input, can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it should
-                 have shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch tensor,
-                 it should have shape `[batch, channel, height, width]`.
-             height (`int`, *optional*, defaults to `None`):
-                 The height of the preprocessed image. If `None`, will use the height of the `image` input.
-             width (`int`, *optional*, defaults to `None`):
-                 The width of the preprocessed image. If `None`, will use the width of the `image` input.
-         """
-
-         if height is None:
-             if isinstance(image, PIL.Image.Image):
-                 height = image.height
-             elif isinstance(image, torch.Tensor):
-                 height = image.shape[2]
-             else:
-                 height = image.shape[1]
-
-         if width is None:
-             if isinstance(image, PIL.Image.Image):
-                 width = image.width
-             elif isinstance(image, torch.Tensor):
-                 width = image.shape[3]
-             else:
-                 width = image.shape[2]
-
-         width, height = (
-             x - x % self.config.vae_scale_factor for x in (width, height)
-         )  # resize to integer multiple of vae_scale_factor
-
-         return height, width
-
-     def preprocess(
-         self,
-         image: PipelineImageInput,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         resize_mode: str = "default",  # "default", "fill", "crop"
-         crops_coords: Optional[Tuple[int, int, int, int]] = None,
-         do_crop: bool = True,
-     ) -> torch.Tensor:
-         """
-         Preprocess the image input.
-
-         Args:
-             image (`pipeline_image_input`):
-                 The image input; accepted formats are PIL images, NumPy arrays, PyTorch tensors, as well as lists of
-                 the supported formats.
-             height (`int`, *optional*, defaults to `None`):
-                 The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
-                 default height.
-             width (`int`, *optional*, defaults to `None`):
-                 The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
-                 default width.
-             resize_mode (`str`, *optional*, defaults to `default`):
-                 The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
-                 the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
-                 resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
-                 center the image within the dimensions, filling empty space with data from the image. If `crop`, will
-                 resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
-                 center the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop`
-                 are only supported for PIL image input.
-             crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
-                 The crop coordinates for each image in the batch. If `None`, will not crop the image.
-         """
-         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
-
-         # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
-         if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
-             if isinstance(image, torch.Tensor):
-                 # if image is a pytorch tensor, it could have 2 possible shapes:
-                 # 1. batch x height x width: we should insert the channel dimension at position 1
-                 # 2. channel x height x width: we should insert the batch dimension at position 0,
-                 # however, since both the channel and batch dimensions have size 1, inserting at position 1 works in both cases
-                 # for simplicity, we insert a dimension of size 1 at position 1 for both cases
-                 image = image.unsqueeze(1)
-             else:
-                 # if it is a numpy array, it could have 2 possible shapes:
-                 # 1. batch x height x width: insert channel dimension on last position
-                 # 2. height x width x channel: insert batch dimension on first position
-                 if image.shape[-1] == 1:
-                     image = np.expand_dims(image, axis=0)
-                 else:
-                     image = np.expand_dims(image, axis=-1)
-
-         if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
-             warnings.warn(
-                 "Passing `image` as a list of 4d np.ndarray is deprecated."
-                 "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
-                 FutureWarning,
-             )
-             image = np.concatenate(image, axis=0)
-         if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
-             warnings.warn(
-                 "Passing `image` as a list of 4d torch.Tensor is deprecated."
-                 "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
-                 FutureWarning,
-             )
-             image = torch.cat(image, axis=0)
-
-         if not is_valid_image_imagelist(image):
-             raise ValueError(
-                 f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
-             )
-         if not isinstance(image, list):
-             image = [image]
-
-         if isinstance(image[0], PIL.Image.Image):
-             pass
-         elif isinstance(image[0], np.ndarray):
-             image = self.numpy_to_pil(image)
-         elif isinstance(image[0], torch.Tensor):
-             image = self.pt_to_numpy(image)
-             image = self.numpy_to_pil(image)
-
-         if do_crop:
-             transforms = T.Compose([
-                 T.Lambda(lambda image: image.convert('RGB')),
-                 T.ToTensor(),
-                 CenterCropResizeImage((height, width)),
-                 T.Normalize([.5], [.5]),
-             ])
-         else:
-             transforms = T.Compose([
-                 T.Lambda(lambda image: image.convert('RGB')),
-                 T.ToTensor(),
-                 T.Resize((height, width)),
-                 T.Normalize([.5], [.5]),
-             ])
-         image = torch.stack([transforms(i) for i in image])
-
-         # expected range [0,1], normalize to [-1,1]
-         do_normalize = self.config.do_normalize
-         if do_normalize and image.min() < 0:
-             warnings.warn(
-                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
-                 f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
-                 FutureWarning,
-             )
-             do_normalize = False
-         if do_normalize:
-             image = self.normalize(image)
-
-         if self.config.do_binarize:
-             image = self.binarize(image)
-
-         return image
-
-     def postprocess(
-         self,
-         image: torch.Tensor,
-         output_type: str = "pil",
-         do_denormalize: Optional[List[bool]] = None,
-     ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
-         """
-         Postprocess the image output from tensor to `output_type`.
-
-         Args:
-             image (`torch.Tensor`):
-                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
-             output_type (`str`, *optional*, defaults to `pil`):
-                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
-             do_denormalize (`List[bool]`, *optional*, defaults to `None`):
-                 Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
-                 `VaeImageProcessor` config.
-
-         Returns:
-             `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
-                 The postprocessed image.
-         """
-         if not isinstance(image, torch.Tensor):
-             raise ValueError(
-                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
-             )
-         if output_type not in ["latent", "pt", "np", "pil"]:
-             deprecation_message = (
-                 f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
-                 "`pil`, `np`, `pt`, `latent`"
-             )
-             deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-             output_type = "np"
-
-         if output_type == "latent":
-             return image
-
-         if do_denormalize is None:
-             do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-         image = torch.stack(
-             [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-         )
-
-         if output_type == "pt":
-             return image
-
-         image = self.pt_to_numpy(image)
-
-         if output_type == "np":
-             return image
-
-         if output_type == "pil":
-             return self.numpy_to_pil(image)
-
-     def apply_overlay(
-         self,
-         mask: PIL.Image.Image,
-         init_image: PIL.Image.Image,
-         image: PIL.Image.Image,
-         crop_coords: Optional[Tuple[int, int, int, int]] = None,
-     ) -> PIL.Image.Image:
-         """
-         Overlay the inpaint output onto the original image.
-         """
-
-         width, height = image.width, image.height
-
-         init_image = self.resize(init_image, width=width, height=height)
-         mask = self.resize(mask, width=width, height=height)
-
-         init_image_masked = PIL.Image.new("RGBa", (width, height))
-         init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
-         init_image_masked = init_image_masked.convert("RGBA")
-
-         if crop_coords is not None:
-             x, y, x2, y2 = crop_coords
-             w = x2 - x
-             h = y2 - y
-             base_image = PIL.Image.new("RGBA", (width, height))
-             image = self.resize(image, height=h, width=w, resize_mode="crop")
-             base_image.paste(image, (x, y))
-             image = base_image.convert("RGB")
-
-         image = image.convert("RGBA")
-         image.alpha_composite(init_image_masked)
-         image = image.convert("RGB")
-
-         return image
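
For context, the deleted VaeImageProcessorOneDiffuser closely follows diffusers' VaeImageProcessor, with preprocessing routed through a torchvision center-crop-and-resize transform. A minimal usage sketch of the preprocess/postprocess round trip, assuming a local copy of the deleted module is still importable:

import PIL.Image

from image_processor import VaeImageProcessorOneDiffuser  # hypothetical local copy of the deleted module

processor = VaeImageProcessorOneDiffuser(vae_scale_factor=8)
image = PIL.Image.open("input.jpg")

# preprocess returns a (1, 3, H, W) tensor scaled to [-1, 1]; the target size is passed explicitly
pixel_values = processor.preprocess(image, height=512, width=512, do_crop=True)

# ... run the VAE / diffusion pipeline on pixel_values here ...

# postprocess maps the tensor back to [0, 1] and converts it to PIL images
outputs = processor.postprocess(pixel_values, output_type="pil")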