from typing import Union, Tuple, Literal, Optional

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm

from toolkit.config_modules import LoRMConfig
conv = nn.Conv2d
lin = nn.Linear

_size_2_t = Union[int, Tuple[int, int]]
ExtractMode = Literal[
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage'
]
LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear'
]
CONV_MODULES = [
    # 'Conv2d',
    # 'LoRACompatibleConv'
]

UNET_TARGET_REPLACE_MODULE = [
    "Transformer2DModel",
    # "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
]
LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE

UNET_TARGET_REPLACE_NAME = [
    "conv_in",
    "conv_out",
    "time_embedding.linear_1",
    "time_embedding.linear_2",
]

UNET_MODULES_TO_AVOID = [
]


# Low Rank Convolution
class LoRMCon2d(nn.Module):
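    """
    Two-layer low-rank replacement for a standard nn.Conv2d.

    The original convolution is factored into a `down` conv that projects from
    `in_channels` to `lorm_channels` using the original spatial kernel, followed
    by an `up` 1x1 conv that projects back out to `out_channels`. Only the `up`
    conv carries a bias, mirroring the bias of the layer it replaces.
    """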
    def __init__(
            self,
            in_channels: int,
            lorm_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 'same',
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.lorm_channels = lorm_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        self.down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=lorm_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype
        )
        # The up projection always uses a 1x1 kernel; the spatial kernel lives
        # entirely in the down projection, so the SVD split cannot give two 3x3 kernels.
        self.up = nn.Conv2d(
            in_channels=lorm_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
            stride=1,
            padding='same',
            dilation=1,
            groups=1,
            bias=bias,
            padding_mode='zeros',
            device=device,
            dtype=dtype
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        x = input
        x = self.down(x)
        x = self.up(x)
        return x


class LoRMLinear(nn.Module):
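    """
    Two-layer low-rank replacement for a standard nn.Linear.

    The original linear layer is factored into a `down` projection from
    `in_features` to `lorm_features` (no bias) followed by an `up` projection
    back out to `out_features`, which carries the original layer's bias.
    """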
    def __init__(
            self,
            in_features: int,
            lorm_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.lorm_features = lorm_features
        self.out_features = out_features

        self.down = nn.Linear(
            in_features=in_features,
            out_features=lorm_features,
            bias=False,
            device=device,
            dtype=dtype
        )
        self.up = nn.Linear(
            in_features=lorm_features,
            out_features=out_features,
            bias=bias,
            # bias=True,
            device=device,
            dtype=dtype
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        x = input
        x = self.down(x)
        x = self.up(x)
        return x


def extract_conv(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu'
) -> Tuple[Tensor, Tensor, int, Tensor]:
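    """
    Factor a conv weight of shape (out_ch, in_ch, k, k) into a low-rank pair
    via SVD of the flattened (out_ch, in_ch * k * k) matrix.

    Returns (down_weight, up_weight, rank, diff), where `down_weight` has shape
    (rank, in_ch, k, k), `up_weight` has shape (out_ch, rank, 1, 1), and `diff`
    is the residual between the original weight and its low-rank reconstruction.
    The rank is chosen according to `mode` / `mode_param` (see ExtractMode).
    """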
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape
    U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1))
    if mode == 'percentage':
        assert 0 <= mode_param <= 1  # Ensure it's a valid percentage.
        original_params = out_ch * in_ch * kernel_size * kernel_size
        desired_params = mode_param * original_params
        # The factorized conv has roughly rank * in_ch * k * k (down) plus
        # out_ch * rank (up) parameters, so solve
        # desired_params = rank * (in_ch * k * k + out_ch) for the rank.
        lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
    elif mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"'
        )
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # Clamp: above half the output channels the factorization stops saving parameters.
        lora_rank = out_ch // 2
        print("Determined rank is higher than it should be; clamping to out_ch // 2")
        # print(f"Skipping layer as determined rank is too high")
        # return None, None, None, None
        # return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight

    return extract_weight_A, extract_weight_B, lora_rank, diff


def extract_linear(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu',
) -> Tuple[Tensor, Tensor, int, Tensor]:
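    """
    Factor a linear weight of shape (out_ch, in_ch) into a low-rank pair via SVD.

    Returns (down_weight, up_weight, rank, diff), where `down_weight` has shape
    (rank, in_ch), `up_weight` has shape (out_ch, rank), and `diff` is the
    residual between the original weight and `up_weight @ down_weight`.
    The rank is chosen according to `mode` / `mode_param` (see ExtractMode).
    """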
    weight = weight.to(device)
    out_ch, in_ch = weight.shape
    U, S, Vh = torch.linalg.svd(weight)
    if mode == 'percentage':
        assert 0 <= mode_param <= 1  # Ensure it's a valid percentage.
        desired_params = mode_param * out_ch * in_ch
        # The factorized linear has roughly rank * in_ch (down) plus out_ch * rank (up)
        # parameters, so solve desired_params = rank * (in_ch + out_ch) for the rank.
        lora_rank = int(desired_params / (in_ch + out_ch))
    elif mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"'
        )
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # Clamp: above half the output features the factorization stops saving parameters.
        # print(f"rank is higher than it should be")
        lora_rank = out_ch // 2
        # return weight, 'full'
        # print(f"Skipping layer as determined rank is too high")
        # return None, None, None, None

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight

    return extract_weight_A, extract_weight_B, lora_rank, diff


def replace_module_by_path(network, name, module):
    """Replace a module in a network by its name."""
    name_parts = name.split('.')
    current_module = network
    for part in name_parts[:-1]:
        current_module = getattr(current_module, part)
    try:
        setattr(current_module, name_parts[-1], module)
    except Exception as e:
        print(e)


def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


def compute_optimal_bias(original_module, linear_down, linear_up, X):
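    """
    Estimate a bias correction for a low-rank pair by running a batch of inputs X
    through both the original module and the down/up approximation, then averaging
    the output error over the batch. (Only referenced from commented-out code below.)
    """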
    Y_original = original_module(X)
    Y_approx = linear_up(linear_down(X))
    E = Y_original - Y_approx
    optimal_bias = E.mean(dim=0)
    return optimal_bias


def format_with_commas(n):
    return f"{n:,}"


def print_lorm_extract_details(
        start_num_params: int,
        end_num_params: int,
        num_replaced: int,
):
    start_formatted = format_with_commas(start_num_params)
    end_formatted = format_with_commas(end_num_params)
    num_replaced_formatted = format_with_commas(num_replaced)
    width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
    print("Convert UNet result:")
    print(f" - converted: {num_replaced:>{width},} modules")
    print(f" - start: {start_num_params:>{width},} params")
    print(f" - end: {end_num_params:>{width},} params")


lorm_ignore_if_contains = [
    'proj_out', 'proj_in',
]

lorm_parameter_threshold = 1_000_000


def convert_diffusers_unet_to_lorm(
        unet: UNet2DConditionModel,
        config: LoRMConfig,
):
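    """
    Replace large Linear layers inside the UNet's target blocks with LoRMLinear
    modules whose weights are initialized from an SVD of the original weights.

    Conv replacement is wired up below but effectively disabled while
    CONV_MODULES is empty. Returns the list of newly created modules.
    """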
    print('Converting UNet to LoRM UNet')
    start_num_params = count_parameters(unet)
    named_modules = list(unet.named_modules())
    num_replaced = 0
    pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
    layer_names_replaced = []
    converted_modules = []
    ignore_if_contains = [
        'proj_out', 'proj_in',
    ]
    for name, module in named_modules:
        module_name = module.__class__.__name__
        if module_name in UNET_TARGET_REPLACE_MODULE:
            for child_name, child_module in module.named_modules():
                new_module: Union[LoRMCon2d, LoRMLinear, None] = None
                combined_name = f"{name}.{child_name}"
                # if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
                #     pass
                lorm_config = config.get_config_for_module(combined_name)
                extract_mode = lorm_config.extract_mode
                extract_mode_param = lorm_config.extract_mode_param
                parameter_threshold = lorm_config.parameter_threshold

                # skip child modules whose name contains any of the ignore substrings
                if any([word in child_name for word in ignore_if_contains]):
                    pass
                elif child_module.__class__.__name__ in LINEAR_MODULES:
                    if count_parameters(child_module) > parameter_threshold:
                        # dtype = child_module.weight.dtype
                        dtype = torch.float32
                        # extract and convert
                        down_weight, up_weight, lora_dim, diff = extract_linear(
                            weight=child_module.weight.clone().detach().float(),
                            mode=extract_mode,
                            mode_param=extract_mode_param,
                            device=child_module.weight.device,
                        )
                        if down_weight is None:
                            continue
                        down_weight = down_weight.to(dtype=dtype)
                        up_weight = up_weight.to(dtype=dtype)
                        bias_weight = None
                        if child_module.bias is not None:
                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)

                        # linear layer weights = (out_features, in_features)
                        new_module = LoRMLinear(
                            in_features=down_weight.shape[1],
                            lorm_features=lora_dim,
                            out_features=up_weight.shape[0],
                            bias=bias_weight is not None,
                            device=down_weight.device,
                            dtype=down_weight.dtype
                        )
                        # replace the weights
                        new_module.down.weight.data = down_weight
                        new_module.up.weight.data = up_weight
                        if bias_weight is not None:
                            new_module.up.bias.data = bias_weight
                        # else:
                        #     new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data)

                        # bias_correction = compute_optimal_bias(
                        #     child_module,
                        #     new_module.down,
                        #     new_module.up,
                        #     torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
                        # )
                        # new_module.up.bias.data += bias_correction
                elif child_module.__class__.__name__ in CONV_MODULES:
                    if count_parameters(child_module) > parameter_threshold:
                        dtype = child_module.weight.dtype
                        down_weight, up_weight, lora_dim, diff = extract_conv(
                            weight=child_module.weight.clone().detach().float(),
                            mode=extract_mode,
                            mode_param=extract_mode_param,
                            device=child_module.weight.device,
                        )
                        if down_weight is None:
                            continue
                        down_weight = down_weight.to(dtype=dtype)
                        up_weight = up_weight.to(dtype=dtype)
                        bias_weight = None
                        if child_module.bias is not None:
                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)

                        new_module = LoRMCon2d(
                            in_channels=down_weight.shape[1],
                            lorm_channels=lora_dim,
                            out_channels=up_weight.shape[0],
                            kernel_size=child_module.kernel_size,
                            dilation=child_module.dilation,
                            padding=child_module.padding,
                            padding_mode=child_module.padding_mode,
                            stride=child_module.stride,
                            bias=bias_weight is not None,
                            device=down_weight.device,
                            dtype=down_weight.dtype
                        )
                        # replace the weights
                        new_module.down.weight.data = down_weight
                        new_module.up.weight.data = up_weight
                        if bias_weight is not None:
                            new_module.up.bias.data = bias_weight
                if new_module is not None:
                    combined_name = f"{name}.{child_name}"
                    replace_module_by_path(unet, combined_name, new_module)
                    converted_modules.append(new_module)
                    num_replaced += 1
                    # child_module still references the original layer here, so this logs its size
                    layer_names_replaced.append(
                        f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
        pbar.update(1)
    pbar.close()
    end_num_params = count_parameters(unet)

    def sorting_key(s):
        # Extract the number part, remove commas, and convert to integer
        return int(s.split("-")[1].strip().replace(",", ""))

    sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
    for layer_name in sorted_layer_names_replaced:
        print(layer_name)

    print_lorm_extract_details(
        start_num_params=start_num_params,
        end_num_params=end_num_params,
        num_replaced=num_replaced,
    )
    return converted_modules
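

if __name__ == "__main__":
    # Minimal sketch (not part of the library API): factor a single Linear layer
    # with extract_linear and check that a LoRMLinear built from the result
    # approximates its output. The layer sizes and fixed rank are arbitrary
    # illustrative choices, not values used by the toolkit.
    torch.manual_seed(0)
    original = nn.Linear(640, 320)

    down_w, up_w, rank, diff = extract_linear(
        original.weight.detach().float(), mode='fixed', mode_param=64
    )

    lorm = LoRMLinear(in_features=640, lorm_features=rank, out_features=320, bias=True)
    lorm.down.weight.data = down_w
    lorm.up.weight.data = up_w
    lorm.up.bias.data = original.bias.data.clone()

    x = torch.randn(4, 640)
    err = (original(x) - lorm(x)).abs().max().item()
    print(
        f"rank={rank}, params {count_parameters(original):,} -> {count_parameters(lorm):,}, "
        f"max abs error {err:.4f}"
    )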