# throughput-calculator/src/throughput_utils.py
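"""Utilities for estimating memory-bandwidth-bound decode throughput.

Models tokens/sec for GQA and MLA attention variants, with optional
sliding-window (local) layers, and renders a comparison plot.
"""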
import io
from enum import Enum

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from PIL import Image


class AttentionType(Enum):
    """Whether a layer attends globally or only within a sliding window."""
    LOCAL = 0
    GLOBAL = 1


def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
    # 2x for K and V; kv_parameter_size is assumed to be in bits (like
    # parameter_size below), so convert to bytes to match total_param_size.
    return 2 * (kv_parameter_size / 8.0) * n_kv_heads * d_head


def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
    # MLA caches a single compressed latent shared across heads; bits -> bytes.
    return (kv_parameter_size / 8.0) * d_compressed


def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
    # Memory-bound estimate: each decoding step streams all weights plus the
    # batch's KV cache through memory once, so tps = B * BW / (B * KV + P).
    return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)


def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
    # For each context length, sum KV-cache bytes across layers: sliding-window
    # (local) layers cap their cache at swa_size tokens, while global layers
    # grow linearly with context.
    tps_values = []
    for ctx_len in seq_len:
        total_kv_size = 0
        for layer in range(num_layers):
            if swa_pattern[layer % len(swa_pattern)] == AttentionType.LOCAL:
                total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
            else:
                total_kv_size += kv_per_layer_per_token * ctx_len
        tps = tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
        tps_values.append(tps)
    return tps_values


def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
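    """Render a tokens-per-second vs. context-length plot and return a PIL image.

    Example with illustrative numbers: at 400 GB/s with 16 GB of fp16 weights
    and an empty KV cache, the batch-1 ceiling is 400 / 16 = 25 tokens/sec.
    """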
    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000  # GB/s -> bytes/s
    num_parameters = float(num_parameters) * 1_000_000_000  # billions -> count
    d_head = d_model // num_heads
    total_param_size = num_parameters * (parameter_size / 8.0)  # bits -> bytes
    # Repeating layer pattern of local (sliding-window) then global layers;
    # fall back to all-global when both counts are zero.
    swa_pattern = ([AttentionType.LOCAL] * local_layers +
                   [AttentionType.GLOBAL] * global_layers)
    if len(swa_pattern) == 0:
        swa_pattern = [AttentionType.GLOBAL]
    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis", len(gqa_heads) + len(mla_d_compressed))
    plt.figure(figsize=(14, 8), dpi=300)
    seq_len = np.logspace(2, 5, 100).astype(int)  # 100 to 100k tokens, log-spaced
    batch_size = 1
    tps_values = []
    gqa_count = len(gqa_heads)
    # One throughput curve per GQA KV-head count.
    for i, n_kv_head in enumerate(gqa_heads):
        n_kv_head = int(n_kv_head)
        kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
        gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(gqa_tps_values)
        plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads",
                 color=palette[i], linewidth=3.5, alpha=0.85)
    plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
                label=f"Max Context Length ({ctx_length:,})")
    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)
    if local_count > 0:
        plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
                    label=f"Sliding Window Limit ({swa_size:,})")
    # One throughput curve per MLA compressed-latent width, using the palette
    # entries after the GQA ones.
    for i, d_comp in enumerate(mla_d_compressed):
        d_comp = int(d_comp)
        kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
        mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(mla_tps_values)
        plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
                 color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)
    plt.xscale('log')
    # Guard against an empty list (no curves plotted), where min()/max() raise.
    if tps_values and all(np.isfinite(tps_values)):
        min_tps = min(tps_values)
        max_tps = max(tps_values)
        plt.ylim(max(0, min_tps * 0.9), max_tps * 1.1)
    else:
        plt.ylim(15, 40)
    ax = plt.gca()
    ax.xaxis.set_major_formatter(ScalarFormatter())
    ax.yaxis.set_major_formatter(ScalarFormatter())
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)
    attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"
    device_name = model_name.split(':')[0] if ':' in model_name else model_name
    plt.annotate(f"{device_name}\nBandwidth: {memory_bandwidth / 1e9:.1f} GB/s\n"
                 f"Parameter Size: {parameter_size:.1f} bits\nAttention Kind: {attn_label}",
                 xy=(0.8, 0.97),
                 xycoords='axes fraction',
                 bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9,
                           edgecolor='darkgray'),
                 va='top',
                 fontsize=11)
    plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
    plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
    plt.tick_params(axis='both', which='major', labelsize=12)
    model_title = model_name.split(':')[1] if ':' in model_name else model_name
    plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
              fontweight='bold', pad=20)
    plt.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12,
               title_fontsize=14)
    plt.grid(True, alpha=0.5)
    # Serialize the figure to an in-memory PNG and hand it back as a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    return Image.open(buf)
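

if __name__ == "__main__":
    # Minimal smoke test. The configuration below is an illustrative,
    # Llama-3-8B-style guess on a hypothetical 400 GB/s device, not a
    # calibrated preset from this repo.
    img = create_throughput_plot(
        model_name="M3 Max:Llama-3-8B",
        memory_bandwidth=400,   # GB/s
        num_parameters=8,       # billions
        parameter_size=16,      # bits per weight (fp16)
        kv_parameter_size=16,   # bits per KV element (fp16)
        num_layers=32,
        num_heads=32,
        d_model=4096,
        ctx_length=8192,
        local_layers=0,
        global_layers=1,
        swa_size=4096,
        gqa_heads=[1, 8, 32],
        mla_d_compressed=[512],
    )
    img.save("throughput.png")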