Barak1 committed on
Commit d48db0f · 2 Parent(s): 65ace7f dc77641

Merge branch 'main' of https://huggingface.co/spaces/barakmeiri/RNRI

app.py CHANGED
@@ -3,38 +3,50 @@ import numpy as np
  import random
  from diffusers import DiffusionPipeline
  import torch

  device = "cuda" if torch.cuda.is_available() else "cpu"

- if torch.cuda.is_available():
-     torch.cuda.max_memory_allocated(device=device)
-     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-     pipe.enable_xformers_memory_efficient_attention()
-     pipe = pipe.to(device)
- else:
-     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
-     pipe = pipe.to(device)

  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024

- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):

-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt = prompt,
-         negative_prompt = negative_prompt,
-         guidance_scale = guidance_scale,
-         num_inference_steps = num_inference_steps,
-         width = width,
-         height = height,
-         generator = generator
-     ).images[0]

      return image

  examples = [
@@ -56,63 +68,38 @@ else:
      power_device = "CPU"

  with gr.Blocks(css=css) as demo:
-
      with gr.Column(elem_id="col-container"):
-         gr.Markdown(f"""
-         # Text-to-Image Gradio Template
-         Currently running on {power_device}.
-         """)

          with gr.Row():

-             prompt = gr.Text(
-                 label="Prompt",
                  show_label=False,
                  max_lines=1,
-                 placeholder="Enter your prompt",
                  container=False,
              )
-
-             run_button = gr.Button("Run", scale=0)

-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
                  max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
              )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=512,
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=512,
-                 )

              with gr.Row():

@@ -121,25 +108,33 @@ with gr.Blocks(css=css) as demo:
                      minimum=0.0,
                      maximum=10.0,
                      step=0.1,
-                     value=0.0,
                  )

                  num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
                      minimum=1,
                      maximum=12,
                      step=1,
-                     value=2,
                  )

-     gr.Examples(
-         examples = examples,
-         inputs = [prompt]
-     )

      run_button.click(
          fn = infer,
-         inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
          outputs = [result]
      )

 
  import random
  from diffusers import DiffusionPipeline
  import torch
+ from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
+ from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
+ from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
+ from src.config import RunConfig
+ from src.editor import ImageEditorDemo

  device = "cuda" if torch.cuda.is_available() else "cpu"

+ scheduler_class = MyEulerAncestralDiscreteScheduler
+
+ pipe_inversion = SDXLDDIMPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True).to(device)
+ pipe_inference = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True).to(device)
+ pipe_inference.scheduler = scheduler_class.from_config(pipe_inference.scheduler.config)
+ pipe_inversion.scheduler = scheduler_class.from_config(pipe_inversion.scheduler.config)
+ pipe_inversion.scheduler_inference = scheduler_class.from_config(pipe_inference.scheduler.config)
+
+ # if torch.cuda.is_available():
+ #     torch.cuda.max_memory_allocated(device=device)
+ #     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
+ #     pipe.enable_xformers_memory_efficient_attention()
+ #     pipe = pipe.to(device)
+ # else:
+ #     pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
+ #     pipe = pipe.to(device)

  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024

+ def infer(input_image, description_prompt, target_prompt, guidance_scale, num_inference_steps=4, num_inversion_steps=4, inversion_max_step=0.6):
+     config = RunConfig(num_inference_steps=num_inference_steps,
+                        num_inversion_steps=num_inversion_steps,
+                        guidance_scale=guidance_scale,
+                        inversion_max_step=inversion_max_step)
+
+     editor = ImageEditorDemo(pipe_inversion, pipe_inference, input_image, description_prompt, config)
+
+     image = editor.edit(target_prompt)
      return image

  examples = [

      power_device = "CPU"

  with gr.Blocks(css=css) as demo:
+
+     gr.Markdown(f"""
+     # RNRI brief and links. Running on device: {power_device}.
+     """)
      with gr.Column(elem_id="col-container"):
+
+         with gr.Row():
+             input_image = gr.Image(label="Input image", sources=['upload', 'webcam', 'clipboard'], type="pil")

          with gr.Row():

+             description_prompt = gr.Text(
+                 label="Image description",
                  show_label=False,
                  max_lines=1,
+                 placeholder="Enter your image description",
                  container=False,
              )
+
+         with gr.Row():
+
+             target_prompt = gr.Text(
+                 label="Edit prompt",
+                 show_label=False,
                  max_lines=1,
+                 placeholder="Enter your edit prompt",
+                 container=False,
              )
+
+         with gr.Accordion("Advanced Settings", open=False):

              with gr.Row():

                      minimum=0.0,
                      maximum=10.0,
                      step=0.1,
+                     value=1.2,
                  )

                  num_inference_steps = gr.Slider(
+                     label="Number of RNRI iterations",
                      minimum=1,
                      maximum=12,
                      step=1,
+                     value=4,
                  )
+
+         with gr.Row():
+             run_button = gr.Button("Edit", scale=0)
+
+     with gr.Column(elem_id="col-container"):
+
+         result = gr.Image(label="Result", show_label=False)
+
+     # gr.Examples(
+     #     examples = examples,
+     #     inputs = [prompt]
+     # )

      run_button.click(
          fn = infer,
+         inputs = [input_image, description_prompt, target_prompt, guidance_scale, num_inference_steps, num_inference_steps],
          outputs = [result]
      )

requirements.txt CHANGED
@@ -1,6 +1,8 @@
- accelerate
- diffusers
  invisible_watermark
- torch
- transformers
- xformers
+ accelerate==0.25.0
+ diffusers==0.24.0
  invisible_watermark
+ torch==2.2.0
+ transformers==4.32.1
+ xformers
+ torchvision==0.17.0
+ pyrallis==0.3.1
src/config.py ADDED
@@ -0,0 +1,17 @@
+ # Code is based on ReNoise https://github.com/garibida/ReNoise-Inversion
+
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class RunConfig:
+     num_inference_steps: int = 4
+
+     num_inversion_steps: int = 100
+
+     guidance_scale: float = 0.0
+
+     inversion_max_step: float = 1.0
+
+     def __post_init__(self):
+         pass
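A minimal usage sketch; the specific values shown are the ones the Space passes in from its sliders, not extra defaults of the dataclass:

    config = RunConfig(num_inference_steps=4,     # denoising steps for the edit pass
                       num_inversion_steps=4,     # RNRI inversion steps
                       guidance_scale=1.2,
                       inversion_max_step=0.6)    # invert only part of the noise schedule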
src/editor.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ from src.config import RunConfig
+ import PIL
+ from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
+ from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
+ from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
+
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+ def inversion_callback(pipe, step, timestep, callback_kwargs):
+     return callback_kwargs
+
+ def inference_callback(pipe, step, timestep, callback_kwargs):
+     return callback_kwargs
+
+ def center_crop(im):
+     width, height = im.size  # Get dimensions
+     min_dim = min(width, height)
+     left = (width - min_dim) / 2
+     top = (height - min_dim) / 2
+     right = (width + min_dim) / 2
+     bottom = (height + min_dim) / 2
+
+     # Crop the center of the image
+     im = im.crop((left, top, right, bottom))
+     return im
+
+
+ def load_im_into_format_from_path(im_path):
+     if isinstance(im_path, str):
+         return center_crop(PIL.Image.open(im_path)).resize((512, 512))
+     else:
+         return center_crop(im_path).resize((512, 512))
+
+
+ class ImageEditorDemo:
+     def __init__(self, pipe_inversion, pipe_inference, input_image, description_prompt, cfg):
+         self.pipe_inversion = pipe_inversion
+         self.pipe_inference = pipe_inference
+         self.original_image = load_im_into_format_from_path(input_image).convert("RGB")
+         self.load_image = True
+         g_cpu = torch.Generator().manual_seed(7865)
+         img_size = (512, 512)
+         VQAE_SCALE = 8
+         latents_size = (1, 4, img_size[0] // VQAE_SCALE, img_size[1] // VQAE_SCALE)
+         noise = [randn_tensor(latents_size, dtype=torch.float16, device=torch.device("cuda:0"), generator=g_cpu) for i
+                  in range(cfg.num_inversion_steps)]
+         pipe_inversion.scheduler.set_noise_list(noise)
+         pipe_inference.scheduler.set_noise_list(noise)
+         pipe_inversion.scheduler_inference.set_noise_list(noise)
+         pipe_inversion.set_progress_bar_config(disable=True)
+         pipe_inference.set_progress_bar_config(disable=True)
+         self.cfg = cfg
+         self.pipe_inversion.cfg = cfg
+         self.pipe_inference.cfg = cfg
+         self.inv_hp = [2, 0.1, 0.2]
+         self.edit_cfg = 1.2
+
+         self.pipe_inference.to("cuda")
+         self.pipe_inversion.to("cuda")
+
+         self.last_latent = self.invert(self.original_image, description_prompt)
+         self.original_latent = self.last_latent
+
+     def invert(self, init_image, base_prompt):
+         res = self.pipe_inversion(prompt=base_prompt,
+                                   num_inversion_steps=self.cfg.num_inversion_steps,
+                                   num_inference_steps=self.cfg.num_inference_steps,
+                                   image=init_image,
+                                   guidance_scale=self.cfg.guidance_scale,
+                                   callback_on_step_end=inversion_callback,
+                                   strength=self.cfg.inversion_max_step,
+                                   denoising_start=1.0 - self.cfg.inversion_max_step,
+                                   inv_hp=self.inv_hp)[0][0]
+         return res
+
+     def edit(self, target_prompt):
+         image = self.pipe_inference(prompt=target_prompt,
+                                     num_inference_steps=self.cfg.num_inference_steps,
+                                     negative_prompt="",
+                                     callback_on_step_end=inference_callback,
+                                     image=self.last_latent,
+                                     strength=self.cfg.inversion_max_step,
+                                     denoising_start=1.0 - self.cfg.inversion_max_step,
+                                     guidance_scale=self.edit_cfg).images[0]
+         return image
+
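As a rough illustration (assuming pipe_inversion and pipe_inference are prepared as in app.py; the file name and prompts are placeholders), the class front-loads the expensive inversion in __init__ so edit() can be called repeatedly against the cached latent:

    from src.config import RunConfig
    from src.editor import ImageEditorDemo

    config = RunConfig(num_inference_steps=4, num_inversion_steps=4,
                       guidance_scale=1.2, inversion_max_step=0.6)
    editor = ImageEditorDemo(pipe_inversion, pipe_inference,
                             "cat.jpg", "a photo of a cat", config)   # inversion runs once here
    dog = editor.edit("a photo of a dog")        # only the denoising pass runs
    tiger = editor.edit("a photo of a tiger")    # reuses the same inverted latent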
src/euler_scheduler.py ADDED
@@ -0,0 +1,584 @@
1
+ # Code is based on ReNoise https://github.com/garibida/ReNoise-Inversion
2
+
3
+ from diffusers import EulerAncestralDiscreteScheduler
4
+ from diffusers.utils import BaseOutput
5
+ import torch
6
+ from typing import List, Optional, Tuple, Union
7
+ import numpy as np
8
+
9
+ from src.eunms import Epsilon_Update_Type
10
+
11
+ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
12
+ """
13
+ Output class for the scheduler's `step` function output.
14
+
15
+ Args:
16
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
17
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
18
+ denoising loop.
19
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
20
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
21
+ `pred_original_sample` can be used to preview progress or for guidance.
22
+ """
23
+
24
+ prev_sample: torch.FloatTensor
25
+ pred_original_sample: Optional[torch.FloatTensor] = None
26
+
27
+ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
28
+ def set_noise_list(self, noise_list):
29
+ self.noise_list = noise_list
30
+
31
+ def get_noise_to_remove(self):
32
+ sigma_from = self.sigmas[self.step_index]
33
+ sigma_to = self.sigmas[self.step_index + 1]
34
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
35
+
36
+ return self.noise_list[self.step_index] * sigma_up\
37
+
38
+ def scale_model_input(
39
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
40
+ ) -> torch.FloatTensor:
41
+ """
42
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
43
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
44
+
45
+ Args:
46
+ sample (`torch.FloatTensor`):
47
+ The input sample.
48
+ timestep (`int`, *optional*):
49
+ The current timestep in the diffusion chain.
50
+
51
+ Returns:
52
+ `torch.FloatTensor`:
53
+ A scaled input sample.
54
+ """
55
+
56
+ self._init_step_index(timestep.view((1)))
57
+ return EulerAncestralDiscreteScheduler.scale_model_input(self, sample, timestep)
58
+
59
+
60
+ def step(
61
+ self,
62
+ model_output: torch.FloatTensor,
63
+ timestep: Union[float, torch.FloatTensor],
64
+ sample: torch.FloatTensor,
65
+ generator: Optional[torch.Generator] = None,
66
+ return_dict: bool = True,
67
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
68
+ """
69
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
70
+ process from the learned model outputs (most often the predicted noise).
71
+
72
+ Args:
73
+ model_output (`torch.FloatTensor`):
74
+ The direct output from learned diffusion model.
75
+ timestep (`float`):
76
+ The current discrete timestep in the diffusion chain.
77
+ sample (`torch.FloatTensor`):
78
+ A current instance of a sample created by the diffusion process.
79
+ generator (`torch.Generator`, *optional*):
80
+ A random number generator.
81
+ return_dict (`bool`):
82
+ Whether or not to return a
83
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
84
+
85
+ Returns:
86
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
87
+ If return_dict is `True`,
88
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
89
+ otherwise a tuple is returned where the first element is the sample tensor.
90
+
91
+ """
92
+
93
+ if (
94
+ isinstance(timestep, int)
95
+ or isinstance(timestep, torch.IntTensor)
96
+ or isinstance(timestep, torch.LongTensor)
97
+ ):
98
+ raise ValueError(
99
+ (
100
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
101
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
102
+ " one of the `scheduler.timesteps` as a timestep."
103
+ ),
104
+ )
105
+
106
+ if not self.is_scale_input_called:
107
+ logger.warning(
108
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
109
+ "See `StableDiffusionPipeline` for a usage example."
110
+ )
111
+
112
+ self._init_step_index(timestep.view((1)))
113
+
114
+ sigma = self.sigmas[self.step_index]
115
+
116
+ # Upcast to avoid precision issues when computing prev_sample
117
+ sample = sample.to(torch.float32)
118
+
119
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
120
+ if self.config.prediction_type == "epsilon":
121
+ pred_original_sample = sample - sigma * model_output
122
+ elif self.config.prediction_type == "v_prediction":
123
+ # * c_out + input * c_skip
124
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
125
+ elif self.config.prediction_type == "sample":
126
+ raise NotImplementedError("prediction_type not implemented yet: sample")
127
+ else:
128
+ raise ValueError(
129
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
130
+ )
131
+
132
+ sigma_from = self.sigmas[self.step_index]
133
+ sigma_to = self.sigmas[self.step_index + 1]
134
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
135
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
136
+
137
+ # 2. Convert to an ODE derivative
138
+ # derivative = (sample - pred_original_sample) / sigma
139
+ derivative = model_output
140
+
141
+ dt = sigma_down - sigma
142
+
143
+ prev_sample = sample + derivative * dt
144
+
145
+ device = model_output.device
146
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
147
+ # prev_sample = prev_sample + noise * sigma_up
148
+
149
+ prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
150
+
151
+ # Cast sample back to model compatible dtype
152
+ prev_sample = prev_sample.to(model_output.dtype)
153
+
154
+ # upon completion increase step index by one
155
+ self._step_index += 1
156
+
157
+ if not return_dict:
158
+ return (prev_sample,)
159
+
160
+ return EulerAncestralDiscreteSchedulerOutput(
161
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
162
+ )
163
+
164
+ def step_and_update_noise(
165
+ self,
166
+ model_output: torch.FloatTensor,
167
+ timestep: Union[float, torch.FloatTensor],
168
+ sample: torch.FloatTensor,
169
+ expected_prev_sample: torch.FloatTensor,
170
+ update_epsilon_type=Epsilon_Update_Type.OVERRIDE,
171
+ generator: Optional[torch.Generator] = None,
172
+ return_dict: bool = True,
173
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
174
+ """
175
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
176
+ process from the learned model outputs (most often the predicted noise).
177
+
178
+ Args:
179
+ model_output (`torch.FloatTensor`):
180
+ The direct output from learned diffusion model.
181
+ timestep (`float`):
182
+ The current discrete timestep in the diffusion chain.
183
+ sample (`torch.FloatTensor`):
184
+ A current instance of a sample created by the diffusion process.
185
+ generator (`torch.Generator`, *optional*):
186
+ A random number generator.
187
+ return_dict (`bool`):
188
+ Whether or not to return a
189
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
190
+
191
+ Returns:
192
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
193
+ If return_dict is `True`,
194
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
195
+ otherwise a tuple is returned where the first element is the sample tensor.
196
+
197
+ """
198
+
199
+ if (
200
+ isinstance(timestep, int)
201
+ or isinstance(timestep, torch.IntTensor)
202
+ or isinstance(timestep, torch.LongTensor)
203
+ ):
204
+ raise ValueError(
205
+ (
206
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
207
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
208
+ " one of the `scheduler.timesteps` as a timestep."
209
+ ),
210
+ )
211
+
212
+ if not self.is_scale_input_called:
213
+ logger.warning(
214
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
215
+ "See `StableDiffusionPipeline` for a usage example."
216
+ )
217
+
218
+ self._init_step_index(timestep.view((1)))
219
+
220
+ sigma = self.sigmas[self.step_index]
221
+
222
+ # Upcast to avoid precision issues when computing prev_sample
223
+ sample = sample.to(torch.float32)
224
+
225
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
226
+ if self.config.prediction_type == "epsilon":
227
+ pred_original_sample = sample - sigma * model_output
228
+ elif self.config.prediction_type == "v_prediction":
229
+ # * c_out + input * c_skip
230
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
231
+ elif self.config.prediction_type == "sample":
232
+ raise NotImplementedError("prediction_type not implemented yet: sample")
233
+ else:
234
+ raise ValueError(
235
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
236
+ )
237
+
238
+ sigma_from = self.sigmas[self.step_index]
239
+ sigma_to = self.sigmas[self.step_index + 1]
240
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
241
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
242
+
243
+ # 2. Convert to an ODE derivative
244
+ # derivative = (sample - pred_original_sample) / sigma
245
+ derivative = model_output
246
+
247
+ dt = sigma_down - sigma
248
+
249
+ prev_sample = sample + derivative * dt
250
+
251
+ device = model_output.device
252
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
253
+ # prev_sample = prev_sample + noise * sigma_up
254
+
255
+ if sigma_up > 0:
256
+ req_noise = (expected_prev_sample - prev_sample) / sigma_up
257
+ if update_epsilon_type == Epsilon_Update_Type.OVERRIDE:
258
+ self.noise_list[self.step_index] = req_noise
259
+ else:
260
+ for i in range(10):
261
+ n = torch.autograd.Variable(self.noise_list[self.step_index].detach().clone(), requires_grad=True)
262
+ loss = torch.norm(n - req_noise.detach())
263
+ loss.backward()
264
+ self.noise_list[self.step_index] -= n.grad.detach() * 1.8
265
+
266
+
267
+ prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
268
+
269
+ # Cast sample back to model compatible dtype
270
+ prev_sample = prev_sample.to(model_output.dtype)
271
+
272
+ # upon completion increase step index by one
273
+ self._step_index += 1
274
+
275
+ if not return_dict:
276
+ return (prev_sample,)
277
+
278
+ return EulerAncestralDiscreteSchedulerOutput(
279
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
280
+ )
281
+
282
+ def inv_step(
283
+ self,
284
+ model_output: torch.FloatTensor,
285
+ timestep: Union[float, torch.FloatTensor],
286
+ sample: torch.FloatTensor,
287
+ generator: Optional[torch.Generator] = None,
288
+ return_dict: bool = True,
289
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
290
+ """
291
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
292
+ process from the learned model outputs (most often the predicted noise).
293
+
294
+ Args:
295
+ model_output (`torch.FloatTensor`):
296
+ The direct output from learned diffusion model.
297
+ timestep (`float`):
298
+ The current discrete timestep in the diffusion chain.
299
+ sample (`torch.FloatTensor`):
300
+ A current instance of a sample created by the diffusion process.
301
+ generator (`torch.Generator`, *optional*):
302
+ A random number generator.
303
+ return_dict (`bool`):
304
+ Whether or not to return a
305
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
306
+
307
+ Returns:
308
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
309
+ If return_dict is `True`,
310
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
311
+ otherwise a tuple is returned where the first element is the sample tensor.
312
+
313
+ """
314
+
315
+ if (
316
+ isinstance(timestep, int)
317
+ or isinstance(timestep, torch.IntTensor)
318
+ or isinstance(timestep, torch.LongTensor)
319
+ ):
320
+ raise ValueError(
321
+ (
322
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
323
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
324
+ " one of the `scheduler.timesteps` as a timestep."
325
+ ),
326
+ )
327
+
328
+ if not self.is_scale_input_called:
329
+ logger.warning(
330
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
331
+ "See `StableDiffusionPipeline` for a usage example."
332
+ )
333
+
334
+ self._init_step_index(timestep.view((1)))
335
+
336
+ sigma = self.sigmas[self.step_index]
337
+
338
+ # Upcast to avoid precision issues when computing prev_sample
339
+ sample = sample.to(torch.float32)
340
+
341
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
342
+ if self.config.prediction_type == "epsilon":
343
+ pred_original_sample = sample - sigma * model_output
344
+ elif self.config.prediction_type == "v_prediction":
345
+ # * c_out + input * c_skip
346
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
347
+ elif self.config.prediction_type == "sample":
348
+ raise NotImplementedError("prediction_type not implemented yet: sample")
349
+ else:
350
+ raise ValueError(
351
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
352
+ )
353
+
354
+ sigma_from = self.sigmas[self.step_index]
355
+ sigma_to = self.sigmas[self.step_index+1]
356
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
357
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
358
+ # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
359
+ sigma_down = sigma_to**2 / sigma_from
360
+
361
+ # 2. Convert to an ODE derivative
362
+ # derivative = (sample - pred_original_sample) / sigma
363
+ derivative = model_output
364
+
365
+ dt = sigma_down - sigma
366
+ # dt = sigma_down - sigma_from
367
+
368
+ prev_sample = sample - derivative * dt
369
+
370
+ device = model_output.device
371
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
372
+ # prev_sample = prev_sample + noise * sigma_up
373
+
374
+ prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
375
+
376
+ # Cast sample back to model compatible dtype
377
+ prev_sample = prev_sample.to(model_output.dtype)
378
+
379
+ # upon completion increase step index by one
380
+ self._step_index += 1
381
+
382
+ if not return_dict:
383
+ return (prev_sample,)
384
+
385
+ return EulerAncestralDiscreteSchedulerOutput(
386
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
387
+ )
388
+
389
+ def get_all_sigmas(self) -> torch.FloatTensor:
390
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
391
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
392
+ return torch.from_numpy(sigmas)
393
+
394
+ def add_noise_off_schedule(
395
+ self,
396
+ original_samples: torch.FloatTensor,
397
+ noise: torch.FloatTensor,
398
+ timesteps: torch.FloatTensor,
399
+ ) -> torch.FloatTensor:
400
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
401
+ sigmas = self.get_all_sigmas()
402
+ sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
403
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
404
+ # mps does not support float64
405
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
406
+ else:
407
+ timesteps = timesteps.to(original_samples.device)
408
+
409
+ step_indices = 1000 - int(timesteps.item())
410
+
411
+ sigma = sigmas[step_indices].flatten()
412
+ while len(sigma.shape) < len(original_samples.shape):
413
+ sigma = sigma.unsqueeze(-1)
414
+
415
+ noisy_samples = original_samples + noise * sigma
416
+ return noisy_samples
417
+
418
+ # def update_noise_for_friendly_inversion(
419
+ # self,
420
+ # model_output: torch.FloatTensor,
421
+ # timestep: Union[float, torch.FloatTensor],
422
+ # z_t: torch.FloatTensor,
423
+ # z_tp1: torch.FloatTensor,
424
+ # return_dict: bool = True,
425
+ # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
426
+ # if (
427
+ # isinstance(timestep, int)
428
+ # or isinstance(timestep, torch.IntTensor)
429
+ # or isinstance(timestep, torch.LongTensor)
430
+ # ):
431
+ # raise ValueError(
432
+ # (
433
+ # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
434
+ # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
435
+ # " one of the `scheduler.timesteps` as a timestep."
436
+ # ),
437
+ # )
438
+
439
+ # if not self.is_scale_input_called:
440
+ # logger.warning(
441
+ # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
442
+ # "See `StableDiffusionPipeline` for a usage example."
443
+ # )
444
+
445
+ # self._init_step_index(timestep.view((1)))
446
+
447
+ # sigma = self.sigmas[self.step_index]
448
+
449
+ # sigma_from = self.sigmas[self.step_index]
450
+ # sigma_to = self.sigmas[self.step_index+1]
451
+ # # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
452
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
453
+ # # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
454
+ # sigma_down = sigma_to**2 / sigma_from
455
+
456
+ # # 2. Conv = (sample - pred_original_sample) / sigma
457
+ # derivative = model_output
458
+
459
+ # dt = sigma_down - sigma
460
+ # # dt = sigma_down - sigma_from
461
+
462
+ # prev_sample = z_t - derivative * dt
463
+
464
+ # if sigma_up > 0:
465
+ # self.noise_list[self.step_index] = (prev_sample - z_tp1) / sigma_up
466
+
467
+ # prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
468
+
469
+
470
+ # if not return_dict:
471
+ # return (prev_sample,)
472
+
473
+ # return EulerAncestralDiscreteSchedulerOutput(
474
+ # prev_sample=prev_sample, pred_original_sample=None
475
+ # )
476
+
477
+
478
+ # def step_friendly_inversion(
479
+ # self,
480
+ # model_output: torch.FloatTensor,
481
+ # timestep: Union[float, torch.FloatTensor],
482
+ # sample: torch.FloatTensor,
483
+ # generator: Optional[torch.Generator] = None,
484
+ # return_dict: bool = True,
485
+ # expected_next_sample: torch.FloatTensor = None,
486
+ # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
487
+ # """
488
+ # Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
489
+ # process from the learned model outputs (most often the predicted noise).
490
+
491
+ # Args:
492
+ # model_output (`torch.FloatTensor`):
493
+ # The direct output from learned diffusion model.
494
+ # timestep (`float`):
495
+ # The current discrete timestep in the diffusion chain.
496
+ # sample (`torch.FloatTensor`):
497
+ # A current instance of a sample created by the diffusion process.
498
+ # generator (`torch.Generator`, *optional*):
499
+ # A random number generator.
500
+ # return_dict (`bool`):
501
+ # Whether or not to return a
502
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
503
+
504
+ # Returns:
505
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
506
+ # If return_dict is `True`,
507
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
508
+ # otherwise a tuple is returned where the first element is the sample tensor.
509
+
510
+ # """
511
+
512
+ # if (
513
+ # isinstance(timestep, int)
514
+ # or isinstance(timestep, torch.IntTensor)
515
+ # or isinstance(timestep, torch.LongTensor)
516
+ # ):
517
+ # raise ValueError(
518
+ # (
519
+ # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
520
+ # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
521
+ # " one of the `scheduler.timesteps` as a timestep."
522
+ # ),
523
+ # )
524
+
525
+ # if not self.is_scale_input_called:
526
+ # logger.warning(
527
+ # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
528
+ # "See `StableDiffusionPipeline` for a usage example."
529
+ # )
530
+
531
+ # self._init_step_index(timestep.view((1)))
532
+
533
+ # sigma = self.sigmas[self.step_index]
534
+
535
+ # # Upcast to avoid precision issues when computing prev_sample
536
+ # sample = sample.to(torch.float32)
537
+
538
+ # # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
539
+ # if self.config.prediction_type == "epsilon":
540
+ # pred_original_sample = sample - sigma * model_output
541
+ # elif self.config.prediction_type == "v_prediction":
542
+ # # * c_out + input * c_skip
543
+ # pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
544
+ # elif self.config.prediction_type == "sample":
545
+ # raise NotImplementedError("prediction_type not implemented yet: sample")
546
+ # else:
547
+ # raise ValueError(
548
+ # f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
549
+ # )
550
+
551
+ # sigma_from = self.sigmas[self.step_index]
552
+ # sigma_to = self.sigmas[self.step_index + 1]
553
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
554
+ # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
555
+
556
+ # # 2. Convert to an ODE derivative
557
+ # # derivative = (sample - pred_original_sample) / sigma
558
+ # derivative = model_output
559
+
560
+ # dt = sigma_down - sigma
561
+
562
+ # prev_sample = sample + derivative * dt
563
+
564
+ # device = model_output.device
565
+ # # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
566
+ # # prev_sample = prev_sample + noise * sigma_up
567
+
568
+ # if sigma_up > 0:
569
+ # self.noise_list[self.step_index] = (expected_next_sample - prev_sample) / sigma_up
570
+
571
+ # prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
572
+
573
+ # # Cast sample back to model compatible dtype
574
+ # prev_sample = prev_sample.to(model_output.dtype)
575
+
576
+ # # upon completion increase step index by one
577
+ # self._step_index += 1
578
+
579
+ # if not return_dict:
580
+ # return (prev_sample,)
581
+
582
+ # return EulerAncestralDiscreteSchedulerOutput(
583
+ # prev_sample=prev_sample, pred_original_sample=pred_original_sample
584
+ # )
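For orientation, a small sketch of how this scheduler is wired up elsewhere in the Space (mirroring editor.py; the latent shape assumes SDXL at 512x512, and pipe stands in for an already-loaded pipeline):

    import torch
    from diffusers.utils.torch_utils import randn_tensor
    from src.euler_scheduler import MyEulerAncestralDiscreteScheduler

    scheduler = MyEulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    g_cpu = torch.Generator().manual_seed(7865)
    # One fixed noise tensor per step; sharing the same list between the inversion and
    # inference schedulers is what lets inv_step() and step() undo each other exactly.
    noise = [randn_tensor((1, 4, 64, 64), generator=g_cpu,
                          device=torch.device("cuda:0"), dtype=torch.float16)
             for _ in range(4)]
    scheduler.set_noise_list(noise)
    pipe.scheduler = scheduler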
src/eunms.py ADDED
@@ -0,0 +1,26 @@
+ from enum import Enum
+
+ class Scheduler_Type(Enum):
+     DDIM = 1
+     EULER = 2
+     LCM = 3
+     DDPM = 4
+
+ class Model_Type(Enum):
+     SDXL = 1
+     SDXL_Turbo = 2
+     LCM_SDXL = 3
+     SD15 = 4
+     SD21 = 5
+     SD21_Turbo = 6
+     SD14 = 7
+
+ class Gradient_Averaging_Type(Enum):
+     NONE = 1
+     EACH_ITER = 2
+     ON_END = 3
+
+ class Epsilon_Update_Type(Enum):
+     NONE = 1
+     OVERRIDE = 2
+     OPTIMIZE = 3
src/sdxl_inversion_pipeline.py ADDED
@@ -0,0 +1,375 @@
1
+ # Code is based on ReNoise https://github.com/garibida/ReNoise-Inversion
2
+
3
+ import torch
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ from diffusers import (
7
+ StableDiffusionXLImg2ImgPipeline,
8
+ )
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+
11
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
12
+ StableDiffusionXLPipelineOutput,
13
+ retrieve_timesteps,
14
+ PipelineImageInput
15
+ )
16
+
17
+ from src.eunms import Epsilon_Update_Type
18
+
19
+
20
+ def _backward_ddim(x_tm1, alpha_t, alpha_tm1, eps_xt):
21
+ """
22
+ let a = alpha_t, b = alpha_{t - 1}
23
+ We have a > b,
24
+ x_{t} - x_{t - 1} = sqrt(a) ((sqrt(1/b) - sqrt(1/a)) * x_{t-1} + (sqrt(1/a - 1) - sqrt(1/b - 1)) * eps_{t-1})
25
+ From https://arxiv.org/pdf/2105.05233.pdf, section F.
26
+ """
27
+
28
+ a, b = alpha_t, alpha_tm1
29
+ sa = a ** 0.5
30
+ sb = b ** 0.5
31
+
32
+ return sa * ((1 / sb) * x_tm1 + ((1 / a - 1) ** 0.5 - (1 / b - 1) ** 0.5) * eps_xt)
33
+
34
+
35
+ class SDXLDDIMPipeline(StableDiffusionXLImg2ImgPipeline):
36
+ # @torch.no_grad()
37
+ def __call__(
38
+ self,
39
+ prompt: Union[str, List[str]] = None,
40
+ prompt_2: Optional[Union[str, List[str]]] = None,
41
+ image: PipelineImageInput = None,
42
+ strength: float = 0.3,
43
+ num_inversion_steps: int = 50,
44
+ timesteps: List[int] = None,
45
+ denoising_start: Optional[float] = None,
46
+ denoising_end: Optional[float] = None,
47
+ guidance_scale: float = 1.0,
48
+ negative_prompt: Optional[Union[str, List[str]]] = None,
49
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
50
+ num_images_per_prompt: Optional[int] = 1,
51
+ eta: float = 0.0,
52
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
53
+ latents: Optional[torch.FloatTensor] = None,
54
+ prompt_embeds: Optional[torch.FloatTensor] = None,
55
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
56
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
57
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
58
+ ip_adapter_image: Optional[PipelineImageInput] = None,
59
+ output_type: Optional[str] = "pil",
60
+ return_dict: bool = True,
61
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
62
+ guidance_rescale: float = 0.0,
63
+ original_size: Tuple[int, int] = None,
64
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
65
+ target_size: Tuple[int, int] = None,
66
+ negative_original_size: Optional[Tuple[int, int]] = None,
67
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
68
+ negative_target_size: Optional[Tuple[int, int]] = None,
69
+ aesthetic_score: float = 6.0,
70
+ negative_aesthetic_score: float = 2.5,
71
+ clip_skip: Optional[int] = None,
72
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
73
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
74
+ num_inference_steps: int = 50,
75
+ inv_hp=None,
76
+ **kwargs,
77
+ ):
78
+ callback = kwargs.pop("callback", None)
79
+ callback_steps = kwargs.pop("callback_steps", None)
80
+
81
+ if callback is not None:
82
+ deprecate(
83
+ "callback",
84
+ "1.0.0",
85
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
86
+ )
87
+ if callback_steps is not None:
88
+ deprecate(
89
+ "callback_steps",
90
+ "1.0.0",
91
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
92
+ )
93
+
94
+ # 1. Check inputs. Raise error if not correct
95
+ self.check_inputs(
96
+ prompt,
97
+ prompt_2,
98
+ strength,
99
+ num_inversion_steps,
100
+ callback_steps,
101
+ negative_prompt,
102
+ negative_prompt_2,
103
+ prompt_embeds,
104
+ negative_prompt_embeds,
105
+ callback_on_step_end_tensor_inputs,
106
+ )
107
+
108
+ denoising_start_fr = 1.0 - denoising_start
109
+ denoising_start = denoising_start
110
+
111
+ self._guidance_scale = guidance_scale
112
+ self._guidance_rescale = guidance_rescale
113
+ self._clip_skip = clip_skip
114
+ self._cross_attention_kwargs = cross_attention_kwargs
115
+ self._denoising_end = denoising_end
116
+ self._denoising_start = denoising_start
117
+
118
+ # 2. Define call parameters
119
+ if prompt is not None and isinstance(prompt, str):
120
+ batch_size = 1
121
+ elif prompt is not None and isinstance(prompt, list):
122
+ batch_size = len(prompt)
123
+ else:
124
+ batch_size = prompt_embeds.shape[0]
125
+
126
+ device = self._execution_device
127
+
128
+ # 3. Encode input prompt
129
+ text_encoder_lora_scale = (
130
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
131
+ )
132
+ (
133
+ prompt_embeds,
134
+ negative_prompt_embeds,
135
+ pooled_prompt_embeds,
136
+ negative_pooled_prompt_embeds,
137
+ ) = self.encode_prompt(
138
+ prompt=prompt,
139
+ prompt_2=prompt_2,
140
+ device=device,
141
+ num_images_per_prompt=num_images_per_prompt,
142
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
143
+ negative_prompt=negative_prompt,
144
+ negative_prompt_2=negative_prompt_2,
145
+ prompt_embeds=prompt_embeds,
146
+ negative_prompt_embeds=negative_prompt_embeds,
147
+ pooled_prompt_embeds=pooled_prompt_embeds,
148
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
149
+ lora_scale=text_encoder_lora_scale,
150
+ clip_skip=self.clip_skip,
151
+ )
152
+
153
+ # 4. Preprocess image
154
+ image = self.image_processor.preprocess(image)
155
+
156
+ # 5. Prepare timesteps
157
+ def denoising_value_valid(dnv):
158
+ return isinstance(self.denoising_end, float) and 0 < dnv < 1
159
+
160
+ timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
161
+ timesteps_num_inference_steps, num_inference_steps = retrieve_timesteps(self.scheduler_inference,
162
+ num_inference_steps, device, None)
163
+
164
+ timesteps, num_inversion_steps = self.get_timesteps(
165
+ num_inversion_steps,
166
+ strength,
167
+ device,
168
+ denoising_start=self.denoising_start if denoising_value_valid else None,
169
+ )
170
+ # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
171
+
172
+ # add_noise = True if self.denoising_start is None else False
173
+ # 6. Prepare latent variables
174
+ with torch.no_grad():
175
+ latents = self.prepare_latents(
176
+ image,
177
+ None,
178
+ batch_size,
179
+ num_images_per_prompt,
180
+ prompt_embeds.dtype,
181
+ device,
182
+ generator,
183
+ False,
184
+ )
185
+ # 7. Prepare extra step kwargs.
186
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
187
+
188
+ height, width = latents.shape[-2:]
189
+ height = height * self.vae_scale_factor
190
+ width = width * self.vae_scale_factor
191
+
192
+ original_size = original_size or (height, width)
193
+ target_size = target_size or (height, width)
194
+
195
+ # 8. Prepare added time ids & embeddings
196
+ if negative_original_size is None:
197
+ negative_original_size = original_size
198
+ if negative_target_size is None:
199
+ negative_target_size = target_size
200
+
201
+ add_text_embeds = pooled_prompt_embeds
202
+ if self.text_encoder_2 is None:
203
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
204
+ else:
205
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
206
+
207
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
208
+ original_size,
209
+ crops_coords_top_left,
210
+ target_size,
211
+ aesthetic_score,
212
+ negative_aesthetic_score,
213
+ negative_original_size,
214
+ negative_crops_coords_top_left,
215
+ negative_target_size,
216
+ dtype=prompt_embeds.dtype,
217
+ text_encoder_projection_dim=text_encoder_projection_dim,
218
+ )
219
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
220
+
221
+ if self.do_classifier_free_guidance:
222
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
223
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
224
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
225
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
226
+
227
+ prompt_embeds = prompt_embeds.to(device)
228
+ add_text_embeds = add_text_embeds.to(device)
229
+ add_time_ids = add_time_ids.to(device)
230
+
231
+ if ip_adapter_image is not None:
232
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
233
+ if self.do_classifier_free_guidance:
234
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
235
+ image_embeds = image_embeds.to(device)
236
+
237
+ # 9. Denoising loop
238
+ num_warmup_steps = max(len(timesteps) - num_inversion_steps * self.scheduler.order, 0)
239
+ prev_timestep = None
240
+
241
+ self._num_timesteps = len(timesteps)
242
+ self.prev_z = torch.clone(latents)
243
+ self.prev_z4 = torch.clone(latents)
244
+ self.z_0 = torch.clone(latents)
245
+ g_cpu = torch.Generator().manual_seed(7865)
246
+ self.noise = randn_tensor(self.z_0.shape, generator=g_cpu, device=self.z_0.device, dtype=self.z_0.dtype)
247
+
248
+ # Friendly inversion params
249
+ timesteps_for = reversed(timesteps)
250
+ noise = randn_tensor(latents.shape, generator=g_cpu, device=latents.device, dtype=latents.dtype)
251
+ #latents = latents
252
+ z_T = latents.clone()
253
+
254
+ all_latents = [latents.clone()]
255
+ with self.progress_bar(total=num_inversion_steps) as progress_bar:
256
+ for i, t in enumerate(timesteps_for):
257
+
258
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
259
+ if ip_adapter_image is not None:
260
+ added_cond_kwargs["image_embeds"] = image_embeds
261
+
262
+ z_tp1 = self.inversion_step(latents,
263
+ t,
264
+ prompt_embeds,
265
+ added_cond_kwargs,
266
+ prev_timestep=prev_timestep,
267
+ inv_hp=inv_hp,
268
+ z_0=self.z_0)
269
+
270
+ prev_timestep = t
271
+ latents = z_tp1
272
+
273
+ all_latents.append(latents.clone())
274
+
275
+ if callback_on_step_end is not None:
276
+ callback_kwargs = {}
277
+ for k in callback_on_step_end_tensor_inputs:
278
+ callback_kwargs[k] = locals()[k]
279
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
280
+
281
+ latents = callback_outputs.pop("latents", latents)
282
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
283
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
284
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
285
+ negative_pooled_prompt_embeds = callback_outputs.pop(
286
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
287
+ )
288
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
289
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
290
+
291
+ # call the callback, if provided
292
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
293
+ progress_bar.update()
294
+ if callback is not None and i % callback_steps == 0:
295
+ step_idx = i // getattr(self.scheduler, "order", 1)
296
+ callback(step_idx, t, latents)
297
+
298
+ image = latents
299
+
300
+ # Offload all models
301
+ self.maybe_free_model_hooks()
302
+
303
+ return StableDiffusionXLPipelineOutput(images=image), all_latents
304
+
305
+ def get_timestamp_dist(self, z_0, timesteps):
306
+ timesteps = timesteps.to(z_0.device)
307
+ sigma = self.scheduler.sigmas.cuda()[:-1][self.scheduler.timesteps == timesteps]
308
+ z_0 = z_0.reshape(-1, 1)
309
+
310
+ def gaussian_pdf(x):
311
+ shape = x.shape
312
+ x = x.reshape(-1, 1)
313
+ all_probs = - 0.5 * torch.pow(((x - z_0) / sigma), 2)
314
+ return all_probs.reshape(shape)
315
+
316
+ return gaussian_pdf
317
+
318
+ # @torch.no_grad()
319
+ def inversion_step(
320
+ self,
321
+ z_t: torch.tensor,
322
+ t: torch.tensor,
323
+ prompt_embeds,
324
+ added_cond_kwargs,
325
+ prev_timestep: Optional[torch.tensor] = None,
326
+ inv_hp=None,
327
+ z_0=None,
328
+ ) -> torch.tensor:
329
+
330
+ n_iters, alpha, lr = inv_hp
331
+ latent = z_t
332
+ best_latent = None
333
+ best_score = torch.inf
334
+ curr_dist = self.get_timestamp_dist(z_0, t)
335
+ for i in range(n_iters):
336
+ latent.requires_grad = True
337
+ noise_pred = self.unet_pass(latent, t, prompt_embeds, added_cond_kwargs)
338
+
339
+ next_latent = self.backward_step(noise_pred, t, z_t, prev_timestep)
340
+ f_x = (next_latent - latent).abs() - alpha * curr_dist(next_latent)
341
+ score = f_x.mean()
342
+
343
+ if score < best_score:
344
+ best_score = score
345
+ best_latent = next_latent.detach()
346
+
347
+ f_x.sum().backward()
348
+ latent = latent - lr * (f_x / latent.grad)
349
+ latent.grad = None
350
+ latent._grad_fn = None
351
+
352
+ # if self.cfg.update_epsilon_type != Epsilon_Update_Type.NONE:
353
+ # noise_pred = self.unet_pass(best_latent, t, prompt_embeds, added_cond_kwargs)
354
+ # self.scheduler.step_and_update_noise(noise_pred, t, best_latent, z_t, return_dict=False,
355
+ # update_epsilon_type=self.cfg.update_epsilon_type)
356
+ return best_latent
357
+
358
+ @torch.no_grad()
359
+ def unet_pass(self, z_t, t, prompt_embeds, added_cond_kwargs):
360
+ latent_model_input = torch.cat([z_t] * 2) if self.do_classifier_free_guidance else z_t
361
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
362
+ return self.unet(
363
+ latent_model_input,
364
+ t,
365
+ encoder_hidden_states=prompt_embeds,
366
+ timestep_cond=None,
367
+ cross_attention_kwargs=self.cross_attention_kwargs,
368
+ added_cond_kwargs=added_cond_kwargs,
369
+ return_dict=False,
370
+ )[0]
371
+
372
+ @torch.no_grad()
373
+ def backward_step(self, nosie_pred, t, z_t, prev_timestep):
374
+ extra_step_kwargs = {}
375
+ return self.scheduler.inv_step(nosie_pred, t, z_t, **extra_step_kwargs, return_dict=False)[0].detach()
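Finally, a condensed sketch of how this inversion pipeline is chained with the standard img2img pipeline (it mirrors ImageEditorDemo.invert/edit above; source_image and the prompts are placeholders):

    # Assumes pipe_inversion (SDXLDDIMPipeline) and pipe_inference (AutoPipelineForImage2Image)
    # share the same MyEulerAncestralDiscreteScheduler noise list, as set up in app.py / editor.py.
    inv_latent = pipe_inversion(prompt="a photo of a cat",
                                image=source_image,              # 512x512 PIL image
                                num_inversion_steps=4,
                                num_inference_steps=4,
                                guidance_scale=1.2,
                                strength=0.6,
                                denoising_start=0.4,
                                inv_hp=[2, 0.1, 0.2])[0][0]      # [n_iters, alpha, lr]
    edited = pipe_inference(prompt="a photo of a dog",
                            image=inv_latent,
                            num_inference_steps=4,
                            strength=0.6,
                            denoising_start=0.4,
                            guidance_scale=1.2).images[0]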