import io
from enum import Enum

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from PIL import Image

class AttentionType(Enum):
    """Per-layer attention kind: LOCAL is sliding-window, GLOBAL attends over the full context."""
    LOCAL = 0
    GLOBAL = 1

def gqa_kv_per_layer_per_token(n_kv_heads, d_head, kv_parameter_size):
    """KV-cache size per layer per token for grouped-query attention: K and V for each KV head."""
    return 2 * kv_parameter_size * n_kv_heads * d_head

def mla_kv_per_layer_per_token(d_compressed, kv_parameter_size):
    """KV-cache size per layer per token for multi-head latent attention: one compressed latent per token."""
    return kv_parameter_size * d_compressed

def tokens_per_second(batch_size, bandwidth, total_kv_size, param_size):
    """Bandwidth-bound decode throughput: each step streams the weights once plus every sequence's KV cache."""
    return (batch_size * bandwidth) / (batch_size * total_kv_size + param_size)

def compute_tps(kv_per_layer_per_token, seq_len, batch_size, total_param_size,
                num_layers, swa_pattern, swa_size, bandwidth):
    """Estimate tokens/s at each context length, capping the KV cache of sliding-window (local) layers at swa_size."""
    tps_values = []
    for ctx_len in seq_len:
        total_kv_size = 0
        for layer in range(num_layers):
            if swa_pattern[layer % len(swa_pattern)] == AttentionType.LOCAL:
                total_kv_size += kv_per_layer_per_token * min(ctx_len, swa_size)
            else:
                total_kv_size += kv_per_layer_per_token * ctx_len
        tps = tokens_per_second(batch_size, bandwidth, total_kv_size, total_param_size)
        tps_values.append(tps)
    return tps_values

def create_throughput_plot(
    model_name,
    memory_bandwidth,
    num_parameters,
    parameter_size,
    kv_parameter_size,
    num_layers,
    num_heads,
    d_model,
    ctx_length,
    local_layers,
    global_layers,
    swa_size,
    gqa_heads,
    mla_d_compressed,
):
    """Plot estimated decode throughput vs. context length for GQA and MLA KV-cache configurations."""
    memory_bandwidth = float(memory_bandwidth) * 1_000_000_000  # GB/s -> bytes/s
    num_parameters = float(num_parameters) * 1_000_000_000  # billions -> parameter count

    d_head = d_model // num_heads
    total_param_size = num_parameters * (parameter_size / 8.0)  # parameter_size is in bits
    
    # Repeating per-layer pattern of local (sliding-window) and global attention;
    # fall back to all-global if both counts are zero.
    swa_pattern = ([AttentionType.LOCAL] * local_layers +
                   [AttentionType.GLOBAL] * global_layers)

    if len(swa_pattern) == 0:
        swa_pattern = [AttentionType.GLOBAL]
    
    sns.set_theme(style="whitegrid", context="paper")
    palette = sns.color_palette("viridis", len(gqa_heads) + len(mla_d_compressed))
    plt.figure(figsize=(14, 8), dpi=300)
    
    seq_len = np.logspace(2, 5, 100).astype(int)  # context lengths from 1e2 to 1e5 tokens
    batch_size = 1  # single-sequence decode
    
    tps_values = []
    gqa_count = len(gqa_heads) 
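    # GQA curves: one line per KV-head count under comparison.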
    for i, n_kv_head in enumerate(gqa_heads):
        n_kv_head = int(n_kv_head)
        kv_per_token = gqa_kv_per_layer_per_token(n_kv_head, d_head, kv_parameter_size)
        gqa_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size, 
                                num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(gqa_tps_values)
        plt.plot(seq_len, gqa_tps_values, label=f"GQA: {n_kv_head} heads", color=palette[i], 
                linewidth=3.5, alpha=0.85)
    
    plt.axvline(x=ctx_length, color='red', linestyle='--', alpha=0.8, linewidth=2.5,
                label=f"Max Context Length ({ctx_length:,})")
    
    local_count = swa_pattern.count(AttentionType.LOCAL)
    global_count = swa_pattern.count(AttentionType.GLOBAL)
    if local_count > 0:
        plt.axvline(x=swa_size, color='blue', linestyle='--', alpha=0.8, linewidth=2.5,
                    label=f"Sliding Window Limit ({swa_size:,})")
    
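    # MLA curves: one line per compressed KV latent dimension under comparison.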
    for i, d_comp in enumerate(mla_d_compressed):
        d_comp = int(d_comp)
        kv_per_token = mla_kv_per_layer_per_token(d_comp, kv_parameter_size)
        mla_tps_values = compute_tps(kv_per_token, seq_len, batch_size, total_param_size, 
                                num_layers, swa_pattern, swa_size, memory_bandwidth)
        tps_values.extend(mla_tps_values)
        plt.plot(seq_len, mla_tps_values, label=f"MLA: dc = {d_comp}", 
                color=palette[i + gqa_count], linewidth=3.5, alpha=0.85)
    
    plt.xscale('log')
    # Fit the y-axis to the computed curves; fall back to a fixed range if the values
    # are empty or non-finite (all() is True for an empty list, and min() would raise on one).
    if tps_values and all(np.isfinite(tps_values)):
        min_tps = min(tps_values)
        max_tps = max(tps_values)
        plt.ylim(max(0, min_tps * 0.9), max_tps * 1.1)
    else:
        plt.ylim(15, 40)

    plt.gca().xaxis.set_major_formatter(ScalarFormatter())
    plt.gca().yaxis.set_major_formatter(ScalarFormatter())
    
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)
    
    attn_label = "Global" if local_count == 0 else f"SWA {local_count}:{global_count}"
    device_name = model_name.split(':')[0] if ':' in model_name else model_name
    
    plt.annotate(f"{device_name}\nBandwidth: {memory_bandwidth/1e9:.1f} GB/s\nParameter Size: {parameter_size:.1f} bits\nAttention Kind: {attn_label}",
                 xy=(0.8, 0.97), 
                 xycoords='axes fraction',
                 bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor='darkgray'),
                 va='top',
                 fontsize=11)
    
    plt.xlabel('Context Length (tokens)', fontsize=14, fontweight='bold')
    plt.ylabel('Tokens per Second', fontsize=14, fontweight='bold')
    
    plt.tick_params(axis='both', which='major', labelsize=12)
    
    model_title = model_name.split(':')[1] if ':' in model_name else model_name
    plt.title(f"{model_title}: Tokens Per Second vs. Sequence Length", fontsize=18,
              fontweight='bold', pad=20)
    
    plt.legend(title="Configuration", frameon=True, framealpha=0.95, fontsize=12, title_fontsize=14)
    
    plt.grid(True, alpha=0.5)
    
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    img = Image.open(buf)
    return img
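

# Example usage: an illustrative sketch only. Every value below is a hypothetical
# placeholder chosen to exercise the function, not a configuration taken from this module.
if __name__ == "__main__":
    img = create_throughput_plot(
        model_name="GPU: Example-7B",  # "<device>:<model>", split on ':' for the annotation and title
        memory_bandwidth=900,          # GB/s
        num_parameters=7,              # billions of weights
        parameter_size=16,             # bits per weight
        kv_parameter_size=2,           # per-element KV size, kept in bytes to match total_param_size
        num_layers=32,
        num_heads=32,
        d_model=4096,
        ctx_length=32_768,
        local_layers=3,                # repeating 3 local : 1 global sliding-window pattern
        global_layers=1,
        swa_size=4_096,
        gqa_heads=[8, 2],              # KV-head counts to compare
        mla_d_compressed=[512],        # MLA compressed latent dims to compare
    )
    img.save("throughput.png")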