import math
import os
from typing import Optional, Union, List, Type

import torch
from lycoris.kohya import LycorisNetwork, LoConModule
from lycoris.modules.glora import GLoRAModule
from torch import nn
from transformers import CLIPTextModel
from torch.nn import functional as F

from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin

# diffusers specific stuff
LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear'
]
CONV_MODULES = [
    'Conv2d',
    'LoRACompatibleConv'
]

class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin):
    def __init__(
            self,
            lora_name, org_module: nn.Module,
            multiplier=1.0,
            lora_dim=4, alpha=1,
            dropout=0., rank_dropout=0., module_dropout=0.,
            use_cp=False,
            network: 'LycorisSpecialNetwork' = None,
            use_bias=False,
            **kwargs,
    ):
        """If alpha is 0 or None, alpha defaults to lora_dim (no scaling)."""
        # skip LoConModule.__init__; call the grandparent nn.Module directly,
        # since this module is constructed from scratch below
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False

        # if the wrapped module has no bias, force use_bias off
        if org_module.bias is None:
            use_bias = False

        self.scalar = nn.Parameter(torch.tensor(0.0))
        orig_module_name = org_module.__class__.__name__
        if orig_module_name in CONV_MODULES:
            self.isconv = True
            # for general LoCon
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            out_dim = org_module.out_channels
            self.down_op = F.conv2d
            self.up_op = F.conv2d
            if use_cp and k_size != (1, 1):
                # CP decomposition: 1x1 down-projection, then a rank x rank conv
                # at the full kernel size
                self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
                self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False)
                self.cp = True
            else:
                self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
            self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias)
        elif orig_module_name in LINEAR_MODULES:
            self.isconv = False
            self.down_op = F.linear
            self.up_op = F.linear
            if orig_module_name == 'GroupNorm':
                # only reachable if 'GroupNorm' is ever added to LINEAR_MODULES;
                # GroupNorm exposes num_channels instead of in/out_features, and
                # using the wrong attributes raises e.g.
                # RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32)
                in_dim = org_module.num_channels
                out_dim = org_module.num_channels
            else:
                in_dim = org_module.in_features
                out_dim = org_module.out_features
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias)
        else:
            raise NotImplementedError(f'unsupported module type: {orig_module_name}')
        self.shape = org_module.weight.shape

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes an error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant

        # same initialization as Microsoft's LoRA implementation
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.lora_up.weight)
        if self.cp:
            torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))

        self.multiplier = multiplier
        # keep a reference without registering the original module as a submodule
        self.org_module = [org_module]
        self.register_load_state_dict_post_hook(self.load_weight_hook)
    def load_weight_hook(self, *args, **kwargs):
        # reset the trainable scalar to 1 after loading saved weights so the
        # loaded low-rank weights take full effect
        self.scalar = nn.Parameter(torch.ones_like(self.scalar))
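
# Note: the forward pass itself is implemented by ToolkitModuleMixin / LoConModule.
# Based on the fields set up above, the effective computation for a linear layer is
# expected to be roughly:
#     h = org_module(x) + multiplier * scalar * (alpha / lora_dim) * lora_up(lora_down(x))
# (with lora_mid inserted between down and up in the CP-decomposed conv case).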

class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D",
        # 'UNet2DConditionModel',
        # 'Conv2d',
        # 'Timesteps',
        # 'TimestepEmbedding',
        # 'Linear',
        # 'SiLU',
        # 'ModuleList',
        # 'DownBlock2D',
        # 'ResnetBlock2D',  # need
        # 'GroupNorm',
        # 'LoRACompatibleConv',
        # 'LoRACompatibleLinear',
        # 'Dropout',
        # 'CrossAttnDownBlock2D',  # needed
        # 'Transformer2DModel',  # maybe not, has duplicates
        # 'BasicTransformerBlock',  # duplicates
        # 'LayerNorm',
        # 'Attention',
        # 'FeedForward',
        # 'GEGLU',
        # 'UpBlock2D',
        # 'UNetMidBlock2DCrossAttn'
    ]
    UNET_TARGET_REPLACE_NAME = [
        "conv_in",
        "conv_out",
        "time_embedding.linear_1",
        "time_embedding.linear_2",
    ]
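
    # These class-level lists are consumed by create_modules() in __init__ below:
    # modules whose class name appears in UNET_TARGET_REPLACE_MODULE are searched
    # for wrappable children, while entries in UNET_TARGET_REPLACE_NAME match exact
    # module paths that are wrapped directly.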
    def __init__(
            self,
            text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
            unet,
            multiplier: float = 1.0,
            lora_dim: int = 4,
            alpha: float = 1,
            dropout: Optional[float] = None,
            rank_dropout: Optional[float] = None,
            module_dropout: Optional[float] = None,
            conv_lora_dim: Optional[int] = None,
            conv_alpha: Optional[float] = None,
            use_cp: Optional[bool] = False,
            network_module: Type[object] = LoConSpecialModule,
            train_unet: bool = True,
            train_text_encoder: bool = True,
            use_text_encoder_1: bool = True,
            use_text_encoder_2: bool = True,
            use_bias: bool = False,
            is_lorm: bool = False,
            **kwargs,
    ) -> None:
        # call the ToolkitNetworkMixin super
        ToolkitNetworkMixin.__init__(
            self,
            train_text_encoder=train_text_encoder,
            train_unet=train_unet,
            is_lorm=is_lorm,
            **kwargs
        )
        # skip LycorisNetwork.__init__; call the grandparent nn.Module directly
        torch.nn.Module.__init__(self)
        # LyCORIS-specific settings
        if dropout is None:
            dropout = 0
        if rank_dropout is None:
            rank_dropout = 0
        if module_dropout is None:
            module_dropout = 0
        self.train_unet = train_unet
        self.train_text_encoder = train_text_encoder
        self.torch_multiplier = None
        # setting the multiplier triggers a tensor update
        self.multiplier = multiplier
        self.lora_dim = lora_dim

        if not self.ENABLE_CONV or conv_lora_dim is None:
            conv_lora_dim = 0
            conv_alpha = 0
        self.conv_lora_dim = int(conv_lora_dim)
        if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
            print('Applying a different lora dim for conv layers')
            print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}')
        elif self.conv_lora_dim == 0:
            print('Conv layers disabled')

        self.alpha = alpha
        self.conv_alpha = float(conv_alpha)
        if self.conv_lora_dim and self.alpha != self.conv_alpha:
            print('Applying a different alpha value for conv layers')
            print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}')

        if 1 >= dropout >= 0:
            print(f'Using dropout value: {dropout}')
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        # create module instances
        def create_modules(
                prefix,
                root_module: torch.nn.Module,
                target_replace_modules,
                target_replace_names=None
        ) -> List[network_module]:
            print('Create LyCORIS Module')
            if target_replace_names is None:
                target_replace_names = []
            loras = []
            for name, module in root_module.named_modules():
                module_name = module.__class__.__name__
                if module_name in target_replace_modules:
                    if module_name in self.MODULE_ALGO_MAP:
                        algo = self.MODULE_ALGO_MAP[module_name]
                    else:
                        algo = network_module
                    for child_name, child_module in module.named_modules():
                        lora_name = prefix + '.' + name + '.' + child_name
                        lora_name = lora_name.replace('.', '_')
                        if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0:
                            lora = algo(
                                lora_name, child_module, self.multiplier,
                                self.lora_dim, self.alpha,
                                self.dropout, self.rank_dropout, self.module_dropout,
                                use_cp,
                                network=self,
                                parent=module,
                                use_bias=use_bias,
                                **kwargs
                            )
                        elif child_module.__class__.__name__ in CONV_MODULES:
                            # 1x1 convs use the linear rank/alpha; larger kernels
                            # fall back to the conv-specific settings
                            k_size, *_ = child_module.kernel_size
                            if k_size == 1 and lora_dim > 0:
                                lora = algo(
                                    lora_name, child_module, self.multiplier,
                                    self.lora_dim, self.alpha,
                                    self.dropout, self.rank_dropout, self.module_dropout,
                                    use_cp,
                                    network=self,
                                    parent=module,
                                    use_bias=use_bias,
                                    **kwargs
                                )
                            elif conv_lora_dim > 0:
                                lora = algo(
                                    lora_name, child_module, self.multiplier,
                                    self.conv_lora_dim, self.conv_alpha,
                                    self.dropout, self.rank_dropout, self.module_dropout,
                                    use_cp,
                                    network=self,
                                    parent=module,
                                    use_bias=use_bias,
                                    **kwargs
                                )
                            else:
                                continue
                        else:
                            continue
                        loras.append(lora)
                elif name in target_replace_names:
                    if name in self.NAME_ALGO_MAP:
                        algo = self.NAME_ALGO_MAP[name]
                    else:
                        algo = network_module
                    lora_name = prefix + '.' + name
                    lora_name = lora_name.replace('.', '_')
                    if module.__class__.__name__ == 'Linear' and lora_dim > 0:
                        lora = algo(
                            lora_name, module, self.multiplier,
                            self.lora_dim, self.alpha,
                            self.dropout, self.rank_dropout, self.module_dropout,
                            use_cp,
                            parent=module,
                            network=self,
                            use_bias=use_bias,
                            **kwargs
                        )
                    elif module.__class__.__name__ == 'Conv2d':
                        k_size, *_ = module.kernel_size
                        if k_size == 1 and lora_dim > 0:
                            lora = algo(
                                lora_name, module, self.multiplier,
                                self.lora_dim, self.alpha,
                                self.dropout, self.rank_dropout, self.module_dropout,
                                use_cp,
                                network=self,
                                parent=module,
                                use_bias=use_bias,
                                **kwargs
                            )
                        elif conv_lora_dim > 0:
                            lora = algo(
                                lora_name, module, self.multiplier,
                                self.conv_lora_dim, self.conv_alpha,
                                self.dropout, self.rank_dropout, self.module_dropout,
                                use_cp,
                                network=self,
                                parent=module,
                                use_bias=use_bias,
                                **kwargs
                            )
                        else:
                            continue
                    else:
                        continue
                    loras.append(lora)
            return loras
        if network_module == GLoRAModule:
            print('GLoRA enabled, only training transformer blocks')
            # for GLoRA, only train the transformer blocks
            LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [
                "Transformer2DModel",
                "Attention",
            ]
            LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = []

        if isinstance(text_encoder, list):
            text_encoders = text_encoder
            use_index = True
        else:
            text_encoders = [text_encoder]
            use_index = False

        self.text_encoder_loras = []
        if self.train_text_encoder:
            for i, te in enumerate(text_encoders):
                if not use_text_encoder_1 and i == 0:
                    continue
                if not use_text_encoder_2 and i == 1:
                    continue
                self.text_encoder_loras.extend(create_modules(
                    LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
                    te,
                    LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
                ))
        print(f"created LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")

        if self.train_unet:
            self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
                                             LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
        else:
            self.unet_loras = []
        print(f"created LyCORIS for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # assert there are no duplicate module names
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)
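

if __name__ == '__main__':
    # Minimal self-contained sketch (illustrative only, not part of the library):
    # it reproduces, in plain torch, the low-rank update math that
    # LoConSpecialModule.__init__ sets up for a Linear layer. The exact forward
    # lives in the mixins; the combination below is an assumption based on the
    # fields above. The trainable `scalar` starts at 0, so the wrapped layer
    # initially behaves exactly like the original (the load hook resets it to 1).
    torch.manual_seed(0)
    in_dim, out_dim, rank, alpha, multiplier = 320, 640, 4, 1.0, 1.0

    base = nn.Linear(in_dim, out_dim)
    lora_down = nn.Linear(in_dim, rank, bias=False)
    lora_up = nn.Linear(rank, out_dim, bias=False)
    torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
    torch.nn.init.kaiming_uniform_(lora_up.weight)
    scalar = torch.tensor(0.0)  # mirrors self.scalar before any weights are loaded

    scale = alpha / rank  # mirrors self.scale = alpha / self.lora_dim
    x = torch.randn(2, in_dim)
    h = base(x) + multiplier * scalar * scale * lora_up(lora_down(x))
    assert torch.allclose(h, base(x))  # scalar == 0 -> the update is a no-op
    print('delta W shape:', (lora_up.weight @ lora_down.weight).shape)  # (640, 320)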