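"""Estimate decode throughput (tokens/s) as a function of context length for GQA
and MLA KV-cache layouts under a simple memory-bandwidth-bound model, and render
the comparison as a matplotlib figure."""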
import io
from enum import Enum

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from PIL import Image
class AttentionType(Enum):
    # LOCAL = sliding-window attention layer, GLOBAL = full-attention layer.
    LOCAL = 0
    GLOBAL = 1
def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
    # GQA caches separate K and V tensors: 2 * n_kv_heads * d_head cached
    # values per layer per token, each of size kv_parameter_size.
    return 2 * kv_parameter_size * n_kv_heads * d_head
def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
    # MLA caches a single compressed latent of width d_compressed per layer per token.
    return kv_parameter_size * d_compressed
def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
    # Memory-bandwidth-bound decode model: each step reads all weights once
    # plus the KV cache of every sequence in the batch.
    return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)
def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
    tps_values = []
    for ctx_len in seq_len:
        total_kv_size = 0
        for l in range(num_layers):
            if swa_pattern[l % len(swa_pattern)] == AttentionType.LOCAL:
                # Sliding-window layers cache at most swa_size tokens.
                total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
            else:
                total_kv_size += kv_per_layer_per_token * ctx_len
        tps = tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
        tps_values.append(tps)
    return tps_values
def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000  # GB/s -> bytes/s
    num_parameters = float(num_parameters) * 1_000_000_000      # billions -> parameter count
    d_head = d_model // num_heads
    total_param_size = num_parameters * (parameter_size / 8.0)  # bits per weight -> bytes
    # Repeating layer pattern, e.g. N local (sliding-window) layers per global layer.
    swa_pattern = ([AttentionType.LOCAL] * local_layers +
                   [AttentionType.GLOBAL] * global_layers)
    if len(swa_pattern) == 0:
        swa_pattern = [AttentionType.GLOBAL]
    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis", len(gqa_heads) + len(mla_d_compressed))
    plt.figure(figsize=(14, 8), dpi=300)
    seq_len = np.logspace(2, 5, 100).astype(int)
    batch_size = 1
    tps_values = []
    gqa_count = len(gqa_heads)
    for i, n_kv_head in enumerate(gqa_heads):
        n_kv_head = int(n_kv_head)
        kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
        gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(gqa_tps_values)
        plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads", color=palette[i],
                 linewidth=3.5, alpha=0.85)
    plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
                label=f"Max Context Length ({ctx_length:,})")
    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)
    if local_count > 0:
        plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
                    label=f"Sliding Window Limit ({swa_size:,})")
    for i, d_comp in enumerate(mla_d_compressed):
        d_comp = int(d_comp)
        kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
        mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(mla_tps_values)
        plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
                 color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)
    plt.xscale('log')
    if all(np.isfinite(tps_values)):
        min_tps = min(tps_values)
        max_tps = max(tps_values)
        y_min = max(0, min_tps * 0.9)
        y_max = max_tps * 1.1
        plt.ylim(y_min, y_max)
    else:
        plt.ylim(15, 40)
    plt.gca().xaxis.set_major_formatter(ScalarFormatter())
    plt.gca().yaxis.set_major_formatter(ScalarFormatter())
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)
    attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"
    # model_name is expected as "<device>:<model>"; fall back to the full string otherwise.
    device_name = model_name.split(':')[0] if ':' in model_name else model_name
    plt.annotate(f"{device_name}\nBandwidth: {memory_bandwidth/1e9:.1f} GB/s\n"
                 f"Parameter Size: {parameter_size:.1f} bits\nAttention Kind: {attn_label}",
                 xy=(0.8, 0.97),
                 xycoords='axes fraction',
                 bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor='darkgray'),
                 va='top',
                 fontsize=11)
    plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
    plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
    plt.tick_params(axis='both', which='major', labelsize=12)
    model_title = model_name.split(':')[1] if ':' in model_name else model_name
    plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
              fontweight='bold', pad=20)
    plt.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12, title_fontsize=14)
    plt.grid(True, alpha=0.5)
    # Render to an in-memory PNG and hand it back as a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    img = Image.open(buf)
    return img
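

# Illustrative standalone invocation. The argument values below are assumptions
# chosen to resemble an 8B-parameter model on a ~1 TB/s GPU, not figures taken
# from this Space; the app itself presumably supplies them through its UI.
if __name__ == "__main__":
    example = create_throughput_plot(
        model_name="Example GPU:Example 8B Model",
        memory_bandwidth=1000,   # GB/s (assumed)
        num_parameters=8,        # billions of parameters (assumed)
        parameter_size=16,       # bits per weight (fp16, assumed)
        kv_parameter_size=2,     # bytes per cached KV value (fp16, assumed)
        num_layers=32,
        num_heads=32,
        d_model=4096,
        ctx_length=8192,
        local_layers=0,
        global_layers=1,
        swa_size=4096,
        gqa_heads=[32, 8, 1],
        mla_d_compressed=[512],
    )
    example.save("throughput_example.png")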