import json
import os
from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal

import torch
from optimum.quanto import QTensor
from torch import nn
import weakref
from tqdm import tqdm

from toolkit.config_modules import NetworkConfig
from toolkit.lorm import extract_conv, extract_linear, count_parameters
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
from toolkit.saving import get_lora_keymap_from_model_keymap
from optimum.quanto import QBytesTensor

if TYPE_CHECKING:
    from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
    from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.models.DoRA import DoRAModule

Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule']

LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear',
    'QLinear'
    # 'GroupNorm',
]
CONV_MODULES = [
    'Conv2d',
    'LoRACompatibleConv'
]

ExtractMode = Literal[
    'existing',
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage'
]

def broadcast_and_multiply(tensor, multiplier):
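    """Multiply `tensor` by `multiplier`, broadcasting over trailing dims.

    The multiplier is unsqueezed on the right until it has as many dims as the
    tensor; e.g. a (batch,) multiplier applied to a (batch, channels, h, w)
    tensor is viewed as (batch, 1, 1, 1) so each batch element gets its own
    scale.
    """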
    # determine how many trailing dims the multiplier is missing
    num_extra_dims = tensor.dim() - multiplier.dim()

    # unsqueeze the multiplier until it matches the tensor's dimensionality
    for _ in range(num_extra_dims):
        multiplier = multiplier.unsqueeze(-1)

    try:
        # multiply the tensor by the broadcasted multiplier
        result = tensor * multiplier
    except RuntimeError as e:
        print(e)
        print(tensor.size())
        print(multiplier.size())
        raise e

    return result

def add_bias(tensor, bias):
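    """Add `bias` to `tensor`, repeating the bias across the batch dimension
    and permuting it when the channel dimension does not line up.
    """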
    if bias is None:
        return tensor

    # add a batch dim and repeat the bias for every batch element
    bias = bias.unsqueeze(0)
    bias = torch.cat([bias] * tensor.size(0), dim=0)

    # determine how many trailing dims the bias is missing
    num_extra_dims = tensor.dim() - bias.dim()

    # unsqueeze the bias until it matches the tensor's dimensionality
    for _ in range(num_extra_dims):
        bias = bias.unsqueeze(-1)

    # we may need to move the channel dim so it lines up with the tensor
    if bias.size(1) != tensor.size(1):
        if len(bias.size()) == 3:
            bias = bias.permute(0, 2, 1)
        elif len(bias.size()) == 4:
            bias = bias.permute(0, 3, 1, 2)

    try:
        # add the broadcasted bias to the tensor
        result = tensor + bias
    except RuntimeError as e:
        print(e)
        print(tensor.size())
        print(bias.size())
        raise e

    return result

class ExtractableModuleMixin:
    def extract_weight(
            self: Module,
            extract_mode: ExtractMode = "existing",
            extract_mode_param: Union[int, float] = None,
    ):
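        """Initialize this module's lora_down / lora_up weights by decomposing
        the original module's weight.

        `extract_mode` controls how the new rank is chosen ('existing' reuses
        the module's current `lora_dim` as a fixed rank); the decomposition
        itself is delegated to `extract_conv` / `extract_linear` from
        `toolkit.lorm`.
        """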
        device = self.lora_down.weight.device
        weight_to_extract = self.org_module[0].weight

        if extract_mode == "existing":
            extract_mode = 'fixed'
            extract_mode_param = self.lora_dim

        if isinstance(weight_to_extract, QBytesTensor):
            weight_to_extract = weight_to_extract.dequantize()

        weight_to_extract = weight_to_extract.clone().detach().float()

        if self.org_module[0].__class__.__name__ in CONV_MODULES:
            # do conv extraction
            down_weight, up_weight, new_dim, diff = extract_conv(
                weight=weight_to_extract,
                mode=extract_mode,
                mode_param=extract_mode_param,
                device=device
            )
        elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
            # do linear extraction
            down_weight, up_weight, new_dim, diff = extract_linear(
                weight=weight_to_extract,
                mode=extract_mode,
                mode_param=extract_mode_param,
                device=device,
            )
        else:
            raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}")

        self.lora_dim = new_dim

        # inject the extracted weights into the lora params
        self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach()
        self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach()

        # copy the bias if the original module has one and we are using biases
        if self.org_module[0].bias is not None and self.lora_up.bias is not None:
            self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach()

        # set up alpha so that scale = alpha / lora_dim
        self.alpha = (self.alpha * 0) + down_weight.shape[0]
        self.scale = self.alpha / self.lora_dim

        # handle the trainable scalar method LoCon uses
        if hasattr(self, 'scalar'):
            # scalar is a parameter; reset its value to 1.0
            self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype)

class ToolkitModuleMixin:
    def __init__(
            self: Module,
            *args,
            network: Network,
            **kwargs
    ):
        self.network_ref: weakref.ref = weakref.ref(network)
        self.is_checkpointing = False
        self._multiplier: Union[float, list, torch.Tensor] = None

    def _call_forward(self: Module, x):
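        """Compute only the LoRA branch: lora_up(dropout(lora_down(x))) * scale
        (optionally through lora_mid).

        While training, module dropout can skip the branch entirely, element
        dropout and rank dropout are applied to the down-projected activations,
        and the optional trainable `scalar` parameter is folded into the scale.
        """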
        # module dropout
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return 0.0  # added to original forward

        if hasattr(self, 'lora_mid') and self.lora_mid is not None:
            lx = self.lora_mid(self.lora_down(x))
        else:
            try:
                lx = self.lora_down(x)
            except RuntimeError as e:
                print(f"Error in {self.__class__.__name__} lora_down")
                print(e)
                raise e

        if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
            lx = self.dropout(lx)
        # normal dropout
        elif self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout
        if self.rank_dropout is not None and self.rank_dropout > 0 and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            # this could be derived from the mask, but rank_dropout is used here
            # in the hope of an augmentation-like effect
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        # handle the trainable scalar method LoCon uses
        if hasattr(self, 'scalar'):
            scale = scale * self.scalar

        return lx * scale

    def lorm_forward(self: Module, x, *args, **kwargs):
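        """Forward pass used when the network runs as a LoRM: the low-rank pair
        stands in for the original module.

        In 'local' train mode the frozen original output is used as a target,
        an MSE loss against lora_up(lora_down(x)) is backpropagated
        immediately, and the original output is returned so downstream modules
        are unaffected; otherwise the low-rank output is returned in place of
        the original module's output.
        """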
        network: Network = self.network_ref()
        if not network.is_active:
            return self.org_forward(x, *args, **kwargs)

        orig_dtype = x.dtype
        if x.dtype != self.lora_down.weight.dtype:
            x = x.to(self.lora_down.weight.dtype)

        if network.lorm_train_mode == 'local':
            # run both paths on the same input and take a loss between them
            inputs = x.detach()
            with torch.no_grad():
                # get the local target prediction from the original module
                target_pred = self.org_forward(inputs, *args, **kwargs).detach()
            with torch.set_grad_enabled(True):
                # make a prediction with the lorm
                lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))

                local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
                # backprop
                local_loss.backward()

            network.module_losses.append(local_loss.detach())

            # return the original prediction so our training does not affect modules down the line
            return target_pred
        else:
            x = self.lora_up(self.lora_down(x))

        if x.dtype != orig_dtype:
            x = x.to(orig_dtype)
        return x

    def forward(self: Module, x, *args, **kwargs):
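        """Full forward pass: the original module's output plus the LoRA branch
        scaled by the network's current multiplier.

        Falls back to the plain original forward when the network is inactive,
        already merged in, or has a multiplier of 0, and routes to lorm_forward
        when the network is running as a LoRM.
        """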
        skip = False
        network: Network = self.network_ref()

        if network.is_lorm:
            # we are doing lorm
            return self.lorm_forward(x, *args, **kwargs)

        # skip if not active
        if not network.is_active:
            skip = True

        # skip if merged in
        if network.is_merged_in:
            skip = True

        # skip if the multiplier is 0
        if network._multiplier == 0:
            skip = True

        if skip:
            # network is not active, avoid doing anything
            return self.org_forward(x, *args, **kwargs)

        # if self.__class__.__name__ == "DoRAModule":
        #     # return dora forward
        #     return self.dora_forward(x, *args, **kwargs)

        if self.__class__.__name__ == "LokrModule":
            return self._call_forward(x)

        org_forwarded = self.org_forward(x, *args, **kwargs)

        if isinstance(x, QTensor):
            x = x.dequantize()

        # cast the input to the lora weight dtype
        lora_input = x.to(self.lora_down.weight.dtype)
        lora_output = self._call_forward(lora_input)
        multiplier = self.network_ref().torch_multiplier

        lora_output_batch_size = lora_output.size(0)
        multiplier_batch_size = multiplier.size(0)
        if lora_output_batch_size != multiplier_batch_size:
            num_interleaves = lora_output_batch_size // multiplier_batch_size
            # todo check if this is correct, do we just concat when doing cfg?
            multiplier = multiplier.repeat_interleave(num_interleaves)

        scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
        scaled_lora_output = scaled_lora_output.to(org_forwarded.dtype)

        if self.__class__.__name__ == "DoRAModule":
            # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
            # x = dropout(x)
            # todo this won't match the dropout applied to the lora
            if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
                lx = self.dropout(x)
            # normal dropout
            elif self.dropout is not None and self.training:
                lx = torch.nn.functional.dropout(x, p=self.dropout)
            else:
                lx = x

            lora_weight = self.lora_up.weight @ self.lora_down.weight
            # scale it here
            # todo handle our batch-split scalers for slider training. For now take the mean of them
            scale = multiplier.mean()
            scaled_lora_weight = lora_weight * scale
            scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight).to(org_forwarded.dtype)

        try:
            x = org_forwarded + scaled_lora_output
        except RuntimeError as e:
            print(e)
            print(org_forwarded.size())
            print(scaled_lora_output.size())
            raise e

        return x

    def enable_gradient_checkpointing(self: Module):
        self.is_checkpointing = True

    def disable_gradient_checkpointing(self: Module):
        self.is_checkpointing = False

    def merge_out(self: Module, merge_out_weight=1.0):
        # make sure it is positive
        merge_out_weight = abs(merge_out_weight)
        # merging out is just merging in the negative of the weight
        self.merge_in(merge_weight=-merge_out_weight)

    def merge_in(self: Module, merge_weight=1.0):
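        """Fold this module's LoRA delta into the original module's weight:
        W <- W + merge_weight * scale * (lora_up @ lora_down), with the
        corresponding convolution form for conv layers. Quantized weights are
        skipped for now (see the todo below).
        """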
        if not self.can_merge_in:
            return

        # get the up/down weights
        up_weight = self.lora_up.weight.clone().float()
        down_weight = self.lora_down.weight.clone().float()

        # extract the weight from org_module
        org_sd = self.org_module[0].state_dict()

        # todo find a way to merge in weights when doing a quantized model
        if 'weight._data' in org_sd:
            # quantized weight
            return

        weight_key = "weight"
        if 'weight._data' in org_sd:
            # quantized weight
            weight_key = "weight._data"

        orig_dtype = org_sd[weight_key].dtype
        weight = org_sd[weight_key].float()

        multiplier = merge_weight
        scale = self.scale
        # handle the trainable scalar method LoCon uses
        if hasattr(self, 'scalar'):
            scale = scale * self.scalar

        # merge the weight
        if len(weight.size()) == 2:
            # linear
            weight = weight + multiplier * (up_weight @ down_weight) * scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                weight
                + multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * scale
            )
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            # print(conved.size(), weight.size(), module.stride, module.padding)
            weight = weight + multiplier * conved * scale

        # set the merged weight back on org_module
        org_sd[weight_key] = weight.to(orig_dtype)
        self.org_module[0].load_state_dict(org_sd)

    def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
        # LoRM (Low Rank Middle) is a method to reduce the number of parameters in a module while
        # keeping the inputs and outputs the same. It is basically a LoRA with the original module removed.
        # if a state dict is passed, use those weights instead of extracting
        # todo load from state dict
        network: Network = self.network_ref()
        lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name)

        extract_mode = lorm_config.extract_mode
        extract_mode_param = lorm_config.extract_mode_param
        parameter_threshold = lorm_config.parameter_threshold

        self.extract_weight(
            extract_mode=extract_mode,
            extract_mode_param=extract_mode_param
        )

class ToolkitNetworkMixin:
    def __init__(
            self: Network,
            *args,
            train_text_encoder: Optional[bool] = True,
            train_unet: Optional[bool] = True,
            is_sdxl=False,
            is_v2=False,
            is_ssd=False,
            is_vega=False,
            network_config: Optional[NetworkConfig] = None,
            is_lorm=False,
            **kwargs
    ):
        self.train_text_encoder = train_text_encoder
        self.train_unet = train_unet
        self.is_checkpointing = False
        self._multiplier: float = 1.0
        self.is_active: bool = False
        self.is_sdxl = is_sdxl
        self.is_ssd = is_ssd
        self.is_vega = is_vega
        self.is_v2 = is_v2
        self.is_v1 = not is_v2 and not is_sdxl and not is_ssd and not is_vega
        self.is_merged_in = False
        self.is_lorm = is_lorm
        self.network_config: NetworkConfig = network_config
        self.module_losses: List[torch.Tensor] = []
        self.lorm_train_mode: Literal['local', None] = None
        self.can_merge_in = not is_lorm

    def get_keymap(self: Network, force_weight_mapping=False):
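        """Return the LDM <-> diffusers key mapping used when saving and
        loading weights, chosen by base model type (SD1 / SD2 / SDXL / SSD /
        Vega), or None when no keymap file exists. For DoRA networks, '.alpha'
        keys are rewritten to '.magnitude'.
        """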
        use_weight_mapping = False

        if self.is_ssd:
            keymap_tail = 'ssd'
            use_weight_mapping = True
        elif self.is_vega:
            keymap_tail = 'vega'
            use_weight_mapping = True
        elif self.is_sdxl:
            keymap_tail = 'sdxl'
        elif self.is_v2:
            keymap_tail = 'sd2'
        else:
            keymap_tail = 'sd1'
            # todo double check this
            # use_weight_mapping = True

        if force_weight_mapping:
            use_weight_mapping = True

        # load keymap
        keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
        if use_weight_mapping:
            keymap_name = f"stable_diffusion_{keymap_tail}.json"

        keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)

        keymap = None
        # check if file exists
        if os.path.exists(keymap_path):
            with open(keymap_path, 'r') as f:
                keymap = json.load(f)['ldm_diffusers_keymap']

        if use_weight_mapping and keymap is not None:
            # get keymap from weights
            keymap = get_lora_keymap_from_model_keymap(keymap)

        # upgrade keymaps for DoRA
        if self.network_type.lower() == 'dora':
            if keymap is not None:
                new_keymap = {}
                for ldm_key, diffusers_key in keymap.items():
                    ldm_key = ldm_key.replace('.alpha', '.magnitude')
                    # ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
                    # ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')

                    diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
                    # diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
                    # diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
                    new_keymap[ldm_key] = diffusers_key

                keymap = new_keymap

        return keymap

    def get_state_dict(self: Network, extra_state_dict=None, dtype=torch.float16):
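        """Build the state dict to save: keys are remapped through the inverted
        keymap, values are moved to the CPU and cast to `dtype`, and key names
        are converted for the PEFT (lora_A / lora_B, no alpha) or LoKr formats
        when those are in use.
        """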
        keymap = self.get_keymap()

        save_keymap = {}
        if keymap is not None:
            for ldm_key, diffusers_key in keymap.items():
                # invert them
                save_keymap[diffusers_key] = ldm_key

        state_dict = self.state_dict()
        save_dict = OrderedDict()

        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(dtype)
            save_key = save_keymap[key] if key in save_keymap else key
            save_dict[save_key] = v
            del state_dict[key]

        if extra_state_dict is not None:
            # add extra items to the state dict
            for key in list(extra_state_dict.keys()):
                v = extra_state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                save_dict[key] = v

        if self.peft_format:
            # lora_down -> lora_A
            # lora_up -> lora_B
            # no alpha
            new_save_dict = {}
            for key, value in save_dict.items():
                if key.endswith('.alpha'):
                    continue
                new_key = key
                new_key = new_key.replace('lora_down', 'lora_A')
                new_key = new_key.replace('lora_up', 'lora_B')
                # replace all $$ with .
                new_key = new_key.replace('$$', '.')
                new_save_dict[new_key] = value

            save_dict = new_save_dict

        if self.network_type.lower() == "lokr":
            new_save_dict = {}
            for key, value in save_dict.items():
                # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 -> lycoris_transformer_blocks_7_attn_to_v.lokr_w1
                new_key = key
                new_key = new_key.replace('lora_transformer_', 'lycoris_')
                new_save_dict[new_key] = value

            save_dict = new_save_dict

        if self.base_model_ref is not None:
            save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict)

        return save_dict

    def save_weights(
            self: Network,
            file,
            dtype=torch.float16,
            metadata=None,
            extra_state_dict: Optional[OrderedDict] = None
    ):
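        """Save the network's state dict to `file` as safetensors or a torch
        checkpoint, adding a model hash to the metadata.
        """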
        save_dict = self.get_state_dict(extra_state_dict=extra_state_dict, dtype=dtype)

        if metadata is not None and len(metadata) == 0:
            metadata = None

        if metadata is None:
            metadata = OrderedDict()

        metadata = add_model_hash_to_meta(save_dict, metadata)

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            save_file(save_dict, file, metadata)
        else:
            torch.save(save_dict, file)

    def load_weights(self: Network, file, force_weight_mapping=False):
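        """Load weights from a file path or an in-memory state dict, remapping
        keys through the keymap and the PEFT / LoKr conventions.

        Keys that do not exist in this network are collected and returned as an
        extra state dict (or None); for SD1 networks a retry with a forced
        keymap is attempted when such keys are found.
        """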
        # allows us to save to and load from ldm weights
        keymap = self.get_keymap(force_weight_mapping)
        keymap = {} if keymap is None else keymap

        if isinstance(file, str):
            if os.path.splitext(file)[1] == ".safetensors":
                from safetensors.torch import load_file
                weights_sd = load_file(file)
            else:
                weights_sd = torch.load(file, map_location="cpu")
        else:
            # probably a state dict
            weights_sd = file

        if self.base_model_ref is not None:
            weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd)

        load_sd = OrderedDict()
        for key, value in weights_sd.items():
            load_key = keymap[key] if key in keymap else key
            # replace the old double __ with a single _
            if self.is_pixart:
                load_key = load_key.replace('__', '_')

            if self.peft_format:
                # lora_A -> lora_down
                # lora_B -> lora_up
                # no alpha
                if load_key.endswith('.alpha'):
                    continue
                load_key = load_key.replace('lora_A', 'lora_down')
                load_key = load_key.replace('lora_B', 'lora_up')
                # replace all . with $$
                load_key = load_key.replace('.', '$$')
                load_key = load_key.replace('$$lora_down$$', '.lora_down.')
                load_key = load_key.replace('$$lora_up$$', '.lora_up.')

            if self.network_type.lower() == "lokr":
                # lycoris_transformer_blocks_7_attn_to_v.lokr_w1 -> lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
                load_key = load_key.replace('lycoris_', 'lora_transformer_')

            load_sd[load_key] = value

        # pull out any items in the state dict that this network does not have
        current_state_dict = self.state_dict()
        extra_dict = OrderedDict()
        to_delete = []
        for key in list(load_sd.keys()):
            if key not in current_state_dict:
                extra_dict[key] = load_sd[key]
                to_delete.append(key)
        for key in to_delete:
            del load_sd[key]

        print(f"Missing keys: {to_delete}")

        if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (
                len(to_delete) == 1 and 'emb_params' in to_delete):
            print(" Attempting to load with forced keymap")
            return self.load_weights(file, force_weight_mapping=True)

        info = self.load_state_dict(load_sd, False)

        if len(extra_dict.keys()) == 0:
            extra_dict = None

        return extra_dict

    def _update_torch_multiplier(self: Network):
        # builds a tensor for fast usage in the forward pass of the network modules
        # without having to set it in every single module every time it changes
        multiplier = self._multiplier

        # get the first module to determine the device and dtype
        try:
            first_module = self.get_all_modules()[0]
        except IndexError:
            raise ValueError("There are no LoRA modules in this network. Check your config and try again")

        if hasattr(first_module, 'lora_down'):
            device = first_module.lora_down.weight.device
            dtype = first_module.lora_down.weight.dtype
        elif hasattr(first_module, 'lokr_w1'):
            device = first_module.lokr_w1.device
            dtype = first_module.lokr_w1.dtype
        elif hasattr(first_module, 'lokr_w1_a'):
            device = first_module.lokr_w1_a.device
            dtype = first_module.lokr_w1_a.dtype
        else:
            raise ValueError("Unknown module type")

        with torch.no_grad():
            tensor_multiplier = None
            if isinstance(multiplier, int) or isinstance(multiplier, float):
                tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
            elif isinstance(multiplier, list):
                tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype)
            elif isinstance(multiplier, torch.Tensor):
                tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)

            self.torch_multiplier = tensor_multiplier.clone().detach()

    @property
    def multiplier(self) -> Union[float, List[float], List[List[float]]]:
        return self._multiplier

    @multiplier.setter
    def multiplier(self, value: Union[float, List[float], List[List[float]]]):
        # it takes time to update all the multipliers, so we only do it if the value has changed
        if self._multiplier == value:
            return
        # if we are setting a single value but have a list, keep the list if every item is the same as value
        self._multiplier = value
        self._update_torch_multiplier()

    # called when the context manager is entered
    # ie: with network:
    def __enter__(self: Network):
        self.is_active = True

    def __exit__(self: Network, exc_type, exc_value, tb):
        self.is_active = False

    def force_to(self: Network, device, dtype):
        self.to(device, dtype)
        loras = []
        if hasattr(self, 'unet_loras'):
            loras += self.unet_loras
        if hasattr(self, 'text_encoder_loras'):
            loras += self.text_encoder_loras
        for lora in loras:
            lora.to(device, dtype)

    def get_all_modules(self: Network) -> List[Module]:
        loras = []
        if hasattr(self, 'unet_loras'):
            loras += self.unet_loras
        if hasattr(self, 'text_encoder_loras'):
            loras += self.text_encoder_loras
        return loras

    def _update_checkpointing(self: Network):
        for module in self.get_all_modules():
            if self.is_checkpointing:
                module.enable_gradient_checkpointing()
            else:
                module.disable_gradient_checkpointing()

    def enable_gradient_checkpointing(self: Network):
        # not supported
        self.is_checkpointing = True
        self._update_checkpointing()

    def disable_gradient_checkpointing(self: Network):
        # not supported
        self.is_checkpointing = False
        self._update_checkpointing()

    def merge_in(self, merge_weight=1.0):
        if self.network_type.lower() == 'dora':
            return
        self.is_merged_in = True
        for module in self.get_all_modules():
            module.merge_in(merge_weight)

    def merge_out(self: Network, merge_weight=1.0):
        if not self.is_merged_in:
            return
        self.is_merged_in = False
        for module in self.get_all_modules():
            module.merge_out(merge_weight)

    def extract_weight(
            self: Network,
            extract_mode: ExtractMode = "existing",
            extract_mode_param: Union[int, float] = None,
    ):
        if extract_mode_param is None:
            raise ValueError("extract_mode_param must be set")
        for module in tqdm(self.get_all_modules(), desc="Extracting weights"):
            module.extract_weight(
                extract_mode=extract_mode,
                extract_mode_param=extract_mode_param
            )

    def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None):
        for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"):
            module.setup_lorm(state_dict=state_dict)

    def calculate_lorem_parameter_reduction(self):
        params_reduced = 0
        for module in self.get_all_modules():
            num_orig_module_params = count_parameters(module.org_module[0])
            num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up)
            params_reduced += (num_orig_module_params - num_lorem_params)

        return params_reduced