File size: 20,454 Bytes
ce8e9f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85f61a7
94a5ae8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ModernBERT Features Visualization</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <style>
        /* Custom styles for better visualization */
        body {
            font-family: 'Inter', sans-serif; /* Using Inter font as standard */
        }
        canvas {
            display: block;
            background-color: #f3f4f6; /* Light gray background */
            border-radius: 0.5rem; /* Rounded corners */
            border: 1px solid #d1d5db; /* Gray border */
            margin: 1rem auto; /* Center canvas with margin */
        }
        .token {
            display: inline-block; /* Align tokens nicely */
            padding: 0.3rem 0.6rem;
            margin: 0.2rem;
            border-radius: 0.375rem; /* Rounded corners for tokens */
            font-size: 0.875rem; /* Smaller font size */
            font-weight: 500;
            min-width: 30px; /* Ensure minimum width */
            text-align: center;
        }
        .token-real { background-color: #60a5fa; color: white; } /* Blue for real tokens */
        .token-pad { background-color: #e5e7eb; color: #6b7280; } /* Gray for padding */
        .token-attend { border: 2px solid #f87171; } /* Red border for attending token */
        .token-attended { background-color: #fbbf24; color: white; } /* Amber for attended tokens */
        .token-local { background-color: #a78bfa; color: white; } /* Violet for local attention window */

        /* Ensure canvas is responsive */
        #animationCanvas {
            width: 100%;
            max-width: 800px; /* Limit max width for larger screens */
            height: auto; /* Adjust height automatically */
            aspect-ratio: 16 / 9; /* Maintain aspect ratio, adjust as needed */
        }

        /* Style buttons */
        button {
            transition: all 0.2s ease-in-out; /* Smooth transitions */
            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); /* Subtle shadow */
        }
        button:hover {
            transform: translateY(-1px); /* Slight lift on hover */
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); /* Enhanced shadow on hover */
        }
         button:active {
            transform: translateY(0px); /* Press effect */
            box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
        }
    </style>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap" rel="stylesheet">
</head>
<body class="bg-gray-100 p-4 md:p-8">

    <div class="max-w-4xl mx-auto bg-white p-6 rounded-lg shadow-md">
        <h1 class="text-2xl md:text-3xl font-bold text-center text-gray-800 mb-6">Visualizing ModernBERT Efficiency Features</h1>

        <canvas id="animationCanvas"></canvas>

        <div class="flex flex-wrap justify-center gap-3 mt-6 mb-4">
            <button id="btnPadding" class="bg-blue-500 hover:bg-blue-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Padding</button>
            <button id="btnUnpadding" class="bg-green-500 hover:bg-green-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Unpadding</button>
            <button id="btnPacking" class="bg-purple-500 hover:bg-purple-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Sequence Packing</button>
            <button id="btnLocalAttn" class="bg-indigo-500 hover:bg-indigo-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Local Attention</button>
            <button id="btnGlobalAttn" class="bg-red-500 hover:bg-red-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Global Attention</button>
             <button id="btnReset" class="bg-gray-500 hover:bg-gray-600 text-white font-semibold py-2 px-4 rounded-lg shadow">Reset</button>
        </div>

        <div id="explanation" class="mt-4 p-4 bg-gray-50 border border-gray-200 rounded-lg text-gray-700 min-h-[100px]">
            Click a button above to see the animation and explanation.
        </div>

        <div class="mt-6 p-4 bg-gray-50 border border-gray-200 rounded-lg">
            <h3 class="font-semibold mb-2 text-gray-800">Legend:</h3>
            <div class="flex flex-wrap gap-2 items-center">
                <span class="token token-real">Token</span>
                <span class="token token-pad">Pad</span>
                <span class="token token-attend">Attending</span>
                <span class="token token-attended">Attended</span>
                 <span class="token token-local">Local Window</span>
            </div>
        </div>
    </div>

    <script>
        // --- DOM handles ---
        const explanationDiv = document.getElementById('explanation');
        const canvas = document.getElementById('animationCanvas');
        const ctx = canvas.getContext('2d');

        // --- Configuration ---
        // Token box geometry, in canvas pixels.
        const tokenWidth = 45;
        const tokenHeight = 30;
        // Gap between neighbouring tokens; the animations also use multiples of it as
        // margins and (doubled) as the spacing between rows.
        const padding = 10;
        // Multiplier applied to every per-frame progress increment; larger is faster.
        const animationSpeed = 2;

        // Demo batch: four sequences of 5, 3, 7 and 2 tokens, each labelled T1..Tn.
        let sequences = [5, 3, 7, 2].map((length) =>
            Array.from({ length }, (_, i) => `T${i + 1}`)
        );
        const maxSeqLen = 8; // Max length for padding example (NOTE(review): not read anywhere in this file)
        const packMaxLen = 10; // Row capacity, in tokens, for the sequence-packing demo
        const localWindowSize = 3; // Local attention window width: the token itself +/- floor(3 / 2) = 1 neighbour

        // Id of the pending requestAnimationFrame callback, so a new demo can cancel the old one.
        let animationFrameId = null;

        // --- Drawing Functions ---
        /**
         * Paint one rounded token box with its label centred inside it.
         *
         * @param {number} x - Left edge of the box.
         * @param {number} y - Top edge of the box.
         * @param {string} text - Label drawn in the middle of the box.
         * @param {string} type - 'pad', 'attended' or 'local' select that colour scheme;
         *     any other value (including the default 'real') uses the blue scheme.
         * @param {string} highlight - 'attending' adds a red 2px outline; other values add nothing.
         */
        function drawToken(x, y, text, type = 'real', highlight = 'none') {
            ctx.font = '12px Inter';
            ctx.textAlign = 'center';
            ctx.textBaseline = 'middle';

            // Colours mirror the .token-* CSS classes used by the legend.
            let fill;
            let ink = 'white';
            switch (type) {
                case 'pad':
                    fill = '#e5e7eb';
                    ink = '#6b7280';
                    break;
                case 'attended':
                    fill = '#fbbf24';
                    break;
                case 'local':
                    fill = '#a78bfa';
                    break;
                default:
                    fill = '#60a5fa';
            }
            const outline = highlight === 'attending' ? '#f87171' : null;

            const centerX = x + tokenWidth / 2;
            const centerY = y + tokenHeight / 2;

            // Box body; the same path is reused for the optional outline below.
            ctx.fillStyle = fill;
            ctx.beginPath();
            ctx.roundRect(x, y, tokenWidth, tokenHeight, 5);
            ctx.fill();

            if (outline !== null) {
                ctx.strokeStyle = outline;
                ctx.lineWidth = 2;
                ctx.stroke();
            }

            ctx.fillStyle = ink;
            ctx.fillText(text, centerX, centerY);
        }

        /**
         * Draw a horizontal row of tokens, left to right, starting at (x, y).
         *
         * @param {number} x - Left edge of the first token.
         * @param {number} y - Top edge of the row.
         * @param {Array} sequence - Token labels; string labels starting with 'T' are real
         *     tokens, anything else is drawn as padding.
         * @param {Object<number, string>} highlightMap - Optional per-index highlight:
         *     'attending' draws a red border; 'attended' and 'local' recolour the token.
         */
        function drawSequence(x, y, sequence, highlightMap = {}) {
            sequence.forEach((token, index) => {
                const tokenX = x + index * (tokenWidth + padding);
                const highlight = highlightMap[index] || 'none';
                const isReal = typeof token === 'string' && token.startsWith('T');
                // drawToken only reads `highlight` for the 'attending' border; the
                // 'attended'/'local' colours are selected by the token *type*, so those
                // highlights must be passed as the type or they are never rendered.
                const tokenType = (highlight === 'attended' || highlight === 'local')
                    ? highlight
                    : (isReal ? 'real' : 'pad');
                drawToken(tokenX, y, token, tokenType, highlight);
            });
        }

        /** Erase the entire drawing surface of the visualization canvas. */
        function clearCanvas() {
            const { width, height } = canvas;
            ctx.clearRect(0, 0, width, height);
        }

        // --- Animation Logic ---

        // 1. Padding Animation
        /**
         * Padding demo: right-pads every sequence with 'Pad' tokens up to the length of
         * the longest sequence, then reveals the padded rows one token column at a time.
         */
        function animatePadding() {
            cancelAnimationFrame(animationFrameId); // Stop previous animation
            explanationDiv.innerHTML = `
                <h3 class="font-semibold mb-1">Padding</h3>
                Traditional processing requires all sequences in a batch to have the same length. Shorter sequences are "padded" with special (PAD) tokens to match the longest sequence. This wastes computation on meaningless tokens.
            `;
            const targetMaxLength = Math.max(...sequences.map(s => s.length));
            const paddedSequences = sequences.map(seq => [
                ...seq,
                ...Array(targetMaxLength - seq.length).fill('Pad'),
            ]);

            // Number of token columns revealed so far; fractional between frames and
            // rounded up when drawing.
            let progress = 0;
            const totalSteps = targetMaxLength;

            function step() {
                clearCanvas();
                const startY = padding * 3; // Start drawing lower

                paddedSequences.forEach((seq, i) => {
                    const y = startY + i * (tokenHeight + padding * 2); // Increased vertical spacing
                    const displayLength = Math.min(seq.length, Math.ceil(progress));
                    drawSequence(padding * 2, y, seq.slice(0, displayLength));
                });

                if (progress < totalSteps) {
                    progress += animationSpeed / 10; // Adjust speed
                    animationFrameId = requestAnimationFrame(step);
                }
            }
            step();
        }

        // 2. Unpadding Animation
        /**
         * Unpadding demo: draws every sequence padded to the longest length, fades the
         * 'Pad' tokens out over several frames, then shows only the original tokens.
         */
        function animateUnpadding() {
            cancelAnimationFrame(animationFrameId);
            explanationDiv.innerHTML = `
                <h3 class="font-semibold mb-1">Unpadding</h3>
                Unpadding removes the PAD tokens before processing. Sequences are treated individually (conceptually), avoiding wasted computation. ModernBERT concatenates these unpadded sequences.
            `;
            const targetMaxLength = Math.max(...sequences.map(s => s.length));
            const paddedSequences = sequences.map(seq => [
                ...seq,
                ...Array(targetMaxLength - seq.length).fill('Pad'),
            ]);
            const startY = padding * 3;
            const rowY = (i) => startY + i * (tokenHeight + padding * 2);

            // Opacity of the PAD tokens: 1 = fully visible, counting down towards 0.
            let fadeAmount = 1;

            function step() {
                clearCanvas();

                if (fadeAmount <= 0) {
                    // Fade finished: show only the original (unpadded) tokens.
                    sequences.forEach((seq, i) => drawSequence(padding * 2, rowY(i), seq));
                    return;
                }

                paddedSequences.forEach((seq, i) => {
                    const y = rowY(i);
                    seq.forEach((token, index) => {
                        const tokenX = padding * 2 + index * (tokenWidth + padding);
                        const isPad = token === 'Pad';
                        // Only PAD tokens fade. fadeAmount is in (0, 1] on this path, so it is
                        // always a valid globalAlpha (out-of-range values would be ignored).
                        ctx.globalAlpha = isPad ? fadeAmount : 1;
                        drawToken(tokenX, y, token, isPad ? 'pad' : 'real');
                    });
                });
                ctx.globalAlpha = 1; // Don't leak partial opacity into later drawing.

                fadeAmount -= 0.02 * animationSpeed; // Fade out speed
                animationFrameId = requestAnimationFrame(step);
            }
            step();
        }

        // 3. Sequence Packing Animation
        /**
         * Greedily pack sequences, in order, into rows of at most `maxLen` tokens.
         * A sequence that does not fit in the current row starts a new row; a sequence
         * longer than `maxLen` is never split and gets a row to itself.
         *
         * @param {string[][]} seqs - Sequences to pack (not modified).
         * @param {number} maxLen - Row capacity in tokens.
         * @returns {string[][]} Packed rows; never contains an empty row.
         */
        function packSequences(seqs, maxLen) {
            const packs = [];
            let current = [];
            for (const seq of seqs) {
                // Close the current row only if it already holds tokens; this avoids
                // emitting an empty row when an oversized sequence comes first.
                if (current.length > 0 && current.length + seq.length > maxLen) {
                    packs.push(current);
                    current = [];
                }
                current.push(...seq);
            }
            if (current.length > 0) {
                packs.push(current);
            }
            return packs;
        }

        /**
         * Sequence-packing demo: cross-fades from the unpadded sequences to the packed
         * rows (drawn at the same row positions), then shows only the packed rows.
         */
        function animateSequencePacking() {
            cancelAnimationFrame(animationFrameId);
            explanationDiv.innerHTML = `
                <h3 class="font-semibold mb-1">Sequence Packing</h3>
                After unpadding, sequences are concatenated (packed) together into longer sequences, up to the model's maximum length (e.g., ${packMaxLen} here). This maximizes GPU utilization by processing more real tokens per batch. Careful masking ensures tokens only attend within their original sequence.
            `;

            const packedSequences = packSequences(sequences, packMaxLen);
            const totalSequences = packedSequences.length;
            const startY = padding * 3;
            const rowY = (i) => startY + i * (tokenHeight + padding * 2);

            // Runs from 0 to totalSequences; packed row i is shown once progress > i.
            let progress = 0;

            function step() {
                clearCanvas();

                // Also covers an empty batch (totalSequences === 0) without dividing by zero.
                if (progress >= totalSequences) {
                    // Final state: only the packed rows.
                    packedSequences.forEach((pack, i) => drawSequence(padding * 2, rowY(i), pack));
                    return;
                }

                const ratio = progress / totalSequences;

                // Original rows fade out...
                ctx.globalAlpha = Math.max(0, 1 - ratio);
                sequences.forEach((seq, i) => drawSequence(padding * 2, rowY(i), seq));

                // ...while packed rows fade in on top of them, twice as fast.
                ctx.globalAlpha = Math.min(1, ratio * 2);
                packedSequences.slice(0, Math.ceil(progress)).forEach((pack, i) => {
                    drawSequence(padding * 2, rowY(i), pack);
                });
                ctx.globalAlpha = 1.0;

                progress += animationSpeed / 20; // Adjust speed
                animationFrameId = requestAnimationFrame(step);
            }
            step();
        }

        // 4. Attention Animations (Local/Global)
        /**
         * Attention demo on the longest sequence: the middle token is marked 'attending'
         * and the tokens it attends to are revealed deterministically, spreading outward
         * from it, until the full pattern is shown.
         *
         * @param {boolean} isGlobal - true: every other token is marked 'attended';
         *     false: only tokens within floor(localWindowSize / 2) positions are marked 'local'.
         */
        function animateAttention(isGlobal) {
            cancelAnimationFrame(animationFrameId);
            // Longest sequence in the batch (the earliest one wins ties).
            const seq = sequences.reduce((longest, s) => (s.length > longest.length ? s : longest), []);
            const midIndex = Math.floor(seq.length / 2); // Token that will 'attend'
            explanationDiv.innerHTML = `
                <h3 class="font-semibold mb-1">${isGlobal ? 'Global' : 'Local'} Attention</h3>
                Attention allows tokens to "look" at other tokens to understand context.
                <b>${isGlobal ? 'Global Attention:' : 'Local Attention:'}</b>
                ${isGlobal
                    ? 'Every token attends to every other token in the sequence. Powerful but computationally expensive, especially for long sequences.'
                    : `Each token attends only to a fixed-size window of nearby tokens (e.g., +/- ${Math.floor(localWindowSize / 2)} here). Much faster for long sequences.`
                } ModernBERT alternates between these layers.
            `;

            const halfWindow = Math.floor(localWindowSize / 2);
            const startX = padding * 2;
            const startY = padding * 3;

            /**
             * Highlight map for a reveal level in [0, 1]. A token at distance d from
             * midIndex is marked once level >= d / reach, so the map only grows as the
             * level rises and level 1 yields the complete pattern.
             */
            function buildHighlightMap(level) {
                const map = { [midIndex]: 'attending' };
                let first;
                let last;
                let reach;
                let kind;
                if (isGlobal) {
                    first = 0;
                    last = seq.length - 1;
                    reach = Math.max(midIndex, last - midIndex);
                    kind = 'attended';
                } else {
                    first = Math.max(0, midIndex - halfWindow);
                    last = Math.min(seq.length - 1, midIndex + halfWindow);
                    reach = localWindowSize / 2; // Closer tokens highlight sooner
                    kind = 'local';
                }
                for (let i = first; i <= last; i++) {
                    if (i === midIndex) continue;
                    if (level >= Math.abs(i - midIndex) / reach) {
                        map[i] = kind;
                    }
                }
                return map;
            }

            let highlightProgress = 0; // 0 to 1

            function step() {
                clearCanvas();
                drawSequence(startX, startY, seq, buildHighlightMap(Math.min(1, highlightProgress)));

                if (highlightProgress < 1) {
                    highlightProgress += 0.01 * animationSpeed;
                    animationFrameId = requestAnimationFrame(step);
                }
            }
            step();
        }

        // --- Reset Function ---
        /**
         * Stop any running animation, blank the canvas and restore the placeholder
         * explanation text.
         */
        function resetVisualization() {
            const placeholder = 'Click a button above to see the animation and explanation.';
            cancelAnimationFrame(animationFrameId);
            clearCanvas();
            explanationDiv.innerHTML = placeholder;
        }

        // --- Event Listeners ---
        // Each control button starts one demo; the two attention buttons choose the variant.
        for (const [buttonId, handler] of [
            ['btnPadding', animatePadding],
            ['btnUnpadding', animateUnpadding],
            ['btnPacking', animateSequencePacking],
            ['btnLocalAttn', () => animateAttention(false)],
            ['btnGlobalAttn', () => animateAttention(true)],
            ['btnReset', resetVisualization],
        ]) {
            document.getElementById(buttonId).addEventListener('click', handler);
        }


        // --- Initial Setup & Resize ---
        /**
         * Make the canvas bitmap size match its laid-out (client) size. When the size
         * changes, any running demo is abandoned and the view is reset rather than
         * redrawn mid-animation.
         */
        function resizeCanvas() {
            const { clientWidth: targetWidth, clientHeight: targetHeight } = canvas;
            const sizeChanged = canvas.width !== targetWidth || canvas.height !== targetHeight;
            if (!sizeChanged) {
                return;
            }
            canvas.width = targetWidth;
            canvas.height = targetHeight;
            resetVisualization();
        }

        // Keep the canvas bitmap in sync with its displayed size on every window resize.
        window.addEventListener('resize', resizeCanvas);

        // Size the canvas once layout is complete, then reset explicitly as well:
        // resizeCanvas only resets when the size actually changed.
        function handleInitialLoad() {
            resizeCanvas();
            resetVisualization(); // Start clean
        }
        window.addEventListener('load', handleInitialLoad);

    </script>

</body>
</html>