# based heavily on https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from toolkit.network_mixins import ToolkitModuleMixin
from typing import TYPE_CHECKING, Union, List
from optimum.quanto import QBytesTensor, QTensor

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork


def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Return a tuple of two values decomposing the input dimension, using the number closest to factor.
    The second value is greater than or equal to the first value.

    In LoRA with Kronecker product, the first value is used for the weight scale and the
    second value for the weight itself.

    Because of the non-commutative property, A⊗B ≠ B⊗A, the meaning of the two matrices differs slightly.

    examples)
    factor
            -1                2                4                8                16 ...
    127  -> 127, 1    127  -> 127, 1   127  -> 127, 1   127  -> 127, 1   127  -> 127, 1
    128  -> 16, 8     128  -> 64, 2    128  -> 32, 4    128  -> 16, 8    128  -> 16, 8
    250  -> 125, 2    250  -> 125, 2   250  -> 125, 2   250  -> 125, 2   250  -> 125, 2
    360  -> 45, 8     360  -> 180, 2   360  -> 90, 4    360  -> 45, 8    360  -> 45, 8
    512  -> 32, 16    512  -> 256, 2   512  -> 128, 4   512  -> 64, 8    512  -> 32, 16
    1024 -> 32, 32    1024 -> 512, 2   1024 -> 256, 4   1024 -> 128, 8   1024 -> 64, 16
    '''
    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        return m, n
    if factor == -1:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
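
# Added note: with the default factor=-1 the search above settles on the most
# "square" divisor pair, e.g. factorization(128) returns (8, 16); when factor
# divides the dimension, the early-return branch is taken instead, e.g.
# factorization(640, factor=4) returns (4, 160). The docstring table lists the
# pairs with the larger value first.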


def make_weight_cp(t, wa, wb):
    # Rebuild a conv-style weight from its CP/Tucker factors:
    # t: [dim, dim, k1, k2], wa: [dim, p], wb: [dim, r] -> [p, r, k1, k2]
    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
                            t, wa, wb)  # [c, d, k1, k2]
    return rebuild2


def make_kron(w1, w2, scale):
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    rebuild = torch.kron(w1, w2)

    return rebuild * scale
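
# Added note: torch.kron assembles the full LoKr delta weight. For Linear layers,
# w1 of shape (out_l, in_m) and w2 of shape (out_k, in_n) combine into
# (out_l*out_k, in_m*in_n) = (out_dim, in_dim). When w2 is 4-D (Conv2d with a full
# or CP-decomposed w2), w1 is unsqueezed to (out_l, in_m, 1, 1) so the Kronecker
# product also broadcasts over the kernel dimensions.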


class LokrModule(ToolkitModuleMixin, nn.Module):
    def __init__(
            self,
            lora_name,
            org_module: nn.Module,
            multiplier=1.0,
            lora_dim=4,
            alpha=1,
            dropout=0.,
            rank_dropout=0.,
            module_dropout=0.,
            use_cp=False,
            decompose_both=False,
            network: 'LoRASpecialNetwork' = None,
            factor: int = -1,  # factorization factor
            **kwargs,
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        factor = int(factor)
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False
        self.use_w1 = False
        self.use_w2 = False
        self.can_merge_in = True

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            # ((a, b), (c, d), *k_size)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)

            self.cp = use_cp and k_size != (1, 1)
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2:
                self.lokr_w1_a = nn.Parameter(
                    torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(
                    shape[0][0], shape[1][0]))  # a*c, 1-mode

            if lora_dim >= max(shape[0][1], shape[1][1]) / 2:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(
                    shape[0][1], shape[1][1], *k_size))
            elif self.cp:
                self.lokr_t2 = nn.Parameter(torch.empty(
                    lora_dim, lora_dim, shape[2], shape[3]))
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(lora_dim, shape[0][1]))  # b, 1-mode
                self.lokr_w2_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][1]))  # d, 2-mode
            else:  # Conv2d not cp
                # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(
                    lora_dim, shape[1][1] * shape[2] * shape[3]))
                # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)

            self.op = F.conv2d
            self.extra_args = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups
            }
        else:  # Linear
            in_dim = org_module.in_features
            out_dim = org_module.out_features

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
            shape = ((out_l, out_k), (in_m, in_n))

            # smaller part. weight scale
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2:
                self.lokr_w1_a = nn.Parameter(
                    torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(
                    shape[0][0], shape[1][0]))  # a*c, 1-mode

            if lora_dim < max(shape[0][1], shape[1][1]) / 2:
                # bigger part. weight and LoRA. [b, dim] x [dim, d]
                self.lokr_w2_a = nn.Parameter(
                    torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(
                    torch.empty(lora_dim, shape[1][1]))
                # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
            else:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(
                    torch.empty(shape[0][1], shape[1][1]))

            self.op = F.linear
            self.extra_args = {}

        self.dropout = dropout
        if dropout:
            print("[WARN] LoKr has not implemented normal dropout yet.")
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha

        if self.use_w2 and self.use_w1:
            # use scale = 1
            alpha = lora_dim

        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # treat as constant

        if self.use_w2:
            torch.nn.init.constant_(self.lokr_w2, 0)
        else:
            if self.cp:
                torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
            torch.nn.init.constant_(self.lokr_w2_b, 0)

        if self.use_w1:
            torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
        else:
            torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5))
            torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5))

        self.multiplier = multiplier
        self.org_module = [org_module]

        # sanity check: the reconstructed delta weight should contain no NaNs
        weight = make_kron(
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a @ self.lokr_w2_b),
            torch.tensor(self.multiplier * self.scale)
        )
        assert torch.sum(torch.isnan(weight)) == 0, "weight is nan"

    # Same as locon.py
    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def get_weight(self, orig_weight=None):
        weight = make_kron(
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a @ self.lokr_w2_b),
            torch.tensor(self.scale)
        )
        if orig_weight is not None:
            weight = weight.reshape(orig_weight.shape)
        if self.training and self.rank_dropout:
            # zero out whole output rows with probability rank_dropout
            drop = torch.rand(weight.size(0)) < self.rank_dropout
            weight *= drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
        return weight

    def merge_in(self, merge_weight=1.0):
        if not self.can_merge_in:
            return

        # extract weight from org_module
        org_sd = self.org_module[0].state_dict()

        # todo find a way to merge in weights when doing quantized model
        if 'weight._data' in org_sd:
            # quantized weight
            return

        weight_key = "weight"
        if 'weight._data' in org_sd:
            # quantized weight
            weight_key = "weight._data"

        orig_dtype = org_sd[weight_key].dtype
        weight = org_sd[weight_key].float()

        scale = self.scale
        # handle trainable scaler method locon does
        if hasattr(self, 'scalar'):
            scale = scale * self.scalar

        lokr_weight = self.get_weight(weight)

        merged_weight = (
            weight
            + (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
        )

        # set weight to org_module
        org_sd[weight_key] = merged_weight.to(orig_dtype)
        self.org_module[0].load_state_dict(org_sd)

    def get_orig_weight(self):
        weight = self.org_module[0].weight
        if isinstance(weight, (QTensor, QBytesTensor)):
            return weight.dequantize().data.detach()
        else:
            return weight.data.detach()

    def get_orig_bias(self):
        if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
            if isinstance(self.org_module[0].bias, (QTensor, QBytesTensor)):
                return self.org_module[0].bias.dequantize().data.detach()
            else:
                return self.org_module[0].bias.data.detach()
        return None

    def _call_forward(self, x):
        if isinstance(x, (QTensor, QBytesTensor)):
            x = x.dequantize()
        orig_dtype = x.dtype
        orig_weight = self.get_orig_weight()
        lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
        multiplier = self.network_ref().torch_multiplier

        if x.dtype != orig_weight.dtype:
            x = x.to(dtype=orig_weight.dtype)

        # we do not currently support split batch multipliers for lokr. Just do a mean
        multiplier = torch.mean(multiplier)

        weight = (
            orig_weight
            + lokr_weight * multiplier
        )

        bias = self.get_orig_bias()
        if bias is not None:
            bias = bias.to(weight.device, dtype=weight.dtype)

        output = self.op(
            x,
            weight.view(self.shape),
            bias,
            **self.extra_args
        )
        return output.to(orig_dtype)
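

# ---------------------------------------------------------------------------
# Minimal sketch (added, not part of the training path): it exercises only the
# pure helpers above, so it stays inert on import. The layer sizes (128, 640)
# and lora_dim are illustrative assumptions; run this file directly to check
# the shape algebra.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # most "square" factor pairs found by the search loop
    out_l, out_k = factorization(128)            # -> (8, 16)
    in_m, in_n = factorization(640)              # -> (20, 32)
    print(factorization(640, factor=4))          # divisor branch -> (4, 160)

    # LoKr delta for a hypothetical 128x640 Linear: small w1 factor, low-rank w2
    lora_dim = 4
    w1 = torch.randn(out_l, in_m)                # (8, 20)
    w2_a = torch.randn(out_k, lora_dim)          # (16, 4)
    w2_b = torch.randn(lora_dim, in_n)           # (4, 32)
    delta = make_kron(w1, w2_a @ w2_b, torch.tensor(1.0))
    print(delta.shape)                           # torch.Size([128, 640]) == (out_dim, in_dim)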