Spaces:
Runtime error
Runtime error
File size: 7,700 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
"""SAMPLING ONLY."""
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC, get_time_steps
from modules import shared, devices
class UniPCSampler(object):
    """Sampler front-end that drives the UniPC solver with a diffusion model.

    The wrapped ``model`` must expose ``alphas_cumprod``, ``device``,
    ``betas``, ``parameterization`` and ``apply_model`` -- those are the only
    attributes this class reads from it.
    """

    def __init__(self, model, **kwargs):
        """Store the model and build a discrete noise schedule from it.

        Args:
            model: the diffusion model wrapper (see class docstring).
            **kwargs: accepted for call-site compatibility; ignored.
        """
        super().__init__()
        self.model = model
        # All three hooks default to None so decode()/sample() can be called
        # before set_hooks(); previously ``after_update`` was left undefined
        # and reading it raised AttributeError.
        # NOTE(review): assumes UniPC treats after_update=None like the other
        # two hooks -- confirm against uni_pc.UniPC.
        self.before_sample = None
        self.after_sample = None
        self.after_update = None
        alphas_cumprod = model.alphas_cumprod.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.noise_schedule = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Remember the requested step count for a later stochastic_encode().

        Only ``ddim_num_steps`` is used; the remaining parameters are accepted
        for signature compatibility and ignored.
        """
        # persist steps so we can eventually find denoising strength
        self.inflated_steps = ddim_num_steps

    @staticmethod
    def _expand_trailing_dims(values, ndim):
        """Flatten ``values`` and append size-1 dims until it has ``ndim`` dims."""
        values = values.flatten()
        while len(values.shape) < ndim:
            values = values.unsqueeze(-1)
        return values

    @devices.inference_context()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Noise ``x0`` up to the first timestep that decode() will denoise from.

        Side effects: sets ``self.denoise_steps`` and ``self.timesteps``, which
        decode() consumes. Requires make_schedule() to have been called so that
        ``self.inflated_steps`` exists.

        Args:
            x0: tensor to be noised.
            t: tensor whose first element is the last step index to run
                (``t[0] + 1`` steps are denoised); its device is used for the
                timestep tensor.
            use_original_steps: unused.
            noise: optional noise tensor; sampled with ``torch.randn_like(x0)``
                when omitted.

        Returns:
            ``sqrt(alpha_cumprod) * x0 + sqrt(1 - alpha_cumprod) * noise`` at
            the selected timestep.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        # first time we have all the info to get the real parameters from the ui
        # value from the hires steps slider; converted to a plain int so it can
        # be used as a slice bound and step count without tensor semantics
        num_inference_steps = int(t[0]) + 1
        # never run fewer steps than the configured solver order
        self.denoise_steps = max(num_inference_steps, shared.opts.schedulers_solver_order)

        # full schedule for the inflated step count
        all_timesteps = get_time_steps(
            self.noise_schedule,
            shared.opts.uni_pc_skip_type,
            self.noise_schedule.T,
            1. / self.noise_schedule.total_N,
            self.inflated_steps + 1,
            t.device,
        )

        # only the tail of the schedule is used for denoising
        self.timesteps = all_timesteps[-(self.denoise_steps + 1):]

        # first denoise timestep, scaled by the number of alphas to get an
        # integer index (minus one because the index is 0-based), repeated
        # once per batch item
        first_step_index = (self.timesteps[:1] * self.noise_schedule.total_N).int() - 1
        latent_timestep = first_step_index.repeat(x0.shape[0])

        alpha_prod = self.alphas_cumprod[latent_timestep]
        ndim = len(x0.shape)
        sqrt_alpha_prod = self._expand_trailing_dims(alpha_prod ** 0.5, ndim)
        sqrt_one_minus_alpha_prod = self._expand_trailing_dims((1 - alpha_prod) ** 0.5, ndim)

        return sqrt_alpha_prod * x0 + sqrt_one_minus_alpha_prod * noise

    def _build_solver(self, conditioning, unconditional_conditioning, guidance_scale):
        """Create a UniPC solver around the model for the given conditioning."""
        # SD 1.X is "noise", SD 2.X is "v"
        model_type = "v" if self.model.parameterization == "v" else "noise"
        # The conditioning is handed to UniPC below, not to model_wrapper.
        model_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            self.noise_schedule,
            model_type=model_type,
            guidance_type="classifier-free",
            guidance_scale=guidance_scale,
        )
        return UniPC(
            model_fn,
            self.noise_schedule,
            predict_x0=True,
            thresholding=False,
            variant=shared.opts.uni_pc_variant,
            condition=conditioning,
            unconditional_condition=unconditional_conditioning,
            before_sample=self.before_sample,
            after_sample=self.after_sample,
            after_update=self.after_update,
        )

    def decode(self, x_latent, conditioning, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):
        """Denoise ``x_latent`` over the timesteps chosen by stochastic_encode().

        Requires a prior stochastic_encode() call, which sets
        ``self.denoise_steps`` and ``self.timesteps``. The solver is kept on
        ``self.uni_pc``. ``t_start``, ``use_original_steps`` and ``callback``
        are accepted for signature compatibility and not used.

        Returns:
            Whatever ``UniPC.sample`` returns for the given latent.
        """
        self.uni_pc = self._build_solver(conditioning, unconditional_conditioning, unconditional_guidance_scale)
        return self.uni_pc.sample(
            x_latent,
            steps=self.denoise_steps,
            skip_type=shared.opts.uni_pc_skip_type,
            method="multistep",
            order=shared.opts.schedulers_solver_order,
            lower_order_final=shared.opts.schedulers_use_loworder,
            denoise_to_zero=True,
            timesteps=self.timesteps,
        )

    def register_buffer(self, name, attr):
        """Set ``self.<name> = attr``, moving tensors to ``devices.device`` first."""
        if isinstance(attr, torch.Tensor) and attr.device != devices.device:
            attr = attr.to(devices.device)
        setattr(self, name, attr)

    def set_hooks(self, before_sample, after_sample, after_update):
        """Install the callbacks that are forwarded to the UniPC solver."""
        self.before_sample = before_sample
        self.after_sample = after_sample
        self.after_update = after_update

    @staticmethod
    def _warn_on_batch_mismatch(conditioning, batch_size):
        """Log a warning when the conditioning's leading dim differs from batch_size."""
        if conditioning is None:
            return
        if isinstance(conditioning, dict):
            # inspect the first entry, unwrapping nested lists down to a tensor
            ctmp = next(iter(conditioning.values()))
            while isinstance(ctmp, list):
                ctmp = ctmp[0]
            cbs = ctmp.shape[0]
            if cbs != batch_size:
                shared.log.warning(f"UniPC: got {cbs} conditionings but batch-size is {batch_size}")
        elif isinstance(conditioning, list):
            for ctmp in conditioning:
                if ctmp.shape[0] != batch_size:
                    # report this entry's own size (the original read an
                    # undefined ``cbs`` here)
                    shared.log.warning(f"UniPC: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
        elif conditioning.shape[0] != batch_size:
            shared.log.warning(f"UniPC: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

    @devices.inference_context()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Run ``S`` UniPC steps starting from ``x_T`` or fresh Gaussian noise.

        Only ``S``, ``batch_size``, ``shape``, ``conditioning``, ``x_T``,
        ``unconditional_guidance_scale`` and ``unconditional_conditioning`` are
        used; the other parameters are accepted for signature compatibility.

        Args:
            S: number of solver steps.
            batch_size: number of samples; also checked against the
                conditioning's leading dimension (mismatch only logs a warning).
            shape: ``(C, H, W)`` of a single sample.
            conditioning: dict, list or tensor-like conditioning passed to UniPC.
            x_T: optional starting tensor; random noise of size
                ``(batch_size, C, H, W)`` is drawn when omitted.
            unconditional_guidance_scale: classifier-free guidance scale.
            unconditional_conditioning: unconditional counterpart of
                ``conditioning``.

        Returns:
            Tuple ``(samples, None)`` with samples moved to the device of
            ``model.betas``.
        """
        self._warn_on_batch_mismatch(conditioning, batch_size)

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        device = self.model.betas.device
        img = torch.randn(size, device=device) if x_T is None else x_T

        uni_pc = self._build_solver(conditioning, unconditional_conditioning, unconditional_guidance_scale)
        x = uni_pc.sample(
            img,
            steps=S,
            skip_type=shared.opts.uni_pc_skip_type,
            method="multistep",
            order=shared.opts.schedulers_solver_order,
            lower_order_final=shared.opts.schedulers_use_loworder,
        )
        return x.to(device), None
|