|
import argparse |
|
from typing import Any, Dict |
|
|
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from safetensors.torch import load_file |
|
|
|
from diffusers import AutoencoderDC |
|
|
|
|
|
def remap_qkv_(key: str, state_dict: Dict[str, Any]): |
|
qkv = state_dict.pop(key) |
|
q, k, v = torch.chunk(qkv, 3, dim=0) |
|
parent_module, _, _ = key.rpartition(".qkv.conv.weight") |
|
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze() |
|
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze() |
|
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze() |
|
|
|
|
|
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]): |
|
parent_module, _, _ = key.rpartition(".proj.conv.weight") |
|
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze() |
|
|
|
|
|
AE_KEYS_RENAME_DICT = { |
|
|
|
"main.": "", |
|
"op_list.": "", |
|
"context_module": "attn", |
|
"local_module": "conv_out", |
|
|
|
|
|
"aggreg.0.0": "to_qkv_multiscale.0.proj_in", |
|
"aggreg.0.1": "to_qkv_multiscale.0.proj_out", |
|
"depth_conv.conv": "conv_depth", |
|
"inverted_conv.conv": "conv_inverted", |
|
"point_conv.conv": "conv_point", |
|
"point_conv.norm": "norm", |
|
"conv.conv.": "conv.", |
|
"conv1.conv": "conv1", |
|
"conv2.conv": "conv2", |
|
"conv2.norm": "norm", |
|
"proj.norm": "norm_out", |
|
|
|
"encoder.project_in.conv": "encoder.conv_in", |
|
"encoder.project_out.0.conv": "encoder.conv_out", |
|
"encoder.stages": "encoder.down_blocks", |
|
|
|
"decoder.project_in.conv": "decoder.conv_in", |
|
"decoder.project_out.0": "decoder.norm_out", |
|
"decoder.project_out.2.conv": "decoder.conv_out", |
|
"decoder.stages": "decoder.up_blocks", |
|
} |
|
|
|
AE_F32C32_KEYS = { |
|
|
|
"encoder.project_in.conv": "encoder.conv_in.conv", |
|
|
|
"decoder.project_out.2.conv": "decoder.conv_out.conv", |
|
} |
|
|
|
AE_F64C128_KEYS = { |
|
|
|
"encoder.project_in.conv": "encoder.conv_in.conv", |
|
|
|
"decoder.project_out.2.conv": "decoder.conv_out.conv", |
|
} |
|
|
|
AE_F128C512_KEYS = { |
|
|
|
"encoder.project_in.conv": "encoder.conv_in.conv", |
|
|
|
"decoder.project_out.2.conv": "decoder.conv_out.conv", |
|
} |
|
|
|
AE_SPECIAL_KEYS_REMAP = { |
|
"qkv.conv.weight": remap_qkv_, |
|
"proj.conv.weight": remap_proj_conv_, |
|
} |
|
|
|
|
|
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: |
|
state_dict = saved_dict |
|
if "model" in saved_dict.keys(): |
|
state_dict = state_dict["model"] |
|
if "module" in saved_dict.keys(): |
|
state_dict = state_dict["module"] |
|
if "state_dict" in saved_dict.keys(): |
|
state_dict = state_dict["state_dict"] |
|
return state_dict |
|
|
|
|
|
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: |
|
state_dict[new_key] = state_dict.pop(old_key) |
|
|
|
|
|
def convert_ae(config_name: str, dtype: torch.dtype): |
|
config = get_ae_config(config_name) |
|
hub_id = f"mit-han-lab/{config_name}" |
|
ckpt_path = hf_hub_download(hub_id, "model.safetensors") |
|
original_state_dict = get_state_dict(load_file(ckpt_path)) |
|
|
|
ae = AutoencoderDC(**config).to(dtype=dtype) |
|
|
|
for key in list(original_state_dict.keys()): |
|
new_key = key[:] |
|
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items(): |
|
new_key = new_key.replace(replace_key, rename_key) |
|
update_state_dict_(original_state_dict, key, new_key) |
|
|
|
for key in list(original_state_dict.keys()): |
|
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items(): |
|
if special_key not in key: |
|
continue |
|
handler_fn_inplace(key, original_state_dict) |
|
|
|
ae.load_state_dict(original_state_dict, strict=True) |
|
return ae |
|
|
|
|
|
def get_ae_config(name: str): |
|
if name in ["dc-ae-f32c32-sana-1.0"]: |
|
config = { |
|
"latent_channels": 32, |
|
"encoder_block_types": ( |
|
"ResBlock", |
|
"ResBlock", |
|
"ResBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
), |
|
"decoder_block_types": ( |
|
"ResBlock", |
|
"ResBlock", |
|
"ResBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
), |
|
"encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024), |
|
"decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024), |
|
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)), |
|
"decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)), |
|
"encoder_layers_per_block": (2, 2, 2, 3, 3, 3), |
|
"decoder_layers_per_block": [3, 3, 3, 3, 3, 3], |
|
"downsample_block_type": "conv", |
|
"upsample_block_type": "interpolate", |
|
"decoder_norm_types": "rms_norm", |
|
"decoder_act_fns": "silu", |
|
"scaling_factor": 0.41407, |
|
} |
|
elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]: |
|
AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS) |
|
config = { |
|
"latent_channels": 32, |
|
"encoder_block_types": [ |
|
"ResBlock", |
|
"ResBlock", |
|
"ResBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
], |
|
"decoder_block_types": [ |
|
"ResBlock", |
|
"ResBlock", |
|
"ResBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
], |
|
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], |
|
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], |
|
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2], |
|
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2], |
|
"encoder_qkv_multiscales": ((), (), (), (), (), ()), |
|
"decoder_qkv_multiscales": ((), (), (), (), (), ()), |
|
"decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"], |
|
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"], |
|
} |
|
if name == "dc-ae-f32c32-in-1.0": |
|
config["scaling_factor"] = 0.3189 |
|
elif name == "dc-ae-f32c32-mix-1.0": |
|
config["scaling_factor"] = 0.4552 |
|
elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]: |
|
AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS) |
|
config = { |
|
"latent_channels": 128, |
|
"encoder_block_types": [ |
|
"ResBlock", |
|
"ResBlock", |
|
"ResBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
], |
|
"decoder_block_types": [ |
|
"ResBlock", |
|
"ResBlock", |
|
"ResBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
], |
|
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048], |
|
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048], |
|
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2], |
|
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2], |
|
"encoder_qkv_multiscales": ((), (), (), (), (), (), ()), |
|
"decoder_qkv_multiscales": ((), (), (), (), (), (), ()), |
|
"decoder_norm_types": [ |
|
"batch_norm", |
|
"batch_norm", |
|
"batch_norm", |
|
"rms_norm", |
|
"rms_norm", |
|
"rms_norm", |
|
"rms_norm", |
|
], |
|
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"], |
|
} |
|
if name == "dc-ae-f64c128-in-1.0": |
|
config["scaling_factor"] = 0.2889 |
|
elif name == "dc-ae-f64c128-mix-1.0": |
|
config["scaling_factor"] = 0.4538 |
|
elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]: |
|
AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS) |
|
config = { |
|
"latent_channels": 512, |
|
"encoder_block_types": [ |
|
"ResBlock", |
|
"ResBlock", |
|
"ResBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
], |
|
"decoder_block_types": [ |
|
"ResBlock", |
|
"ResBlock", |
|
"ResBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
"EfficientViTBlock", |
|
], |
|
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048], |
|
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048], |
|
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2], |
|
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2], |
|
"encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()), |
|
"decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()), |
|
"decoder_norm_types": [ |
|
"batch_norm", |
|
"batch_norm", |
|
"batch_norm", |
|
"rms_norm", |
|
"rms_norm", |
|
"rms_norm", |
|
"rms_norm", |
|
"rms_norm", |
|
], |
|
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"], |
|
} |
|
if name == "dc-ae-f128c512-in-1.0": |
|
config["scaling_factor"] = 0.4883 |
|
elif name == "dc-ae-f128c512-mix-1.0": |
|
config["scaling_factor"] = 0.3620 |
|
else: |
|
raise ValueError("Invalid config name provided.") |
|
|
|
return config |
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--config_name", |
|
type=str, |
|
default="dc-ae-f32c32-sana-1.0", |
|
choices=[ |
|
"dc-ae-f32c32-sana-1.0", |
|
"dc-ae-f32c32-in-1.0", |
|
"dc-ae-f32c32-mix-1.0", |
|
"dc-ae-f64c128-in-1.0", |
|
"dc-ae-f64c128-mix-1.0", |
|
"dc-ae-f128c512-in-1.0", |
|
"dc-ae-f128c512-mix-1.0", |
|
], |
|
help="The DCAE checkpoint to convert", |
|
) |
|
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") |
|
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") |
|
return parser.parse_args() |
|
|
|
|
|
DTYPE_MAPPING = { |
|
"fp32": torch.float32, |
|
"fp16": torch.float16, |
|
"bf16": torch.bfloat16, |
|
} |
|
|
|
VARIANT_MAPPING = { |
|
"fp32": None, |
|
"fp16": "fp16", |
|
"bf16": "bf16", |
|
} |
|
|
|
|
|
if __name__ == "__main__": |
|
args = get_args() |
|
|
|
dtype = DTYPE_MAPPING[args.dtype] |
|
variant = VARIANT_MAPPING[args.dtype] |
|
|
|
ae = convert_ae(args.config_name, dtype) |
|
ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant) |
|
|