# Reforge / modules_forge/forge_sampler.py
# Uploaded by YXStableDiffusion via huggingface_hub ("Upload folder using huggingface_hub"),
# commit ecc4278 (verified), 5.11 kB.
import torch
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
from ldm_patched.modules.samplers import sampling_function
from ldm_patched.modules import model_management
from ldm_patched.modules.ops import cleanup_cache
def cond_from_a1111_to_patched_ldm(cond):
    """Convert one A1111 conditioning value into ldm_patched cond entries.

    Args:
        cond: either a bare ``torch.Tensor`` (used as the cross-attention
            context), or a mapping that holds the cross-attention tensor
            under ``'crossattn'`` and a pooled vector under ``'vector'``.

    Returns:
        A one-element list containing a dict with the raw tensor(s)
        (``cross_attn`` and, for the mapping form, ``pooled_output``) and a
        ``model_conds`` dict wrapping them as ``CONDCrossAttn`` (key
        ``c_crossattn``) and, for the mapping form, ``CONDRegular`` (key ``y``).
    """
    if not isinstance(cond, torch.Tensor):
        cross_attn = cond['crossattn']
        pooled_output = cond['vector']
        model_conds = {
            'c_crossattn': CONDCrossAttn(cross_attn),
            'y': CONDRegular(pooled_output),
        }
        return [{
            'cross_attn': cross_attn,
            'pooled_output': pooled_output,
            'model_conds': model_conds,
        }]

    # Bare tensor: cross-attention conditioning only, no pooled vector.
    return [{
        'cross_attn': cond,
        'model_conds': {'c_crossattn': CONDCrossAttn(cond)},
    }]
def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
    """Split a batched A1111 cond into weighted ldm_patched cond entries.

    Args:
        cond: the batched conditioning. It is indexed either through its
            ``advanced_indexing(indices)`` method when present, or with
            ``cond[indices]`` otherwise.
        weights: an iterable of equal-role lists of ``(index, weight)``
            pairs (presumably one list per batch item, i.e. A1111's
            ``cond_composition`` -- confirm against the caller).

    The lists are transposed: the j-th pair of every list forms one column.
    Each column selects its indices from ``cond``, is converted with
    ``cond_from_a1111_to_patched_ldm``, and gets a ``'strength'`` key.

    Returns:
        A flat list with one converted entry per column.
    """
    # NOTE(review): zip() stops at the shortest inner list, so trailing pairs
    # of longer lists are silently ignored.
    results = []
    for column in zip(*weights):
        row_indices = [row_index for row_index, _ in column]
        # NOTE(review): only the weight of the column's last pair is kept;
        # differing weights across the column are not represented.
        _, strength = column[-1]  # zip never yields an empty tuple

        if hasattr(cond, 'advanced_indexing'):
            selected = cond.advanced_indexing(row_indices)
        else:
            selected = cond[row_indices]

        entries = cond_from_a1111_to_patched_ldm(selected)
        entries[0]['strength'] = strength
        results.extend(entries)
    return results
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
    """Run one denoising call through ldm_patched's ``sampling_function``.

    Converts the A1111 text conditionings into ldm_patched cond dicts,
    attaches optional concat image conditioning and ControlNet, applies the
    UNet's registered ``conditioning_modifiers`` and returns the prediction.

    Args:
        self: the A1111 denoiser object. Reads
            ``inner_model.inner_model.forge_objects.unet`` and ``p.seeds``,
            plus the optional attributes ``skip_uncond``, ``is_edit_model``,
            ``image_cfg_scale``, ``mask``, ``init_latent`` and
            ``mask_before_denoising`` (missing ones fall back to defaults).
        denoiser_params: provides ``x``, ``sigma``, ``text_cond``,
            ``text_uncond`` and ``image_cond``.
        cond_scale: guidance scale forwarded to ``sampling_function``.
        cond_composition: ``(index, weight)`` pairs used to split
            ``text_cond`` via ``cond_from_a1111_to_patched_ldm_weighted``.

    Returns:
        The output of ``sampling_function``. When ``mask_before_denoising``
        is set and both ``mask`` and ``init_latent`` exist, the output is
        blended as ``denoised * (1 - mask) + init_latent * mask``.
    """
    unet = self.inner_model.inner_model.forge_objects.unet
    model = unet.model
    control = unet.controlnet_linked_list
    extra_concat_condition = unet.extra_concat_condition
    # NOTE(review): this is the UNet's own dict, not a copy, so keys written
    # below (image_cfg_scale, mask, init_latent) stay on the UNet after this
    # call and are never removed -- confirm that persistence is intended.
    model_options = unet.model_options

    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
    cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
    seed = self.p.seeds[0]

    # Prefer the UNet-provided concat condition over the denoiser's image_cond.
    if extra_concat_condition is not None:
        image_cond_in = extra_concat_condition
    else:
        image_cond_in = denoiser_params.image_cond

    # Attach it as c_concat only when batch size and spatial dims (axes 0, 2, 3)
    # match the latent x; otherwise it is ignored.
    if isinstance(image_cond_in, torch.Tensor) \
            and image_cond_in.shape[0] == x.shape[0] \
            and image_cond_in.shape[2] == x.shape[2] \
            and image_cond_in.shape[3] == x.shape[3]:
        for entry in uncond + cond:
            entry['model_conds']['c_concat'] = CONDRegular(image_cond_in)

    if control is not None:
        for entry in cond + uncond:
            entry['control'] = control

    # Drop the negative conditioning entirely for this call.
    if getattr(self, 'skip_uncond', False):
        uncond = None

    if getattr(self, 'is_edit_model', False):
        model_options['image_cfg_scale'] = getattr(self, 'image_cfg_scale', None)

    mask = getattr(self, 'mask', None)
    init_latent = getattr(self, 'init_latent', None)
    has_mask = mask is not None and init_latent is not None
    if has_mask:
        model_options['mask'] = mask
        model_options['init_latent'] = init_latent

    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(
            model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    denoised = sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    # Blending needs init_latent as well as mask. Checking mask alone would
    # raise a TypeError (None * tensor) whenever init_latent is missing.
    # NOTE(review): despite the flag name, this blends the *output* of the
    # model; confirm the caller does not also blend, which would apply it twice.
    if getattr(self, 'mask_before_denoising', False) and has_mask:
        denoised = denoised * (1 - mask) + init_latent * mask
    return denoised
def sampling_prepare(unet, x):
    """Load the UNet and its helper models to the GPU and prime ControlNets.

    Args:
        unet: the patched UNet wrapper. Reads ``model_options``,
            ``memory_required``, ``extra_preserved_memory_during_sampling``,
            ``extra_model_patchers_during_sampling``,
            ``controlnet_linked_list``, ``model_dtype()``, ``model`` and
            ``list_controlnets()``.
        x: the latent batch; its shape is unpacked as ``(B, C, H, W)``.
    """
    B, C, H, W = x.shape

    # Memory is estimated for a doubled batch (presumably cond + uncond
    # evaluated together -- TODO confirm), optionally via a registered override.
    memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required)
    unet_inference_memory = memory_estimation_function([B * 2, C, H, W])

    additional_inference_memory = unet.extra_preserved_memory_during_sampling
    # Build a fresh list. Aliasing the UNet's attribute would make the `+=`
    # below extend it in place, re-appending the ControlNet models on every call.
    additional_model_patchers = list(unet.extra_model_patchers_during_sampling)

    controlnets = unet.controlnet_linked_list
    if controlnets is not None:
        additional_inference_memory += controlnets.inference_memory_requirements(unet.model_dtype())
        additional_model_patchers += controlnets.get_models()

    model_management.load_models_gpu(
        models=[unet] + additional_model_patchers,
        memory_required=unet_inference_memory + additional_inference_memory)

    real_model = unet.model

    def percent_to_timestep_function(p):
        # Looked up at call time, so it reflects the model's current model_sampling.
        return real_model.model_sampling.percent_to_sigma(p)

    for cnet in unet.list_controlnets():
        cnet.pre_run(real_model, percent_to_timestep_function)
def sampling_cleanup(unet):
    """Tear down per-run ControlNet state after sampling.

    Calls ``cleanup()`` on every ControlNet returned by
    ``unet.list_controlnets()``, then calls ``cleanup_cache()`` from
    ``ldm_patched.modules.ops``.
    """
    controlnets = unet.list_controlnets()
    for controlnet in controlnets:
        controlnet.cleanup()
    cleanup_cache()