|
import gradio as gr |
|
from enum import Enum |
|
from throughput_utils import create_throughput_plot |
|
|
|
class AttentionType(Enum):
    """Kinds of attention layer distinguished by the calculator.

    NOTE(review): this enum is not referenced in the code visible here; the
    LOCAL/GLOBAL split presumably mirrors the "Local Attention Layers" /
    "Global Attention Layers" inputs of the UI — confirm against callers.
    """

    # Presumably sliding-window (local) attention layers — TODO confirm.
    LOCAL = 0
    # Presumably full-context (global) attention layers — TODO confirm.
    GLOBAL = 1
|
|
|
class PhoneBandwidth(Enum):
    """Memory bandwidth available on each supported iPhone generation.

    The member *names* are the choices offered by the "iPhone Model" dropdown,
    and the selected member's *value* is passed to ``create_throughput_plot``
    as the memory bandwidth. Units are presumably GB/s — TODO confirm against
    ``throughput_utils``.
    """

    # Declaration order is also the dropdown order, so keep newest first.
    Sixteen = 60      # "iPhone Sixteen"
    Fifteen = 51.2    # "iPhone Fifteen"
    Fourteen = 34.1   # "iPhone Fourteen"
|
|
|
# Stylesheet passed to gr.Blocks(css=...). Selectors map onto the UI below:
#   #plot-container     -> the gr.Image showing the throughput plot (elem_id)
#   #generate-button    -> the "Generate Throughput Plot" button (elem_id)
#   .gradio-container   -> Gradio's outer app container (page background)
#   .sliders-container  -> the gr.Columns holding the GQA / MLA sliders (elem_classes)
#   #error-status       -> the hidden-by-default Markdown used for error messages (elem_id)
custom_css = """
#plot-container {
    border-radius: 10px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 1px 3px rgba(0, 0, 0, 0.08);
    padding: 1rem;
    background-color: white;
    height: 100%;
    margin-bottom: 1.5rem;
}

#generate-button {
    background-color: #2563eb;
    color: white;
    border-radius: 8px;
    font-weight: bold;
    padding: 10px 20px;
    box-shadow: 0 4px 6px rgba(37, 99, 235, 0.1);
    transition: all 0.2s ease;
    width: 100%;
    max-width: 400px;
    margin: 0 auto;
    font-size: 16px;
}

#generate-button:hover {
    background-color: #1d4ed8;
    box-shadow: 0 6px 8px rgba(37, 99, 235, 0.2);
    transform: translateY(-2px);
}

.gradio-container {
    background-color: #f5f7fa;
}

/* Custom styles for sliders containers */
.sliders-container {
    border: 1px solid rgba(0, 0, 0, 0.1);
    border-radius: 8px;
    padding: 1rem;
    margin-top: 0.5rem;
    background-color: rgba(255, 255, 255, 0.8);
}

#error-status {
    color: #b91c1c;
    background-color: #fee2e2;
    border-radius: 8px;
    padding: 0.75rem;
    margin-top: 0.5rem;
    border: 1px solid #f87171;
    font-weight: 500;
}
"""
|
|
|
def generate_throughput_plot(
    model_name, iphone_model, num_parameters, parameter_size,
    kv_parameter_size, num_layers, num_heads, d_model, ctx_length,
    local_layers, global_layers, swa_size, gqa_1, gqa_2, mla_1, mla_2
):
    """Validate the UI inputs and render the throughput plot.

    Args:
        model_name: Free-text model name for the plot. It is prefixed with
            "iPhone <model>: " unless it already contains "iPhone".
        iphone_model: Name of a ``PhoneBandwidth`` member (dropdown choice).
        num_parameters, parameter_size, kv_parameter_size, num_layers,
        ctx_length, local_layers, global_layers, swa_size: Forwarded
            unchanged to ``create_throughput_plot``.
        num_heads: Total attention heads; upper bound for ``gqa_1``/``gqa_2``.
        d_model: Model dimension; upper bound for ``mla_1``/``mla_2``.
        gqa_1, gqa_2: GQA head counts to plot.
        mla_1, mla_2: MLA compressed dimensions to plot.

    Returns:
        ``[plot_update, status_update]`` for ``[plot_output, status_output]``.
        On success the image is set and the status box is hidden. On any
        error the image is cleared and the status box shows the message.
    """
    try:
        # Looked up inside the try so that a cleared or unknown dropdown
        # selection is shown in the status box rather than escaping as a
        # raw KeyError from the event handler.
        if iphone_model not in PhoneBandwidth.__members__:
            raise ValueError(f"Unknown iPhone model: {iphone_model!r}")
        memory_bandwidth = PhoneBandwidth[iphone_model].value

        # gr.Number fields yield None once cleared; name the empty fields
        # instead of failing below with an opaque comparison TypeError.
        missing = [
            label
            for label, value in (("Number of Heads", num_heads), ("D Model", d_model))
            if value is None
        ]
        if missing:
            raise ValueError(f"Please enter a value for: {', '.join(missing)}")

        model_name = model_name or ""
        if "iPhone" not in model_name:
            model_name = f"iPhone {iphone_model}: {model_name}"

        for gqa_heads, label in [(gqa_1, "GQA Head Count #1"), (gqa_2, "GQA Head Count #2")]:
            if gqa_heads > num_heads:
                raise ValueError(f"{label} ({gqa_heads}) cannot be greater than the total number of attention heads ({num_heads})")

        for mla_dim, label in [(mla_1, "MLA Dimension #1"), (mla_2, "MLA Dimension #2")]:
            if mla_dim > d_model:
                raise ValueError(f"{label} ({mla_dim}) cannot be greater than the model dimension (d_model = {d_model})")

        plot_img = create_throughput_plot(
            model_name,
            memory_bandwidth,
            num_parameters,
            parameter_size,
            kv_parameter_size,
            num_layers,
            num_heads,
            d_model,
            ctx_length,
            local_layers,
            global_layers,
            swa_size,
            [gqa_1, gqa_2],
            [mla_1, mla_2],
        )
    except Exception as e:
        # UI boundary: any failure (validation or plotting) is reported to the
        # user in the status box instead of breaking the Gradio event.
        err_string = f"Error generating plot: {str(e)}"
        print(err_string)
        return [
            gr.update(value=None),
            gr.update(visible=True, value=f"⚠️ {err_string}"),
        ]

    return [
        gr.update(value=plot_img),
        gr.update(visible=False, value=""),
    ]


def _clamped_slider_updates(limit, current_values):
    """Return one gr.update per slider that caps its maximum and value at *limit*."""
    return [gr.update(maximum=limit, value=min(value, limit)) for value in current_values]


def update_gqa_sliders(heads_value, *current_values):
    """Cap the GQA head-count sliders at the current number of heads.

    Args:
        heads_value: New "Number of Heads" value. None, 0 or values below 1
            are treated as 1.
        *current_values: The sliders' live values, in ``gqa_sliders`` order.
            When omitted, the sliders' construction-time values are used.

    Returns:
        One ``gr.update`` per GQA slider.
    """
    if not heads_value or heads_value < 1:
        heads_value = 1
    if not current_values:
        # Backward-compatible single-argument call: fall back to the values
        # the sliders were built with (they do not track the browser state).
        current_values = [slider.value for slider in gqa_sliders]
    return _clamped_slider_updates(heads_value, current_values)


def update_mla_sliders(d_model_value, *current_values):
    """Cap the MLA dimension sliders at the current d_model.

    Args:
        d_model_value: New "D Model" value. None, 0 or values below 64 (the
            sliders' minimum) are treated as 64.
        *current_values: The sliders' live values, in ``mla_sliders`` order.
            When omitted, the sliders' construction-time values are used.

    Returns:
        One ``gr.update`` per MLA slider.
    """
    if not d_model_value or d_model_value < 64:
        d_model_value = 64
    if not current_values:
        # Backward-compatible single-argument call; see update_gqa_sliders.
        current_values = [slider.value for slider in mla_sliders]
    return _clamped_slider_updates(d_model_value, current_values)


# UI layout and event wiring. Event listeners must be registered inside the
# Blocks context, so all wiring happens within this `with` block.
with gr.Blocks(css=custom_css) as demo:
    gqa_sliders = []
    mla_sliders = []

    with gr.Column():
        gr.Markdown(
            """# On-Device LLM Throughput Calculator

This tool estimates the throughput (tokens per second) of Large Language Models on devices with memory bandwidth constraints.
It visualizes how different attention mechanisms (GQA, MLA) and context lengths affect throughput.
"""
        )

        with gr.Row():
            plot_output = gr.Image(label="Throughput Plot", type="pil", elem_id="plot-container")

        # Hidden until generate_throughput_plot reports an error.
        status_output = gr.Markdown(visible=False, elem_id="error-status")

        with gr.Row():
            plot_button = gr.Button("Generate Throughput Plot", size="lg", elem_id="generate-button", variant="primary")

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Device Configuration")
                    model_name = gr.Textbox(label="Model Name", value="TinyLLM")
                    iphone_model = gr.Dropdown(
                        label="iPhone Model",
                        choices=[e.name for e in PhoneBandwidth],
                        value=PhoneBandwidth.Sixteen.name,
                        interactive=True,
                    )

                with gr.Group():
                    gr.Markdown("### Attention Configurations to Plot")

                    gr.Markdown("#### GQA Head Configurations")
                    gr.Markdown("*Note: GQA head count must be less than or equal to the total number of heads*")

                    with gr.Column(elem_classes="sliders-container"):
                        # step=1: with minimum=1 a step of 2 only allowed odd
                        # counts, so the defaults 4/8 and power-of-two head
                        # counts were not selectable.
                        gqa_slider1 = gr.Slider(minimum=1, maximum=32, step=1, value=4,
                                                label="GQA Head Count #1")
                        gqa_slider2 = gr.Slider(minimum=1, maximum=32, step=1, value=8,
                                                label="GQA Head Count #2")
                        gqa_sliders.extend([gqa_slider1, gqa_slider2])

                    gr.Markdown("#### MLA Compressed Dimensions")
                    gr.Markdown("*Note: MLA dimension must be less than or equal to d_model*")

                    with gr.Column(elem_classes="sliders-container"):
                        mla_slider1 = gr.Slider(minimum=64, maximum=1024, step=64, value=256,
                                                label="MLA Dimension #1")
                        mla_slider2 = gr.Slider(minimum=64, maximum=1024, step=64, value=512,
                                                label="MLA Dimension #2")
                        mla_sliders.extend([mla_slider1, mla_slider2])

            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Model Configuration")
                    num_parameters = gr.Number(label="Parameters (Billions)", value=3)
                    parameter_size = gr.Slider(minimum=1, maximum=16.0, step=1.0, label="Parameter Size (bits per param)", value=5)
                    kv_parameter_size = gr.Slider(minimum=0.25, maximum=4.0, step=0.25,
                                                  label="KV Cache Size (bytes per value)", value=2.0)
                    num_layers = gr.Number(label="Number of Layers", value=36)
                    num_heads = gr.Number(label="Number of Heads", value=16,
                                          info="GQA head counts must be less than or equal to this value")
                    d_model = gr.Number(label="D Model", value=2048,
                                        info="MLA dimensions must be less than or equal to this value")

                with gr.Group():
                    gr.Markdown("### Context Configuration")
                    ctx_length = gr.Slider(minimum=1024, maximum=131072, step=1024,
                                           label="Max Context Length", value=65536)
                    local_layers = gr.Number(label="Local Attention Layers", value=0)
                    global_layers = gr.Number(label="Global Attention Layers", value=1)
                    swa_size = gr.Slider(minimum=1024, maximum=32768, step=1024,
                                         label="Sliding Window Size", value=4096)

        gr.Markdown(
            """
For more information, see [JAX ML Scaling Book](https://jax-ml.github.io/scaling-book/inference/#theoretical-estimates-for-llm-latency-and-throughput).
"""
        )

    # The sliders are passed as inputs so the handlers clamp the user's
    # *current* values. A component's .value attribute is only its initial value.
    num_heads.change(
        update_gqa_sliders,
        inputs=[num_heads, *gqa_sliders],
        outputs=gqa_sliders,
    )

    d_model.change(
        update_mla_sliders,
        inputs=[d_model, *mla_sliders],
        outputs=mla_sliders,
    )

    # Input order must match generate_throughput_plot's parameter order.
    plot_button.click(
        generate_throughput_plot,
        inputs=[
            model_name,
            iphone_model,
            num_parameters,
            parameter_size,
            kv_parameter_size,
            num_layers,
            num_heads,
            d_model,
            ctx_length,
            local_layers,
            global_layers,
            swa_size,
            *gqa_sliders,
            *mla_sliders,
        ],
        outputs=[plot_output, status_output],
    )
|
|
|
# Start the Gradio server for the UI built in the `gr.Blocks` context above.
# NOTE(review): this is not behind an `if __name__ == "__main__":` guard, so
# it also runs if the module is imported.
demo.launch()
|
|