from __future__ import annotations

import functools
import logging
import os
import re

import gradio as gr
import torch

from typing import Union

import lora_patches
import network

from modules import shared, devices, sd_models, errors, scripts
from ldm_patched.modules.utils import load_torch_file
from ldm_patched.modules.sd import load_lora_for_models


@functools.lru_cache(maxsize=5)
def load_lora_state_dict(filename):
    return load_torch_file(filename, safe_load=True)
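
# Note: the cache above keys on filename alone, so a file replaced on disk at
# the same path can serve stale weights until the cache is dropped. A minimal
# sketch (hypothetical path; not executed at import):
#
#   sd = load_lora_state_dict('/models/Lora/example.safetensors')
#   load_lora_state_dict.cache_clear()  # standard functools.lru_cache API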


def convert_diffusers_name_to_compvis(key, is_sd2):
    pass


def assign_network_names_to_compvis_modules(sd_model):
    pass
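
# The two hooks above are kept as no-ops: diffusers-to-compvis key conversion
# and module naming belong to the A1111 LoRA codepath, while this module
# delegates actual LoRA application to load_lora_for_models (see load_networks
# below). That reading is inferred from this file, not a documented guarantee.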


class BundledTIHash(str):
    def __init__(self, hash_str):
        self.hash = hash_str

    def __str__(self):
        return self.hash if shared.opts.lora_bundled_ti_to_infotext else ''
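
# Usage sketch (hypothetical hash value): str(BundledTIHash('aabbccdd'))
# yields 'aabbccdd' while shared.opts.lora_bundled_ti_to_infotext is enabled
# and '' otherwise, so infotext code can interpolate it unconditionally.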


def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    return net


def purge_networks_from_memory():
    pass


def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    current_sd = sd_models.model_data.get_sd_model()
    if current_sd is None:
        return

    loaded_networks.clear()

    # Refresh the registry for any requested names that are not known yet.
    unavailable_networks = []
    for name in names:
        if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
            unavailable_networks.append(name)
        elif available_network_aliases.get(name) is None:
            unavailable_networks.append(name)

    if unavailable_networks:
        update_available_networks_by_names(unavailable_networks)

    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()
        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

    for network_on_disk, name in zip(networks_on_disk, names):
        if network_on_disk is None:
            logging.info(f"Couldn't find network with name {name}")
            continue

        try:
            net = load_network(name, network_on_disk)
        except Exception as e:
            errors.display(e, f"loading network {network_on_disk.filename}")
            continue

        net.mentioned_name = name
        network_on_disk.read_hash()
        loaded_networks.append(net)

    compiled_lora_targets = []
    for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
        if a is not None:
            compiled_lora_targets.append([a.filename, b, c])

    compiled_lora_targets_hash = str(compiled_lora_targets)

    if current_sd.current_lora_hash == compiled_lora_targets_hash:
        return

    # Rebuild from the pristine UNet/CLIP so LoRA strengths never stack across calls.
    current_sd.current_lora_hash = compiled_lora_targets_hash
    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip

    for filename, strength_model, strength_clip in compiled_lora_targets:
        lora_sd = load_lora_state_dict(filename)
        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
            filename=filename)

    current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
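
# Usage sketch ('detail_tweaker' is a hypothetical LoRA name; callers pass one
# text-encoder and one UNet multiplier per name, parsed from <lora:name:weight>
# prompt tags):
#
#   load_networks(['detail_tweaker'], te_multipliers=[0.8], unet_multipliers=[0.8])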


def allowed_layer_without_weight(layer):
    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
        return True

    return False
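
# Rationale: torch.nn.LayerNorm(8, elementwise_affine=False).weight is None,
# so such layers are the one case where a missing weight tensor is expected.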


def store_weights_backup(weight):
    if weight is None:
        return None

    return weight.to(devices.cpu, copy=True)


def restore_weights_backup(obj, field, weight):
    if weight is None:
        setattr(obj, field, None)
        return

    getattr(obj, field).copy_(weight)
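
# Round-trip sketch for the two helpers above ('layer' is a hypothetical
# torch.nn.Linear; copying back into a leaf parameter belongs under no_grad):
#
#   with torch.no_grad():
#       backup = store_weights_backup(layer.weight)      # detached CPU copy
#       layer.weight.zero_()                             # ...weight gets patched...
#       restore_weights_backup(layer, 'weight', backup)  # in-place restore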


# The per-module hooks below are intentionally empty. In the A1111 LoRA
# codepath they patch weights and forwards module-by-module; here that work
# appears to be handled wholesale by load_lora_for_models in load_networks,
# and the names are kept only so the rest of the codebase can import them.
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    pass


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    pass


def network_forward(org_module, input, original_forward):
    pass


def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    pass


def network_Linear_forward(self, input):
    pass


def network_Linear_load_state_dict(self, *args, **kwargs):
    pass


def network_Conv2d_forward(self, input):
    pass


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    pass


def network_GroupNorm_forward(self, input):
    pass


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    pass


def network_LayerNorm_forward(self, input):
    pass


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    pass


def network_MultiheadAttention_forward(self, *args, **kwargs):
    pass


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    pass


def process_network_files(names: list[str] | None = None):
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]

        if names and name not in names:
            continue

        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        # If two files claim the same alias, the alias becomes ambiguous and is
        # marked forbidden; such networks must then be referenced by filename.
        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


def update_available_networks_by_names(names: list[str]):
    process_network_files(names)


def list_available_networks():
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    process_network_files()


re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
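
# The pattern strips a trailing "(hash)" suffix from AddNet model names. Note
# that the greedy leading group keeps any space before the parenthesis, e.g.:
#
#   re_network_name.match('myLora (a1b2c3d4)').group(1)  # -> 'myLora '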


def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # another active extension already registers AddNet fields and will handle them

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]  # the index after "AddNet Model "

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)
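
# Example of the conversion performed above (hypothetical infotext values):
#
#   {'AddNet Module 1': 'LoRA', 'AddNet Model 1': 'myLora(a1b2c3d4)',
#    'AddNet Weight A 1': '0.8'}
#
# appends '<lora:myLora:0.8>' to params['Prompt'].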


originals: lora_patches.LoraPatches = None

extra_network_lora = None

# Module-level registries, (re)populated by list_available_networks() below.
available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()