Spaces:
Paused
Paused
File size: 2,255 Bytes
1c72248 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
def convert_to_diffusers(state_dict):
    """Rename state-dict keys from the original layout to the diffusers-style layout.

    Each key is rewritten in four steps, in this order:

    1. A leading ``diffusion_model.`` prefix becomes ``transformer.``.
    2. ``self_attn`` becomes ``attn1``; otherwise ``cross_attn`` becomes ``attn2``.
    3. Dot-separated components that are exactly ``q``/``k``/``v`` become
       ``to_q``/``to_k``/``to_v``, and ``o`` becomes ``to_out.0``.
    4. ``ffn.0`` becomes ``ffn.net.0.proj``; otherwise ``ffn.2`` becomes
       ``ffn.net.2``.

    Args:
        state_dict: Mapping from parameter name to value. It is not modified.

    Returns:
        A new dict holding the same values under the renamed keys. Keys that
        match none of the rules are copied through unchanged.
    """
    old_prefix = "diffusion_model."
    new_prefix = "transformer."
    # Exact-component renames for the attention projections.
    component_names = {
        "q": "to_q",
        "k": "to_k",
        "v": "to_v",
        "o": "to_out.0",
    }

    new_state_dict = {}
    for key in state_dict:
        new_key = key

        # Base model name change. Only the leading prefix is swapped:
        # str.replace() would also rewrite any later occurrence of the same
        # substring inside the key.
        if key.startswith(old_prefix):
            new_key = new_prefix + key[len(old_prefix):]

        # Attention blocks conversion
        if "self_attn" in new_key:
            new_key = new_key.replace("self_attn", "attn1")
        elif "cross_attn" in new_key:
            new_key = new_key.replace("cross_attn", "attn2")

        # Attention components conversion. Matching whole components (not
        # substrings) keeps names that merely contain these letters intact.
        new_key = ".".join(
            component_names.get(part, part) for part in new_key.split(".")
        )

        # FFN conversion
        if "ffn.0" in new_key:
            new_key = new_key.replace("ffn.0", "ffn.net.0.proj")
        elif "ffn.2" in new_key:
            new_key = new_key.replace("ffn.2", "ffn.net.2")

        new_state_dict[new_key] = state_dict[key]
    return new_state_dict
def convert_to_original(state_dict):
    """Rename state-dict keys from the diffusers-style layout to the original layout.

    Each key is rewritten in four steps, in this order:

    1. A leading ``transformer.`` prefix becomes ``diffusion_model.``.
    2. ``attn1`` becomes ``self_attn``; otherwise ``attn2`` becomes
       ``cross_attn``.
    3. The first match among ``to_out.0`` -> ``o``, ``to_q`` -> ``q``,
       ``to_k`` -> ``k``, ``to_v`` -> ``v`` is applied (substring match).
    4. ``ffn.net.0.proj`` becomes ``ffn.0``; otherwise ``ffn.net.2`` becomes
       ``ffn.2``.

    Args:
        state_dict: Mapping from parameter name to value. It is not modified.

    Returns:
        A new dict holding the same values under the renamed keys. Keys that
        match none of the rules are copied through unchanged.
    """
    old_prefix = "transformer."
    new_prefix = "diffusion_model."

    new_state_dict = {}
    for key in state_dict:
        new_key = key

        # Base model name change. Only the leading prefix is swapped:
        # str.replace() would also rewrite any later occurrence of the same
        # substring inside the key.
        if key.startswith(old_prefix):
            new_key = new_prefix + key[len(old_prefix):]

        # Attention blocks conversion
        if "attn1" in new_key:
            new_key = new_key.replace("attn1", "self_attn")
        elif "attn2" in new_key:
            new_key = new_key.replace("attn2", "cross_attn")

        # Attention components conversion. "to_out.0" is tested first; only
        # one of these renames is applied per key.
        if "to_out.0" in new_key:
            new_key = new_key.replace("to_out.0", "o")
        elif "to_q" in new_key:
            new_key = new_key.replace("to_q", "q")
        elif "to_k" in new_key:
            new_key = new_key.replace("to_k", "k")
        elif "to_v" in new_key:
            new_key = new_key.replace("to_v", "v")

        # FFN conversion
        if "ffn.net.0.proj" in new_key:
            new_key = new_key.replace("ffn.net.0.proj", "ffn.0")
        elif "ffn.net.2" in new_key:
            new_key = new_key.replace("ffn.net.2", "ffn.2")

        new_state_dict[new_key] = state_dict[key]
    return new_state_dict
|