Spaces:
Sleeping
Sleeping
import gradio as gr | |
from enum import Enum | |
from throughput_utils import create_throughput_plot | |
class AttentionType(Enum):
    """Enumerates the two attention-layer kinds: ``LOCAL`` and ``GLOBAL``.

    NOTE(review): nothing in the visible part of this file references this
    enum; presumably it mirrors the local/global layer split configured in
    the UI -- confirm before relying on the numeric values.
    """

    LOCAL = 0
    GLOBAL = 1
class PhoneBandwidth(Enum):
    """Memory bandwidth figure for each selectable iPhone generation.

    Member names are what the "iPhone Model" dropdown lists; the member value
    is looked up by name and handed to the plotting helper as the device's
    memory bandwidth.
    """

    # NOTE(review): units are not stated anywhere in this file -- presumably
    # GB/s; confirm against throughput_utils.
    Sixteen = 60
    Fifteen = 51.2
    Fourteen = 34.1
# Stylesheet passed to gr.Blocks(css=...). The id/class selectors correspond to
# the elem_id / elem_classes values assigned to components in the layout:
#   #plot-container    -> the throughput plot image
#   #generate-button   -> the "Generate Throughput Plot" button
#   #error-status      -> the validation/error banner
#   .sliders-container -> the columns wrapping the GQA and MLA sliders
# .gradio-container is not assigned in this file (presumably Gradio's own root
# element class -- confirm).
custom_css = """
#plot-container {
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1), 0 1px 3px rgba(0, 0, 0, 0.08);
padding: 1rem;
background-color: white;
height: 100%;
margin-bottom: 1.5rem;
}
#generate-button {
background-color: #2563eb;
color: white;
border-radius: 8px;
font-weight: bold;
padding: 10px 20px;
box-shadow: 0 4px 6px rgba(37, 99, 235, 0.1);
transition: all 0.2s ease;
width: 100%;
max-width: 400px;
margin: 0 auto;
font-size: 16px;
}
#generate-button:hover {
background-color: #1d4ed8;
box-shadow: 0 6px 8px rgba(37, 99, 235, 0.2);
transform: translateY(-2px);
}
.gradio-container {
background-color: #f5f7fa;
}
/* Custom styles for sliders containers */
.sliders-container {
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: 8px;
padding: 1rem;
margin-top: 0.5rem;
background-color: rgba(255, 255, 255, 0.8);
}
#error-status {
color: #b91c1c;
background-color: #fee2e2;
border-radius: 8px;
padding: 0.75rem;
margin-top: 0.5rem;
border: 1px solid #f87171;
font-weight: 500;
}
"""
with gr.Blocks(css=custom_css) as demo:
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a copy of this file that had lost its indentation --
    # confirm it against the rendered page.

    # Slider components are collected in lists so they can be passed in bulk
    # as event inputs/outputs further down.
    gqa_sliders = []
    mla_sliders = []

    with gr.Column():
        gr.Markdown(
            """# 📊 On-Device LLM Throughput Calculator
This tool estimates the throughput (tokens per second) of Large Language Models on devices with memory bandwidth constraints.
It visualizes how different attention mechanisms (GQA, MLA) and context lengths affect throughput.
"""
        )
        with gr.Row():
            plot_output = gr.Image(label="Throughput Plot", type="pil", elem_id="plot-container")
        # Status element used to display validation errors; hidden until needed.
        status_output = gr.Markdown(visible=False, elem_id="error-status")
        with gr.Row():
            plot_button = gr.Button("Generate Throughput Plot", size="lg", elem_id="generate-button", variant="primary")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Device Configuration")
                    model_name = gr.Textbox(label="Model Name", value="TinyLLM")
                    iphone_model = gr.Dropdown(
                        label="iPhone Model",
                        choices=[e.name for e in PhoneBandwidth],
                        value=PhoneBandwidth.Sixteen.name,
                        interactive=True
                    )
                with gr.Group():
                    gr.Markdown("### Attention Configurations to Plot")
                    gr.Markdown("#### GQA Head Configurations")
                    gr.Markdown("*Note: GQA head count must be less than or equal to the total number of heads*")
                    with gr.Column(elem_classes="sliders-container"):
                        # step=1 keeps the defaults (4 and 8) on the slider grid
                        # and makes every head count from 1 up to the maximum
                        # selectable; with minimum=1 a step of 2 would only
                        # offer odd values.
                        gqa_slider1 = gr.Slider(minimum=1, maximum=32, step=1, value=4,
                                                label="GQA Head Count #1")
                        gqa_slider2 = gr.Slider(minimum=1, maximum=32, step=1, value=8,
                                                label="GQA Head Count #2")
                        gqa_sliders.extend([gqa_slider1, gqa_slider2])
                    gr.Markdown("#### MLA Compressed Dimensions")
                    gr.Markdown("*Note: MLA dimension must be less than or equal to d_model*")
                    with gr.Column(elem_classes="sliders-container"):
                        mla_slider1 = gr.Slider(minimum=64, maximum=1024, step=64, value=256,
                                                label="MLA Dimension #1")
                        mla_slider2 = gr.Slider(minimum=64, maximum=1024, step=64, value=512,
                                                label="MLA Dimension #2")
                        mla_sliders.extend([mla_slider1, mla_slider2])
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Model Configuration")
                    num_parameters = gr.Number(label="Parameters (Billions)", value=3)
                    parameter_size = gr.Slider(minimum=1, maximum=16.0, step=1.0, label="Parameter Size (bits per param)", value=5)
                    kv_parameter_size = gr.Slider(minimum=0.25, maximum=4.0, step=0.25,
                                                  label="KV Cache Size (bytes per value)", value=2.0)
                    num_layers = gr.Number(label="Number of Layers", value=36)
                    num_heads = gr.Number(label="Number of Heads", value=16,
                                          info="GQA head counts must be less than or equal to this value")
                    d_model = gr.Number(label="D Model", value=2048,
                                        info="MLA dimensions must be less than or equal to this value")
                with gr.Group():
                    gr.Markdown("### Context Configuration")
                    ctx_length = gr.Slider(minimum=1024, maximum=131072, step=1024,
                                           label="Max Context Length", value=65536)
                    local_layers = gr.Number(label="Local Attention Layers", value=0)
                    global_layers = gr.Number(label="Global Attention Layers", value=1)
                    swa_size = gr.Slider(minimum=1024, maximum=32768, step=1024,
                                         label="Sliding Window Size", value=4096)
        gr.Markdown(
            """
For more information, see [JAX ML Scaling Book](https://jax-ml.github.io/scaling-book/inference/#theoretical-estimates-for-llm-latency-and-throughput).
"""
        )

    def generate_throughput_plot(
        model_name, iphone_model, num_parameters, parameter_size,
        kv_parameter_size, num_layers, num_heads, d_model, ctx_length,
        local_layers, global_layers, swa_size, gqa_1, gqa_2, mla_1, mla_2
    ):
        """Validate the form values and render the throughput plot.

        Arguments arrive positionally, in the order of the ``inputs`` list
        wired to ``plot_button.click`` below.

        Returns:
            A list of two ``gr.update`` objects, for ``plot_output`` and
            ``status_output`` respectively. On success the plot is set and
            the error banner is hidden; on any failure the plot is cleared
            and the banner shows the error text.
        """
        try:
            # Everything that can fail lives inside the ``try`` so that bad
            # input surfaces in the error banner instead of escaping the
            # handler.
            if iphone_model not in PhoneBandwidth.__members__:
                raise ValueError(f"Unknown iPhone model: {iphone_model!r}")
            memory_bandwidth = PhoneBandwidth[iphone_model].value

            # A cleared Number field arrives as None; report that plainly
            # rather than as a comparison TypeError.
            if num_heads is None or d_model is None:
                raise ValueError("Number of Heads and D Model must both be set")

            # Prefix the plot title with the device unless the user already did.
            if "iPhone" not in model_name:
                model_name = f"iPhone {iphone_model}: {model_name}"

            # GQA head counts may not exceed the total number of attention heads.
            for gqa_heads, label in [(gqa_1, "GQA Head Count #1"), (gqa_2, "GQA Head Count #2")]:
                if gqa_heads > num_heads:
                    raise ValueError(f"{label} ({gqa_heads}) cannot be greater than the total number of attention heads ({num_heads})")
            # MLA compressed dimensions may not exceed d_model.
            for mla_dim, label in [(mla_1, "MLA Dimension #1"), (mla_2, "MLA Dimension #2")]:
                if mla_dim > d_model:
                    raise ValueError(f"{label} ({mla_dim}) cannot be greater than the model dimension (d_model = {d_model})")

            plot_img = create_throughput_plot(
                model_name,
                memory_bandwidth,
                num_parameters,
                parameter_size,
                kv_parameter_size,
                num_layers,
                num_heads,
                d_model,
                ctx_length,
                local_layers,
                global_layers,
                swa_size,
                [gqa_1, gqa_2],
                [mla_1, mla_2],
            )
            # Hide error message, show plot
            return [
                gr.update(value=plot_img),
                gr.update(visible=False, value="")
            ]
        except Exception as e:
            # UI boundary: report every failure to the user rather than crash
            # the handler.
            err_string = f"Error generating plot: {str(e)}"
            print(err_string)
            # Show error message, clear plot
            return [
                gr.update(value=None),
                gr.update(visible=True, value=f"⚠️ {err_string}")
            ]

    def _clamped_slider_updates(upper_bound, current_values, sliders):
        """Build one update per slider capping its maximum and value at *upper_bound*."""
        if not current_values:
            # Called without live values (legacy single-argument form): fall
            # back to the construction-time defaults, as the original code did.
            current_values = [slider.value for slider in sliders]
        return [
            gr.update(maximum=upper_bound, value=min(current, upper_bound))
            for current in current_values
        ]

    def update_gqa_sliders(heads_value, *current_values):
        """Cap the GQA sliders at the configured number of heads.

        Args:
            heads_value: Current "Number of Heads" value; empty or values
                below 1 are treated as 1.
            *current_values: Live values of the GQA sliders. They must be
                passed in as event inputs because a component's ``.value``
                attribute only holds the default it was constructed with.

        Returns:
            One ``gr.update`` per GQA slider.
        """
        if not heads_value or heads_value < 1:
            heads_value = 1
        return _clamped_slider_updates(heads_value, current_values, gqa_sliders)

    def update_mla_sliders(d_model_value, *current_values):
        """Cap the MLA sliders at the configured model dimension.

        Args:
            d_model_value: Current "D Model" value; empty or values below 64
                are treated as 64 (the sliders' minimum).
            *current_values: Live values of the MLA sliders, passed in as
                event inputs for the same reason as in ``update_gqa_sliders``.

        Returns:
            One ``gr.update`` per MLA slider.
        """
        if not d_model_value or d_model_value < 64:
            d_model_value = 64
        return _clamped_slider_updates(d_model_value, current_values, mla_sliders)

    # Keep slider ranges in sync with the model configuration. The sliders are
    # listed as inputs too, so the handlers clamp what the user has currently
    # selected.
    num_heads.change(
        update_gqa_sliders,
        inputs=[num_heads, *gqa_sliders],
        outputs=gqa_sliders
    )
    d_model.change(
        update_mla_sliders,
        inputs=[d_model, *mla_sliders],
        outputs=mla_sliders
    )
    plot_button.click(
        generate_throughput_plot,
        inputs=[
            model_name,
            iphone_model,
            num_parameters,
            parameter_size,
            kv_parameter_size,
            num_layers,
            num_heads,
            d_model,
            ctx_length,
            local_layers,
            global_layers,
            swa_size,
            *gqa_sliders,
            *mla_sliders,
        ],
        outputs=[plot_output, status_output]
    )
# Start serving the UI assembled in the `demo` Blocks object.
demo.launch()