Spaces:
Paused
Paused
def convert_to_diffusers(state_dict):
    """Return a copy of ``state_dict`` with its keys renamed.

    The following rewrites are applied to every key, in this order:

    1. A leading ``diffusion_model.`` becomes ``transformer.``.
    2. ``self_attn`` becomes ``attn1``; otherwise ``cross_attn`` becomes
       ``attn2`` (only one of the two is applied per key).
    3. Dot-separated segments that are exactly ``q``/``k``/``v`` become
       ``to_q``/``to_k``/``to_v``, and a segment that is exactly ``o``
       becomes ``to_out.0``.
    4. ``ffn.0`` becomes ``ffn.net.0.proj``; otherwise ``ffn.2`` becomes
       ``ffn.net.2``.

    Keys that match none of the rules are copied unchanged. Values are
    carried over as-is (no copy), and the input mapping is not modified.

    NOTE(review): judging by the function name, the target layout is the
    diffusers naming scheme -- confirm against the checkpoints in use.

    Args:
        state_dict: Mapping from parameter name (str) to its value.

    Returns:
        A new dict with the renamed keys.
    """
    old_prefix = "diffusion_model."
    new_prefix = "transformer."
    # Exact-segment renames for the attention projections.
    segment_renames = {
        "q": "to_q",
        "k": "to_k",
        "v": "to_v",
        "o": "to_out.0",
    }

    new_state_dict = {}
    for key in state_dict:
        new_key = key

        # Base model name change. Only the leading prefix is rewritten;
        # str.replace() would also touch any later occurrence of the text.
        if key.startswith(old_prefix):
            new_key = new_prefix + key[len(old_prefix):]

        # Attention blocks conversion
        if "self_attn" in new_key:
            new_key = new_key.replace("self_attn", "attn1")
        elif "cross_attn" in new_key:
            new_key = new_key.replace("cross_attn", "attn2")

        # Attention components conversion: match whole segments so that
        # names merely containing the letters q/k/v/o are left alone.
        new_key = ".".join(
            segment_renames.get(part, part) for part in new_key.split(".")
        )

        # FFN conversion
        if "ffn.0" in new_key:
            new_key = new_key.replace("ffn.0", "ffn.net.0.proj")
        elif "ffn.2" in new_key:
            new_key = new_key.replace("ffn.2", "ffn.net.2")

        new_state_dict[new_key] = state_dict[key]
    return new_state_dict
def convert_to_original(state_dict):
    """Return a copy of ``state_dict`` with its keys renamed.

    The following rewrites are applied to every key, in this order:

    1. A leading ``transformer.`` becomes ``diffusion_model.``.
    2. ``attn1`` becomes ``self_attn``; otherwise ``attn2`` becomes
       ``cross_attn`` (only one of the two is applied per key).
    3. The first of ``to_out.0``, ``to_q``, ``to_k``, ``to_v`` found in the
       key is replaced by ``o``, ``q``, ``k``, ``v`` respectively; the
       remaining ones are not tried for that key.
    4. ``ffn.net.0.proj`` becomes ``ffn.0``; otherwise ``ffn.net.2``
       becomes ``ffn.2``.

    Keys that match none of the rules are copied unchanged. Values are
    carried over as-is (no copy), and the input mapping is not modified.

    NOTE(review): judging by the function name, the target layout is the
    original (non-diffusers) checkpoint naming -- confirm against the
    checkpoints in use.

    Args:
        state_dict: Mapping from parameter name (str) to its value.

    Returns:
        A new dict with the renamed keys.
    """
    old_prefix = "transformer."
    new_prefix = "diffusion_model."
    # Ordered: the first pattern present in a key wins.
    component_renames = (
        ("to_out.0", "o"),
        ("to_q", "q"),
        ("to_k", "k"),
        ("to_v", "v"),
    )

    new_state_dict = {}
    for key in state_dict:
        new_key = key

        # Base model name change. Only the leading prefix is rewritten;
        # str.replace() would also touch any later occurrence of the text.
        if key.startswith(old_prefix):
            new_key = new_prefix + key[len(old_prefix):]

        # Attention blocks conversion
        if "attn1" in new_key:
            new_key = new_key.replace("attn1", "self_attn")
        elif "attn2" in new_key:
            new_key = new_key.replace("attn2", "cross_attn")

        # Attention components conversion
        for pattern, replacement in component_renames:
            if pattern in new_key:
                new_key = new_key.replace(pattern, replacement)
                break

        # FFN conversion
        if "ffn.net.0.proj" in new_key:
            new_key = new_key.replace("ffn.net.0.proj", "ffn.0")
        elif "ffn.net.2" in new_key:
            new_key = new_key.replace("ffn.net.2", "ffn.2")

        new_state_dict[new_key] = state_dict[key]
    return new_state_dict