"""Extract a LoRA from a fine-tuned Flex.1 transformer by diffing it against the
base model and low-rank factorizing the weight differences (via lycoris utils)."""
import os
import argparse
from collections import OrderedDict

from tqdm import tqdm

parser = argparse.ArgumentParser(description="Extract LoRA from Flex")
parser.add_argument("--base", type=str, default="ostris/Flex.1-alpha", help="Base model path")
parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
parser.add_argument("--output", type=str, required=True, help="Output path for LoRA")
parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
parser.add_argument("--gpu", type=int, default=0, help="GPU to use for extraction")
parser.add_argument("--full", action="store_true", help="Do a full transformer extraction, not just transformer blocks")
args = parser.parse_args()
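
# Example invocation (a sketch; the script name and paths are illustrative):
#   python extract_flex_lora.py \
#       --base ostris/Flex.1-alpha \
#       --tuned /path/to/tuned/flex \
#       --output output/flex_extracted_lora.safetensors \
#       --rank 32 --gpu 0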

if True:
    # Set the visible GPU *before* importing torch: CUDA reads
    # CUDA_VISIBLE_DEVICES at initialization, so the import order matters here.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    import torch
    from safetensors.torch import load_file, save_file
    from lycoris.utils import extract_linear, extract_conv, make_sparse
    from diffusers import FluxTransformer2DModel

base = args.base
tuned = args.tuned
output_path = args.output
dim = args.rank

# Only create the parent directory if the output path actually has one;
# os.makedirs("") would raise for a bare filename.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

state_dict_base = {}
state_dict_tuned = {}
output_dict = {}


def extract_diff(
    base_unet,
    db_unet,
    mode="fixed",
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device="cpu",
    use_bias=False,
    sparsity=0.98,
    # small_conv=True,
    small_conv=False,
):
    """Extract a LoRA-style approximation of the weight difference between the
    tuned transformer (db_unet) and the base transformer (base_unet)."""
    UNET_TARGET_REPLACE_MODULE = [
        "Linear",
        "Conv2d",
        "LayerNorm",
        "GroupNorm",
        "GroupNorm32",
        "LoRACompatibleLinear",
        "LoRACompatibleConv",
    ]
    LORA_PREFIX_UNET = "transformer"

    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
    ):
        loras = {}
        temp = {}

        # Collect candidate modules from the base model, keyed by name.
        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = module

        # Walk the tuned model and pair each module with its base counterpart:
        # `module` comes from the tuned model, `weights` from the base model.
        for name, module in tqdm(
            list((n, m) for n, m in target_module.named_modules() if n in temp)
        ):
            weights = temp[name]
            lora_name = prefix + "." + name
            # lora_name = lora_name.replace(".", "_")
            layer = module.__class__.__name__

            # Unless --full is set, only extract from the transformer blocks.
            if "transformer_blocks" not in lora_name and not args.full:
                continue

            if layer in {
                "Linear",
                "Conv2d",
                "LayerNorm",
                "GroupNorm",
                "GroupNorm32",
                "Embedding",
                "LoRACompatibleLinear",
                "LoRACompatibleConv",
            }:
                root_weight = module.weight
                try:
                    # Skip layers whose weights did not change during tuning.
                    if torch.allclose(root_weight, weights.weight):
                        continue
                except Exception:
                    continue
            else:
                continue

            module = module.to(extract_device, torch.float32)
            weights = weights.to(extract_device, torch.float32)

            if mode == "full":
                decompose_mode = "full"
            elif layer == "Linear":
                # Low-rank factorization of the weight delta (tuned - base).
                weight, decompose_mode = extract_linear(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
            elif layer == "Conv2d":
                # 1x1 convolutions are treated like linear layers.
                is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
                weight, decompose_mode = extract_conv(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param if is_linear else conv_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
                if small_conv and not is_linear and decompose_mode == "low rank":
                    dim = extract_a.size(0)
                    (extract_c, extract_a, _), _ = extract_conv(
                        extract_a.transpose(0, 1),
                        "fixed",
                        dim,
                        extract_device,
                        True,
                    )
                    extract_a = extract_a.transpose(0, 1)
                    extract_c = extract_c.transpose(0, 1)
                    loras[f"{lora_name}.lora_mid.weight"] = (
                        extract_c.detach().cpu().contiguous().half()
                    )
                    # Residual left after recomposing the three conv factors.
                    diff = (
                        (
                            root_weight
                            - torch.einsum(
                                "i j k l, j r, p i -> p r k l",
                                extract_c,
                                extract_a.flatten(1, -1),
                                extract_b.flatten(1, -1),
                            )
                        )
                        .detach()
                        .cpu()
                        .contiguous()
                    )
                    del extract_c
            else:
                module = module.to("cpu")
                weights = weights.to("cpu")
                continue

            if decompose_mode == "low rank":
                loras[f"{lora_name}.lora_A.weight"] = (
                    extract_a.detach().cpu().contiguous().half()
                )
                loras[f"{lora_name}.lora_B.weight"] = (
                    extract_b.detach().cpu().contiguous().half()
                )
                # loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
                if use_bias:
                    diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                    indices = sparse_diff.indices().to(torch.int16)
                    values = sparse_diff.values().half()
                    loras[f"{lora_name}.bias_indices"] = indices
                    loras[f"{lora_name}.bias_values"] = values
                    loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
                        torch.int16
                    )
                del extract_a, extract_b, diff
            elif decompose_mode == "full":
                if "Norm" in layer:
                    w_key = "w_norm"
                    b_key = "b_norm"
                else:
                    w_key = "diff"
                    b_key = "diff_b"
                weight_diff = module.weight - weights.weight
                loras[f"{lora_name}.{w_key}"] = (
                    weight_diff.detach().cpu().contiguous().half()
                )
                if getattr(weights, "bias", None) is not None:
                    bias_diff = module.bias - weights.bias
                    loras[f"{lora_name}.{b_key}"] = (
                        bias_diff.detach().cpu().contiguous().half()
                    )
            else:
                raise NotImplementedError

            module = module.to("cpu", torch.bfloat16)
            weights = weights.to("cpu", torch.bfloat16)

        return loras

    all_loras = {}

    all_loras |= make_state_dict(
        LORA_PREFIX_UNET,
        base_unet,
        db_unet,
        UNET_TARGET_REPLACE_MODULE,
    )

    del base_unet, db_unet
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Report how many distinct modules produced extracted weights.
    all_lora_name = set()
    for k in all_loras:
        lora_name, _ = k.rsplit(".", 1)
        all_lora_name.add(lora_name)
    print(f"Extracted weights for {len(all_lora_name)} modules")
    return all_loras
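
# The resulting state dict uses PEFT-style keys built from the module path,
# e.g. (illustrative; actual names depend on the Flex/Flux transformer layout):
#   transformer.transformer_blocks.0.attn.to_q.lora_A.weight
#   transformer.transformer_blocks.0.attn.to_q.lora_B.weight
# By the usual PEFT convention, lora_A is the down-projection and lora_B the
# up-projection factor at the chosen --rank.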

# Load the base and tuned transformers from their pretrained checkpoints.
print("Loading Base")
base_model = FluxTransformer2DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.bfloat16)
print("Loading Tuned")
tuned_model = FluxTransformer2DModel.from_pretrained(tuned, subfolder="transformer", torch_dtype=torch.bfloat16)

output_dict = extract_diff(
    base_model,
    tuned_model,
    mode="fixed",
    linear_mode_param=dim,
    conv_mode_param=dim,
    extract_device="cuda",
    use_bias=False,
    sparsity=0.98,
    small_conv=False,
)

meta = OrderedDict()
meta["format"] = "pt"

save_file(output_dict, output_path, metadata=meta)
print("Done")
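
# Optional sanity check of the saved file (a sketch, meant to be run separately;
# the path is illustrative):
#
#   from safetensors.torch import load_file
#   sd = load_file("output/flex_extracted_lora.safetensors")
#   print(f"{len(sd)} tensors")
#   for k in sorted(sd)[:5]:
#       print(k, tuple(sd[k].shape), sd[k].dtype)
#
# If the keys follow the diffusers/PEFT convention shown above, recent diffusers
# releases can usually load the file onto a Flux/Flex pipeline with
# pipe.load_lora_weights(...), but verify against your diffusers version.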