Spaces:
Running
Running
File size: 20,454 Bytes
ce8e9f4 85f61a7 94a5ae8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 |
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ModernBERT Features Visualization</title>
<script src="https://cdn.tailwindcss.com"></script>
<style>
/* Custom styles for better visualization */
body {
font-family: 'Inter', sans-serif; /* Using Inter font as standard */
}
canvas {
display: block;
background-color: #f3f4f6; /* Light gray background */
border-radius: 0.5rem; /* Rounded corners */
border: 1px solid #d1d5db; /* Gray border */
margin: 1rem auto; /* Center canvas with margin */
}
.token {
display: inline-block; /* Align tokens nicely */
padding: 0.3rem 0.6rem;
margin: 0.2rem;
border-radius: 0.375rem; /* Rounded corners for tokens */
font-size: 0.875rem; /* Smaller font size */
font-weight: 500;
min-width: 30px; /* Ensure minimum width */
text-align: center;
}
.token-real { background-color: #60a5fa; color: white; } /* Blue for real tokens */
.token-pad { background-color: #e5e7eb; color: #6b7280; } /* Gray for padding */
.token-attend { border: 2px solid #f87171; } /* Red border for attending token */
.token-attended { background-color: #fbbf24; color: white; } /* Amber for attended tokens */
.token-local { background-color: #a78bfa; color: white; } /* Violet for local attention window */
/* Ensure canvas is responsive */
#animationCanvas {
width: 100%;
max-width: 800px; /* Limit max width for larger screens */
height: auto; /* Adjust height automatically */
aspect-ratio: 16 / 9; /* Maintain aspect ratio, adjust as needed */
}
/* Style buttons */
button {
transition: all 0.2s ease-in-out; /* Smooth transitions */
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); /* Subtle shadow */
}
button:hover {
transform: translateY(-1px); /* Slight lift on hover */
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); /* Enhanced shadow on hover */
}
button:active {
transform: translateY(0px); /* Press effect */
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
}
</style>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap" rel="stylesheet">
</head>
<body class="bg-gray-100 p-4 md:p-8">
<div class="max-w-4xl mx-auto bg-white p-6 rounded-lg shadow-md">
<h1 class="text-2xl md:text-3xl font-bold text-center text-gray-800 mb-6">Visualizing ModernBERT Efficiency Features</h1>
<canvas id="animationCanvas"></canvas>
<div class="flex flex-wrap justify-center gap-3 mt-6 mb-4">
<button id="btnPadding" class="bg-blue-500 hover:bg-blue-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Padding</button>
<button id="btnUnpadding" class="bg-green-500 hover:bg-green-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Unpadding</button>
<button id="btnPacking" class="bg-purple-500 hover:bg-purple-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Sequence Packing</button>
<button id="btnLocalAttn" class="bg-indigo-500 hover:bg-indigo-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Local Attention</button>
<button id="btnGlobalAttn" class="bg-red-500 hover:bg-red-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Global Attention</button>
<button id="btnReset" class="bg-gray-500 hover:bg-gray-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Reset</button>
</div>
<div id="explanation" class="mt-4 p-4 bg-gray-50 border border-gray-200 rounded-lg text-gray-700 min-h-[100px]">
Click a button above to see the animation and explanation.
</div>
<div class="mt-6 p-4 bg-gray-50 border border-gray-200 rounded-lg">
<h3 class="font-semibold mb-2 text-gray-800">Legend:</h3>
<div class="flex flex-wrap gap-2 items-center">
<span class="token token-real">Token</span>
<span class="token token-pad">Pad</span>
<span class="token token-attend">Attending</span>
<span class="token token-attended">Attended</span>
<span class="token token-local">Local Window</span>
</div>
</div>
</div>
<script>
// DOM handles shared by every drawing and animation routine below.
const canvas = document.getElementById('animationCanvas');
const ctx = canvas.getContext('2d');
const explanationDiv = document.getElementById('explanation');

// --- Configuration ---
// Token geometry, in canvas units.
const tokenWidth = 45;
const tokenHeight = 30;
// Gap between neighbouring tokens; multiples of it are also used for the
// canvas margins and the spacing between rows.
const padding = 10;
// Multiplier applied to every per-frame increment; higher is faster.
const animationSpeed = 2;

// Demo batch: each inner array is one sequence of token labels.
const sequences = [
  ['T1', 'T2', 'T3', 'T4', 'T5'],
  ['T1', 'T2', 'T3'],
  ['T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7'],
  ['T1', 'T2'],
];

// Max length for the padding example.
// NOTE(review): nothing in the visible script reads this value.
const maxSeqLen = 8;
// Max number of tokens per row in the packing example.
const packMaxLen = 10;
// Width of the local attention window, e.g. 3 = self +/- 1.
const localWindowSize = 3;

// Handle of the most recently requested animation frame, so a new animation
// (or a reset) can cancel the one in flight.
let animationFrameId = null;
// --- Drawing Functions ---
/**
 * Draws one token as a filled, rounded rectangle with a centered label.
 *
 * The fill colour is chosen from `type`, unless `highlight` is 'attended' or
 * 'local': drawSequence forwards highlight-map values through `highlight`, so
 * those two values must recolour the token here or the attention animations
 * would never show their amber/violet tokens.
 *
 * @param {number} x - Left edge of the token, in canvas coordinates.
 * @param {number} y - Top edge of the token, in canvas coordinates.
 * @param {string} text - Label drawn in the middle of the token.
 * @param {string} [type='real'] - Base look: 'real', 'pad', 'attended' or 'local'.
 * @param {string} [highlight='none'] - 'attending' adds a red border;
 *   'attended' / 'local' override the fill colour; anything else is ignored.
 */
function drawToken(x, y, text, type = 'real', highlight = 'none') {
  // A fill-changing highlight wins over the base type.
  const look = (highlight === 'attended' || highlight === 'local') ? highlight : type;

  let bgColor = '#60a5fa'; // token-real (blue)
  let textColor = 'white';
  if (look === 'pad') {
    bgColor = '#e5e7eb'; // token-pad (gray)
    textColor = '#6b7280';
  } else if (look === 'attended') {
    bgColor = '#fbbf24'; // token-attended (amber)
  } else if (look === 'local') {
    bgColor = '#a78bfa'; // token-local (violet)
  }
  const borderColor = highlight === 'attending' ? '#f87171' : null; // token-attend (red)

  ctx.font = '12px Inter';
  ctx.textAlign = 'center';
  ctx.textBaseline = 'middle';

  // Background rectangle.
  ctx.fillStyle = bgColor;
  ctx.beginPath();
  if (typeof ctx.roundRect === 'function') {
    ctx.roundRect(x, y, tokenWidth, tokenHeight, 5);
  } else {
    // Older browsers lack roundRect; square corners beat a thrown TypeError.
    ctx.rect(x, y, tokenWidth, tokenHeight);
  }
  ctx.fill();

  // Border, stroked along the same path as the fill.
  if (borderColor) {
    ctx.strokeStyle = borderColor;
    ctx.lineWidth = 2;
    ctx.stroke();
  }

  // Label.
  ctx.fillStyle = textColor;
  ctx.fillText(text, x + tokenWidth / 2, y + tokenHeight / 2);
}
/**
 * Draws a row of tokens left to right, starting at (x, y).
 *
 * A label counts as a real token when it is a string starting with 'T';
 * everything else is drawn as padding.
 *
 * @param {number} x - Left edge of the first token.
 * @param {number} y - Top edge of the row.
 * @param {Array} sequence - Token labels, in drawing order.
 * @param {Object} [highlightMap={}] - Optional highlight per token index;
 *   indices without a (truthy) entry get 'none'.
 */
function drawSequence(x, y, sequence, highlightMap = {}) {
  const stride = tokenWidth + padding; // horizontal distance between token origins
  for (const [position, label] of sequence.entries()) {
    const looksReal = typeof label === 'string' && label.startsWith('T');
    drawToken(
      x + position * stride,
      y,
      label,
      looksReal ? 'real' : 'pad',
      highlightMap[position] || 'none',
    );
  }
}
/** Erases the whole drawing surface, using the canvas' current width and height. */
function clearCanvas() {
  const { width, height } = canvas;
  ctx.clearRect(0, 0, width, height);
}
// --- Animation Logic ---
// 1. Padding Animation
/**
 * Plays the "Padding" animation: every sequence is padded with 'Pad' labels
 * up to the length of the longest one, and the rows are revealed column by
 * column. Cancels whatever animation frame was pending and replaces the
 * explanation text.
 */
function animatePadding() {
  cancelAnimationFrame(animationFrameId); // Stop previous animation
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Padding</h3>
Traditional processing requires all sequences in a batch to have the same length. Shorter sequences are "padded" with special (PAD) tokens to match the longest sequence. This wastes computation on meaningless tokens.
`;
  const targetMaxLength = Math.max(...sequences.map((s) => s.length));
  const paddedSequences = sequences.map((seq) => [
    ...seq,
    ...Array(targetMaxLength - seq.length).fill('Pad'),
  ]);

  // Layout is the same on every frame, so compute it once.
  const startX = padding * 2;
  const startY = padding * 3;
  const rowPitch = tokenHeight + padding * 2;

  let progress = 0; // Number of columns revealed so far (fractional between frames)

  function step() {
    clearCanvas();
    const visibleColumns = Math.ceil(progress);
    paddedSequences.forEach((seq, i) => {
      const displayLength = Math.min(seq.length, visibleColumns);
      drawSequence(startX, startY + i * rowPitch, seq.slice(0, displayLength));
    });
    // Keep scheduling frames until every column has been revealed.
    if (progress < targetMaxLength) {
      progress += animationSpeed / 10;
      animationFrameId = requestAnimationFrame(step);
    }
  }

  step();
}
// 2. Unpadding Animation
/**
 * Plays the "Unpadding" animation: starts from the padded batch and fades the
 * 'Pad' tokens out while real tokens stay fully opaque, then redraws only the
 * original sequences. Cancels whatever animation frame was pending and
 * replaces the explanation text.
 */
function animateUnpadding() {
  cancelAnimationFrame(animationFrameId);
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Unpadding</h3>
Unpadding removes the PAD tokens before processing. Sequences are treated individually (conceptually), avoiding wasted computation. ModernBERT concatenates these unpadded sequences.
`;
  const targetMaxLength = Math.max(...sequences.map((s) => s.length));
  const paddedSequences = sequences.map((seq) => [
    ...seq,
    ...Array(targetMaxLength - seq.length).fill('Pad'),
  ]);

  const startX = padding * 2;
  const startY = padding * 3;
  const rowPitch = tokenHeight + padding * 2;

  let fadeAmount = 1; // 1 = pads fully visible, 0 = pads invisible

  function step() {
    clearCanvas();

    // Fade finished: show only the original tokens. Checking this first means
    // pads are never drawn with a zero or negative alpha.
    if (fadeAmount <= 0) {
      sequences.forEach((seq, i) => {
        drawSequence(startX, startY + i * rowPitch, seq);
      });
      return;
    }

    paddedSequences.forEach((seq, i) => {
      const y = startY + i * rowPitch;
      seq.forEach((token, index) => {
        const isPad = token === 'Pad';
        // Only padding tokens fade; real tokens stay opaque.
        ctx.globalAlpha = isPad ? fadeAmount : 1;
        drawToken(startX + index * (tokenWidth + padding), y, token, isPad ? 'pad' : 'real');
      });
    });
    ctx.globalAlpha = 1; // Do not leak the fade into later drawing

    fadeAmount -= 0.02 * animationSpeed;
    animationFrameId = requestAnimationFrame(step);
  }

  step();
}
// 3. Sequence Packing Animation
/**
 * Greedily concatenates sequences, in order, into packs of at most `maxLen`
 * tokens. A sequence longer than `maxLen` gets a pack of its own; empty packs
 * are never produced.
 *
 * @param {Array<Array>} seqs - Sequences to pack; not modified.
 * @param {number} maxLen - Maximum number of tokens per pack.
 * @returns {Array<Array>} The packs, each a new flat array of tokens.
 */
function packSequences(seqs, maxLen) {
  const packs = [];
  let current = [];
  for (const seq of seqs) {
    // Flush only a non-empty pack; otherwise an oversized first sequence
    // would leave an empty pack (a blank row) behind.
    if (current.length > 0 && current.length + seq.length > maxLen) {
      packs.push(current);
      current = [];
    }
    current.push(...seq);
  }
  if (current.length > 0) {
    packs.push(current);
  }
  return packs;
}

/**
 * Plays the "Sequence Packing" animation: the original sequences fade out
 * while the packed rows fade in at the same positions, ending with only the
 * packed rows. Cancels whatever animation frame was pending and replaces the
 * explanation text.
 */
function animateSequencePacking() {
  cancelAnimationFrame(animationFrameId);
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">Sequence Packing</h3>
After unpadding, sequences are concatenated (packed) together into longer sequences, up to the model's maximum length (e.g., ${packMaxLen} here). This maximizes GPU utilization by processing more real tokens per batch. Careful masking ensures tokens only attend within their original sequence.
`;
  const packedSequences = packSequences(sequences, packMaxLen);
  const totalSequences = packedSequences.length;

  const startX = padding * 2;
  const startY = padding * 3;
  const rowPitch = tokenHeight + padding * 2;
  const drawRows = (rows) => {
    rows.forEach((row, i) => drawSequence(startX, startY + i * rowPitch, row));
  };

  let progress = 0; // How many packed rows have been revealed (fractional between frames)

  function step() {
    clearCanvas();

    // Final state: only packed sequences. Also covers an empty batch, where
    // dividing by totalSequences below would yield NaN.
    if (progress >= totalSequences) {
      ctx.globalAlpha = 1;
      drawRows(packedSequences);
      return;
    }

    const fraction = progress / totalSequences;

    // Original sequences, fading out.
    ctx.globalAlpha = Math.max(0, 1 - fraction);
    drawRows(sequences);

    // Packed sequences, fading in twice as fast, drawn over the same rows.
    ctx.globalAlpha = Math.min(1, fraction * 2);
    drawRows(packedSequences.slice(0, Math.ceil(progress)));

    ctx.globalAlpha = 1;
    progress += animationSpeed / 20;
    animationFrameId = requestAnimationFrame(step);
  }

  step();
}
// 4. Attention Animations (Local/Global)
/**
 * Plays the attention animation on the longest demo sequence. The middle
 * token is marked 'attending'; the tokens it attends to are revealed as the
 * animation progresses. Cancels whatever animation frame was pending and
 * replaces the explanation text.
 *
 * @param {boolean} isGlobal - true: every other token ends up 'attended';
 *   false: only tokens within +/- floor(localWindowSize / 2) end up 'local'.
 */
function animateAttention(isGlobal) {
  cancelAnimationFrame(animationFrameId);
  // Pick the longest sequence rather than a fixed index, so the demo keeps
  // working when the batch has fewer entries or a different order.
  const seq = sequences.reduce(
    (longest, candidate) => (candidate.length > longest.length ? candidate : longest),
    [],
  );
  const midIndex = Math.floor(seq.length / 2); // Token that will 'attend'
  const halfWindow = Math.floor(localWindowSize / 2);
  explanationDiv.innerHTML = `
<h3 class="font-semibold mb-1">${isGlobal ? 'Global' : 'Local'} Attention</h3>
Attention allows tokens to "look" at other tokens to understand context.
<b>${isGlobal ? 'Global Attention:' : 'Local Attention:'}</b>
${isGlobal
? 'Every token attends to every other token in the sequence. Powerful but computationally expensive, especially for long sequences.'
: `Each token attends only to a fixed-size window of nearby tokens (e.g., +/- ${Math.floor(localWindowSize / 2)} here). Much faster for long sequences.`
} ModernBERT alternates between these layers.
`;

  const startX = padding * 2;
  const startY = padding * 3;

  // Builds the index -> highlight map for a progress value in [0, 1].
  // At progress 1 every eligible token is highlighted.
  function buildHighlightMap(progress) {
    const map = { [midIndex]: 'attending' };
    if (isGlobal) {
      for (let i = 0; i < seq.length; i += 1) {
        // Each token is re-rolled every frame, so tokens flicker in with a
        // probability equal to the progress; at 1 the roll always succeeds.
        if (i !== midIndex && Math.random() < progress) {
          map[i] = 'attended';
        }
      }
      return map;
    }
    const windowStart = Math.max(0, midIndex - halfWindow);
    const windowEnd = Math.min(seq.length - 1, midIndex + halfWindow);
    for (let i = windowStart; i <= windowEnd; i += 1) {
      if (i === midIndex) continue;
      // Closer tokens highlight sooner.
      const requiredProgress = Math.abs(i - midIndex) / (localWindowSize / 2);
      if (progress >= requiredProgress) {
        map[i] = 'local';
      }
    }
    return map;
  }

  let highlightProgress = 0; // 0 to 1 (may overshoot 1 by one increment)

  function step() {
    clearCanvas();
    // Clamping means the last frame is drawn at exactly 1, i.e. fully
    // highlighted, so no separate final redraw is needed.
    drawSequence(startX, startY, seq, buildHighlightMap(Math.min(1, highlightProgress)));
    if (highlightProgress < 1) {
      highlightProgress += 0.01 * animationSpeed;
      animationFrameId = requestAnimationFrame(step);
    }
  }

  step();
}
// --- Reset Function ---
/**
 * Returns the page to its idle state: cancels the pending animation frame,
 * wipes the canvas and restores the initial prompt in the explanation box.
 */
function resetVisualization() {
  const idlePrompt = 'Click a button above to see the animation and explanation.';
  cancelAnimationFrame(animationFrameId);
  clearCanvas();
  explanationDiv.innerHTML = idlePrompt;
}
// --- Event Listeners ---
// Button id -> click handler, registered in the order the buttons appear.
const clickHandlers = [
  ['btnPadding', animatePadding],
  ['btnUnpadding', animateUnpadding],
  ['btnPacking', animateSequencePacking],
  ['btnLocalAttn', () => animateAttention(false)],
  ['btnGlobalAttn', () => animateAttention(true)],
  ['btnReset', resetVisualization],
];
for (const [buttonId, handler] of clickHandlers) {
  document.getElementById(buttonId).addEventListener('click', handler);
}
// --- Initial Setup & Resize ---
/**
 * Keeps the canvas backing store in step with its displayed (CSS) size.
 *
 * The backing store is sized in device pixels (CSS size times
 * devicePixelRatio) so drawing stays sharp on high-density screens, and a
 * matching transform lets all drawing code keep using CSS-pixel coordinates.
 * Any running animation is reset when the size actually changes.
 */
function resizeCanvas() {
  const pixelRatio = window.devicePixelRatio || 1;
  const targetWidth = Math.round(canvas.clientWidth * pixelRatio);
  const targetHeight = Math.round(canvas.clientHeight * pixelRatio);
  if (canvas.width === targetWidth && canvas.height === targetHeight) {
    return;
  }
  canvas.width = targetWidth;
  canvas.height = targetHeight;
  // Resizing a canvas resets its context state, so the scale has to be
  // re-applied after every size change.
  ctx.setTransform(pixelRatio, 0, 0, pixelRatio, 0, 0);
  // For simplicity, reset rather than redraw the in-flight animation.
  resetVisualization();
}
// Initial resize and setup listener
// Re-fit the canvas whenever the window size changes.
window.addEventListener('resize', resizeCanvas);

// Sizing has to wait for 'load' so the canvas has been laid out; after
// fitting it, reset so the page starts clean.
function handleWindowLoad() {
  resizeCanvas();
  resetVisualization();
}
window.addEventListener('load', handleWindowLoad);
</script>
</body>
</html>
|