def memory_for_attention_layer(precision: int,
                               seq_len: int,
                               batch_size: int,
                               hidden_size: int,
                               num_heads: int):
    """
    head_dim = hidden_size // num_heads

    Model Parameters:
    q_proj: (hidden_size, num_heads * head_dim)
    k_proj: (hidden_size, num_key_value_heads * head_dim)
    v_proj: (hidden_size, num_key_value_heads * head_dim)
    o_proj: (hidden_size, hidden_size)

    Total parameters = 3 * hidden_size * num_heads * head_dim + hidden_size^2

    Memory required for model parameters = (3 * hidden_size * num_heads * head_dim + hidden_size^2)

    Gradients:
        Gradients have the same size as the model parameters.
        Memory required for gradients = (3 * hidden_size * num_heads * head_dim + hidden_size^2)

    Optimizer States:
        Assuming Adam optimizer with two states per parameter (momentum and variance).
        Memory required for optimizer states = 2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2)

    Activations:
        query_states: (batch_size, num_heads, q_len, head_dim)
        key_states: (batch_size, num_key_value_heads, q_len, head_dim)
        value_states: (batch_size, num_key_value_heads, q_len, head_dim)
        attn_weights: (batch_size, num_heads, q_len, q_len)
        attn_output: (batch_size, q_len, hidden_size)
        Total activations = batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)

        Memory required for activations = batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)

    Temporary Memory:
        Additional temporary memory for intermediate computations and buffer storage.
        Assuming 20% of the total memory as temporary memory.

    total_memory = (model_parameters + gradients + optimizer_states + activations) * (1 + temporary_memory_factor)

     ((3 * hidden_size * num_heads * head_dim + hidden_size^2)  +
                    (3 * hidden_size * num_heads * head_dim + hidden_size^2) +
                    2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2)  +
                    batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)) * (1 + 0.2)

    """
    head_dim = hidden_size // num_heads
    # Model Memory (3 * hidden_size * num_heads * head_dim + hidden_size^2)
    model_memory = 3 * hidden_size * num_heads * head_dim + hidden_size ** 2

    # Gradients = model_memory
    gradients = model_memory

    # Optimizer
    optimizer = 2 * model_memory

    # Activation
    activation = batch_size * (3 * num_heads * seq_len * head_dim +
                               num_heads * seq_len ** 2 +
                               seq_len * hidden_size
                               )
    # Scale element counts to bytes with the per-element precision (e.g. 4 for fp32)
    total_memory = (model_memory + gradients + optimizer + activation) * precision

    return total_memory
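

def _example_attention_memory():
    """Illustrative usage sketch (not part of the original code): estimate the
    training memory of one attention layer with Llama-2-7B-like dimensions.
    The values below (fp32 -> 4 bytes per element, seq_len=2048, batch_size=1,
    hidden_size=4096, num_heads=32) are assumptions for demonstration only.
    """
    total_bytes = memory_for_attention_layer(4,      # precision: bytes per element (fp32)
                                             2048,   # seq_len
                                             1,      # batch_size
                                             4096,   # hidden_size
                                             32)     # num_heads
    # Roughly 1.74e9 bytes (~1.6 GiB); the num_heads * seq_len ** 2 attention
    # weights dominate the activation term at this sequence length.
    return total_bytes / 2 ** 30  # GiB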


def memory_mlp_layer(precision: int,
                     seq_len: int,
                     batch_size: int,
                     hidden_size: int,
                     intermediate_size: int):
    """
    MLP model
    gate_proj (hidden_size, intermediate_size)
    up_proj (hidden_size, intermediate_size)
    down_proj (intermediate_size, hidden_size)

    Memory required for gate_proj weights = intermediate_size * hidden_size
    Memory required for up_proj weights = intermediate_size * hidden_size
    Memory required for down_proj weights = intermediate_size * hidden_size

    model memory = 3 * (hidden_size * intermediate_size)
    gradient = model_memory
    optimizer = 2 * model_memory
    activations = batch_size * seq_len * hidden_size + 2 * batch_size * seq_len * intermediate_size

    total_memory = 3 * (hidden_size * intermediate_size) + 3 * (hidden_size * intermediate_size) + 6 * (hidden_size * intermediate_size) + batch_size * (2 * intermediate_size + hidden_size)
    total_memory = (hidden_size * intermediate_size) * 12 + Batch_size * seq_len * (2 * intermediate_size + hidden_size)

    Args:
        hidden_size:
        intermediate_size:
        batch_size:
        seq_len:

    Returns:

    """
    model_memory = 3 * (hidden_size * intermediate_size)
    gradient = model_memory
    optimizer = 2 * model_memory
    activation = batch_size * seq_len * (2 * intermediate_size + hidden_size)
    # Scale element counts to bytes with the per-element precision (e.g. 4 for fp32)
    total_memory = (model_memory + gradient + optimizer + activation) * precision
    return total_memory
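

def _example_mlp_memory():
    """Illustrative usage sketch (not part of the original code): estimate the
    training memory of one MLP layer with Llama-2-7B-like dimensions. The
    values below (fp32 -> 4 bytes per element, seq_len=2048, batch_size=1,
    hidden_size=4096, intermediate_size=11008) are assumptions for
    demonstration only.
    """
    total_bytes = memory_mlp_layer(4,       # precision: bytes per element (fp32)
                                   2048,    # seq_len
                                   1,       # batch_size
                                   4096,    # hidden_size
                                   11008)   # intermediate_size
    # 12 * 4096 * 11008 parameter-related elements plus the activations,
    # times 4 bytes: roughly 2.4e9 bytes (~2.2 GiB).
    return total_bytes / 2 ** 30  # GiB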


def memory_moe_mlp(precision: int,
                   seq_len: int,
                   batch_size: int,
                   hidden_size: int,
                   intermediate_size: int,
                   num_expert: int,
                   top_k: int):
    """
    Estimate the training memory (in bytes) of a single MoE MLP layer with
    num_expert experts and top_k routing.
    """
    # Model parameters (element counts)
    # Router (gate): one logit per expert for every hidden feature
    gate_params = hidden_size * num_expert
    # Each expert is a gated MLP (gate_proj, up_proj, down_proj)
    expert_params = 3 * hidden_size * intermediate_size
    # Total model memory in bytes
    model_memory = (gate_params + num_expert * expert_params) * precision

    # Gradients and optimizer states (Adam), as before
    gradient = model_memory
    optimizer = 2 * model_memory

    # Routing activations (worst case), in bytes
    max_memory_activation = (
            (batch_size * seq_len * num_expert * precision) +  # Router logits
            (batch_size * seq_len * top_k * precision) +  # Routing weights
            (batch_size * seq_len * top_k * precision) +  # Selected experts
            (batch_size * seq_len * hidden_size * precision) +  # Final hidden states
            (batch_size * seq_len * hidden_size * precision) +  # Current state (worst-case)
            (batch_size * seq_len * hidden_size * precision)  # Current hidden states (worst-case)
    )
    total_memory = model_memory + gradient + optimizer + max_memory_activation

    return total_memory
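

def _example_moe_mlp_memory():
    """Illustrative usage sketch (not part of the original code): estimate the
    training memory of one MoE MLP layer with Mixtral-8x7B-like dimensions.
    The values below (fp32 -> 4 bytes per element, seq_len=2048, batch_size=1,
    hidden_size=4096, intermediate_size=14336, num_expert=8, top_k=2) are
    assumptions for demonstration only.
    """
    total_bytes = memory_moe_mlp(4,        # precision: bytes per element (fp32)
                                 2048,     # seq_len
                                 1,        # batch_size
                                 4096,     # hidden_size
                                 14336,    # intermediate_size
                                 8,        # num_expert
                                 2)        # top_k
    # Parameter, gradient and optimizer memory scale with num_expert, so the
    # expert weights dominate: roughly 22.6e9 bytes (~21 GiB) under this estimate.
    return total_bytes / 2 ** 30  # GiB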