def memory_for_attention_layer(precision: int,
seq_len: int,
batch_size: int,
hidden_size: int,
num_heads: int):
"""
head_dim = hidden_size // num_heads
Model Parameters:
q_proj: (hidden_size, num_heads * head_dim)
k_proj: (hidden_size, num_key_value_heads * head_dim)
v_proj: (hidden_size, num_key_value_heads * head_dim)
o_proj: (hidden_size, hidden_size)
Total parameters = 3 * hidden_size * num_heads * head_dim + hidden_size^2
Memory required for model parameters = (3 * hidden_size * num_heads * head_dim + hidden_size^2)
Gradients:
Gradients have the same size as the model parameters.
Memory required for gradients = (3 * hidden_size * num_heads * head_dim + hidden_size^2)
Optimizer States:
Assuming Adam optimizer with two states per parameter (momentum and variance).
Memory required for optimizer states = 2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2)
Activations:
query_states: (batch_size, num_heads, q_len, head_dim)
key_states: (batch_size, num_key_value_heads, q_len, head_dim)
value_states: (batch_size, num_key_value_heads, q_len, head_dim)
attn_weights: (batch_size, num_heads, q_len, q_len)
attn_output: (batch_size, q_len, hidden_size)
Total activations = batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)
Memory required for activations = batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)
Temporary Memory:
Additional temporary memory for intermediate computations and buffer storage.
Assuming 20% of the total memory as temporary memory.
total_memory = (model_parameters + gradients + optimizer_states + activations) * (1 + temporary_memory_factor)
((3 * hidden_size * num_heads * head_dim + hidden_size^2) +
(3 * hidden_size * num_heads * head_dim + hidden_size^2) +
2 * (3 * hidden_size * num_heads * head_dim + hidden_size^2) +
batch_size * (num_heads * q_len * head_dim + 2 * num_key_value_heads * q_len * head_dim + num_heads * q_len^2 + q_len * hidden_size)) * (1 + 0.2)
"""
head_dim = hidden_size // num_heads
    # Model parameters, in elements: 3 * hidden_size * num_heads * head_dim + hidden_size^2
model_memory = 3 * hidden_size * num_heads * head_dim + hidden_size ** 2
# Gradients = model_memory
gradients = model_memory
    # Adam optimizer states: two per parameter (momentum and variance)
optimizer = 2 * model_memory
    # Activations: Q/K/V states, attention weights and attention output
activation = batch_size * (3 * num_heads * seq_len * head_dim +
num_heads * seq_len ** 2 +
seq_len * hidden_size
)
    # Convert element counts to bytes and add the 20% temporary-memory buffer
    total_memory = (model_memory + gradients + optimizer + activation) * precision * 1.2
return total_memory
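
# Worked example (illustrative, not from the original file): with assumed
# Llama-7B-like sizes hidden_size=4096, num_heads=32 (head_dim=128),
# seq_len=2048, batch_size=1 and fp16 training (precision=2 bytes):
#   parameters  = 3*4096*32*128 + 4096^2                 =  67,108,864 elements
#   gradients   =                                           67,108,864 elements
#   optimizer   = 2 * 67,108,864                         = 134,217,728 elements
#   activations = 3*32*2048*128 + 32*2048^2 + 2048*4096  = 167,772,160 elements
#   total       = 436,207,616 elements * 2 bytes * 1.2   ~ 1.05e9 bytes (~0.98 GiB)
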
def memory_mlp_layer(precision: int,
seq_len: int,
batch_size: int,
hidden_size: int,
intermediate_size: int):
"""
MLP model
gate_proj (hidden_size, intermediate_size)
up_proj (hidden_size, intermediate_size)
down_proj (intermediate_size, hidden_size)
Memory required for gate_proj weights = intermediate_size * hidden_size
Memory required for up_proj weights = intermediate_size * hidden_size
Memory required for down_proj weights = intermediate_size * hidden_size
model memory = 3 * (hidden_size * intermediate_size)
gradient = model_memory
optimizer = 2 * model_memory
activations = batch_size * seq_len * hidden_size + 2 * batch_size * seq_len * intermediate_size
total_memory = 3 * (hidden_size * intermediate_size) + 3 * (hidden_size * intermediate_size) + 6 * (hidden_size * intermediate_size) + batch_size * (2 * intermediate_size + hidden_size)
total_memory = (hidden_size * intermediate_size) * 12 + Batch_size * seq_len * (2 * intermediate_size + hidden_size)
Args:
hidden_size:
intermediate_size:
batch_size:
seq_len:
Returns:
"""
model_memory = 3 * (hidden_size * intermediate_size)
gradient = model_memory
optimizer = 2 * model_memory
activation = batch_size * seq_len * (2 * intermediate_size + hidden_size)
    total_memory = (model_memory + gradient + optimizer + activation) * precision
return total_memory
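
# Worked example (illustrative, not from the original file): with assumed
# Llama-7B-like sizes hidden_size=4096, intermediate_size=11008, seq_len=2048,
# batch_size=1 and fp16 training (precision=2 bytes):
#   parameters + gradients + optimizer = 4 * 3 * 4096 * 11008 = 541,065,216 elements
#   activations = 2048 * (2 * 11008 + 4096)                   =  53,477,376 elements
#   total       = 594,542,592 elements * 2 bytes              ~ 1.19e9 bytes (~1.11 GiB)
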
def memory_moe_mlp(precision: int,
seq_len: int,
batch_size: int,
hidden_size: int,
intermediate_size: int,
num_expert: int,
top_k: int):
    """
    Estimate the training-time memory of one MoE MLP block
    (router gate + num_expert expert MLPs), in bytes.
    """
    # Router (gate) parameters: one logit per expert for each hidden vector,
    # in elements; multiplied by 4 to cover gradients and the two Adam states.
    gate_memory = hidden_size * num_expert
    gate_total = 4 * gate_memory * precision
    # Each expert is a standard MLP layer; memory_mlp_layer already accounts for
    # its parameters, gradients, optimizer states and activations (result in bytes).
    # Counting full activations for every expert is a worst-case upper bound.
    expert_total = memory_mlp_layer(precision, seq_len, batch_size,
                                    hidden_size, intermediate_size) * num_expert
    # Routing / dispatch activations (worst case), in bytes.
    max_memory_activation = (
        (batch_size * seq_len * num_expert * precision) +    # Router logits
        (batch_size * seq_len * top_k * precision) +         # Routing weights
        (batch_size * seq_len * top_k * precision) +         # Selected experts
        (batch_size * seq_len * hidden_size * precision) +   # Final hidden states
        (batch_size * seq_len * hidden_size * precision) +   # Current state (worst case)
        (batch_size * seq_len * hidden_size * precision)     # Current hidden states (worst case)
    )
    total_memory = gate_total + expert_total + max_memory_activation
return total_memory
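
# Minimal usage sketch (not part of the original file). The layer sizes below
# are illustrative assumptions: a Llama-7B-like attention/MLP layer and a
# Mixtral-style MoE block, trained in fp16 (precision = 2 bytes per element).
if __name__ == "__main__":
    GiB = 1024 ** 3
    attn = memory_for_attention_layer(precision=2, seq_len=2048, batch_size=1,
                                      hidden_size=4096, num_heads=32)
    mlp = memory_mlp_layer(precision=2, seq_len=2048, batch_size=1,
                           hidden_size=4096, intermediate_size=11008)
    moe = memory_moe_mlp(precision=2, seq_len=2048, batch_size=1,
                         hidden_size=4096, intermediate_size=14336,
                         num_expert=8, top_k=2)
    print(f"Attention layer memory: {attn / GiB:.2f} GiB")
    print(f"Dense MLP layer memory: {mlp / GiB:.2f} GiB")
    print(f"MoE MLP block memory  : {moe / GiB:.2f} GiB")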