File size: 12,786 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import math
import numpy as np
import skimage
import gradio as gr
from PIL import Image, ImageDraw
import modules.scripts as scripts
from modules import images
from modules.processing import Processed, process_images
from modules.shared import opts, state


# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
    """Fill the masked part of an image with noise shaped like the unmasked part.

    Seeded random noise is given the Fourier magnitude (raised to ``noise_q``)
    and phase of the windowed source image, histogram-matched against the
    unmasked source pixels, and blended into the source with the mask as the
    blend weight.

    Args:
        _np_src_image: Three-channel float array indexed ``[axis0, axis1, channel]``.
            The caller in this file passes RGB data scaled to ``[0, 1]``.
            The array is not modified.
        np_mask_rgb: Array with the same shape as the image. Values near 1 mark
            pixels to replace with noise, values near 0 mark pixels to keep.
            The array is not modified.
        noise_q: Exponent applied to the source spectrum magnitude when shaping
            the noise.
        color_variation: Weight of the per-channel (coloured) noise; the
            remaining ``1 - color_variation`` is the grey average of the noise.

    Returns:
        A new array with the shape of ``_np_src_image``, clipped to ``[0, 1]``.
    """
    # helper fft routines that keep ortho normalization and auto-shift before and after fft
    def _fft2(data):
        if data.ndim > 2:  # has channels
            out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
            for c in range(data.shape[2]):
                c_data = data[:, :, c]
                out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
                out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
        else:  # one channel
            out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
            out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
            out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
        return out_fft

    def _ifft2(data):
        if data.ndim > 2:  # has channels
            out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
            for c in range(data.shape[2]):
                c_data = data[:, :, c]
                out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
                out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
        else:  # one channel
            out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
            out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
            out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
        return out_ifft

    def _get_gaussian_window(width, height, std=3.14, mode=0):
        # Builds a (width, height) window centred on the array; mode 0 is a
        # gaussian, any other mode is the heavier-tailed fall-off below.
        window_scale_x = float(width / min(width, height))
        window_scale_y = float(height / min(width, height))
        window = np.zeros((width, height))
        x = (np.arange(width) / width * 2. - 1.) * window_scale_x
        for y in range(height):
            fy = (y / height * 2. - 1.) * window_scale_y
            if mode == 0:
                window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)
            else:
                window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14)  # hey wait a minute that's not gaussian
        return window

    def _get_masked_window_rgb(np_mask_grey, hardness=1.):
        # Replicates the grey mask (optionally raised to `hardness`) into three channels.
        np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
        if hardness != 1.:
            hardened = np_mask_grey[:] ** hardness
        else:
            hardened = np_mask_grey[:]
        for c in range(3):
            np_mask_rgb[:, :, c] = hardened[:]
        return np_mask_rgb

    # NOTE: "width"/"height" are simply the extents of axis 0 and axis 1 of the array.
    width = _np_src_image.shape[0]
    height = _np_src_image.shape[1]
    num_channels = _np_src_image.shape[2]
    np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.
    img_mask = np_mask_grey > 1e-6  # pixels that receive noise
    ref_mask = np_mask_grey < 1e-3  # pixels used as the histogram reference
    windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey))
    # Normalise to a peak of 1; skip when the unmasked area is entirely zero,
    # where the division would be 0/0 and fill the array with NaN.
    window_peak = np.max(windowed_image)
    if window_peak != 0:
        windowed_image /= window_peak
    windowed_image += np.average(_np_src_image) * np_mask_rgb  # / (1.-np.average(np_mask_rgb))  # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
    src_fft = _fft2(windowed_image)  # get feature statistics from masked src img
    src_dist = np.absolute(src_fft)
    # Unit-magnitude phase. Bins whose magnitude is exactly zero carry no phase;
    # leave them at 0 instead of producing NaN, which the inverse FFT would
    # otherwise spread over the whole noise image.
    src_phase = np.divide(src_fft, src_dist, out=np.zeros_like(src_fft), where=src_dist != 0)
    # create a generator with a static seed to make outpainting deterministic / only follow global seed
    rng = np.random.default_rng(0)
    noise_window = _get_gaussian_window(width, height, mode=1)  # start with simple gaussian noise
    noise_rgb = rng.random((width, height, num_channels))
    noise_grey = np.sum(noise_rgb, axis=2) / 3.
    noise_rgb *= color_variation  # the colorfulness of the starting noise is blended to greyscale with a parameter
    for c in range(num_channels):
        noise_rgb[:, :, c] += (1. - color_variation) * noise_grey
    noise_fft = _fft2(noise_rgb)
    for c in range(num_channels):
        noise_fft[:, :, c] *= noise_window
    noise_rgb = np.real(_ifft2(noise_fft))
    shaped_noise_fft = _fft2(noise_rgb)
    shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase  # perform the actual shaping
    brightness_variation = 0.  # color_variation
    contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
    # scikit-image is used for histogram matching, very convenient!
    shaped_noise = np.real(_ifft2(shaped_noise_fft))
    shaped_noise -= np.min(shaped_noise)
    # Rescale to [0, 1]; a constant noise field has a peak of 0 after the shift
    # above, so leave it untouched rather than dividing by zero.
    noise_peak = np.max(shaped_noise)
    if noise_peak != 0:
        shaped_noise /= noise_peak
    shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
    # Keep the source where the mask is 0 and use the noise where it is 1.
    shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
    matched_noise = shaped_noise[:]
    return np.clip(matched_noise, 0., 1.)


class Script(scripts.Script):
    """Img2img script that enlarges the init image and fills the new border.

    For each selected direction the image is padded, the padding is pre-filled
    with `get_matched_noise`, and the padded region is run through
    `process_images` with a mask covering the new area.
    """

    def title(self):
        """Return the script's title string."""
        return "Outpainting"

    def show(self, is_img2img):
        """Return `is_img2img` unchanged.

        NOTE(review): presumably the host uses this to offer the script only in
        img2img mode -- confirm against `scripts.Script`.
        """
        return is_img2img

    def ui(self, is_img2img):
        """Build the gradio controls for this script.

        Returns None when `is_img2img` is falsy. Otherwise returns the created
        components in the order `[info, pixels, mask_blur, direction, noise_q,
        color_variation]`, which lines up with the positional parameters of
        `run` after `p` (the `info` HTML element maps to the ignored `_`).
        """
        if not is_img2img:
            return None
        with gr.Row():
            info = gr.HTML("<span>&nbsp Outpainting</span><br>")
        with gr.Row():
            pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
            mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
        direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
        with gr.Row():
            noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
            color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))
        return [info, pixels, mask_blur, direction, noise_q, color_variation]

    def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): # pylint: disable=arguments-differ
        """Outpaint `p.init_images[0]` in the selected directions.

        Args:
            p: Project-defined processing object. This method reads its
                `width`, `height`, `init_images`, `n_iter`, `batch_size`,
                `prompt`, `outpath_samples` and `outpath_grids`, and overwrites
                `mask_blur`, `inpaint_full_res`, `inpainting_fill`,
                `do_not_save_samples`, `do_not_save_grid`, `n_iter`, and (per
                expansion step) `width`, `height`, `init_images`, `image_mask`
                and `latent_mask`.
            _: Ignored (the value of the `info` HTML element from `ui`).
            pixels: Requested number of pixels to add on each selected side.
            mask_blur: Blur amount; used to inset the mask rectangles and,
                multiplied by 4, assigned to `p.mask_blur`.
            direction: Collection containing any of 'left', 'right', 'up',
                'down'.
            noise_q: Forwarded to `get_matched_noise`.
            color_variation: Forwarded to `get_matched_noise`.

        Returns:
            A `Processed` built from `p`, the resulting images (preceded by a
            grid image when `opts.return_grid` is set and there is more than
            one image), and the seed/info of the first `process_images` call
            (both None if no direction was selected).
        """
        # Mutable holder so the nested expand() can record the first seed/info.
        initial_seed_and_info = [None, None]
        # Remember the requested processing size; expand() overwrites p.width/p.height.
        process_width = p.width
        process_height = p.height
        p.mask_blur = mask_blur*4
        p.inpaint_full_res = False
        # NOTE(review): meaning of fill mode 1 is defined by the host project -- confirm.
        p.inpainting_fill = 1
        # Saving is done explicitly at the end of this method instead.
        p.do_not_save_samples = True
        p.do_not_save_grid = True
        left = pixels if "left" in direction else 0
        right = pixels if "right" in direction else 0
        up = pixels if "up" in direction else 0
        down = pixels if "down" in direction else 0
        init_img = p.init_images[0]
        # Final canvas size, rounded up to a multiple of 64 in each dimension.
        target_w = math.ceil((init_img.width + left + right) / 64) * 64
        target_h = math.ceil((init_img.height + up + down) / 64) * 64
        # Spread the rounding slack over the selected sides: the first side gets
        # its proportional share, the opposite side gets whatever remains.
        if left > 0:
            left = left * (target_w - init_img.width) // (left + right)
        if right > 0:
            right = target_w - init_img.width - left
        if up > 0:
            up = up * (target_h - init_img.height) // (up + down)
        if down > 0:
            down = target_h - init_img.height - up
        def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
            """Grow the first `count` images of `init` by `expand_pixels` on one side.

            Exactly one of the `is_*` flags is set by the callers below. Returns
            a list of `count` images, each `expand_pixels` larger along the
            chosen axis. Mutates `p` and `initial_seed_and_info`.
            """
            is_horiz = is_left or is_right
            is_vert = is_top or is_bottom
            pixels_horiz = expand_pixels if is_horiz else 0
            pixels_vert = expand_pixels if is_vert else 0
            images_to_process = []
            output_images = []
            for n in range(count):
                # Size after growth, and the same size rounded up to a multiple of 64.
                res_w = init[n].width + pixels_horiz
                res_h = init[n].height + pixels_vert
                process_res_w = math.ceil(res_w / 64) * 64
                process_res_h = math.ceil(res_h / 64) * 64
                # Blank canvas with the source pasted in, shifted when growing left/up.
                img = Image.new("RGB", (process_res_w, process_res_h))
                img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
                # White mask with a black rectangle over the area to keep; the
                # rectangle is inset by mask_blur on the side being grown.
                mask = Image.new("RGB", (process_res_w, process_res_h), "white")
                draw = ImageDraw.Draw(mask)
                draw.rectangle((
                    expand_pixels + mask_blur if is_left else 0,
                    expand_pixels + mask_blur if is_top else 0,
                    mask.width - expand_pixels - mask_blur if is_right else res_w,
                    mask.height - expand_pixels - mask_blur if is_bottom else res_h,
                ), fill="black")
                # Scale 8-bit pixel values to floats in [0, 1] for the noise routine.
                np_image = (np.asarray(img) / 255.0).astype(np.float64)
                np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
                noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
                output_images.append(Image.fromarray(np.clip(255.0 * noised, 0.0, 255.0).astype(np.uint8), mode="RGB"))
                # Along the growth axis, process at most the originally requested size.
                target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
                target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
                p.width = target_width if is_horiz else img.width
                p.height = target_height if is_vert else img.height
                # Region of target_width x target_height anchored at the left/top
                # edge when growing left/up, otherwise at the right/bottom edge.
                crop_region = (
                    0 if is_left else output_images[n].width - target_width,
                    0 if is_top else output_images[n].height - target_height,
                    target_width if is_left else output_images[n].width,
                    target_height if is_top else output_images[n].height,
                )
                mask = mask.crop(crop_region)
                # Overwritten on every iteration; the last image's mask is the one left on p.
                p.image_mask = mask
                image_to_process = output_images[n].crop(crop_region)
                images_to_process.append(image_to_process)
            p.init_images = images_to_process
            # NOTE: `mask`, `res_w` and `res_h` below are the values left over from
            # the last loop iteration. The latent mask uses twice the inset of the
            # image mask.
            latent_mask = Image.new("RGB", (p.width, p.height), "white")
            draw = ImageDraw.Draw(latent_mask)
            draw.rectangle((
                expand_pixels + mask_blur * 2 if is_left else 0,
                expand_pixels + mask_blur * 2 if is_top else 0,
                mask.width - expand_pixels - mask_blur * 2 if is_right else res_w,
                mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h,
            ), fill="black")
            p.latent_mask = latent_mask
            proc = process_images(p)
            # Only the first processing call's seed and info are kept.
            if initial_seed_and_info[0] is None:
                initial_seed_and_info[0] = proc.seed
                initial_seed_and_info[1] = proc.info
            for n in range(count):
                # Paste the processed region back at the edge it was cropped from,
                # then drop the padding added by rounding up to a multiple of 64.
                output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
                output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
            return output_images
        batch_count = p.n_iter
        batch_size = p.batch_size
        # Batches are looped over here, so p itself is set to a single iteration.
        p.n_iter = 1
        # One job per batch per selected direction.
        state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
        all_processed_images = []
        for i in range(batch_count):
            # Each batch starts again from the original image; the output of one
            # direction is the input of the next.
            imgs = [init_img] * batch_size
            state.job = f"outpainting batch {i+1}/{batch_count}"
            if left > 0:
                imgs = expand(imgs, batch_size, left, is_left=True)
            if right > 0:
                imgs = expand(imgs, batch_size, right, is_right=True)
            if up > 0:
                imgs = expand(imgs, batch_size, up, is_top=True)
            if down > 0:
                imgs = expand(imgs, batch_size, down, is_bottom=True)
            all_processed_images += imgs
        all_images = all_processed_images
        # The grid is built unconditionally; it is only returned/saved when there
        # is more than one image and the matching option is enabled.
        combined_grid_image = images.image_grid(all_processed_images)
        if opts.return_grid and len(all_processed_images) > 1:
            all_images = [combined_grid_image] + all_processed_images
        res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
        if opts.samples_save:
            for img in all_processed_images:
                images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p)
        if opts.grid_save and len(all_processed_images) > 1:
            images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.samples_format, info=res.info, grid=True, p=p)
        return res