File size: 5,109 Bytes
ecc4278 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import torch
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
from ldm_patched.modules.samplers import sampling_function
from ldm_patched.modules import model_management
from ldm_patched.modules.ops import cleanup_cache
def cond_from_a1111_to_patched_ldm(cond):
    """Wrap an A1111-style conditioning in a one-element list of cond dicts.

    A ``torch.Tensor`` is used directly as the cross-attention input. Any
    other value is indexed with the keys ``'crossattn'`` and ``'vector'``;
    the latter is additionally exposed as ``pooled_output`` and as the ``y``
    model cond.

    Returns:
        A list holding exactly one dict with ``cross_attn``, ``model_conds``
        and, for the non-tensor form, ``pooled_output``.
    """
    if not isinstance(cond, torch.Tensor):
        text_embedding = cond['crossattn']
        pooled_vector = cond['vector']
        return [{
            'cross_attn': text_embedding,
            'pooled_output': pooled_vector,
            'model_conds': {
                'c_crossattn': CONDCrossAttn(text_embedding),
                'y': CONDRegular(pooled_vector),
            },
        }]

    # Plain tensor: there is no pooled vector, so only cross-attention is set.
    return [{
        'cross_attn': cond,
        'model_conds': {
            'c_crossattn': CONDCrossAttn(cond),
        },
    }]
def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
    """Build one converted cond entry per column of ``weights``.

    ``weights`` is a collection of sequences of ``(index, weight)`` pairs
    (presumably one sequence per batch item -- confirm against the caller).
    The sequences are transposed with ``zip``, so the k-th column holds the
    k-th pair of every sequence. For each column the indices are gathered,
    ``cond`` is sub-selected with them, and the converted entry receives the
    weight of the column's last pair as ``'strength'``.

    Returns:
        A flat list of cond dicts, one per column.
    """
    columns = [list(column) for column in zip(*weights)]
    converted = []

    for column in columns:
        indices = []
        strength = 0
        # The loop rebinds ``strength`` on every pair, so the last one wins.
        for index, strength in column:
            indices.append(index)

        # Prefer the object's own ``advanced_indexing`` when it offers one.
        selected = (
            cond.advanced_indexing(indices)
            if hasattr(cond, 'advanced_indexing')
            else cond[indices]
        )

        entries = cond_from_a1111_to_patched_ldm(selected)
        entries[0]['strength'] = strength
        converted.extend(entries)

    return converted
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
    """Run one denoising call through ``sampling_function`` for the Forge UNet.

    Converts the text conditionings carried by ``denoiser_params`` into the
    list-of-dicts form, attaches the optional concat-image and control
    entries, applies the optional ``skip_uncond`` / edit-model / mask
    settings read from ``self``, lets every callable registered under
    ``model_options['conditioning_modifiers']`` rewrite the sampling
    arguments, and finally calls ``sampling_function``.

    Args:
        denoiser_params: object providing ``x``, ``sigma``, ``text_cond``,
            ``text_uncond`` and ``image_cond``.
        cond_scale: forwarded to the modifiers and to ``sampling_function``.
        cond_composition: ``(index, weight)`` layout forwarded to
            ``cond_from_a1111_to_patched_ldm_weighted``.

    Returns:
        The result of ``sampling_function``. When ``mask_before_denoising``
        is set on ``self`` and both ``mask`` and ``init_latent`` are
        available, the result is blended with ``init_latent`` using ``mask``.
    """
    unet = self.inner_model.inner_model.forge_objects.unet
    model = unet.model
    control = unet.controlnet_linked_list
    extra_concat_condition = unet.extra_concat_condition
    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
    cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
    # NOTE(review): this is the unet's own dict, not a copy, so the writes
    # below ('image_cfg_scale', 'mask', 'init_latent') stay on the unet after
    # this call returns -- confirm that persistence is intended.
    model_options = unet.model_options
    seed = self.p.seeds[0]

    # An explicit extra concat condition on the unet takes precedence over
    # the image cond supplied with the denoiser params.
    image_cond_in = (
        extra_concat_condition
        if extra_concat_condition is not None
        else denoiser_params.image_cond
    )

    # Only attach the image cond when batch, height and width match ``x``.
    if isinstance(image_cond_in, torch.Tensor) \
            and image_cond_in.shape[0] == x.shape[0] \
            and image_cond_in.shape[2] == x.shape[2] \
            and image_cond_in.shape[3] == x.shape[3]:
        for entry in (*uncond, *cond):
            entry['model_conds']['c_concat'] = CONDRegular(image_cond_in)

    if control is not None:
        for entry in cond + uncond:
            entry['control'] = control

    # Handle skip_uncond
    if getattr(self, 'skip_uncond', False):
        uncond = None

    # Handle is_edit_model
    if getattr(self, 'is_edit_model', False):
        model_options['image_cfg_scale'] = getattr(self, 'image_cfg_scale', None)

    # Handle mask and init_latent
    mask = getattr(self, 'mask', None)
    init_latent = getattr(self, 'init_latent', None)
    has_mask_inputs = mask is not None and init_latent is not None
    if has_mask_inputs:
        model_options['mask'] = mask
        model_options['init_latent'] = init_latent

    # Each modifier may replace any of the sampling arguments.
    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(
            model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    denoised = sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    # Handle mask_before_denoising. The blend multiplies ``init_latent``, so
    # it needs both tensors; checking only ``mask`` raised a TypeError when
    # ``init_latent`` was missing.
    if getattr(self, 'mask_before_denoising', False) and has_mask_inputs:
        denoised = denoised * (1 - mask) + init_latent * mask

    return denoised
def sampling_prepare(unet, x):
    """Load the unet and its helper models, then pre-run its controlnets.

    Estimates the inference memory for ``x``, adds the memory and extra
    model patchers the unet (and its linked controlnets, if any) asks for,
    hands everything to ``model_management.load_models_gpu`` and finally
    calls ``pre_run`` on every controlnet reported by
    ``unet.list_controlnets()``.

    Args:
        unet: patcher object exposing ``model``, ``model_options``,
            ``memory_required``, ``controlnet_linked_list`` and the
            ``extra_*_during_sampling`` attributes read below.
        x: 4-D tensor; its shape is unpacked as ``B, C, H, W``.

    Returns:
        None.
    """
    B, C, H, W = x.shape

    memory_estimation_function = unet.model_options.get(
        'memory_peak_estimation_modifier', unet.memory_required)

    # The estimate uses twice the batch size (presumably because cond and
    # uncond are evaluated together -- confirm).
    unet_inference_memory = memory_estimation_function([B * 2, C, H, W])
    additional_inference_memory = unet.extra_preserved_memory_during_sampling
    # Work on a copy: ``+=`` on the unet's own list would extend it in place
    # and keep the controlnet models attached to the unet after every call.
    additional_model_patchers = list(unet.extra_model_patchers_during_sampling)

    controlnets = unet.controlnet_linked_list
    if controlnets is not None:
        additional_inference_memory = (
            additional_inference_memory
            + controlnets.inference_memory_requirements(unet.model_dtype())
        )
        additional_model_patchers.extend(controlnets.get_models())

    model_management.load_models_gpu(
        models=[unet] + additional_model_patchers,
        memory_required=unet_inference_memory + additional_inference_memory)

    real_model = unet.model

    def percent_to_timestep_function(p):
        # Looks up ``model_sampling`` at call time, as the original lambda did.
        return real_model.model_sampling.percent_to_sigma(p)

    for cnet in unet.list_controlnets():
        cnet.pre_run(real_model, percent_to_timestep_function)
def sampling_cleanup(unet):
    """Call ``cleanup()`` on each of the unet's controlnets, then ``cleanup_cache()``.

    Returns:
        None.
    """
    for controlnet in unet.list_controlnets():
        controlnet.cleanup()

    cleanup_cache()
|