<!-- ucalyptus's picture
Update index.html
4bf0bc7 verified -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>FlashInfer: Attention States & Recursive Merge</title>
<script src="https://cdn.tailwindcss.com"></script>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
<style>
  /* Page-wide typography. Horizontal overflow is hidden rather than scrollable. */
  body {
    font-family: 'Inter', sans-serif;
    overflow-x: hidden;
  }

  /* Inline formulas in the explanatory cards. */
  .math {
    font-family: 'Times New Roman', serif;
    font-style: italic;
  }

  /* NOTE(review): no element in the visible markup or script uses .node. */
  .node {
    transition: all 0.5s ease-in-out;
  }

  /*
   * Merge arrows: a 10-on/10-off dash pattern (period 20) whose offset is
   * animated by exactly one period, so the "marching" loop has no visible seam.
   */
  .arrow {
    stroke-dasharray: 10;
    animation: dash 1s linear infinite;
  }
  @keyframes dash {
    to {
      stroke-dashoffset: -20;
    }
  }

  /*
   * Pulse applied by the script to the final merge node, which is an SVG <g>.
   * SVG elements default to transform-origin 0 0 measured against the SVG
   * viewport, so without the two declarations below scale() would push the
   * node away from the top-left corner instead of pulsing it in place.
   * fill-box + center makes the scale pivot on the node's own bounding box.
   */
  .highlight {
    transform-box: fill-box;
    transform-origin: center;
    animation: pulse 2s infinite;
  }
  @keyframes pulse {
    0% { transform: scale(1); opacity: 1; }
    50% { transform: scale(1.05); opacity: 0.9; }
    100% { transform: scale(1); opacity: 1; }
  }

  /*
   * Entrance effect: the element starts transparent and `forwards` keeps the
   * final keyframe (opacity 1) once the animation ends. The start time of each
   * group is staggered with an inline animation-delay in the SVG markup.
   */
  .fade-in {
    opacity: 0;
    animation: fadeIn 1s forwards;
  }
  @keyframes fadeIn {
    to { opacity: 1; }
  }

  /* NOTE(review): no element in the visible markup or script uses .equation. */
  .equation {
    transition: all 0.5s ease;
  }

  /* NOTE(review): no element in the visible markup or script uses .panel. */
  .panel {
    transition: transform 0.5s ease-out;
  }
</style>
</head>
<body class="bg-gray-100 text-gray-900">
<div class="container mx-auto p-4 max-w-6xl">
<header class="text-center my-8">
<h1 class="text-4xl font-bold text-blue-800 mb-2">FlashInfer: Attention States & Recursive Merge</h1>
<p class="text-xl text-gray-600">Visualizing how FlashInfer accelerates LLM inference</p>
</header>
<div class="bg-white rounded-xl shadow-lg p-6 mb-8">
<h2 class="text-2xl font-semibold mb-4 text-blue-700">Key Innovation: Attention States</h2>
<p class="mb-4">FlashInfer introduces the concept of <strong>attention states</strong>, which fully characterize the attention between a query and a set of key/value pairs. Each attention state consists of two components:</p>
<div class="flex flex-col md:flex-row gap-4 mb-6">
<div class="flex-1 bg-blue-50 p-4 rounded-lg">
<h3 class="font-medium text-blue-800 mb-2">Generalized Score (s)</h3>
<div class="flex justify-center">
<div class="math text-xl">
s(I) = log(∑<sub>i∈I</sub> exp(s<sub>i</sub>))
</div>
</div>
<p class="mt-2 text-sm text-gray-600">The log-sum-exp (LSE) of pre-softmax attention scores</p>
</div>
<div class="flex-1 bg-blue-50 p-4 rounded-lg">
<h3 class="font-medium text-blue-800 mb-2">Generalized Value (v)</h3>
<div class="flex justify-center">
<div class="math text-xl">
v(I) = ∑<sub>i∈I</sub> softmax(s<sub>i</sub>)v<sub>i</sub>
</div>
</div>
<p class="mt-2 text-sm text-gray-600">The weighted sum of value vectors using the softmax of scores</p>
</div>
</div>
</div>
<div class="bg-white rounded-xl shadow-lg p-6 mb-8">
<h2 class="text-2xl font-semibold mb-4 text-blue-700">Recursive Merge Operator</h2>
<p class="mb-4">The key insight of FlashInfer is that attention states can be <strong>merged</strong> efficiently. Given two attention states corresponding to different subsets of KV pairs, we can compute the attention state for their union:</p>
<div class="bg-blue-50 p-4 rounded-lg mb-6">
<div class="flex justify-center">
<div class="math text-xl">
[v(I∪J), s(I∪J)] = [v(I), s(I)] ⊕ [v(J), s(J)]
</div>
</div>
</div>
<p class="mb-6">This merge operator (⊕) is <strong>commutative</strong> and <strong>associative</strong>, allowing flexible computation strategies.</p>
</div>
<!-- Interactive Visualization -->
<div class="bg-white rounded-xl shadow-lg p-6 mb-8">
<h2 class="text-2xl font-semibold mb-4 text-blue-700">Interactive Visualization</h2>
<p class="mb-6">This animation shows how FlashInfer computes attention for a query over 4 KV pairs by partitioning the work and merging results.</p>
<div class="flex justify-center">
<div class="relative" style="width: 700px; height: 650px;" id="visualization">
<!-- SVG will be inserted here -->
</div>
</div>
<!-- Playback controls; click handlers are attached by id in the script at the end of the page. -->
<div class="flex justify-center mt-4">
  <!-- type="button" is explicit so these never act as submit buttons if a form is ever wrapped around them. -->
  <button id="resetBtn" class="bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-lg mr-4" type="button">
    Reset Animation
  </button>
  <button id="playBtn" class="bg-green-600 hover:bg-green-700 text-white px-4 py-2 rounded-lg" type="button">
    Play Animation
  </button>
</div>
</div>
<div class="bg-white rounded-xl shadow-lg p-6 mb-8">
<h2 class="text-2xl font-semibold mb-4 text-blue-700">Applications</h2>
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div class="bg-blue-50 p-4 rounded-lg">
<h3 class="font-medium text-blue-800 mb-2">Shared-Prefix Batch Decoding</h3>
<p>When multiple sequences share a common prefix (e.g., same prompt), compute the attention state for the shared part once, then merge with each sequence's unique suffix.</p>
<p class="mt-2 text-sm font-medium text-green-600">Up to 30x speedup in long-context scenarios</p>
</div>
<div class="bg-blue-50 p-4 rounded-lg">
<h3 class="font-medium text-blue-800 mb-2">KV Sequence Parallelism</h3>
<p>Partition long KV sequences across multiple processing units, compute partial attention states in parallel, then merge the results.</p>
<p class="mt-2 text-sm font-medium text-green-600">Improves GPU utilization for memory-constrained scenarios</p>
</div>
</div>
</div>
</div>
<script>
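// SVG markup for the visualization, kept as a template string and injected into
// the #visualization container below. Elements drawn with opacity="0" start
// hidden and are revealed one at a time by the animation steps.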
const svg = `
<svg width="100%" height="100%" viewBox="0 0 700 650">
<!-- Query Node -->
<g class="fade-in" style="animation-delay: 0.5s">
<circle id="queryNode" cx="350" cy="50" r="30" fill="#4299e1" />
<text x="350" y="55" text-anchor="middle" fill="white" font-weight="bold">Q</text>
</g>
<!-- KV Nodes -->
<g class="fade-in" style="animation-delay: 1s">
<circle id="kv1" cx="175" cy="150" r="25" fill="#9ae6b4" />
<text x="175" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₁</text>
<circle id="kv2" cx="275" cy="150" r="25" fill="#9ae6b4" />
<text x="275" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₂</text>
<circle id="kv3" cx="425" cy="150" r="25" fill="#9ae6b4" />
<text x="425" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₃</text>
<circle id="kv4" cx="525" cy="150" r="25" fill="#9ae6b4" />
<text x="525" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₄</text>
</g>
<!-- Lines connecting Query to KVs -->
<g class="fade-in" style="animation-delay: 1.5s">
<path id="line1" d="M 350 80 L 175 125" stroke="#4299e1" stroke-width="2" fill="none" />
<path id="line2" d="M 350 80 L 275 125" stroke="#4299e1" stroke-width="2" fill="none" />
<path id="line3" d="M 350 80 L 425 125" stroke="#4299e1" stroke-width="2" fill="none" />
<path id="line4" d="M 350 80 L 525 125" stroke="#4299e1" stroke-width="2" fill="none" />
</g>
<!-- Attention State Nodes -->
<g id="stateNodes">
<!-- These will be animated in via JS -->
<g id="state1" opacity="0">
<circle cx="175" cy="250" r="25" fill="#feb2b2" />
<text x="175" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₁,v₁</text>
</g>
<g id="state2" opacity="0">
<circle cx="275" cy="250" r="25" fill="#feb2b2" />
<text x="275" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₂,v₂</text>
</g>
<g id="state3" opacity="0">
<circle cx="425" cy="250" r="25" fill="#feb2b2" />
<text x="425" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₃,v₃</text>
</g>
<g id="state4" opacity="0">
<circle cx="525" cy="250" r="25" fill="#feb2b2" />
<text x="525" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₄,v₄</text>
</g>
<!-- Merge Level 1 -->
<g id="merge1" opacity="0">
<circle cx="225" cy="350" r="30" fill="#fbd38d" />
<text x="225" y="355" text-anchor="middle" fill="#7b341e" font-weight="bold">s₁₂,v₁₂</text>
</g>
<g id="merge2" opacity="0">
<circle cx="475" cy="350" r="30" fill="#fbd38d" />
<text x="475" y="355" text-anchor="middle" fill="#7b341e" font-weight="bold">s₃₄,v₃₄</text>
</g>
<!-- Final Merge -->
<g id="finalMerge" opacity="0">
<circle cx="350" cy="450" r="35" fill="#d6bcfa" />
<text x="350" y="455" text-anchor="middle" fill="#553c9a" font-weight="bold">s₁₂₃₄,v₁₂₃₄</text>
</g>
</g>
<!-- Merge Arrows - will be animated in -->
<g id="mergeArrows">
<!-- State to Merge Level 1 -->
<path id="arrow1" opacity="0" d="M 175 275 L 210 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
<path id="arrow2" opacity="0" d="M 275 275 L 240 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
<path id="arrow3" opacity="0" d="M 425 275 L 460 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
<path id="arrow4" opacity="0" d="M 525 275 L 490 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
<!-- Merge Level 1 to Final -->
<path id="arrow5" opacity="0" d="M 225 380 L 320 425" stroke="#b794f4" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
<path id="arrow6" opacity="0" d="M 475 380 L 380 425" stroke="#b794f4" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
</g>
<!-- Equation Panel -->
<g id="equationPanel" opacity="0" transform="translate(350, 580)">
<rect x="-330" y="-48" width="660" height="72" rx="10" fill="#ebf8ff" stroke="#4299e1" stroke-width="2" />
<text id="equationText" x="0" y="0" text-anchor="middle" font-family="monospace" font-size="14">
Attention States: Computing...
</text>
</g>
<!-- Arrow marker definition -->
<defs>
<marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#000" />
</marker>
</defs>
</svg>`;
// Mount the SVG markup into its placeholder container so that the element ids
// used by the animation steps below exist in the document.
document.getElementById('visualization').innerHTML = svg;

// Animation state.
// Index into `steps` of the next step to run.
let animationStep = 0;
// Timer handle assigned from setInterval() while playback runs; passed to clearInterval().
let animationInterval;

/**
 * Make the listed elements visible by setting their `opacity` attribute to "1".
 * @param {...string} ids Ids of elements defined in the SVG markup.
 */
function revealElements(...ids) {
  ids.forEach((id) => {
    document.getElementById(id).setAttribute('opacity', '1');
  });
}

/**
 * Replace the caption shown inside the equation panel.
 * @param {string} caption Plain text. It is assigned through textContent, so it
 *   is never parsed as markup.
 */
function setEquationText(caption) {
  document.getElementById('equationText').textContent = caption;
}

/**
 * Ordered animation steps. Each entry is a zero-argument function that reveals
 * part of the diagram and updates the caption; the player calls them in index
 * order, one per tick.
 */
const steps = [
  // Step 0: query and KV nodes are already on screen; open the caption panel.
  () => {
    setEquationText("Computing individual attention states for each KV pair");
    revealElements('equationPanel');
  },
  // Steps 1-4: one attention state per KV pair.
  () => {
    revealElements('state1');
    setEquationText("s₁ = q·k₁ᵀ, v₁ = softmax(s₁)·v₁");
  },
  () => {
    revealElements('state2');
    setEquationText("s₂ = q·k₂ᵀ, v₂ = softmax(s₂)·v₂");
  },
  () => {
    revealElements('state3');
    setEquationText("s₃ = q·k₃ᵀ, v₃ = softmax(s₃)·v₃");
  },
  () => {
    revealElements('state4');
    setEquationText("s₄ = q·k₄ᵀ, v₄ = softmax(s₄)·v₄");
  },
  // Step 5: arrows feeding the first (left) merge.
  () => {
    revealElements('arrow1', 'arrow2');
    setEquationText("Merging s₁,v₁ and s₂,v₂ using the ⊕ operator");
  },
  // Step 6: result of the first merge.
  () => {
    revealElements('merge1');
    setEquationText("s₁₂ = log(e^s₁ + e^s₂), v₁₂ = (e^s₁·v₁ + e^s₂·v₂)/(e^s₁ + e^s₂)");
  },
  // Step 7: arrows feeding the second (right) merge.
  () => {
    revealElements('arrow3', 'arrow4');
    setEquationText("Merging s₃,v₃ and s₄,v₄ using the ⊕ operator");
  },
  // Step 8: result of the second merge.
  () => {
    revealElements('merge2');
    setEquationText("s₃₄ = log(e^s₃ + e^s₄), v₃₄ = (e^s₃·v₃ + e^s₄·v₄)/(e^s₃ + e^s₄)");
  },
  // Step 9: arrows feeding the final merge.
  () => {
    revealElements('arrow5', 'arrow6');
    setEquationText("Final merge: Combining s₁₂,v₁₂ and s₃₄,v₃₄");
  },
  // Step 10: final result, emphasised with the pulsing .highlight class.
  () => {
    revealElements('finalMerge');
    document.getElementById('finalMerge').classList.add('highlight');
    setEquationText("s₁₂₃₄ = log(e^s₁₂ + e^s₃₄), v₁₂₃₄ = (e^s₁₂·v₁₂ + e^s₃₄·v₃₄)/(e^s₁₂ + e^s₃₄)");
  },
  // Step 11: closing caption; stops the timer so no further ticks run.
  () => {
    // Plain text, so textContent is used here like in every other step.
    setEquationText("Complete! This is equivalent to standard attention but computed in parallel.");
    clearInterval(animationInterval);
    // Rewind the cursor. When this step is run from the interval callback, the
    // callback increments animationStep again right after this returns, so the
    // start-at-zero guarantee really comes from playAnimation() calling
    // resetAnimation() before every run.
    animationStep = 0;
  }
];
/**
 * Stop playback and put the diagram back in its initial state: every
 * state/merge node and every merge arrow hidden, the pulse class removed from
 * the nodes, the equation panel hidden, and the step cursor rewound to 0.
 */
function resetAnimation() {
  // Cancel the timer first so no further step runs after the reset.
  clearInterval(animationInterval);

  // Attention-state and merge-result groups.
  for (const group of document.querySelectorAll('#stateNodes > g')) {
    group.setAttribute('opacity', '0');
    group.classList.remove('highlight');
  }

  // Arrows between the levels of the merge tree.
  for (const path of document.querySelectorAll('#mergeArrows > path')) {
    path.setAttribute('opacity', '0');
  }

  document.getElementById('equationPanel').setAttribute('opacity', '0');
  animationStep = 0;
}
/**
 * Restart the animation from the beginning: reset the diagram, run the first
 * step immediately, then run one further step every 1500 ms until the list of
 * steps is exhausted or a step cancels the timer.
 */
function playAnimation() {
  const STEP_DELAY_MS = 1500; // Animation step timing

  // Runs the step under the cursor, then moves the cursor forward. The cursor
  // is re-read after the call because a step is allowed to change it.
  const advance = () => {
    steps[animationStep]();
    animationStep += 1;
  };

  resetAnimation();
  advance();

  animationInterval = setInterval(() => {
    if (animationStep >= steps.length) {
      clearInterval(animationInterval);
      return;
    }
    advance();
  }, STEP_DELAY_MS);
}
// Wire each control button to its handler by element id.
for (const [buttonId, onClick] of [
  ['resetBtn', resetAnimation],
  ['playBtn', playAnimation]
]) {
  document.getElementById(buttonId).addEventListener('click', onClick);
}

// Once the page has finished loading, start the animation automatically after
// a one-second pause.
window.addEventListener('load', () => {
  setTimeout(playAnimation, 1000);
});
</script>
</body>
</html>