import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from enum import Enum
import io

class AttentionType(Enum):
    # LOCAL = sliding-window attention; GLOBAL = full-context attention.
    LOCAL = 0
    GLOBAL = 1

def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
    # GQA caches a key and a value vector (factor of 2) of size n_kv_heads * d_head
    # per layer per token, at kv_parameter_size per element.
    return 2 * kv_parameter_size * n_kv_heads * d_head

def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
    # MLA caches a single compressed latent of dimension d_compressed per layer per token.
    return kv_parameter_size * d_compressed
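
# Illustrative comparison (assumed numbers, not from the source): with d_head = 128 and
# 8 KV heads, GQA stores 2 * 8 * 128 = 2048 elements per layer per token, while MLA with
# d_compressed = 512 stores 512 elements, i.e. a 4x smaller KV cache per layer per token.
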
def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
    # Bandwidth-bound decode model: each step must stream the weights once plus one
    # KV cache per sequence in the batch, so the batch produces batch_size tokens every
    # (batch_size * total_kv_size + param_size) / bandwidth seconds.
    return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)

def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
    tps_values = []
    for ctx_len in seq_len:
        total_kv_size = 0
        for l in range(num_layers):
            # Local (sliding-window) layers cache at most swa_size tokens;
            # global layers cache the full context.
            if swa_pattern[l % len(swa_pattern)] == AttentionType.LOCAL:
                total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
            else:
                total_kv_size += kv_per_layer_per_token * ctx_len
        tps = tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
        tps_values.append(tps)
    return tps_values
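
# Worked example of the throughput model (assumed numbers, not from the source):
# with 1000 GB/s of bandwidth, 54 GB of weights, and a 2 GB KV cache at batch size 1,
# tokens_per_second(1, 1000e9, 2e9, 54e9) = 1000e9 / (2e9 + 54e9) ≈ 17.9 tokens/s.
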
def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
    # Convert user-facing units: bandwidth is given in GB/s, parameters in billions,
    # and parameter_size in bits (converted to bytes below).
    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000
    num_parameters = float(num_parameters) * 1_000_000_000
    d_head = d_model // num_heads
    total_param_size = num_parameters * (parameter_size / 8.0)

    # Repeating layer pattern of local (sliding-window) and global attention layers;
    # fall back to all-global if both counts are zero.
    swa_pattern = ([AttentionType.LOCAL] * local_layers +
                   [AttentionType.GLOBAL] * global_layers)
    if len(swa_pattern) == 0:
        swa_pattern = [AttentionType.GLOBAL]

    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis", len(gqa_heads) + len(mla_d_compressed))
    plt.figure(figsize=(14, 8), dpi=300)

    # Sweep context lengths from 100 to 100,000 tokens on a log-spaced grid.
    seq_len = np.logspace(2, 5, 100).astype(int)
    batch_size = 1
    tps_values = []

    # One throughput curve per GQA configuration (number of KV heads).
    gqa_count = len(gqa_heads)
    for i, n_kv_head in enumerate(gqa_heads):
        n_kv_head = int(n_kv_head)
        kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
        gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(gqa_tps_values)
        plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads", color=palette[i],
                 linewidth=3.5, alpha=0.85)

    plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
                label=f"Max Context Length ({ctx_length:,})")

    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)
    if local_count > 0:
        plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
                    label=f"Sliding Window Limit ({swa_size:,})")

    # One throughput curve per MLA configuration (compressed KV dimension).
    for i, d_comp in enumerate(mla_d_compressed):
        d_comp = int(d_comp)
        kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
        mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size,
                                     num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(mla_tps_values)
        plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
                 color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)

    plt.xscale('log')

    if all(np.isfinite(tps_values)):
        min_tps = min(tps_values)
        max_tps = max(tps_values)
        y_min = max(0, min_tps * 0.9)
        y_max = max_tps * 1.1
        plt.ylim(y_min, y_max)
    else:
        plt.ylim(15, 40)

    plt.gca().xaxis.set_major_formatter(ScalarFormatter())
    plt.gca().yaxis.set_major_formatter(ScalarFormatter())

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)

    attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"
    device_name = model_name.split(':')[0] if ':' in model_name else model_name
    plt.annotate(f"{device_name}\nBandwidth: {memory_bandwidth/1e9:.1f} GB/s\nParameter Size: {parameter_size:.1f} bits\nAttention Kind: {attn_label}",
                 xy=(0.8, 0.97),
                 xycoords='axes fraction',
                 bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor='darkgray'),
                 va='top',
                 fontsize=11)

    plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
    plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
    plt.tick_params(axis='both', which='major', labelsize=12)

    model_title = model_name.split(':')[1] if ':' in model_name else model_name
    plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
              fontweight='bold', pad=20)

    plt.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12, title_fontsize=14)
    plt.grid(True, alpha=0.5)

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)

    from PIL import Image
    img = Image.open(buf)
    return img
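
# Minimal usage sketch (all values below are hypothetical placeholders, not taken from the
# source): roughly 1000 GB/s of memory bandwidth, a 27B-parameter model at 16-bit weights,
# 2 bytes per cached KV element, and a 5:1 local:global sliding-window layer pattern.
if __name__ == "__main__":
    example = create_throughput_plot(
        model_name="Example GPU:Example 27B",
        memory_bandwidth=1000,        # GB/s; scaled to bytes/s inside the function
        num_parameters=27,            # billions; scaled inside the function
        parameter_size=16,            # bits per weight
        kv_parameter_size=2,          # per-element KV cache size, used as-is by the helpers
        num_layers=46,
        num_heads=32,
        d_model=4096,
        ctx_length=128_000,
        local_layers=5,
        global_layers=1,
        swa_size=1024,
        gqa_heads=[16, 8, 4],
        mla_d_compressed=[512],
    )
    example.save("throughput_example.png")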