|
<!DOCTYPE html> |
|
<html lang="en"> |
|
<head> |
|
<meta charset="UTF-8"> |
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"> |
|
<title>FlashInfer: Attention States & Recursive Merge</title> |
|
<script src="https://cdn.tailwindcss.com"></script> |
|
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet"> |
|
<style>
  /*
   * Page-level rules that the Tailwind utility classes do not cover:
   * base typography, the formula typeface, and the keyframe animations
   * used by the SVG diagram that the script injects into #visualization.
   */

  /* Base typeface. Horizontal overflow is clipped rather than scrolled. */
  body {
    font-family: 'Inter', sans-serif;
    overflow-x: hidden;
  }

  /* Inline formulas in the explanatory cards. */
  .math {
    font-family: 'Times New Roman', serif;
    font-style: italic;
  }

  /* Generic transition helper; not referenced by the markup or script in this file. */
  .node {
    transition: all 0.5s ease-in-out;
  }

  /*
   * Merge arrows: a 10-unit dash / 10-unit gap pattern whose offset moves by
   * one full period (20 units) per second, so the loop is seamless.
   */
  .arrow {
    stroke-dasharray: 10;
    animation: dash 1s linear infinite;
  }

  @keyframes dash {
    to {
      stroke-dashoffset: -20;
    }
  }

  /*
   * Pulse applied to the final merged node.
   * SVG elements resolve transform-origin against the view-box origin by
   * default, so a bare scale() would push the node away from its position
   * instead of pulsing in place. fill-box + center anchors the scale to the
   * element's own bounding box.
   */
  .highlight {
    transform-box: fill-box;
    transform-origin: center;
    animation: pulse 2s infinite;
  }

  @keyframes pulse {
    0% { transform: scale(1); opacity: 1; }
    50% { transform: scale(1.05); opacity: 0.9; }
    100% { transform: scale(1); opacity: 1; }
  }

  /*
   * Elements start transparent and fade to opaque over one second; `forwards`
   * keeps the final keyframe so they stay visible. The per-group
   * animation-delay is set inline on the SVG groups.
   */
  .fade-in {
    opacity: 0;
    animation: fadeIn 1s forwards;
  }

  @keyframes fadeIn {
    to { opacity: 1; }
  }

  /* Generic transition helpers; not referenced by the markup or script in this file. */
  .equation {
    transition: all 0.5s ease;
  }

  .panel {
    transition: transform 0.5s ease-out;
  }
</style>
|
</head> |
|
<body class="bg-gray-100 text-gray-900"> |
|
<div class="container mx-auto p-4 max-w-6xl"> |
|
<header class="text-center my-8"> |
|
<h1 class="text-4xl font-bold text-blue-800 mb-2">FlashInfer: Attention States & Recursive Merge</h1> |
|
<p class="text-xl text-gray-600">Visualizing how FlashInfer accelerates LLM inference</p> |
|
</header> |
|
|
|
<div class="bg-white rounded-xl shadow-lg p-6 mb-8"> |
|
<h2 class="text-2xl font-semibold mb-4 text-blue-700">Key Innovation: Attention States</h2> |
|
<p class="mb-4">FlashInfer introduces the concept of <strong>attention states</strong>, which fully characterize the attention between a query and a set of key/value pairs. Each attention state consists of two components:</p> |
|
|
|
<div class="flex flex-col md:flex-row gap-4 mb-6"> |
|
<div class="flex-1 bg-blue-50 p-4 rounded-lg"> |
|
<h3 class="font-medium text-blue-800 mb-2">Generalized Score (s)</h3> |
|
<div class="flex justify-center"> |
|
<div class="math text-xl"> |
|
s(I) = log(∑<sub>i∈I</sub> exp(s<sub>i</sub>)) |
|
</div> |
|
</div> |
|
<p class="mt-2 text-sm text-gray-600">The log-sum-exp (LSE) of pre-softmax attention scores</p> |
|
</div> |
|
|
|
<div class="flex-1 bg-blue-50 p-4 rounded-lg"> |
|
<h3 class="font-medium text-blue-800 mb-2">Generalized Value (v)</h3> |
|
<div class="flex justify-center"> |
|
<div class="math text-xl"> |
|
v(I) = ∑<sub>i∈I</sub> softmax(s<sub>i</sub>)v<sub>i</sub> |
|
</div> |
|
</div> |
|
<p class="mt-2 text-sm text-gray-600">The weighted sum of value vectors using the softmax of scores</p> |
|
</div> |
|
</div> |
|
</div> |
|
|
|
<div class="bg-white rounded-xl shadow-lg p-6 mb-8"> |
|
<h2 class="text-2xl font-semibold mb-4 text-blue-700">Recursive Merge Operator</h2> |
|
<p class="mb-4">The key insight of FlashInfer is that attention states can be <strong>merged</strong> efficiently. Given two attention states corresponding to different subsets of KV pairs, we can compute the attention state for their union:</p> |
|
|
|
<div class="bg-blue-50 p-4 rounded-lg mb-6"> |
|
<div class="flex justify-center"> |
|
<div class="math text-xl"> |
|
[v(I∪J), s(I∪J)] = [v(I), s(I)] ⊕ [v(J), s(J)] |
|
</div> |
|
</div> |
|
</div> |
|
|
|
<p class="mb-6">This merge operator (⊕) is <strong>commutative</strong> and <strong>associative</strong>, allowing flexible computation strategies.</p> |
|
</div> |
|
|
|
|
|
<!--
  Animated diagram card. The inline script fills #visualization with the SVG
  and binds the two buttons by id (#resetBtn, #playBtn).
-->
<div class="bg-white rounded-xl shadow-lg p-6 mb-8">
  <h2 class="text-2xl font-semibold mb-4 text-blue-700">Interactive Visualization</h2>
  <p class="mb-6">This animation shows how FlashInfer computes attention for a query over 4 KV pairs by partitioning the work and merging results.</p>

  <div class="flex justify-center">
    <!--
      Fluid box: at most 700px wide and always 700:650, matching the SVG
      viewBox. A fixed 700x650 box overflowed narrow screens and was cut off
      by the body's overflow-x: hidden. Sizing stays inline so the box keeps
      its shape even before the utility CSS has loaded.
    -->
    <div class="relative" id="visualization" style="width: 100%; max-width: 700px; aspect-ratio: 700 / 650;">
    </div>
  </div>

  <div class="flex justify-center mt-4">
    <!-- Explicit type so the buttons never act as submit buttons if this card is placed inside a form. -->
    <button type="button" id="resetBtn" class="bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-lg mr-4">
      Reset Animation
    </button>
    <button type="button" id="playBtn" class="bg-green-600 hover:bg-green-700 text-white px-4 py-2 rounded-lg">
      Play Animation
    </button>
  </div>
</div>
|
|
|
<div class="bg-white rounded-xl shadow-lg p-6 mb-8"> |
|
<h2 class="text-2xl font-semibold mb-4 text-blue-700">Applications</h2> |
|
|
|
<div class="grid grid-cols-1 md:grid-cols-2 gap-6"> |
|
<div class="bg-blue-50 p-4 rounded-lg"> |
|
<h3 class="font-medium text-blue-800 mb-2">Shared-Prefix Batch Decoding</h3> |
|
<p>When multiple sequences share a common prefix (e.g., same prompt), compute the attention state for the shared part once, then merge with each sequence's unique suffix.</p> |
|
<p class="mt-2 text-sm font-medium text-green-600">Up to 30x speedup in long-context scenarios</p> |
|
</div> |
|
|
|
<div class="bg-blue-50 p-4 rounded-lg"> |
|
<h3 class="font-medium text-blue-800 mb-2">KV Sequence Parallelism</h3> |
|
<p>Partition long KV sequences across multiple processing units, compute partial attention states in parallel, then merge the results.</p> |
|
<p class="mt-2 text-sm font-medium text-green-600">Improves GPU utilization for memory-constrained scenarios</p> |
|
</div> |
|
</div> |
|
</div> |
|
</div> |
|
|
|
<script> |
|
// Markup for the merge-tree diagram, kept as one template string so it can be
// injected into #visualization with a single assignment. The ids on the state
// nodes, arrows and caption are the hooks the animation steps look up at
// runtime; everything with opacity="0" starts hidden and is revealed by a step.
//
// role="img" + aria-label give the diagram one meaningful accessible name
// instead of exposing a run of loose node labels (including hidden ones).
const svg = `
  <svg width="100%" height="100%" viewBox="0 0 700 650" role="img"
       aria-label="Diagram: a query Q attends to four key/value pairs KV1 to KV4. The four per-pair attention states are merged pairwise into two states, which are then merged into one final attention state.">

    <!-- Query node -->
    <g class="fade-in" style="animation-delay: 0.5s">
      <circle id="queryNode" cx="350" cy="50" r="30" fill="#4299e1" />
      <text x="350" y="55" text-anchor="middle" fill="white" font-weight="bold">Q</text>
    </g>

    <!-- Key/value pair nodes -->
    <g class="fade-in" style="animation-delay: 1s">
      <circle id="kv1" cx="175" cy="150" r="25" fill="#9ae6b4" />
      <text x="175" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₁</text>

      <circle id="kv2" cx="275" cy="150" r="25" fill="#9ae6b4" />
      <text x="275" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₂</text>

      <circle id="kv3" cx="425" cy="150" r="25" fill="#9ae6b4" />
      <text x="425" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₃</text>

      <circle id="kv4" cx="525" cy="150" r="25" fill="#9ae6b4" />
      <text x="525" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₄</text>
    </g>

    <!-- Lines from the query to each key/value pair -->
    <g class="fade-in" style="animation-delay: 1.5s">
      <path id="line1" d="M 350 80 L 175 125" stroke="#4299e1" stroke-width="2" fill="none" />
      <path id="line2" d="M 350 80 L 275 125" stroke="#4299e1" stroke-width="2" fill="none" />
      <path id="line3" d="M 350 80 L 425 125" stroke="#4299e1" stroke-width="2" fill="none" />
      <path id="line4" d="M 350 80 L 525 125" stroke="#4299e1" stroke-width="2" fill="none" />
    </g>

    <!-- Attention-state nodes; the direct <g> children are what the reset hides -->
    <g id="stateNodes">
      <!-- Per-pair states -->
      <g id="state1" opacity="0">
        <circle cx="175" cy="250" r="25" fill="#feb2b2" />
        <text x="175" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₁,v₁</text>
      </g>

      <g id="state2" opacity="0">
        <circle cx="275" cy="250" r="25" fill="#feb2b2" />
        <text x="275" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₂,v₂</text>
      </g>

      <g id="state3" opacity="0">
        <circle cx="425" cy="250" r="25" fill="#feb2b2" />
        <text x="425" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₃,v₃</text>
      </g>

      <g id="state4" opacity="0">
        <circle cx="525" cy="250" r="25" fill="#feb2b2" />
        <text x="525" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₄,v₄</text>
      </g>

      <!-- First-level merges -->
      <g id="merge1" opacity="0">
        <circle cx="225" cy="350" r="30" fill="#fbd38d" />
        <text x="225" y="355" text-anchor="middle" fill="#7b341e" font-weight="bold">s₁₂,v₁₂</text>
      </g>

      <g id="merge2" opacity="0">
        <circle cx="475" cy="350" r="30" fill="#fbd38d" />
        <text x="475" y="355" text-anchor="middle" fill="#7b341e" font-weight="bold">s₃₄,v₃₄</text>
      </g>

      <!-- Final merge -->
      <g id="finalMerge" opacity="0">
        <circle cx="350" cy="450" r="35" fill="#d6bcfa" />
        <text x="350" y="455" text-anchor="middle" fill="#553c9a" font-weight="bold">s₁₂₃₄,v₁₂₃₄</text>
      </g>
    </g>

    <!-- Merge arrows; the direct <path> children are what the reset hides -->
    <g id="mergeArrows">
      <!-- Per-pair states into first-level merges -->
      <path id="arrow1" opacity="0" d="M 175 275 L 210 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
      <path id="arrow2" opacity="0" d="M 275 275 L 240 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
      <path id="arrow3" opacity="0" d="M 425 275 L 460 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
      <path id="arrow4" opacity="0" d="M 525 275 L 490 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />

      <!-- First-level merges into the final merge -->
      <path id="arrow5" opacity="0" d="M 225 380 L 320 425" stroke="#b794f4" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
      <path id="arrow6" opacity="0" d="M 475 380 L 380 425" stroke="#b794f4" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
    </g>

    <!-- Caption panel; the text is replaced by each animation step -->
    <g id="equationPanel" opacity="0" transform="translate(350, 580)">
      <rect x="-330" y="-48" width="660" height="72" rx="10" fill="#ebf8ff" stroke="#4299e1" stroke-width="2" />
      <text id="equationText" x="0" y="0" text-anchor="middle" font-family="monospace" font-size="14">
        Attention States: Computing...
      </text>
    </g>

    <!-- Arrowhead referenced by marker-end on the merge arrows -->
    <defs>
      <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
        <polygon points="0 0, 10 3.5, 0 7" fill="#000" />
      </marker>
    </defs>
  </svg>`;

// Inject the diagram into its container. This replaces whatever the container
// held before, and must run before any step looks up the ids above.
document.getElementById('visualization').innerHTML = svg;
|
|
|
// Animation state shared with resetAnimation() and playAnimation():
//   animationStep     - index into `steps` of the next step to run
//   animationInterval - id returned by setInterval while a run is in progress
let animationStep = 0;
let animationInterval;

// Reveals every element whose id is given by setting its SVG opacity to 1.
function revealElements(...ids) {
  ids.forEach(id => document.getElementById(id).setAttribute('opacity', '1'));
}

// Replaces the caption in the panel under the diagram. Always written as
// plain text (textContent), never parsed as markup.
function setCaption(text) {
  document.getElementById('equationText').textContent = text;
}

// Ordered animation script. Each entry is a zero-argument function that
// performs one visual step; the driver calls them in index order. Steps do not
// touch `animationStep` - advancing and resetting the cursor is the driver's job.
const steps = [
  // Step 0: query and KV nodes are already on screen; show the caption panel
  function() {
    setCaption("Computing individual attention states for each KV pair");
    revealElements('equationPanel');
  },

  // Steps 1-4: reveal the per-pair attention states one at a time
  function() {
    revealElements('state1');
    setCaption("s₁ = q·k₁ᵀ, v₁ = softmax(s₁)·v₁");
  },
  function() {
    revealElements('state2');
    setCaption("s₂ = q·k₂ᵀ, v₂ = softmax(s₂)·v₂");
  },
  function() {
    revealElements('state3');
    setCaption("s₃ = q·k₃ᵀ, v₃ = softmax(s₃)·v₃");
  },
  function() {
    revealElements('state4');
    setCaption("s₄ = q·k₄ᵀ, v₄ = softmax(s₄)·v₄");
  },

  // Step 5: arrows for the first pairwise merge
  function() {
    revealElements('arrow1', 'arrow2');
    setCaption("Merging s₁,v₁ and s₂,v₂ using the ⊕ operator");
  },

  // Step 6: result of the first merge
  function() {
    revealElements('merge1');
    setCaption("s₁₂ = log(e^s₁ + e^s₂), v₁₂ = (e^s₁·v₁ + e^s₂·v₂)/(e^s₁ + e^s₂)");
  },

  // Step 7: arrows for the second pairwise merge
  function() {
    revealElements('arrow3', 'arrow4');
    setCaption("Merging s₃,v₃ and s₄,v₄ using the ⊕ operator");
  },

  // Step 8: result of the second merge
  function() {
    revealElements('merge2');
    setCaption("s₃₄ = log(e^s₃ + e^s₄), v₃₄ = (e^s₃·v₃ + e^s₄·v₄)/(e^s₃ + e^s₄)");
  },

  // Step 9: arrows for the final merge
  function() {
    revealElements('arrow5', 'arrow6');
    setCaption("Final merge: Combining s₁₂,v₁₂ and s₃₄,v₃₄");
  },

  // Step 10: final merged state, with the pulse highlight
  function() {
    revealElements('finalMerge');
    document.getElementById('finalMerge').classList.add('highlight');
    setCaption("s₁₂₃₄ = log(e^s₁₂ + e^s₃₄), v₁₂₃₄ = (e^s₁₂·v₁₂ + e^s₃₄·v₃₄)/(e^s₁₂ + e^s₃₄)");
  },

  // Step 11: closing caption; stop the timer right away instead of waiting
  // for the driver to notice on its next tick
  function() {
    setCaption("Complete! This is equivalent to standard attention but computed in parallel.");
    clearInterval(animationInterval);
  }
];
|
|
|
// Returns the diagram to its pre-animation state: stops any running timer,
// hides every state node (dropping the pulse highlight), hides every merge
// arrow and the caption panel, and rewinds the step cursor to 0.
function resetAnimation() {
  clearInterval(animationInterval);

  // Direct <g> children of #stateNodes: per-pair states, merges, final merge.
  for (const group of document.querySelectorAll('#stateNodes > g')) {
    group.setAttribute('opacity', '0');
    group.classList.remove('highlight');
  }

  // Direct <path> children of #mergeArrows.
  for (const path of document.querySelectorAll('#mergeArrows > path')) {
    path.setAttribute('opacity', '0');
  }

  document.getElementById('equationPanel').setAttribute('opacity', '0');

  animationStep = 0;
}
|
|
|
// Restarts the animation from the beginning: resets the diagram, runs the
// first step immediately, then runs one further step every 1500 ms until the
// step list is exhausted.
function playAnimation() {
  resetAnimation();

  // Runs the step under the cursor, then moves the cursor forward.
  const runNextStep = () => {
    steps[animationStep]();
    animationStep++;
  };

  // The first step fires without waiting for the first tick.
  runNextStep();

  animationInterval = setInterval(() => {
    if (animationStep >= steps.length) {
      // Nothing left to run: stop ticking.
      clearInterval(animationInterval);
      return;
    }
    runNextStep();
  }, 1500); // Animation step timing
}
|
|
|
// Wire each control button to its action (reset first, then play).
for (const [buttonId, onClick] of [
  ['resetBtn', resetAnimation],
  ['playBtn', playAnimation]
]) {
  document.getElementById(buttonId).addEventListener('click', onClick);
}

// Once the page has fully loaded, start the animation after a one-second pause.
window.addEventListener('load', function scheduleAutoplay() {
  setTimeout(playAnimation, 1000);
});
|
</script> |
|
</body> |
|
</html> |