import io
from enum import Enum

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from PIL import Image


class AttentionType(Enum):
    LOCAL = 0
    GLOBAL = 1


def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
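    """Per-layer, per-token KV-cache size for GQA: keys and values each hold n_kv_heads * d_head elements."""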
    return 2 * kv_parameter_size * n_kv_heads * d_head


def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
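    """Per-layer, per-token KV-cache size for MLA: a single compressed latent of width d_compressed."""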
    return kv_parameter_size * d_compressed


def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
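    """Memory-bound decode throughput: each generated token streams all parameters once
    plus the KV cache of every sequence in the batch."""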
    return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)


def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
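    """Sweep the context lengths in seq_len and return modeled decode tokens/s for each,
    accounting for sliding-window (local) vs. global layers in swa_pattern."""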
    tps_values = []
    for ctx_len in seq_len:
        total_kv_size = 0
        for l in range(num_layers):
            # Sliding-window (local) layers only cache the last swa_size tokens.
            if swa_pattern[l % len(swa_pattern)] == AttentionType.LOCAL:
                total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
            else:
                total_kv_size += kv_per_layer_per_token * ctx_len
        tps = tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
        tps_values.append(tps)
    return tps_values


def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
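    """Plot modeled decode throughput vs. context length for GQA and MLA variants of a model.

    memory_bandwidth is interpreted as GB/s, num_parameters as billions, and parameter_size
    as bits per weight; kv_parameter_size is assumed to already be in bytes per cached element.
    Returns the rendered figure as a PIL Image.
    """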
    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000  # GB/s -> bytes/s
    num_parameters = float(num_parameters) * 1_000_000_000  # billions -> parameter count

    d_head = d_model // num_heads
    total_param_size = num_parameters * (parameter_size / 8.0)  # bits per weight -> total bytes

    # Repeating pattern of local (sliding-window) and global attention layers.
    swa_pattern = ([AttentionType.LOCAL] * local_layers +
                   [AttentionType.GLOBAL] * global_layers)
    if len(swa_pattern) == 0:
        swa_pattern = [AttentionType.GLOBAL]

    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis", len(gqa_heads) + len(mla_d_compressed))
    plt.figure(figsize=(14, 8), dpi=300)

    # Sweep context lengths from 100 to 100,000 tokens (log-spaced) at batch size 1.
    seq_len = np.logspace(2, 5, 100).astype(int)
    batch_size = 1

    # One curve per GQA head-count configuration.
    tps_values = []
    gqa_count = len(gqa_heads)
    for i, n_kv_head in enumerate(gqa_heads):
        n_kv_head = int(n_kv_head)
        kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
        gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(gqa_tps_values)
        plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads", color=palette[i],
                 linewidth=3.5, alpha=0.85)

    # Reference lines: the model's maximum context length and the sliding-window size.
    plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
                label=f"Max Context Length ({ctx_length:,})")

    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)
    if local_count > 0:
        plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
                    label=f"Sliding Window Limit ({swa_size:,})")

    # One curve per MLA compressed-latent width.
    for i, d_comp in enumerate(mla_d_compressed):
        d_comp = int(d_comp)
        kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
        mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(mla_tps_values)
        plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
                 color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)

    plt.xscale('log')
    # Fit the y-axis to the observed throughput range when all values are finite.
    if all(np.isfinite(tps_values)):
        min_tps = min(tps_values)
        max_tps = max(tps_values)
        y_min = max(0, min_tps * 0.9)
        y_max = max_tps * 1.1
        plt.ylim(y_min, y_max)
    else:
        plt.ylim(15, 40)

    plt.gca().xaxis.set_major_formatter(ScalarFormatter())
    plt.gca().yaxis.set_major_formatter(ScalarFormatter())

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)

    # Summary box: device, bandwidth, weight precision, and attention layout.
    attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"
    device_name = model_name.split(':')[0] if ':' in model_name else model_name
    plt.annotate(f"{device_name}\nBandwidth: {memory_bandwidth/1e9:.1f} GB/s\nParameter Size: {parameter_size:.1f} bits\nAttention Kind: {attn_label}",
                 xy=(0.8, 0.97),
                 xycoords='axes fraction',
                 bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor='darkgray'),
                 va='top',
                 fontsize=11)

    plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
    plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
    plt.tick_params(axis='both', which='major', labelsize=12)

    model_title = model_name.split(':')[1] if ':' in model_name else model_name
    plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
              fontweight='bold', pad=20)

    plt.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12, title_fontsize=14)
    plt.grid(True, alpha=0.5)

    # Render to an in-memory PNG and return it as a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    img = Image.open(buf)
    return img
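

# Minimal usage sketch. All hardware and architecture numbers below are illustrative
# placeholders, not a specific device or published model configuration.
if __name__ == "__main__":
    img = create_throughput_plot(
        model_name="Example GPU: Example-7B",  # "device: model" format assumed by the split(':') above
        memory_bandwidth=1000,     # GB/s (hypothetical)
        num_parameters=7,          # billions of parameters (hypothetical)
        parameter_size=16,         # bits per weight
        kv_parameter_size=2,       # assumed bytes per cached KV element
        num_layers=32,
        num_heads=32,
        d_model=4096,
        ctx_length=32_768,
        local_layers=3,
        global_layers=1,
        swa_size=4096,
        gqa_heads=[32, 8, 1],
        mla_d_compressed=[512],
    )
    img.save("throughput_vs_context_length.png")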