import io
from enum import Enum

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from PIL import Image


class AttentionType(Enum):
    LOCAL = 0
    GLOBAL = 1


def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
    """KV-cache size per layer per token for grouped-query attention (GQA).

    The factor of 2 accounts for storing both K and V. kv_parameter_size is
    the size of one cached element (assumed bytes, e.g. 2 for fp16).
    """
    return 2 * kv_parameter_size * n_kv_heads * d_head


def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
    """KV-cache size per layer per token for multi-head latent attention (MLA).

    MLA caches a single compressed latent of dimension d_compressed that is
    shared between K and V, so there is no factor of 2.
    """
    return kv_parameter_size * d_compressed


def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
    """Roofline estimate of decode throughput, assuming memory-bound decoding.

    Each decode step must read every model weight once (param_size bytes)
    plus the KV cache of every sequence in the batch. bandwidth is in
    bytes per second.
    """
    return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)


def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
    """Estimate tokens/s at each context length in seq_len.

    Sliding-window (LOCAL) layers cap their KV cache at swa_size tokens,
    while GLOBAL layers cache the full context. swa_pattern is tiled across
    the model's layers.
    """
    tps_values = []
    for ctx_len in seq_len:
        total_kv_size = 0
        for layer in range(num_layers):
            if swa_pattern[layer % len(swa_pattern)] == AttentionType.LOCAL:
                total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
            else:
                total_kv_size += kv_per_layer_per_token * ctx_len
        tps = tokens_per_second(batch_size, bandwidth, total_kv_size,
                                total_param_size)
        tps_values.append(tps)
    return tps_values
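# Worked example (hypothetical numbers, for intuition only): a 27B-parameter
# model in bf16 weighs about 54 GB. On a device with 1,000 GB/s of memory
# bandwidth, the roofline above gives roughly 1000 / 54 ≈ 18.5 tokens/s at
# batch size 1 before any KV-cache traffic:
#
#     tokens_per_second(1, 1_000e9, 0, 54e9)  # ≈ 18.5
#
# Every byte of KV cache read per step only pushes that number down, which
# is what the plot below visualizes across context lengths.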
def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
    """Plot estimated decode tokens/s vs. context length for GQA and MLA.

    memory_bandwidth is given in GB/s and num_parameters in billions.
    parameter_size is the weight precision in bits, while kv_parameter_size
    is assumed to be the size of one cached KV element in bytes.
    """
    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000  # GB/s -> B/s
    num_parameters = float(num_parameters) * 1_000_000_000      # billions
    d_head = d_model // num_heads
    total_param_size = num_parameters * (parameter_size / 8.0)  # bits -> bytes

    # Tile the local/global layer pattern; fall back to all-global attention.
    swa_pattern = ([AttentionType.LOCAL] * local_layers
                   + [AttentionType.GLOBAL] * global_layers)
    if len(swa_pattern) == 0:
        swa_pattern = [AttentionType.GLOBAL]

    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis",
                                len(gqa_heads) + len(mla_d_compressed))
    plt.figure(figsize=(14, 8), dpi=300)

    seq_len = np.logspace(2, 5, 100).astype(int)  # 100 to 100,000 tokens
    batch_size = 1
    tps_values = []

    # One curve per GQA configuration (number of KV heads).
    gqa_count = len(gqa_heads)
    for i, n_kv_head in enumerate(gqa_heads):
        n_kv_head = int(n_kv_head)
        kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head,
                                                  kv_parameter_size)
        gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size,
                                     total_param_size, num_layers,
                                     swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(gqa_tps_values)
        plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads",
                 color=palette[i], linewidth=3.5, alpha=0.85)

    plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8,
                linewidth=2.5, label=f"Max Context Length ({ctx_length:,})")

    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)
    if local_count > 0:
        plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8,
                    linewidth=2.5,
                    label=f"Sliding Window Limit ({swa_size:,})")

    # One curve per MLA configuration (compressed latent dimension).
    for i, d_comp in enumerate(mla_d_compressed):
        d_comp = int(d_comp)
        kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
        mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size,
                                     total_param_size, num_layers,
                                     swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(mla_tps_values)
        plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}",
                 color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)

    plt.xscale('log')
    # Guard against an empty configuration list before calling min()/max().
    if tps_values and all(np.isfinite(tps_values)):
        min_tps = min(tps_values)
        max_tps = max(tps_values)
        y_min = max(0, min_tps * 0.9)
        y_max = max_tps * 1.1
        plt.ylim(y_min, y_max)
    else:
        plt.ylim(15, 40)

    plt.gca().xaxis.set_major_formatter(ScalarFormatter())
    plt.gca().yaxis.set_major_formatter(ScalarFormatter())
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)

    # model_name may be "Device: Model"; split it for the annotation/title.
    attn_label = ("Global" if local_count == 0
                  else f"SWA {local_count}:{global_count}")
    device_name = model_name.split(':')[0] if ':' in model_name else model_name
    plt.annotate(
        f"{device_name}\n"
        f"Bandwidth: {memory_bandwidth / 1e9:.1f} GB/s\n"
        f"Parameter Size: {parameter_size:.1f} bits\n"
        f"Attention Kind: {attn_label}",
        xy=(0.8, 0.97), xycoords='axes fraction',
        bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9,
                  edgecolor='darkgray'),
        va='top', fontsize=11)

    plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
    plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
    plt.tick_params(axis='both', which='major', labelsize=12)
    model_title = model_name.split(':')[1] if ':' in model_name else model_name
    plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length",
              fontsize=18, fontweight='bold', pad=20)
    plt.legend(title="Configuration", frameon=True, framealpha=0.95,
               fontsize=12, title_fontsize=14)
    plt.grid(True, alpha=0.5)

    # Render to an in-memory PNG and hand back a PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    img = Image.open(buf)
    return img
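# Example usage. The configuration below is purely illustrative (an assumed
# device bandwidth and a made-up SWA/GQA/MLA setup), not the spec of any
# particular model:
if __name__ == "__main__":
    img = create_throughput_plot(
        model_name="H100 SXM: Example-27B",
        memory_bandwidth=3350,       # GB/s (assumed device bandwidth)
        num_parameters=27,           # billions of parameters (assumed)
        parameter_size=16.0,         # weight precision in bits (bf16)
        kv_parameter_size=2,         # bytes per cached KV element (fp16)
        num_layers=46,
        num_heads=32,
        d_model=4096,
        ctx_length=128_000,
        local_layers=5,
        global_layers=1,
        swa_size=4096,
        gqa_heads=[16, 8, 4, 2, 1],
        mla_d_compressed=[512],
    )
    img.save("throughput_vs_context.png")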