import os
import argparse
from collections import OrderedDict

from tqdm import tqdm

parser = argparse.ArgumentParser(description="Extract LoRA from Flex")
parser.add_argument("--base", type=str, default="ostris/Flex.1-alpha", help="Base model path")
parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
parser.add_argument("--output", type=str, required=True, help="Output path for lora")
parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction")
parser.add_argument("--full", action="store_true", help="Do a full transformer extraction, not just transformer blocks")
args = parser.parse_args()

# CUDA_VISIBLE_DEVICES must be set before torch is imported, so the torch /
# diffusers / lycoris imports are deferred until after this assignment.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

import torch
from safetensors.torch import save_file
from lycoris.utils import extract_linear, extract_conv, make_sparse
from diffusers import FluxTransformer2DModel

base = args.base
tuned = args.tuned
output_path = args.output
dim = args.rank

output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)


@torch.no_grad()
def extract_diff(
    base_unet,
    db_unet,
    mode="fixed",
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device="cpu",
    use_bias=False,
    sparsity=0.98,
    # small_conv=True,
    small_conv=False,
):
    UNET_TARGET_REPLACE_MODULE = [
        "Linear",
        "Conv2d",
        "LayerNorm",
        "GroupNorm",
        "GroupNorm32",
        "LoRACompatibleLinear",
        "LoRACompatibleConv",
    ]
    LORA_PREFIX_UNET = "transformer"

    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
    ):
        loras = {}
        temp = {}

        # index the base model's modules by name
        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = module

        # walk the tuned model and pair each module with its base counterpart
        for name, module in tqdm(
            list((n, m) for n, m in target_module.named_modules() if n in temp)
        ):
            weights = temp[name]  # base module
            lora_name = prefix + "." + name
            # lora_name = lora_name.replace(".", "_")
            layer = module.__class__.__name__

            # unless --full is given, only extract the transformer blocks
            if "transformer_blocks" not in lora_name and not args.full:
                continue

            if layer in {
                "Linear",
                "Conv2d",
                "LayerNorm",
                "GroupNorm",
                "GroupNorm32",
                "Embedding",
                "LoRACompatibleLinear",
                "LoRACompatibleConv",
            }:
                root_weight = module.weight  # tuned weight
                try:
                    # skip modules whose weights did not change during tuning
                    if torch.allclose(root_weight, weights.weight):
                        continue
                except Exception:
                    continue
            else:
                continue

            module = module.to(extract_device, torch.float32)
            weights = weights.to(extract_device, torch.float32)

            if mode == "full":
                decompose_mode = "full"
            elif layer == "Linear":
                weight, decompose_mode = extract_linear(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
            elif layer == "Conv2d":
                is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
                weight, decompose_mode = extract_conv(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param if is_linear else conv_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
                if small_conv and not is_linear and decompose_mode == "low rank":
                    dim = extract_a.size(0)
                    (extract_c, extract_a, _), _ = extract_conv(
                        extract_a.transpose(0, 1),
                        "fixed",
                        dim,
                        extract_device,
                        True,
                    )
                    extract_a = extract_a.transpose(0, 1)
                    extract_c = extract_c.transpose(0, 1)
                    loras[f"{lora_name}.lora_mid.weight"] = (
                        extract_c.detach().cpu().contiguous().half()
                    )
                    diff = (
                        (
                            root_weight
                            - torch.einsum(
                                "i j k l, j r, p i -> p r k l",
                                extract_c,
                                extract_a.flatten(1, -1),
                                extract_b.flatten(1, -1),
                            )
                        )
                        .detach()
                        .cpu()
                        .contiguous()
                    )
                    del extract_c
            else:
                module = module.to("cpu")
                weights = weights.to("cpu")
                continue

            if decompose_mode == "low rank":
                loras[f"{lora_name}.lora_A.weight"] = (
                    extract_a.detach().cpu().contiguous().half()
                )
                loras[f"{lora_name}.lora_B.weight"] = (
                    extract_b.detach().cpu().contiguous().half()
                )
                # loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
                if use_bias:
                    # store the residual of the low-rank fit as a sparse bias correction
                    diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
                    indices = sparse_diff.indices().to(torch.int16)
                    values = sparse_diff.values().half()
                    loras[f"{lora_name}.bias_indices"] = indices
                    loras[f"{lora_name}.bias_values"] = values
                    loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
                        torch.int16
                    )
                del extract_a, extract_b, diff
            elif decompose_mode == "full":
                if "Norm" in layer:
                    w_key = "w_norm"
                    b_key = "b_norm"
                else:
                    w_key = "diff"
                    b_key = "diff_b"
                weight_diff = module.weight - weights.weight
                loras[f"{lora_name}.{w_key}"] = (
                    weight_diff.detach().cpu().contiguous().half()
                )
                if getattr(weights, "bias", None) is not None:
                    bias_diff = module.bias - weights.bias
                    loras[f"{lora_name}.{b_key}"] = (
                        bias_diff.detach().cpu().contiguous().half()
                    )
            else:
                raise NotImplementedError

            module = module.to("cpu", torch.bfloat16)
            weights = weights.to("cpu", torch.bfloat16)
        return loras

    all_loras = {}
    all_loras |= make_state_dict(
        LORA_PREFIX_UNET,
        base_unet,
        db_unet,
        UNET_TARGET_REPLACE_MODULE,
    )
    del base_unet, db_unet
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    all_lora_name = set()
    for k in all_loras:
        lora_name, _ = k.rsplit(".", 1)
        all_lora_name.add(lora_name)
    print(f"Extracted {len(all_lora_name)} LoRA modules")
    return all_loras


# load the base and tuned transformers
print("Loading Base")
base_model = FluxTransformer2DModel.from_pretrained(
    base, subfolder="transformer", torch_dtype=torch.bfloat16
)
print("Loading Tuned")
tuned_model = FluxTransformer2DModel.from_pretrained(
    tuned, subfolder="transformer", torch_dtype=torch.bfloat16
)
output_dict = extract_diff(
    base_model,
    tuned_model,
    mode="fixed",
    linear_mode_param=dim,
    conv_mode_param=dim,
    extract_device="cuda",
    use_bias=False,
    sparsity=0.98,
    small_conv=False,
)

meta = OrderedDict()
meta['format'] = 'pt'
save_file(output_dict, output_path, metadata=meta)
print("Done")
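
# A minimal usage sketch (not part of the script). The script filename and the
# paths below are hypothetical, and loading the result assumes the extracted
# transformer.*.lora_A / lora_B keys are accepted by diffusers' PEFT-based
# LoRA loader:
#
#   python extract_flex_lora.py --tuned /path/to/tuned --output loras/flex_lora.safetensors --rank 32
#
#   from diffusers import FluxPipeline
#   pipe = FluxPipeline.from_pretrained("ostris/Flex.1-alpha", torch_dtype=torch.bfloat16).to("cuda")
#   pipe.load_lora_weights("loras/flex_lora.safetensors")
#   image = pipe("a photo of a cat").images[0]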