Sqxww commited on
Commit
7a6754c
·
1 Parent(s): 7c876ac

initial commit

Browse files
README.md CHANGED
@@ -11,3 +11,5 @@ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ Modified from: https://huggingface.co/spaces/turboedit/turbo_edit
app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from app_base import create_demo as create_demo_face
4
+
5
+ with gr.Blocks(css="style.css") as demo:
6
+ with gr.Tabs():
7
+ with gr.Tab(label="Face"):
8
+ create_demo_face()
9
+
10
+ demo.launch(server_name="0.0.0.0")
app_base.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import time
4
+ import torch
5
+ import tempfile
6
+ import os
7
+ import gc
8
+
9
+ from loading_utils import load_image
10
+
11
+ from segment_utils import(
12
+ segment_image,
13
+ restore_result,
14
+ )
15
+ from enhance_utils import enhance_sd_image
16
+ from inversion_run_base import run as base_run
17
+
18
+ DEFAULT_SRC_PROMPT = "a person"
19
+ DEFAULT_EDIT_PROMPT = "a person with perfect face"
20
+
21
+ DEFAULT_CATEGORY = "face"
22
+
23
+ def image_to_image(
24
+ input_image_path: str,
25
+ input_image_prompt: str,
26
+ edit_prompt: str,
27
+ seed: int,
28
+ w1: float,
29
+ num_steps: int,
30
+ start_step: int,
31
+ guidance_scale: float,
32
+ generate_size: int,
33
+ mask_expansion: int = 50,
34
+ mask_dilation: int = 2,
35
+ save_quality: int = 95,
36
+ enable_segment: bool = True,
37
+ ):
38
+ segment_category = "face"
39
+ w2 = 1.0
40
+ run_task_time = 0
41
+ time_cost_str = ''
42
+
43
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
44
+ input_image = load_image(input_image_path)
45
+ icc_profile = input_image.info.get('icc_profile')
46
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'load_image done')
47
+
48
+ if enable_segment:
49
+ target_area_image, croper = segment_image(
50
+ input_image,
51
+ segment_category,
52
+ generate_size,
53
+ mask_expansion,
54
+ mask_dilation,
55
+ )
56
+ else:
57
+ target_area_image = resize_image(input_image, generate_size)
58
+ croper = None
59
+
60
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'segment_image done')
61
+
62
+ run_model = base_run
63
+ try:
64
+ res_image = run_model(
65
+ target_area_image,
66
+ input_image_prompt,
67
+ edit_prompt ,
68
+ seed,
69
+ w1,
70
+ w2,
71
+ num_steps,
72
+ start_step,
73
+ guidance_scale,
74
+ )
75
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'run_sd_model done')
76
+
77
+ finally:
78
+ torch.cuda.empty_cache()
79
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'cuda_empty_cache done')
80
+
81
+ enhanced_image = res_image
82
+ enhanced_image = enhance_sd_image(res_image)
83
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'enhance_image done')
84
+
85
+ if enable_segment:
86
+ restored_image = restore_result(croper, segment_category, enhanced_image)
87
+ else:
88
+ restored_image = enhanced_image.resize(input_image.size)
89
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'restore_result done')
90
+
91
+ torch.cuda.empty_cache()
92
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'cuda_empty_cache done')
93
+ if os.getenv('ENABLE_GC', False):
94
+ gc.collect()
95
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'gc_collect done')
96
+
97
+ extension = 'png'
98
+ if restored_image.mode == 'RGBA':
99
+ extension = 'png'
100
+ else:
101
+ extension = 'webp'
102
+
103
+ output_path = tempfile.mktemp(suffix=f".{extension}")
104
+ restored_image.save(output_path, format=extension, quality=save_quality, icc_profile=icc_profile)
105
+
106
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'save_image done')
107
+
108
+ return output_path, restored_image, time_cost_str
109
+
110
+ def get_time_cost(
111
+ run_task_time,
112
+ time_cost_str,
113
+ step: str = ''
114
+ ):
115
+ now_time = int(time.time()*1000)
116
+ if run_task_time == 0:
117
+ time_cost_str = 'start'
118
+ else:
119
+ if time_cost_str != '':
120
+ time_cost_str += f'-->'
121
+ time_cost_str += f'{now_time - run_task_time}'
122
+ if step != '':
123
+ time_cost_str += f'-->{step}'
124
+ run_task_time = now_time
125
+ return run_task_time, time_cost_str
126
+
127
+ def resize_image(image, target_size = 1024):
128
+ h, w = image.size
129
+ if h >= w:
130
+ w = int(w * target_size / h)
131
+ h = target_size
132
+ else:
133
+ h = int(h * target_size / w)
134
+ w = target_size
135
+ return image.resize((w, h))
136
+
137
+
138
+ def infer(
139
+ input_image_path: str,
140
+ input_image_prompt: str,
141
+ edit_prompt: str,
142
+ seed: int,
143
+ w1: float,
144
+ num_steps: int,
145
+ start_step: int,
146
+ guidance_scale: float,
147
+ generate_size: int,
148
+ mask_expansion: int = 50,
149
+ mask_dilation: int = 2,
150
+ save_quality: int = 95,
151
+ enable_segment: bool = True,
152
+ ):
153
+ return image_to_image(
154
+ input_image_path,
155
+ input_image_prompt,
156
+ edit_prompt,
157
+ seed,
158
+ w1,
159
+ num_steps,
160
+ start_step,
161
+ guidance_scale,
162
+ generate_size,
163
+ mask_expansion,
164
+ mask_dilation,
165
+ save_quality,
166
+ enable_segment
167
+ )
168
+
169
+ infer = spaces.GPU(infer)
170
+
171
+ def create_demo() -> gr.Blocks:
172
+
173
+ with gr.Blocks() as demo:
174
+ with gr.Row():
175
+ with gr.Column():
176
+ input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
177
+ edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
178
+ with gr.Accordion("Advanced Options", open=False):
179
+ enable_segment = gr.Checkbox(label="Enable Segment", value=True)
180
+ mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
181
+ mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
182
+ save_quality = gr.Slider(minimum=1, maximum=100, value=95, step=1, label="Save Quality")
183
+ with gr.Column():
184
+ num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
185
+ start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
186
+ g_btn = gr.Button("Edit Image")
187
+ with gr.Accordion("Advanced Options", open=False):
188
+ guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
189
+ seed = gr.Number(label="Seed", value=8)
190
+ w1 = gr.Number(label="W1", value=1.5)
191
+ generate_size = gr.Number(label="Generate Size", value=1024)
192
+
193
+ with gr.Row():
194
+ with gr.Column():
195
+ input_image_path = gr.Image(label="Input Image", type="filepath", interactive=True)
196
+ with gr.Column():
197
+ restored_image = gr.Image(label="Restored Image", format="png", type="pil", interactive=False)
198
+ download_path = gr.File(label="Download the output image", interactive=False)
199
+ generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
200
+
201
+ g_btn.click(
202
+ fn=infer,
203
+ inputs=[input_image_path, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step, guidance_scale, generate_size, mask_expansion, mask_dilation, save_quality, enable_segment],
204
+ outputs=[download_path, restored_image, generated_cost],
205
+ )
206
+
207
+
208
+
209
+ return demo
checkpoints/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
3
+ size 16371837
config.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ml_collections import config_dict
2
+ import yaml
3
+ from diffusers.schedulers import (
4
+ DDIMScheduler,
5
+ EulerAncestralDiscreteScheduler,
6
+ EulerDiscreteScheduler,
7
+ DDPMScheduler,
8
+ )
9
+ from inversion_utils import (
10
+ deterministic_ddim_step,
11
+ deterministic_ddpm_step,
12
+ deterministic_euler_step,
13
+ deterministic_non_ancestral_euler_step,
14
+ )
15
+
16
+ BREAKDOWNS = ["x_t_c_hat", "x_t_hat_c", "no_breakdown", "x_t_hat_c_with_zeros"]
17
+ SCHEDULERS = ["ddpm", "ddim", "euler", "euler_non_ancestral"]
18
+ MODELS = [
19
+ "stabilityai/sdxl-turbo",
20
+ "stabilityai/stable-diffusion-xl-base-1.0",
21
+ "CompVis/stable-diffusion-v1-4",
22
+ ]
23
+
24
+ def get_num_steps_actual(cfg):
25
+ return (
26
+ cfg.num_steps_inversion
27
+ - cfg.step_start
28
+ + (1 if cfg.clean_step_timestep > 0 else 0)
29
+ if cfg.timesteps is None
30
+ else len(cfg.timesteps) + (1 if cfg.clean_step_timestep > 0 else 0)
31
+ )
32
+
33
+
34
+ def get_config(args):
35
+ if args.config_from_file and args.config_from_file != "":
36
+ with open(args.config_from_file, "r") as f:
37
+ cfg = config_dict.ConfigDict(yaml.safe_load(f))
38
+
39
+ num_steps_actual = get_num_steps_actual(cfg)
40
+
41
+ else:
42
+ cfg = config_dict.ConfigDict()
43
+
44
+ cfg.seed = 2
45
+ cfg.self_r = 0.5
46
+ cfg.cross_r = 0.9
47
+ cfg.eta = 1
48
+ cfg.scheduler_type = SCHEDULERS[0]
49
+
50
+ cfg.num_steps_inversion = 50 # timesteps: 999, 799, 599, 399, 199
51
+ cfg.step_start = 20
52
+ cfg.timesteps = None
53
+ cfg.noise_timesteps = None
54
+ num_steps_actual = get_num_steps_actual(cfg)
55
+ cfg.ws1 = [2] * num_steps_actual
56
+ cfg.ws2 = [1] * num_steps_actual
57
+ cfg.real_cfg_scale = 0
58
+ cfg.real_cfg_scale_save = 0
59
+ cfg.breakdown = BREAKDOWNS[1]
60
+ cfg.noise_shift_delta = 1
61
+ cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
62
+
63
+ cfg.clean_step_timestep = 0
64
+
65
+ cfg.model = MODELS[1]
66
+
67
+ if cfg.scheduler_type == "ddim":
68
+ cfg.scheduler_class = DDIMScheduler
69
+ cfg.step_function = deterministic_ddim_step
70
+ elif cfg.scheduler_type == "ddpm":
71
+ cfg.scheduler_class = DDPMScheduler
72
+ cfg.step_function = deterministic_ddpm_step
73
+ elif cfg.scheduler_type == "euler":
74
+ cfg.scheduler_class = EulerAncestralDiscreteScheduler
75
+ cfg.step_function = deterministic_euler_step
76
+ elif cfg.scheduler_type == "euler_non_ancestral":
77
+ cfg.scheduler_class = EulerDiscreteScheduler
78
+ cfg.step_function = deterministic_non_ancestral_euler_step
79
+ else:
80
+ raise ValueError(f"Unknown scheduler type: {cfg.scheduler_type}")
81
+
82
+ with cfg.ignore_type():
83
+ if isinstance(cfg.max_norm_zs, (int, float)):
84
+ cfg.max_norm_zs = [cfg.max_norm_zs] * num_steps_actual
85
+
86
+ if isinstance(cfg.ws1, (int, float)):
87
+ cfg.ws1 = [cfg.ws1] * num_steps_actual
88
+
89
+ if isinstance(cfg.ws2, (int, float)):
90
+ cfg.ws2 = [cfg.ws2] * num_steps_actual
91
+
92
+ if not hasattr(cfg, "update_eta"):
93
+ cfg.update_eta = False
94
+
95
+ if not hasattr(cfg, "save_timesteps"):
96
+ cfg.save_timesteps = None
97
+
98
+ if not hasattr(cfg, "scheduler_timesteps"):
99
+ cfg.scheduler_timesteps = None
100
+
101
+ assert (
102
+ cfg.scheduler_type == "ddpm" or cfg.timesteps is None
103
+ ), "timesteps must be None for ddim/euler"
104
+
105
+ cfg.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
106
+ assert (
107
+ len(cfg.max_norm_zs) == num_steps_actual
108
+ ), f"len(cfg.max_norm_zs) ({len(cfg.max_norm_zs)}) != num_steps_actual ({num_steps_actual})"
109
+
110
+ assert (
111
+ len(cfg.ws1) == num_steps_actual
112
+ ), f"len(cfg.ws1) ({len(cfg.ws1)}) != num_steps_actual ({num_steps_actual})"
113
+
114
+ assert (
115
+ len(cfg.ws2) == num_steps_actual
116
+ ), f"len(cfg.ws2) ({len(cfg.ws2)}) != num_steps_actual ({num_steps_actual})"
117
+
118
+ assert cfg.noise_timesteps is None or len(cfg.noise_timesteps) == (
119
+ num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
120
+ ), f"len(cfg.noise_timesteps) ({len(cfg.noise_timesteps)}) != num_steps_actual ({num_steps_actual})"
121
+
122
+ assert cfg.save_timesteps is None or len(cfg.save_timesteps) == (
123
+ num_steps_actual - (1 if cfg.clean_step_timestep > 0 else 0)
124
+ ), f"len(cfg.save_timesteps) ({len(cfg.save_timesteps)}) != num_steps_actual ({num_steps_actual})"
125
+
126
+ return cfg
127
+
128
+
129
+ def get_config_name(config, args):
130
+ if args.folder_name is not None and args.folder_name != "":
131
+ return args.folder_name
132
+ timesteps_str = (
133
+ f"step_start {config.step_start}"
134
+ if config.timesteps is None
135
+ else f"timesteps {config.timesteps}"
136
+ )
137
+ return f"""\
138
+ ws1 {config.ws1[0]} ws2 {config.ws2[0]} real_cfg_scale {config.real_cfg_scale} {timesteps_str} \
139
+ real_cfg_scale_save {config.real_cfg_scale_save} seed {config.seed} max_norm_zs {config.max_norm_zs[-1]} noise_shift_delta {config.noise_shift_delta} \
140
+ scheduler_type {config.scheduler_type} fp16 {args.fp16}\
141
+ """
croper.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import numpy as np
3
+
4
+ from PIL import Image
5
+
6
+ class Croper:
7
+ def __init__(
8
+ self,
9
+ input_image: PIL.Image,
10
+ target_mask: np.ndarray,
11
+ mask_size: int = 256,
12
+ mask_expansion: int = 20,
13
+ ):
14
+ self.input_image = input_image
15
+ self.target_mask = target_mask
16
+ self.mask_size = mask_size
17
+ self.mask_expansion = mask_expansion
18
+
19
+ def corp_mask_image(self):
20
+ target_mask = self.target_mask
21
+ input_image = self.input_image
22
+ mask_expansion = self.mask_expansion
23
+ original_width, original_height = input_image.size
24
+ mask_indices = np.where(target_mask)
25
+ start_y = np.min(mask_indices[0])
26
+ end_y = np.max(mask_indices[0])
27
+ start_x = np.min(mask_indices[1])
28
+ end_x = np.max(mask_indices[1])
29
+ mask_height = end_y - start_y
30
+ mask_width = end_x - start_x
31
+ # choose the max side length
32
+ max_side_length = max(mask_height, mask_width)
33
+ # expand the mask area
34
+ height_diff = (max_side_length - mask_height) // 2
35
+ width_diff = (max_side_length - mask_width) // 2
36
+ start_y = start_y - mask_expansion - height_diff
37
+ if start_y < 0:
38
+ start_y = 0
39
+ end_y = end_y + mask_expansion + height_diff
40
+ if end_y > original_height:
41
+ end_y = original_height
42
+ start_x = start_x - mask_expansion - width_diff
43
+ if start_x < 0:
44
+ start_x = 0
45
+ end_x = end_x + mask_expansion + width_diff
46
+ if end_x > original_width:
47
+ end_x = original_width
48
+ expanded_height = end_y - start_y
49
+ expanded_width = end_x - start_x
50
+ expanded_max_side_length = max(expanded_height, expanded_width)
51
+ # calculate the crop area
52
+ crop_mask = target_mask[start_y:end_y, start_x:end_x]
53
+ crop_mask_start_y = (expanded_max_side_length - expanded_height) // 2
54
+ crop_mask_end_y = crop_mask_start_y + expanded_height
55
+ crop_mask_start_x = (expanded_max_side_length - expanded_width) // 2
56
+ crop_mask_end_x = crop_mask_start_x + expanded_width
57
+ # create a square mask
58
+ square_mask = np.zeros((expanded_max_side_length, expanded_max_side_length), dtype=target_mask.dtype)
59
+ square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
60
+ square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))
61
+
62
+ crop_image = input_image.crop((start_x, start_y, end_x, end_y))
63
+ square_image = Image.new("RGB", (expanded_max_side_length, expanded_max_side_length))
64
+ square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
65
+
66
+ self.origin_start_x = start_x
67
+ self.origin_start_y = start_y
68
+ self.origin_end_x = end_x
69
+ self.origin_end_y = end_y
70
+
71
+ self.square_start_x = crop_mask_start_x
72
+ self.square_start_y = crop_mask_start_y
73
+ self.square_end_x = crop_mask_end_x
74
+ self.square_end_y = crop_mask_end_y
75
+
76
+ self.square_length = expanded_max_side_length
77
+ self.square_mask_image = square_mask_image
78
+ self.square_image = square_image
79
+ self.corp_mask = crop_mask
80
+
81
+ mask_size = self.mask_size
82
+ self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
83
+ self.resized_square_image = square_image.resize((mask_size, mask_size))
84
+
85
+ return self.resized_square_mask_image
86
+
87
+ def restore_result(self, generated_image):
88
+ square_length = self.square_length
89
+ generated_image = generated_image.resize((square_length, square_length))
90
+ square_mask_image = self.square_mask_image
91
+ cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
92
+ cropped_square_mask_image = square_mask_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
93
+
94
+ restored_image = self.input_image.copy()
95
+ restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y), cropped_square_mask_image)
96
+
97
+ return restored_image
98
+
99
+ def restore_result_v2(self, generated_image):
100
+ square_length = self.square_length
101
+ generated_image = generated_image.resize((square_length, square_length))
102
+ cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
103
+
104
+ restored_image = self.input_image.copy()
105
+ restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y))
106
+
107
+ return restored_image
108
+
enhance_utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from PIL import Image
4
+ from model_handler import MODELS
5
+
6
+ upscaler = MODELS.upscaler
7
+ upscaler4SD = MODELS.upscaler4SD
8
+
9
+ def enhance_image(
10
+ input_image: Image,
11
+ ):
12
+
13
+ h, w = input_image.size
14
+ max_scale_size = 1024
15
+ if h > max_scale_size:
16
+ w = int(w * max_scale_size / h)
17
+ h = max_scale_size
18
+ if w > max_scale_size:
19
+ h = int(h * max_scale_size / w)
20
+ w = max_scale_size
21
+
22
+ if h != input_image.size[1] or w != input_image.size[0]:
23
+ input_image = input_image.resize((w, h))
24
+
25
+ if os.environ.get("TILING", False):
26
+ tileSizeStr = os.environ.get("TILE_SIZE", 1024)
27
+ tileSize = int(tileSizeStr)
28
+ enhanced_image = upscaler(input_image, tiling=True, tile_width=tileSize, tile_height=tileSize)
29
+ else:
30
+ enhanced_image = upscaler(input_image)
31
+
32
+ return enhanced_image
33
+
34
+ def enhance_sd_image(
35
+ input_image: Image,
36
+ ):
37
+ if os.environ.get("TILING", False):
38
+ tileSizeStr = os.environ.get("TILE_SIZE", 1024)
39
+ tileSize = int(tileSizeStr)
40
+ enhanced_image = upscaler4SD(input_image, tiling=True, tile_width=tileSize, tile_height=tileSize)
41
+ else:
42
+ enhanced_image = upscaler4SD(input_image)
43
+
44
+ return enhanced_image
inversion_run_base.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
4
+ from PIL import Image
5
+ from inversion_utils import get_ddpm_inversion_scheduler, create_xts
6
+ from config import get_config, get_num_steps_actual
7
+ from functools import partial
8
+ from compel import Compel, ReturnedEmbeddingsType
9
+
10
+ from model_handler import MODELS
11
+
12
+ class Object(object):
13
+ pass
14
+
15
+ args = Object()
16
+ args.images_paths = None
17
+ args.images_folder = None
18
+ args.force_use_cpu = False
19
+ args.folder_name = 'test_measure_time'
20
+ args.config_from_file = 'run_configs/noise_shift_guidance_1_5.yaml'
21
+ args.save_intermediate_results = False
22
+ args.batch_size = None
23
+ args.skip_p_to_p = True
24
+ args.only_p_to_p = False
25
+ args.fp16 = False
26
+ args.prompts_file = 'dataset_measure_time/dataset.json'
27
+ args.images_in_prompts_file = None
28
+ args.seed = 986
29
+ args.time_measure_n = 1
30
+
31
+
32
+ assert (
33
+ args.batch_size is None or args.save_intermediate_results is False
34
+ ), "save_intermediate_results is not implemented for batch_size > 1"
35
+
36
+ generator = None
37
+ device = "cuda"
38
+
39
+ pipeline = MODELS.base_pipe
40
+
41
+ config = get_config(args)
42
+
43
+ compel_proc = Compel(
44
+ tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] ,
45
+ text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
46
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
47
+ requires_pooled=[False, True]
48
+ )
49
+
50
+ def run(
51
+ input_image:Image,
52
+ src_prompt:str,
53
+ tgt_prompt:str,
54
+ seed:int,
55
+ w1:float,
56
+ w2:float,
57
+ num_steps:int,
58
+ start_step:int,
59
+ guidance_scale:float,
60
+ ):
61
+ generator = torch.Generator().manual_seed(seed)
62
+
63
+ config.num_steps_inversion = num_steps
64
+ config.step_start = start_step
65
+ num_steps_actual = get_num_steps_actual(config)
66
+
67
+
68
+ num_steps_inversion = config.num_steps_inversion
69
+ denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
70
+
71
+ timesteps, num_inference_steps = retrieve_timesteps(
72
+ pipeline.scheduler, num_steps_inversion, device, None
73
+ )
74
+ timesteps, num_inference_steps = pipeline.get_timesteps(
75
+ num_inference_steps=num_inference_steps,
76
+ denoising_start=denoising_start,
77
+ strength=0,
78
+ device=device,
79
+ )
80
+ timesteps = timesteps.type(torch.int64)
81
+
82
+ timesteps = [torch.tensor(t) for t in timesteps.tolist()]
83
+ timesteps_len = len(timesteps)
84
+ config.step_start = start_step + num_steps_actual - timesteps_len
85
+ num_steps_actual = timesteps_len
86
+ config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
87
+
88
+ pipeline.__call__ = partial(
89
+ pipeline.__call__,
90
+ num_inference_steps=num_steps_inversion,
91
+ guidance_scale=guidance_scale,
92
+ generator=generator,
93
+ denoising_start=denoising_start,
94
+ strength=0,
95
+ )
96
+
97
+ x_0_image = input_image
98
+ x_0 = encode_image(x_0_image, pipeline)
99
+ x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
100
+ x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
101
+ latents = [x_ts[0]]
102
+ x_ts_c_hat = [None]
103
+ config.ws1 = [w1] * num_steps_actual
104
+ config.ws2 = [w2] * num_steps_actual
105
+ pipeline.scheduler = get_ddpm_inversion_scheduler(
106
+ pipeline.scheduler,
107
+ config.step_function,
108
+ config,
109
+ timesteps,
110
+ config.save_timesteps,
111
+ latents,
112
+ x_ts,
113
+ x_ts_c_hat,
114
+ args.save_intermediate_results,
115
+ pipeline,
116
+ x_0,
117
+ v1s_images := [],
118
+ v2s_images := [],
119
+ deltas_images := [],
120
+ v1_x0s := [],
121
+ v2_x0s := [],
122
+ deltas_x0s := [],
123
+ "res12",
124
+ image_name="im_name",
125
+ time_measure_n=args.time_measure_n,
126
+ )
127
+ latent = latents[0].expand(3, -1, -1, -1)
128
+ prompt = [src_prompt, src_prompt, tgt_prompt]
129
+ conditioning, pooled = compel_proc(prompt)
130
+ image = pipeline.__call__(
131
+ image=latent,
132
+ prompt_embeds=conditioning,
133
+ pooled_prompt_embeds=pooled,
134
+ eta=1,
135
+ ).images
136
+ return image[2]
137
+
138
+ def encode_image(image, pipe):
139
+ image = pipe.image_processor.preprocess(image)
140
+ originDtype = pipe.dtype
141
+ image = image.to(device=device, dtype=originDtype)
142
+
143
+ if pipe.vae.config.force_upcast:
144
+ image = image.float()
145
+ pipe.vae.to(dtype=torch.float32)
146
+
147
+ if isinstance(generator, list):
148
+ init_latents = [
149
+ retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
150
+ for i in range(1)
151
+ ]
152
+ init_latents = torch.cat(init_latents, dim=0)
153
+ else:
154
+ init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
155
+
156
+ if pipe.vae.config.force_upcast:
157
+ pipe.vae.to(originDtype)
158
+
159
+ init_latents = init_latents.to(originDtype)
160
+ init_latents = pipe.vae.config.scaling_factor * init_latents
161
+
162
+ return init_latents.to(dtype=torch.float16)
163
+
164
+ def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
165
+ # get the original timestep using init_timestep
166
+ if denoising_start is None:
167
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
168
+ t_start = max(num_inference_steps - init_timestep, 0)
169
+ else:
170
+ t_start = 0
171
+
172
+ timesteps = pipe.scheduler.timesteps[t_start * pipe.scheduler.order :]
173
+
174
+ # Strength is irrelevant if we directly request a timestep to start at;
175
+ # that is, strength is determined by the denoising_start instead.
176
+ if denoising_start is not None:
177
+ discrete_timestep_cutoff = int(
178
+ round(
179
+ pipe.scheduler.config.num_train_timesteps
180
+ - (denoising_start * pipe.scheduler.config.num_train_timesteps)
181
+ )
182
+ )
183
+
184
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
185
+ if pipe.scheduler.order == 2 and num_inference_steps % 2 == 0:
186
+ # if the scheduler is a 2nd order scheduler we might have to do +1
187
+ # because `num_inference_steps` might be even given that every timestep
188
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
189
+ # mean that we cut the timesteps in the middle of the denoising step
190
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
191
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
192
+ num_inference_steps = num_inference_steps + 1
193
+
194
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
195
+ timesteps = timesteps[-num_inference_steps:]
196
+ return timesteps, num_inference_steps
197
+
198
+ return timesteps, num_inference_steps - t_start
inversion_utils.py ADDED
@@ -0,0 +1,793 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import PIL
4
+
5
+ from typing import List, Optional, Union
6
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
7
+ from diffusers.utils import logging
8
+
9
+ VECTOR_DATA_FOLDER = "vector_data"
10
+ VECTOR_DATA_DICT = "vector_data"
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+ def get_ddpm_inversion_scheduler(
15
+ scheduler,
16
+ step_function,
17
+ config,
18
+ timesteps,
19
+ save_timesteps,
20
+ latents,
21
+ x_ts,
22
+ x_ts_c_hat,
23
+ save_intermediate_results,
24
+ pipe,
25
+ x_0,
26
+ v1s_images,
27
+ v2s_images,
28
+ deltas_images,
29
+ v1_x0s,
30
+ v2_x0s,
31
+ deltas_x0s,
32
+ folder_name,
33
+ image_name,
34
+ time_measure_n,
35
+ ):
36
+ def step(
37
+ model_output: torch.FloatTensor,
38
+ timestep: int,
39
+ sample: torch.FloatTensor,
40
+ eta: float = 0.0,
41
+ use_clipped_model_output: bool = False,
42
+ generator=None,
43
+ variance_noise: Optional[torch.FloatTensor] = None,
44
+ return_dict: bool = True,
45
+ ):
46
+ # if scheduler.is_save:
47
+ # start = timer()
48
+ res_inv = step_save_latents(
49
+ scheduler,
50
+ model_output[:1, :, :, :],
51
+ timestep,
52
+ sample[:1, :, :, :],
53
+ eta,
54
+ use_clipped_model_output,
55
+ generator,
56
+ variance_noise,
57
+ return_dict,
58
+ )
59
+ # end = timer()
60
+ # print(f"Run Time Inv: {end - start}")
61
+
62
+ res_inf = step_use_latents(
63
+ scheduler,
64
+ model_output[1:, :, :, :],
65
+ timestep,
66
+ sample[1:, :, :, :],
67
+ eta,
68
+ use_clipped_model_output,
69
+ generator,
70
+ variance_noise,
71
+ return_dict,
72
+ )
73
+ # res = res_inv
74
+ res = (torch.cat((res_inv[0], res_inf[0]), dim=0),)
75
+ return res
76
+ # return res
77
+
78
+ scheduler.step_function = step_function
79
+ scheduler.is_save = True
80
+ scheduler._timesteps = timesteps
81
+ scheduler._save_timesteps = save_timesteps if save_timesteps else timesteps
82
+ scheduler._config = config
83
+ scheduler.latents = latents
84
+ scheduler.x_ts = x_ts
85
+ scheduler.x_ts_c_hat = x_ts_c_hat
86
+ scheduler.step = step
87
+ scheduler.save_intermediate_results = save_intermediate_results
88
+ scheduler.pipe = pipe
89
+ scheduler.v1s_images = v1s_images
90
+ scheduler.v2s_images = v2s_images
91
+ scheduler.deltas_images = deltas_images
92
+ scheduler.v1_x0s = v1_x0s
93
+ scheduler.v2_x0s = v2_x0s
94
+ scheduler.deltas_x0s = deltas_x0s
95
+ scheduler.clean_step_run = False
96
+ scheduler.x_0s = create_xts(
97
+ config.noise_shift_delta,
98
+ config.noise_timesteps,
99
+ config.clean_step_timestep,
100
+ None,
101
+ pipe.scheduler,
102
+ timesteps,
103
+ x_0,
104
+ no_add_noise=True,
105
+ )
106
+ scheduler.folder_name = folder_name
107
+ scheduler.image_name = image_name
108
+ scheduler.p_to_p = False
109
+ scheduler.p_to_p_replace = False
110
+ scheduler.time_measure_n = time_measure_n
111
+ return scheduler
112
+
113
+ def step_save_latents(
114
+ self,
115
+ model_output: torch.FloatTensor,
116
+ timestep: int,
117
+ sample: torch.FloatTensor,
118
+ eta: float = 0.0,
119
+ use_clipped_model_output: bool = False,
120
+ generator=None,
121
+ variance_noise: Optional[torch.FloatTensor] = None,
122
+ return_dict: bool = True,
123
+ ):
124
+ # print(self._save_timesteps)
125
+ # timestep_index = map_timpstep_to_index[timestep]
126
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
127
+ timestep_index = self._save_timesteps.index(timestep) if not self.clean_step_run else -1
128
+ next_timestep_index = timestep_index + 1 if not self.clean_step_run else -1
129
+ u_hat_t = self.step_function(
130
+ model_output=model_output,
131
+ timestep=timestep,
132
+ sample=sample,
133
+ eta=eta,
134
+ use_clipped_model_output=use_clipped_model_output,
135
+ generator=generator,
136
+ variance_noise=variance_noise,
137
+ return_dict=False,
138
+ scheduler=self,
139
+ )
140
+
141
+ x_t_minus_1 = self.x_ts[next_timestep_index]
142
+ self.x_ts_c_hat.append(u_hat_t)
143
+
144
+ z_t = x_t_minus_1 - u_hat_t
145
+ self.latents.append(z_t)
146
+ z_t, _ = normalize(z_t, timestep_index, self._config.max_norm_zs)
147
+
148
+ x_t_minus_1_predicted = u_hat_t + z_t
149
+
150
+ if not return_dict:
151
+ return (x_t_minus_1_predicted,)
152
+
153
+ return DDIMSchedulerOutput(prev_sample=x_t_minus_1, pred_original_sample=None)
154
+
155
+ def step_use_latents(
156
+ self,
157
+ model_output: torch.FloatTensor,
158
+ timestep: int,
159
+ sample: torch.FloatTensor,
160
+ eta: float = 0.0,
161
+ use_clipped_model_output: bool = False,
162
+ generator=None,
163
+ variance_noise: Optional[torch.FloatTensor] = None,
164
+ return_dict: bool = True,
165
+ ):
166
+ # timestep_index = ((self._save_timesteps == timestep).nonzero(as_tuple=True)[0]).item()
167
+ timestep_index = self._timesteps.index(timestep) if not self.clean_step_run else -1
168
+ next_timestep_index = (
169
+ timestep_index + 1 if not self.clean_step_run else -1
170
+ )
171
+ z_t = self.latents[next_timestep_index] # + 1 because latents[0] is X_T
172
+
173
+ _, normalize_coefficient = normalize(
174
+ z_t[0] if self._config.breakdown == "x_t_hat_c_with_zeros" else z_t,
175
+ timestep_index,
176
+ self._config.max_norm_zs,
177
+ )
178
+
179
+ if normalize_coefficient == 0:
180
+ eta = 0
181
+
182
+ # eta = normalize_coefficient
183
+
184
+ x_t_hat_c_hat = self.step_function(
185
+ model_output=model_output,
186
+ timestep=timestep,
187
+ sample=sample,
188
+ eta=eta,
189
+ use_clipped_model_output=use_clipped_model_output,
190
+ generator=generator,
191
+ variance_noise=variance_noise,
192
+ return_dict=False,
193
+ scheduler=self,
194
+ )
195
+
196
+ w1 = self._config.ws1[timestep_index]
197
+ w2 = self._config.ws2[timestep_index]
198
+
199
+ x_t_minus_1_exact = self.x_ts[next_timestep_index]
200
+ x_t_minus_1_exact = x_t_minus_1_exact.expand_as(x_t_hat_c_hat)
201
+
202
+ x_t_c_hat: torch.Tensor = self.x_ts_c_hat[next_timestep_index]
203
+ if self._config.breakdown == "x_t_c_hat":
204
+ raise NotImplementedError("breakdown x_t_c_hat not implemented yet")
205
+
206
+ # x_t_c_hat = x_t_c_hat.expand_as(x_t_hat_c_hat)
207
+ x_t_c = x_t_c_hat[0].expand_as(x_t_hat_c_hat)
208
+
209
+ # if self._config.breakdown == "x_t_c_hat":
210
+ # v1 = x_t_hat_c_hat - x_t_c_hat
211
+ # v2 = x_t_c_hat - x_t_c
212
+ if (
213
+ self._config.breakdown == "x_t_hat_c"
214
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
215
+ ):
216
+ zero_index_reconstruction = 1 if not self.time_measure_n else 0
217
+ edit_prompts_num = (
218
+ (model_output.size(0) - zero_index_reconstruction) // 3
219
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p
220
+ else (model_output.size(0) - zero_index_reconstruction) // 2
221
+ )
222
+ x_t_hat_c_indices = (zero_index_reconstruction, edit_prompts_num + zero_index_reconstruction)
223
+ edit_images_indices = (
224
+ edit_prompts_num + zero_index_reconstruction,
225
+ (
226
+ model_output.size(0)
227
+ if self._config.breakdown == "x_t_hat_c"
228
+ else zero_index_reconstruction + 2 * edit_prompts_num
229
+ ),
230
+ )
231
+ x_t_hat_c = torch.zeros_like(x_t_hat_c_hat)
232
+ x_t_hat_c[edit_images_indices[0] : edit_images_indices[1]] = x_t_hat_c_hat[
233
+ x_t_hat_c_indices[0] : x_t_hat_c_indices[1]
234
+ ]
235
+ v1 = x_t_hat_c_hat - x_t_hat_c
236
+ v2 = x_t_hat_c - normalize_coefficient * x_t_c
237
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
238
+ path = os.path.join(
239
+ self.folder_name,
240
+ VECTOR_DATA_FOLDER,
241
+ self.image_name,
242
+ )
243
+ if not hasattr(self, VECTOR_DATA_DICT):
244
+ os.makedirs(path, exist_ok=True)
245
+ self.vector_data = dict()
246
+
247
+ x_t_0 = x_t_c_hat[1]
248
+ empty_prompt_indices = (1 + 2 * edit_prompts_num, 1 + 3 * edit_prompts_num)
249
+ x_t_hat_0 = x_t_hat_c_hat[empty_prompt_indices[0] : empty_prompt_indices[1]]
250
+
251
+ self.vector_data[timestep.item()] = dict()
252
+ self.vector_data[timestep.item()]["x_t_hat_c"] = x_t_hat_c[
253
+ edit_images_indices[0] : edit_images_indices[1]
254
+ ]
255
+ self.vector_data[timestep.item()]["x_t_hat_0"] = x_t_hat_0
256
+ self.vector_data[timestep.item()]["x_t_c"] = x_t_c[0].expand_as(x_t_hat_0)
257
+ self.vector_data[timestep.item()]["x_t_0"] = x_t_0.expand_as(x_t_hat_0)
258
+ self.vector_data[timestep.item()]["x_t_hat_c_hat"] = x_t_hat_c_hat[
259
+ edit_images_indices[0] : edit_images_indices[1]
260
+ ]
261
+ self.vector_data[timestep.item()]["x_t_minus_1_noisy"] = x_t_minus_1_exact[
262
+ 0
263
+ ].expand_as(x_t_hat_0)
264
+ self.vector_data[timestep.item()]["x_t_minus_1_clean"] = self.x_0s[
265
+ next_timestep_index
266
+ ].expand_as(x_t_hat_0)
267
+
268
+ else: # no breakdown
269
+ v1 = x_t_hat_c_hat - normalize_coefficient * x_t_c
270
+ v2 = 0
271
+
272
+ if self.save_intermediate_results and not self.p_to_p:
273
+ delta = v1 + v2
274
+ v1_plus_x0 = self.x_0s[next_timestep_index] + v1
275
+ v2_plus_x0 = self.x_0s[next_timestep_index] + v2
276
+ delta_plus_x0 = self.x_0s[next_timestep_index] + delta
277
+
278
+ v1_images = decode_latents(v1, self.pipe)
279
+ self.v1s_images.append(v1_images)
280
+ v2_images = (
281
+ decode_latents(v2, self.pipe)
282
+ if self._config.breakdown != "no_breakdown"
283
+ else [PIL.Image.new("RGB", (1, 1))]
284
+ )
285
+ self.v2s_images.append(v2_images)
286
+ delta_images = decode_latents(delta, self.pipe)
287
+ self.deltas_images.append(delta_images)
288
+ v1_plus_x0_images = decode_latents(v1_plus_x0, self.pipe)
289
+ self.v1_x0s.append(v1_plus_x0_images)
290
+ v2_plus_x0_images = (
291
+ decode_latents(v2_plus_x0, self.pipe)
292
+ if self._config.breakdown != "no_breakdown"
293
+ else [PIL.Image.new("RGB", (1, 1))]
294
+ )
295
+ self.v2_x0s.append(v2_plus_x0_images)
296
+ delta_plus_x0_images = decode_latents(delta_plus_x0, self.pipe)
297
+ self.deltas_x0s.append(delta_plus_x0_images)
298
+
299
+ # print(f"v1 norm: {torch.norm(v1, dim=0).mean()}")
300
+ # if self._config.breakdown != "no_breakdown":
301
+ # print(f"v2 norm: {torch.norm(v2, dim=0).mean()}")
302
+ # print(f"v sum norm: {torch.norm(v1 + v2, dim=0).mean()}")
303
+
304
+ x_t_minus_1 = normalize_coefficient * x_t_minus_1_exact + w1 * v1 + w2 * v2
305
+
306
+ if (
307
+ self._config.breakdown == "x_t_hat_c"
308
+ or self._config.breakdown == "x_t_hat_c_with_zeros"
309
+ ):
310
+ x_t_minus_1[x_t_hat_c_indices[0] : x_t_hat_c_indices[1]] = x_t_minus_1[
311
+ edit_images_indices[0] : edit_images_indices[1]
312
+ ] # update x_t_hat_c to be x_t_hat_c_hat
313
+ if self._config.breakdown == "x_t_hat_c_with_zeros" and not self.p_to_p:
314
+ x_t_minus_1[empty_prompt_indices[0] : empty_prompt_indices[1]] = (
315
+ x_t_minus_1[edit_images_indices[0] : edit_images_indices[1]]
316
+ )
317
+ self.vector_data[timestep.item()]["x_t_minus_1_edited"] = x_t_minus_1[
318
+ edit_images_indices[0] : edit_images_indices[1]
319
+ ]
320
+ if timestep == self._timesteps[-1]:
321
+ torch.save(
322
+ self.vector_data,
323
+ os.path.join(
324
+ path,
325
+ f"{VECTOR_DATA_DICT}.pt",
326
+ ),
327
+ )
328
+ # p_to_p_force_perfect_reconstruction
329
+ if not self.time_measure_n:
330
+ x_t_minus_1[0] = x_t_minus_1_exact[0]
331
+
332
+ if not return_dict:
333
+ return (x_t_minus_1,)
334
+
335
+ return DDIMSchedulerOutput(
336
+ prev_sample=x_t_minus_1,
337
+ pred_original_sample=None,
338
+ )
339
+
340
+ def create_xts(
341
+ noise_shift_delta,
342
+ noise_timesteps,
343
+ clean_step_timestep,
344
+ generator,
345
+ scheduler,
346
+ timesteps,
347
+ x_0,
348
+ no_add_noise=False,
349
+ ):
350
+ if noise_timesteps is None:
351
+ noising_delta = noise_shift_delta * (timesteps[0] - timesteps[1])
352
+ noise_timesteps = [timestep - int(noising_delta) for timestep in timesteps]
353
+
354
+ first_x_0_idx = len(noise_timesteps)
355
+ for i in range(len(noise_timesteps)):
356
+ if noise_timesteps[i] <= 0:
357
+ first_x_0_idx = i
358
+ break
359
+
360
+ noise_timesteps = noise_timesteps[:first_x_0_idx]
361
+
362
+ x_0_expanded = x_0.expand(len(noise_timesteps), -1, -1, -1)
363
+ noise = (
364
+ torch.randn(x_0_expanded.size(), generator=generator, device="cpu").to(
365
+ x_0.device
366
+ )
367
+ if not no_add_noise
368
+ else torch.zeros_like(x_0_expanded)
369
+ )
370
+ x_ts = scheduler.add_noise(
371
+ x_0_expanded,
372
+ noise,
373
+ torch.IntTensor(noise_timesteps),
374
+ )
375
+ x_ts = [t.unsqueeze(dim=0) for t in list(x_ts)]
376
+ x_ts += [x_0] * (len(timesteps) - first_x_0_idx)
377
+ x_ts += [x_0]
378
+ if clean_step_timestep > 0:
379
+ x_ts += [x_0]
380
+ return x_ts
381
+
382
+ def normalize(
383
+ z_t,
384
+ i,
385
+ max_norm_zs,
386
+ ):
387
+ max_norm = max_norm_zs[i]
388
+ if max_norm < 0:
389
+ return z_t, 1
390
+
391
+ norm = torch.norm(z_t)
392
+ if norm < max_norm:
393
+ return z_t, 1
394
+
395
+ coeff = max_norm / norm
396
+ z_t = z_t * coeff
397
+ return z_t, coeff
398
+
399
+ def decode_latents(latent, pipe):
400
+ latent_img = pipe.vae.decode(
401
+ latent / pipe.vae.config.scaling_factor, return_dict=False
402
+ )[0]
403
+ return pipe.image_processor.postprocess(latent_img, output_type="pil")
404
+
405
+ def deterministic_ddim_step(
406
+ model_output: torch.FloatTensor,
407
+ timestep: int,
408
+ sample: torch.FloatTensor,
409
+ eta: float = 0.0,
410
+ use_clipped_model_output: bool = False,
411
+ generator=None,
412
+ variance_noise: Optional[torch.FloatTensor] = None,
413
+ return_dict: bool = True,
414
+ scheduler=None,
415
+ ):
416
+
417
+ if scheduler.num_inference_steps is None:
418
+ raise ValueError(
419
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
420
+ )
421
+
422
+ prev_timestep = (
423
+ timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
424
+ )
425
+
426
+ # 2. compute alphas, betas
427
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
428
+ alpha_prod_t_prev = (
429
+ scheduler.alphas_cumprod[prev_timestep]
430
+ if prev_timestep >= 0
431
+ else scheduler.final_alpha_cumprod
432
+ )
433
+
434
+ beta_prod_t = 1 - alpha_prod_t
435
+
436
+ if scheduler.config.prediction_type == "epsilon":
437
+ pred_original_sample = (
438
+ sample - beta_prod_t ** (0.5) * model_output
439
+ ) / alpha_prod_t ** (0.5)
440
+ pred_epsilon = model_output
441
+ elif scheduler.config.prediction_type == "sample":
442
+ pred_original_sample = model_output
443
+ pred_epsilon = (
444
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
445
+ ) / beta_prod_t ** (0.5)
446
+ elif scheduler.config.prediction_type == "v_prediction":
447
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
448
+ beta_prod_t**0.5
449
+ ) * model_output
450
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
451
+ else:
452
+ raise ValueError(
453
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or"
454
+ " `v_prediction`"
455
+ )
456
+
457
+ # 4. Clip or threshold "predicted x_0"
458
+ if scheduler.config.thresholding:
459
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
460
+ elif scheduler.config.clip_sample:
461
+ pred_original_sample = pred_original_sample.clamp(
462
+ -scheduler.config.clip_sample_range,
463
+ scheduler.config.clip_sample_range,
464
+ )
465
+
466
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
467
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
468
+ variance = scheduler._get_variance(timestep, prev_timestep)
469
+ std_dev_t = eta * variance ** (0.5)
470
+
471
+ if use_clipped_model_output:
472
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
473
+ pred_epsilon = (
474
+ sample - alpha_prod_t ** (0.5) * pred_original_sample
475
+ ) / beta_prod_t ** (0.5)
476
+
477
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
478
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
479
+ 0.5
480
+ ) * pred_epsilon
481
+
482
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
483
+ prev_sample = (
484
+ alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
485
+ )
486
+ return prev_sample
487
+
488
+
489
+ def deterministic_euler_step(
490
+ model_output: torch.FloatTensor,
491
+ timestep: Union[float, torch.FloatTensor],
492
+ sample: torch.FloatTensor,
493
+ eta,
494
+ use_clipped_model_output,
495
+ generator,
496
+ variance_noise,
497
+ return_dict,
498
+ scheduler,
499
+ ):
500
+ """
501
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
502
+ process from the learned model outputs (most often the predicted noise).
503
+
504
+ Args:
505
+ model_output (`torch.FloatTensor`):
506
+ The direct output from learned diffusion model.
507
+ timestep (`float`):
508
+ The current discrete timestep in the diffusion chain.
509
+ sample (`torch.FloatTensor`):
510
+ A current instance of a sample created by the diffusion process.
511
+ generator (`torch.Generator`, *optional*):
512
+ A random number generator.
513
+ return_dict (`bool`):
514
+ Whether or not to return a
515
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
516
+
517
+ Returns:
518
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
519
+ If return_dict is `True`,
520
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
521
+ otherwise a tuple is returned where the first element is the sample tensor.
522
+
523
+ """
524
+
525
+ if (
526
+ isinstance(timestep, int)
527
+ or isinstance(timestep, torch.IntTensor)
528
+ or isinstance(timestep, torch.LongTensor)
529
+ ):
530
+ raise ValueError(
531
+ (
532
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
533
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
534
+ " one of the `scheduler.timesteps` as a timestep."
535
+ ),
536
+ )
537
+
538
+ if scheduler.step_index is None:
539
+ scheduler._init_step_index(timestep)
540
+
541
+ sigma = scheduler.sigmas[scheduler.step_index]
542
+
543
+ # Upcast to avoid precision issues when computing prev_sample
544
+ sample = sample.to(torch.float32)
545
+
546
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
547
+ if scheduler.config.prediction_type == "epsilon":
548
+ pred_original_sample = sample - sigma * model_output
549
+ elif scheduler.config.prediction_type == "v_prediction":
550
+ # * c_out + input * c_skip
551
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
552
+ sample / (sigma**2 + 1)
553
+ )
554
+ elif scheduler.config.prediction_type == "sample":
555
+ raise NotImplementedError("prediction_type not implemented yet: sample")
556
+ else:
557
+ raise ValueError(
558
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
559
+ )
560
+
561
+ sigma_from = scheduler.sigmas[scheduler.step_index]
562
+ sigma_to = scheduler.sigmas[scheduler.step_index + 1]
563
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
564
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
565
+
566
+ # 2. Convert to an ODE derivative
567
+ derivative = (sample - pred_original_sample) / sigma
568
+
569
+ dt = sigma_down - sigma
570
+
571
+ prev_sample = sample + derivative * dt
572
+
573
+ # Cast sample back to model compatible dtype
574
+ prev_sample = prev_sample.to(model_output.dtype)
575
+
576
+ # upon completion increase step index by one
577
+ scheduler._step_index += 1
578
+
579
+ return prev_sample
580
+
581
+
582
+ def deterministic_non_ancestral_euler_step(
583
+ model_output: torch.FloatTensor,
584
+ timestep: Union[float, torch.FloatTensor],
585
+ sample: torch.FloatTensor,
586
+ eta: float = 0.0,
587
+ use_clipped_model_output: bool = False,
588
+ s_churn: float = 0.0,
589
+ s_tmin: float = 0.0,
590
+ s_tmax: float = float("inf"),
591
+ s_noise: float = 1.0,
592
+ generator: Optional[torch.Generator] = None,
593
+ variance_noise: Optional[torch.FloatTensor] = None,
594
+ return_dict: bool = True,
595
+ scheduler=None,
596
+ ):
597
+ """
598
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
599
+ process from the learned model outputs (most often the predicted noise).
600
+
601
+ Args:
602
+ model_output (`torch.FloatTensor`):
603
+ The direct output from learned diffusion model.
604
+ timestep (`float`):
605
+ The current discrete timestep in the diffusion chain.
606
+ sample (`torch.FloatTensor`):
607
+ A current instance of a sample created by the diffusion process.
608
+ s_churn (`float`):
609
+ s_tmin (`float`):
610
+ s_tmax (`float`):
611
+ s_noise (`float`, defaults to 1.0):
612
+ Scaling factor for noise added to the sample.
613
+ generator (`torch.Generator`, *optional*):
614
+ A random number generator.
615
+ return_dict (`bool`):
616
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
617
+ tuple.
618
+
619
+ Returns:
620
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
621
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
622
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
623
+ """
624
+
625
+ if (
626
+ isinstance(timestep, int)
627
+ or isinstance(timestep, torch.IntTensor)
628
+ or isinstance(timestep, torch.LongTensor)
629
+ ):
630
+ raise ValueError(
631
+ (
632
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
633
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
634
+ " one of the `scheduler.timesteps` as a timestep."
635
+ ),
636
+ )
637
+
638
+ if not scheduler.is_scale_input_called:
639
+ logger.warning(
640
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
641
+ "See `StableDiffusionPipeline` for a usage example."
642
+ )
643
+
644
+ if scheduler.step_index is None:
645
+ scheduler._init_step_index(timestep)
646
+
647
+ # Upcast to avoid precision issues when computing prev_sample
648
+ sample = sample.to(torch.float32)
649
+
650
+ sigma = scheduler.sigmas[scheduler.step_index]
651
+
652
+ gamma = (
653
+ min(s_churn / (len(scheduler.sigmas) - 1), 2**0.5 - 1)
654
+ if s_tmin <= sigma <= s_tmax
655
+ else 0.0
656
+ )
657
+
658
+ sigma_hat = sigma * (gamma + 1)
659
+
660
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
661
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
662
+ # backwards compatibility
663
+ if (
664
+ scheduler.config.prediction_type == "original_sample"
665
+ or scheduler.config.prediction_type == "sample"
666
+ ):
667
+ pred_original_sample = model_output
668
+ elif scheduler.config.prediction_type == "epsilon":
669
+ pred_original_sample = sample - sigma_hat * model_output
670
+ elif scheduler.config.prediction_type == "v_prediction":
671
+ # denoised = model_output * c_out + input * c_skip
672
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
673
+ sample / (sigma**2 + 1)
674
+ )
675
+ else:
676
+ raise ValueError(
677
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
678
+ )
679
+
680
+ # 2. Convert to an ODE derivative
681
+ derivative = (sample - pred_original_sample) / sigma_hat
682
+
683
+ dt = scheduler.sigmas[scheduler.step_index + 1] - sigma_hat
684
+
685
+ prev_sample = sample + derivative * dt
686
+
687
+ # Cast sample back to model compatible dtype
688
+ prev_sample = prev_sample.to(model_output.dtype)
689
+
690
+ # upon completion increase step index by one
691
+ scheduler._step_index += 1
692
+
693
+ return prev_sample
694
+
695
+
696
+ def deterministic_ddpm_step(
697
+ model_output: torch.FloatTensor,
698
+ timestep: Union[float, torch.FloatTensor],
699
+ sample: torch.FloatTensor,
700
+ eta,
701
+ use_clipped_model_output,
702
+ generator,
703
+ variance_noise,
704
+ return_dict,
705
+ scheduler,
706
+ ):
707
+ """
708
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
709
+ process from the learned model outputs (most often the predicted noise).
710
+
711
+ Args:
712
+ model_output (`torch.FloatTensor`):
713
+ The direct output from learned diffusion model.
714
+ timestep (`float`):
715
+ The current discrete timestep in the diffusion chain.
716
+ sample (`torch.FloatTensor`):
717
+ A current instance of a sample created by the diffusion process.
718
+ generator (`torch.Generator`, *optional*):
719
+ A random number generator.
720
+ return_dict (`bool`, *optional*, defaults to `True`):
721
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
722
+
723
+ Returns:
724
+ [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
725
+ If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
726
+ tuple is returned where the first element is the sample tensor.
727
+
728
+ """
729
+ t = timestep
730
+
731
+ prev_t = scheduler.previous_timestep(t)
732
+
733
+ if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [
734
+ "learned",
735
+ "learned_range",
736
+ ]:
737
+ model_output, predicted_variance = torch.split(
738
+ model_output, sample.shape[1], dim=1
739
+ )
740
+ else:
741
+ predicted_variance = None
742
+
743
+ # 1. compute alphas, betas
744
+ alpha_prod_t = scheduler.alphas_cumprod[t]
745
+ alpha_prod_t_prev = (
746
+ scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
747
+ )
748
+ beta_prod_t = 1 - alpha_prod_t
749
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
750
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
751
+ current_beta_t = 1 - current_alpha_t
752
+
753
+ # 2. compute predicted original sample from predicted noise also called
754
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
755
+ if scheduler.config.prediction_type == "epsilon":
756
+ pred_original_sample = (
757
+ sample - beta_prod_t ** (0.5) * model_output
758
+ ) / alpha_prod_t ** (0.5)
759
+ elif scheduler.config.prediction_type == "sample":
760
+ pred_original_sample = model_output
761
+ elif scheduler.config.prediction_type == "v_prediction":
762
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (
763
+ beta_prod_t**0.5
764
+ ) * model_output
765
+ else:
766
+ raise ValueError(
767
+ f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
768
+ " `v_prediction` for the DDPMScheduler."
769
+ )
770
+
771
+ # 3. Clip or threshold "predicted x_0"
772
+ if scheduler.config.thresholding:
773
+ pred_original_sample = scheduler._threshold_sample(pred_original_sample)
774
+ elif scheduler.config.clip_sample:
775
+ pred_original_sample = pred_original_sample.clamp(
776
+ -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
777
+ )
778
+
779
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
780
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
781
+ pred_original_sample_coeff = (
782
+ alpha_prod_t_prev ** (0.5) * current_beta_t
783
+ ) / beta_prod_t
784
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
785
+
786
+ # 5. Compute predicted previous sample µ_t
787
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
788
+ pred_prev_sample = (
789
+ pred_original_sample_coeff * pred_original_sample
790
+ + current_sample_coeff * sample
791
+ )
792
+
793
+ return pred_prev_sample
loading_utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Callable, Optional, Union
3
+
4
+ import PIL.Image
5
+ import PIL.ImageOps
6
+ import requests
7
+
8
+
9
+ session = requests.Session()
10
+ adapter = requests.adapters.HTTPAdapter(
11
+ pool_connections=10,
12
+ pool_maxsize=10,
13
+ max_retries=3
14
+ )
15
+ session.mount('http://', adapter)
16
+ session.mount('https://', adapter)
17
+
18
+
19
+ def load_image(
20
+ image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
21
+ ) -> PIL.Image.Image:
22
+ """
23
+ Loads `image` to a PIL Image.
24
+
25
+ Args:
26
+ image (`str` or `PIL.Image.Image`):
27
+ The image to convert to the PIL Image format.
28
+ convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
29
+ A conversion method to apply to the image after loading it. When set to `None` the image will be converted
30
+ "RGB".
31
+
32
+ Returns:
33
+ `PIL.Image.Image`:
34
+ A PIL Image.
35
+ """
36
+ if isinstance(image, str):
37
+ if image.startswith("http://") or image.startswith("https://"):
38
+ image = PIL.Image.open(session.get(image, stream=True).raw)
39
+ elif os.path.isfile(image):
40
+ image = PIL.Image.open(image)
41
+ else:
42
+ raise ValueError(
43
+ f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
44
+ )
45
+ elif isinstance(image, PIL.Image.Image):
46
+ image = image
47
+ else:
48
+ raise ValueError(
49
+ "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
50
+ )
51
+
52
+ image = PIL.ImageOps.exif_transpose(image)
53
+
54
+ if convert_method is not None:
55
+ image = convert_method(image)
56
+ else:
57
+ image = image.convert("RGB")
58
+
59
+ return image
60
+
model_handler.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+
4
+ from diffusers import (
5
+ DDPMScheduler,
6
+ StableDiffusionXLImg2ImgPipeline,
7
+ LTXPipeline,
8
+ AutoencoderKL,
9
+ )
10
+
11
+ from hidiffusion import apply_hidiffusion
12
+
13
+ from mediapipe.tasks import python
14
+ from mediapipe.tasks.python import vision
15
+
16
+ from image_gen_aux import UpscaleWithModel
17
+
18
+ BASE_MODEL = "stabilityai/sdxl-turbo"
19
+ VIDEO_MODEL = "Lightricks/LTX-Video"
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+
22
+ class ModelHandler:
23
+ def __init__(self):
24
+ self.base_pipe = None
25
+ self.video_pipe = None
26
+ self.compiled_model = None
27
+ self.segmenter = None
28
+ self.upscaler = None
29
+ self.upscaler4SD = None
30
+ self.load_models()
31
+
32
+ def load_base(self):
33
+ vae = AutoencoderKL.from_pretrained(
34
+ "madebyollin/sdxl-vae-fp16-fix",
35
+ torch_dtype=torch.float16,
36
+ )
37
+
38
+ base_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
39
+ BASE_MODEL,
40
+ vae=vae,
41
+ torch_dtype=torch.float16,
42
+ variant="fp16",
43
+ use_safetensors=True,
44
+ )
45
+ base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
46
+ base_pipe.scheduler = DDPMScheduler.from_pretrained(
47
+ BASE_MODEL,
48
+ subfolder="scheduler",
49
+ )
50
+ apply_hidiffusion(base_pipe)
51
+
52
+ return base_pipe
53
+
54
+ def load_video_pipe(self):
55
+ pipe = LTXPipeline.from_pretrained(VIDEO_MODEL, torch_dtype=torch.bfloat16)
56
+ pipe.to(device)
57
+ return pipe
58
+
59
+ def load_segmenter(self):
60
+ segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
61
+ base_options = python.BaseOptions(model_asset_path=segment_model)
62
+ options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
63
+ segmenter = vision.ImageSegmenter.create_from_options(options)
64
+ return segmenter
65
+
66
+ def load_upscaler(self):
67
+ model_name = os.environ.get("UPSCALE_MODEL", "Phips/4xNomosWebPhoto_RealPLKSR")
68
+ upscaler = UpscaleWithModel.from_pretrained(model_name).to(device)
69
+ return upscaler
70
+
71
+ def load_upscaler4SD(self):
72
+ model_name = os.environ.get("UPSCALE_FOR_SD_MODEL", "Phips/1xDeJPG_realplksr_otf")
73
+ upscaler = UpscaleWithModel.from_pretrained(model_name).to(device)
74
+ return upscaler
75
+
76
+ def load_models(self):
77
+ base_pipe = self.load_base()
78
+ segmenter = self.load_segmenter()
79
+ upscaler = self.load_upscaler()
80
+ upscaler4SD = self.load_upscaler4SD()
81
+
82
+ self.base_pipe = base_pipe
83
+ self.segmenter = segmenter
84
+ self.upscaler = upscaler
85
+ self.upscaler4SD = upscaler4SD
86
+
87
+ MODELS = ModelHandler()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ ml-collections
4
+ mediapipe
5
+ gradio
6
+ diffusers
7
+ transformers
8
+ accelerate
9
+ sentencepiece
10
+ compel
11
+ hidiffusion
12
+ git+https://github.com/asomoza/image_gen_aux.git
run_configs/noise_shift_3_steps.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ breakdown: "x_t_hat_c"
2
+ cross_r: 0.9
3
+ eta_reconstruct: 1
4
+ eta_retrieve: 1
5
+ max_norm_zs: [-1, -1, 15.5]
6
+ model: "stabilityai/sdxl-turbo"
7
+ noise_shift_delta: 1
8
+ noise_timesteps: [599, 299, 0]
9
+ timesteps: [799, 499, 199]
10
+ num_steps_inversion: 5
11
+ step_start: 1
12
+ real_cfg_scale: 0
13
+ real_cfg_scale_save: 0
14
+ scheduler_type: "ddpm"
15
+ seed: 2
16
+ self_r: 0.5
17
+ ws1: 1.5
18
+ ws2: 1
19
+ clean_step_timestep: 0
run_configs/noise_shift_guidance_1_5.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ breakdown: "x_t_hat_c"
2
+ cross_r: 0.9
3
+ eta: 1
4
+ max_norm_zs: [-1, -1, -1, 15.5]
5
+ model: ""
6
+ noise_shift_delta: 1
7
+ noise_timesteps: null
8
+ num_steps_inversion: 20
9
+ step_start: 5
10
+ real_cfg_scale: 0
11
+ real_cfg_scale_save: 0
12
+ scheduler_type: "ddpm"
13
+ seed: 2
14
+ self_r: 0.5
15
+ timesteps: null
16
+ ws1: 1.5
17
+ ws2: 1
18
+ clean_step_timestep: 0
segment_utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import mediapipe as mp
3
+ import uuid
4
+
5
+ from PIL import Image
6
+ from scipy.ndimage import binary_dilation
7
+ from croper import Croper
8
+
9
+ from model_handler import MODELS
10
+
11
+ segmenter = MODELS.segmenter
12
+
13
+ def restore_result(croper, category, generated_image):
14
+ square_length = croper.square_length
15
+ generated_image = generated_image.resize((square_length, square_length))
16
+
17
+ cropped_generated_image = generated_image.crop((croper.square_start_x, croper.square_start_y, croper.square_end_x, croper.square_end_y))
18
+ cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)
19
+
20
+ restored_image = croper.input_image.copy()
21
+ restored_image.paste(cropped_generated_image, (croper.origin_start_x, croper.origin_start_y), cropped_square_mask_image)
22
+
23
+ return restored_image
24
+
25
+ def segment_image(input_image, category, input_size, mask_expansion, mask_dilation):
26
+ mask_size = int(input_size)
27
+ mask_expansion = int(mask_expansion)
28
+
29
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
30
+ segmentation_result = segmenter.segment(image)
31
+ category_mask = segmentation_result.category_mask
32
+ category_mask_np = category_mask.numpy_view()
33
+
34
+ if category == "hair":
35
+ target_mask = get_hair_mask(category_mask_np, mask_dilation)
36
+ elif category == "clothes":
37
+ target_mask = get_clothes_mask(category_mask_np, mask_dilation)
38
+ elif category == "face":
39
+ target_mask = get_face_mask(category_mask_np, mask_dilation)
40
+ else:
41
+ target_mask = get_face_mask(category_mask_np, mask_dilation)
42
+
43
+ croper = Croper(input_image, target_mask, mask_size, mask_expansion)
44
+ croper.corp_mask_image()
45
+ origin_area_image = croper.resized_square_image
46
+
47
+ return origin_area_image, croper
48
+
49
+ def get_face_mask(category_mask_np, dilation=1):
50
+ face_skin_mask = category_mask_np == 3
51
+ if dilation > 0:
52
+ face_skin_mask = binary_dilation(face_skin_mask, iterations=dilation)
53
+
54
+ return face_skin_mask
55
+
56
+ def get_clothes_mask(category_mask_np, dilation=1):
57
+ body_skin_mask = category_mask_np == 2
58
+ clothes_mask = category_mask_np == 4
59
+ combined_mask = np.logical_or(body_skin_mask, clothes_mask)
60
+ combined_mask = binary_dilation(combined_mask, iterations=4)
61
+ if dilation > 0:
62
+ combined_mask = binary_dilation(combined_mask, iterations=dilation)
63
+ return combined_mask
64
+
65
+ def get_hair_mask(category_mask_np, dilation=1):
66
+ hair_mask = category_mask_np == 1
67
+ if dilation > 0:
68
+ hair_mask = binary_dilation(hair_mask, iterations=dilation)
69
+ return hair_mask
70
+
71
+ def get_restore_mask_image(croper, category, generated_image):
72
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
73
+ segmentation_result = segmenter.segment(image)
74
+ category_mask = segmentation_result.category_mask
75
+ category_mask_np = category_mask.numpy_view()
76
+
77
+ if category == "hair":
78
+ target_mask = get_hair_mask(category_mask_np, 0)
79
+ elif category == "clothes":
80
+ target_mask = get_clothes_mask(category_mask_np, 0)
81
+ elif category == "face":
82
+ target_mask = get_face_mask(category_mask_np, 0)
83
+
84
+ combined_mask = np.logical_or(target_mask, croper.corp_mask)
85
+ mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
86
+ return mask_image