File size: 2,255 Bytes
1c72248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def convert_to_diffusers(state_dict):
    """Rename original-format checkpoint keys to diffusers-style keys.

    The following renames are applied to each key, in this order:

    * a leading ``diffusion_model.`` prefix becomes ``transformer.`` (only the
      prefix is rewritten; later occurrences in the key are left alone);
    * ``self_attn`` -> ``attn1``, otherwise ``cross_attn`` -> ``attn2``;
    * whole dot-separated segments ``q`` / ``k`` / ``v`` become
      ``to_q`` / ``to_k`` / ``to_v`` and ``o`` becomes ``to_out.0``;
    * ``ffn.0`` -> ``ffn.net.0.proj``, otherwise ``ffn.2`` -> ``ffn.net.2``.

    Args:
        state_dict: Mapping of parameter names to values. It is not modified.

    Returns:
        A new dict with renamed keys mapped to the original value objects
        (values are not copied). If two input keys rename to the same output
        key, the one iterated later wins.
    """
    old_prefix = "diffusion_model."
    new_prefix = "transformer."
    # Exact segment renames; matching whole segments avoids touching names
    # such as "norm_q" that merely contain one of these letters.
    attn_parts = {"q": "to_q", "k": "to_k", "v": "to_v", "o": "to_out.0"}

    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        # Base model name change. Slice rather than str.replace so only the
        # leading prefix is rewritten, not every occurrence inside the key.
        if new_key.startswith(old_prefix):
            new_key = new_prefix + new_key[len(old_prefix):]

        # Attention blocks conversion
        if "self_attn" in new_key:
            new_key = new_key.replace("self_attn", "attn1")
        elif "cross_attn" in new_key:
            new_key = new_key.replace("cross_attn", "attn2")

        # Attention components conversion (whole segments only)
        new_key = ".".join(attn_parts.get(part, part) for part in new_key.split("."))

        # FFN conversion
        if "ffn.0" in new_key:
            new_key = new_key.replace("ffn.0", "ffn.net.0.proj")
        elif "ffn.2" in new_key:
            new_key = new_key.replace("ffn.2", "ffn.net.2")

        new_state_dict[new_key] = value
    return new_state_dict


def convert_to_original(state_dict):
    """Rename diffusers-style checkpoint keys back to original-format keys.

    Undoes the renames performed by ``convert_to_diffusers`` for keys that
    follow these patterns, applied to each key in this order:

    * a leading ``transformer.`` prefix becomes ``diffusion_model.`` (only the
      prefix is rewritten; later occurrences in the key are left alone);
    * ``attn1`` -> ``self_attn``, otherwise ``attn2`` -> ``cross_attn``;
    * whole dot-separated segments ``to_q`` / ``to_k`` / ``to_v`` become
      ``q`` / ``k`` / ``v``, and the segment pair ``to_out.0`` becomes ``o``;
    * ``ffn.net.0.proj`` -> ``ffn.0``, otherwise ``ffn.net.2`` -> ``ffn.2``.

    Args:
        state_dict: Mapping of parameter names to values. It is not modified.

    Returns:
        A new dict with renamed keys mapped to the original value objects
        (values are not copied). If two input keys rename to the same output
        key, the one iterated later wins.
    """
    old_prefix = "transformer."
    new_prefix = "diffusion_model."
    qkv_parts = {"to_q": "q", "to_k": "k", "to_v": "v"}

    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        # Base model name change. Slice rather than str.replace so only the
        # leading prefix is rewritten, not every occurrence inside the key.
        if new_key.startswith(old_prefix):
            new_key = new_prefix + new_key[len(old_prefix):]

        # Attention blocks conversion
        if "attn1" in new_key:
            new_key = new_key.replace("attn1", "self_attn")
        elif "attn2" in new_key:
            new_key = new_key.replace("attn2", "cross_attn")

        # Attention components conversion. Match whole dot-separated segments,
        # mirroring the forward conversion, so substrings inside unrelated
        # segments (e.g. "proto_q") are not mangled.
        parts = new_key.split(".")
        renamed = []
        i = 0
        while i < len(parts):
            part = parts[i]
            # "to_out.0" spans two segments and collapses to a single "o".
            if part == "to_out" and i + 1 < len(parts) and parts[i + 1] == "0":
                renamed.append("o")
                i += 2
                continue
            renamed.append(qkv_parts.get(part, part))
            i += 1
        new_key = ".".join(renamed)

        # FFN conversion
        if "ffn.net.0.proj" in new_key:
            new_key = new_key.replace("ffn.net.0.proj", "ffn.0")
        elif "ffn.net.2" in new_key:
            new_key = new_key.replace("ffn.net.2", "ffn.2")

        new_state_dict[new_key] = value
    return new_state_dict