ucalyptus committed on
Commit
4bf0bc7
·
verified ·
1 Parent(s): 68707fe

Update index.html

Browse files
Files changed (1) hide show
  1. index.html +336 -404
index.html CHANGED
@@ -3,433 +3,365 @@
3
  <head>
4
  <meta charset="UTF-8">
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
- <title>FlashInfer Attention State Visualization</title>
7
  <script src="https://cdn.tailwindcss.com"></script>
8
- <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap" rel="stylesheet">
9
  <style>
10
  body {
11
  font-family: 'Inter', sans-serif;
12
- background-color: #f3f4f6; /* Tailwind gray-100 */
13
  }
14
- canvas {
15
- background-color: #ffffff; /* White */
16
- border-radius: 0.5rem; /* rounded-lg */
17
- box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1); /* shadow-md */
18
  }
19
- .btn {
20
- padding: 0.5rem 1rem; /* py-2 px-4 */
21
- border-radius: 0.375rem; /* rounded-md */
22
- font-weight: 600; /* font-semibold */
23
- color: white;
24
- background-color: #4f46e5; /* indigo-600 */
25
- transition: background-color 0.2s;
26
- cursor: pointer;
27
- margin: 0 0.5rem; /* mx-2 */
28
- box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); /* shadow-sm */
29
  }
30
- .btn:hover {
31
- background-color: #4338ca; /* indigo-700 */
 
32
  }
33
- .btn:disabled {
34
- background-color: #a5b4fc; /* indigo-300 */
35
- cursor: not-allowed;
 
 
 
 
 
 
 
 
 
36
  }
37
- .info-text {
38
- color: #4b5563; /* gray-600 */
39
- font-size: 0.875rem; /* text-sm */
40
- margin-top: 0.5rem; /* mt-2 */
41
  }
42
- .state-box {
43
- border: 2px solid;
44
- border-radius: 0.375rem; /* rounded-md */
45
- padding: 5px;
46
- text-align: center;
47
- font-size: 0.8rem;
48
- background-color: rgba(255, 255, 255, 0.8);
49
  }
50
- .s-value {
51
- font-weight: bold;
52
- color: #1d4ed8; /* blue-700 */
53
  }
54
- .v-value {
55
- display: inline-block;
56
- width: 15px;
57
- height: 15px;
58
- border-radius: 50%;
59
- margin-left: 5px;
60
- vertical-align: middle;
61
  }
62
  </style>
63
  </head>
64
- <body class="flex flex-col items-center justify-center min-h-screen p-4">
65
-
66
- <h1 class="text-2xl font-semibold text-gray-800 mb-4">FlashInfer: Attention States & Recursive Merge</h1>
67
- <p class="text-center text-gray-600 mb-6 max-w-2xl">
68
- This visualization demonstrates how FlashInfer computes attention by:
69
- <br>1. Calculating partial "Attention States" (s, v) for subsets of Key-Value pairs.
70
- <br>2. Recursively merging these states using the ⊕ operator to get the final result.
71
- </p>
72
-
73
- <canvas id="attentionCanvas" width="800" height="400"></canvas>
74
- <p id="statusText" class="info-text h-6"></p> <div class="mt-6 flex justify-center">
75
- <button id="resetBtn" class="btn">Reset & Initialize</button>
76
- <button id="computeStatesBtn" class="btn" disabled>Compute Partial States</button>
77
- <button id="mergeStatesBtn" class="btn" disabled>Merge States</button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  </div>
79
 
80
  <script>
81
- const canvas = document.getElementById('attentionCanvas');
82
- const ctx = canvas.getContext('2d');
83
- const statusText = document.getElementById('statusText');
84
- const resetBtn = document.getElementById('resetBtn');
85
- const computeStatesBtn = document.getElementById('computeStatesBtn');
86
- const mergeStatesBtn = document.getElementById('mergeStatesBtn');
87
-
88
- let animationFrameId = null;
89
-
90
- // --- Configuration ---
91
- const config = {
92
- queryPos: { x: 50, y: 200 },
93
- kvStartY: 50,
94
- kvSpacingY: 60,
95
- kvCount: 6, // Ensure this is even for easy partitioning
96
- kvPartitionSize: 3, // Size of each partition
97
- stateBoxWidth: 100,
98
- stateBoxHeight: 50,
99
- mergePointX: 600,
100
- finalStateX: 700,
101
- colors: {
102
- query: '#ef4444', // red-500
103
- kv: '#3b82f6', // blue-500
104
- partition1: '#10b981', // emerald-500
105
- partition2: '#f97316', // orange-500
106
- merged: '#8b5cf6', // violet-500
107
- arrow: '#6b7280', // gray-500
108
- text: '#1f2937', // gray-800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  },
110
- arrowHeadSize: 8,
111
- };
112
-
113
- // --- State Variables ---
114
- let query = {};
115
- let kvPairs = [];
116
- let states = {
117
- partition1: null,
118
- partition2: null,
119
- final: null
120
- };
121
- let currentStep = 0; // 0: initial, 1: kvs shown, 2: states computed, 3: merged
122
-
123
- // --- Drawing Functions ---
124
- function drawArrow(startX, startY, endX, endY, color = config.colors.arrow, progress = 1) {
125
- const dx = endX - startX;
126
- const dy = endY - startY;
127
- const length = Math.sqrt(dx * dx + dy * dy);
128
- const angle = Math.atan2(dy, dx);
129
-
130
- const currentX = startX + dx * progress;
131
- const currentY = startY + dy * progress;
132
-
133
- ctx.beginPath();
134
- ctx.moveTo(startX, startY);
135
- ctx.lineTo(currentX, currentY);
136
- ctx.strokeStyle = color;
137
- ctx.lineWidth = 1.5;
138
- ctx.stroke();
139
-
140
- if (progress >= 1) {
141
- // Draw arrowhead
142
- ctx.beginPath();
143
- ctx.moveTo(currentX, currentY);
144
- ctx.lineTo(currentX - config.arrowHeadSize * Math.cos(angle - Math.PI / 6), currentY - config.arrowHeadSize * Math.sin(angle - Math.PI / 6));
145
- ctx.lineTo(currentX - config.arrowHeadSize * Math.cos(angle + Math.PI / 6), currentY - config.arrowHeadSize * Math.sin(angle + Math.PI / 6));
146
- ctx.closePath();
147
- ctx.fillStyle = color;
148
- ctx.fill();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  }
151
-
152
- function drawQuery(q) {
153
- ctx.fillStyle = config.colors.query;
154
- ctx.fillRect(q.x - 15, q.y - 15, 30, 30);
155
- ctx.fillStyle = config.colors.text;
156
- ctx.font = 'bold 14px Inter';
157
- ctx.textAlign = 'center';
158
- ctx.fillText('Q', q.x, q.y + 5);
159
- }
160
-
161
- function drawKVPair(kv, index) {
162
- const kvColor = config.colors.kv;
163
- // Draw K
164
- ctx.fillStyle = kvColor;
165
- ctx.beginPath();
166
- ctx.moveTo(kv.x - 10, kv.y - 10);
167
- ctx.lineTo(kv.x, kv.y);
168
- ctx.lineTo(kv.x - 10, kv.y + 10);
169
- ctx.closePath();
170
- ctx.fill();
171
- // Draw V
172
- ctx.beginPath();
173
- ctx.arc(kv.x + 10, kv.y, 10, 0, Math.PI * 2);
174
- ctx.fill();
175
-
176
- ctx.fillStyle = config.colors.text;
177
- ctx.font = '12px Inter';
178
- ctx.textAlign = 'left';
179
- ctx.fillText(`KV ${index + 1}`, kv.x + 25, kv.y + 4);
180
- }
181
-
182
- function drawAttentionState(state, color) {
183
- if (!state) return;
184
- ctx.strokeStyle = color;
185
- ctx.lineWidth = 2;
186
- ctx.fillStyle = 'rgba(255, 255, 255, 0.9)'; // Semi-transparent white background
187
- ctx.fillRect(state.x, state.y, config.stateBoxWidth, config.stateBoxHeight);
188
- ctx.strokeRect(state.x, state.y, config.stateBoxWidth, config.stateBoxHeight);
189
-
190
- ctx.fillStyle = config.colors.text;
191
- ctx.font = 'bold 12px Inter';
192
- ctx.textAlign = 'center';
193
- ctx.fillText(state.label, state.x + config.stateBoxWidth / 2, state.y + 15);
194
-
195
- ctx.font = '11px Inter';
196
- // Draw 's' value (LSE)
197
- ctx.fillStyle = config.colors.text; // Use standard text color for label
198
- ctx.fillText('s:', state.x + 25, state.y + 35);
199
- ctx.fillStyle = config.colors.text; // Use standard text color for value
200
- ctx.textAlign = 'left';
201
- ctx.fillText(state.s.toFixed(2), state.x + 35, state.y + 35); // Display LSE value
202
-
203
- // Draw 'v' representation
204
- ctx.fillStyle = config.colors.text; // Use standard text color for label
205
- ctx.textAlign = 'center';
206
- ctx.fillText('v:', state.x + 70, state.y + 35);
207
- ctx.fillStyle = state.vColor; // Use the state's specific color for the 'v' circle
208
- ctx.beginPath();
209
- ctx.arc(state.x + 85, state.y + 30, 7, 0, Math.PI * 2); // Draw circle representing 'v'
210
- ctx.fill();
211
- }
212
-
213
- // --- Animation Loop ---
214
- let progress = 0;
215
- const animationSpeed = 0.02;
216
-
217
- function animateStep(stepFunction, nextStep) {
218
- cancelAnimationFrame(animationFrameId); // Cancel previous animation if any
219
- progress = 0;
220
-
221
- function loop() {
222
- progress += animationSpeed;
223
- if (progress >= 1) {
224
- progress = 1;
225
- stepFunction(progress); // Draw final frame
226
- currentStep = nextStep;
227
- updateButtons();
228
- setStatus(''); // Clear status after animation
229
- return;
230
  }
231
-
232
- stepFunction(progress); // Draw intermediate frame
233
- animationFrameId = requestAnimationFrame(loop);
234
- }
235
- loop();
236
  }
237
-
238
-
239
- // --- Visualization Steps ---
240
- function initialize() {
241
- cancelAnimationFrame(animationFrameId); // Stop any ongoing animation
242
- ctx.clearRect(0, 0, canvas.width, canvas.height);
243
- currentStep = 0;
244
- states = { partition1: null, partition2: null, final: null };
245
-
246
- // Define Query
247
- query = { ...config.queryPos };
248
-
249
- // Define KV Pairs
250
- kvPairs = [];
251
- const kvStartX = config.queryPos.x + 100;
252
- for (let i = 0; i < config.kvCount; i++) {
253
- kvPairs.push({
254
- x: kvStartX,
255
- y: config.kvStartY + i * config.kvSpacingY,
256
- id: i
257
- });
258
- }
259
-
260
- // Initial Draw
261
- drawQuery(query);
262
- kvPairs.forEach((kv, i) => drawKVPair(kv, i));
263
- currentStep = 1;
264
- updateButtons();
265
- setStatus('Initialized Query and KV Pairs.');
266
- }
267
-
268
- function computePartialStatesStep(p) {
269
- ctx.clearRect(0, 0, canvas.width, canvas.height);
270
- drawQuery(query);
271
- kvPairs.forEach((kv, i) => drawKVPair(kv, i));
272
-
273
- // --- Partition 1 ---
274
- const state1X = config.queryPos.x + 250;
275
- const state1Y = config.kvStartY + (config.kvPartitionSize / 2 - 0.5) * config.kvSpacingY - config.stateBoxHeight / 2;
276
- if (!states.partition1) {
277
- // Simulate LSE and create a representative color for v
278
- states.partition1 = { x: state1X, y: state1Y, s: Math.random() * 5 + 5, vColor: config.colors.partition1, label: `State 1..${config.kvPartitionSize}` };
279
- }
280
-
281
- // Draw arrows from Q to KV (Partition 1)
282
- for (let i = 0; i < config.kvPartitionSize; i++) {
283
- drawArrow(query.x + 15, query.y, kvPairs[i].x - 15, kvPairs[i].y, config.colors.partition1, p);
284
- }
285
- // Draw arrows from KV to State (Partition 1)
286
- if (p >= 0.5) { // Start drawing state arrows halfway through
287
- const stateProgress = (p - 0.5) * 2;
288
- for (let i = 0; i < config.kvPartitionSize; i++) {
289
- drawArrow(kvPairs[i].x + 15, kvPairs[i].y, states.partition1.x, states.partition1.y + config.stateBoxHeight / 2, config.colors.partition1, stateProgress);
290
- }
291
- // Draw state box with fade-in effect (using alpha)
292
- ctx.globalAlpha = stateProgress;
293
- drawAttentionState(states.partition1, config.colors.partition1);
294
- ctx.globalAlpha = 1.0; // Reset alpha
295
- }
296
-
297
-
298
- // --- Partition 2 ---
299
- const state2X = state1X; // Align horizontally
300
- const state2Y = config.kvStartY + (config.kvPartitionSize + config.kvPartitionSize / 2 - 0.5) * config.kvSpacingY - config.stateBoxHeight / 2;
301
- if (!states.partition2) {
302
- states.partition2 = { x: state2X, y: state2Y, s: Math.random() * 5 + 5, vColor: config.colors.partition2, label: `State ${config.kvPartitionSize+1}..${config.kvCount}` };
303
- }
304
-
305
- // Draw arrows from Q to KV (Partition 2)
306
- for (let i = config.kvPartitionSize; i < config.kvCount; i++) {
307
- drawArrow(query.x + 15, query.y, kvPairs[i].x - 15, kvPairs[i].y, config.colors.partition2, p);
308
- }
309
- // Draw arrows from KV to State (Partition 2)
310
- if (p >= 0.5) {
311
- const stateProgress = (p - 0.5) * 2;
312
- for (let i = config.kvPartitionSize; i < config.kvCount; i++) {
313
- drawArrow(kvPairs[i].x + 15, kvPairs[i].y, states.partition2.x, states.partition2.y + config.stateBoxHeight / 2, config.colors.partition2, stateProgress);
314
- }
315
- // Draw state box with fade-in
316
- ctx.globalAlpha = stateProgress;
317
- drawAttentionState(states.partition2, config.colors.partition2);
318
- ctx.globalAlpha = 1.0;
319
- }
320
- }
321
-
322
- function mergeStatesStep(p) {
323
- // Redraw previous step completely first
324
- computePartialStatesStep(1);
325
-
326
- // --- Merge Operation ---
327
- const mergePointY = canvas.height / 2;
328
- const finalStateY = mergePointY - config.stateBoxHeight / 2;
329
-
330
- if (!states.final) {
331
- // Simulate merged state calculation: s = log(e^s1 + e^s2), v is weighted average
332
- const s_final = Math.log(Math.exp(states.partition1.s) + Math.exp(states.partition2.s));
333
- // Simple color mixing for v visualization
334
- const finalVColor = averageHexColors(states.partition1.vColor, states.partition2.vColor);
335
- states.final = { x: config.finalStateX, y: finalStateY, s: s_final, vColor: finalVColor, label: `Final State (1..${config.kvCount})` };
336
- }
337
-
338
- // Draw arrows from partial states to merge point
339
- const mergeArrowEndX = config.mergePointX - 10; // End slightly before the symbol
340
- drawArrow(states.partition1.x + config.stateBoxWidth, states.partition1.y + config.stateBoxHeight / 2, mergeArrowEndX, mergePointY, config.colors.partition1, p);
341
- drawArrow(states.partition2.x + config.stateBoxWidth, states.partition2.y + config.stateBoxHeight / 2, mergeArrowEndX, mergePointY, config.colors.partition2, p);
342
-
343
- // Draw merge symbol (⊕) - appears halfway
344
- if (p >= 0.5) {
345
- const symbolProgress = (p - 0.5) * 2;
346
- ctx.globalAlpha = symbolProgress;
347
- ctx.font = 'bold 30px Inter';
348
- ctx.fillStyle = config.colors.merged;
349
- ctx.textAlign = 'center';
350
- ctx.fillText('⊕', config.mergePointX, mergePointY + 10); // Adjust Y for vertical centering
351
-
352
- // Draw arrow from merge symbol to final state
353
- const finalArrowStartX = config.mergePointX + 15; // Start after the symbol
354
- drawArrow(finalArrowStartX, mergePointY, states.final.x, states.final.y + config.stateBoxHeight / 2, config.colors.merged, symbolProgress);
355
-
356
- // Draw final state box with fade-in
357
- drawAttentionState(states.final, config.colors.merged);
358
- ctx.globalAlpha = 1.0; // Reset alpha
359
- }
360
- }
361
-
362
- // --- Helper Functions ---
363
- function setStatus(text) {
364
- statusText.textContent = text;
365
- }
366
-
367
- function updateButtons() {
368
- resetBtn.disabled = false;
369
- computeStatesBtn.disabled = currentStep < 1 || currentStep >= 2;
370
- mergeStatesBtn.disabled = currentStep < 2 || currentStep >= 3;
371
- }
372
-
373
- // Simple hex color averaging for visualization
374
- function averageHexColors(color1, color2) {
375
- const c1 = parseInt(color1.substring(1), 16);
376
- const c2 = parseInt(color2.substring(1), 16);
377
-
378
- const r1 = (c1 >> 16) & 255;
379
- const g1 = (c1 >> 8) & 255;
380
- const b1 = c1 & 255;
381
-
382
- const r2 = (c2 >> 16) & 255;
383
- const g2 = (c2 >> 8) & 255;
384
- const b2 = c2 & 255;
385
-
386
- const rAvg = Math.round((r1 + r2) / 2);
387
- const gAvg = Math.round((g1 + g2) / 2);
388
- const bAvg = Math.round((b1 + b2) / 2);
389
-
390
- return `#${(1 << 24 | rAvg << 16 | gAvg << 8 | bAvg).toString(16).slice(1).padStart(6, '0')}`;
391
- }
392
-
393
-
394
- // --- Event Listeners ---
395
- resetBtn.addEventListener('click', () => {
396
- initialize();
397
- });
398
-
399
- computeStatesBtn.addEventListener('click', () => {
400
- if (currentStep === 1) {
401
- setStatus('Computing partial attention states...');
402
- animateStep(computePartialStatesStep, 2);
403
- }
404
- });
405
-
406
- mergeStatesBtn.addEventListener('click', () => {
407
- if (currentStep === 2) {
408
- setStatus('Merging attention states...');
409
- animateStep(mergeStatesStep, 3);
410
- }
411
  });
412
-
413
- // --- Initial Setup ---
414
- window.onload = () => {
415
- // Adjust canvas size slightly for high DPI if needed, but keep logical size
416
- const dpr = window.devicePixelRatio || 1;
417
- const rect = canvas.getBoundingClientRect();
418
- // canvas.width = rect.width * dpr; // Keep logical size for layout
419
- // canvas.height = rect.height * dpr;
420
- // ctx.scale(dpr, dpr); // Scale context instead
421
-
422
- initialize(); // Draw initial state on load
423
- };
424
-
425
- // Optional: Redraw on resize
426
- window.addEventListener('resize', () => {
427
- // Basic redraw based on current step - could be more sophisticated
428
- if (currentStep === 1) initialize();
429
- else if (currentStep === 2) computePartialStatesStep(1); // Draw completed step
430
- else if (currentStep === 3) mergeStatesStep(1); // Draw completed step
431
- });
432
-
433
  </script>
434
  </body>
435
- </html>
 
3
  <head>
4
  <meta charset="UTF-8">
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>FlashInfer: Attention States & Recursive Merge</title>
7
  <script src="https://cdn.tailwindcss.com"></script>
8
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
9
  <style>
10
  body {
11
  font-family: 'Inter', sans-serif;
12
+ overflow-x: hidden;
13
  }
14
+ .math {
15
+ font-family: 'Times New Roman', serif;
16
+ font-style: italic;
 
17
  }
18
+ .node {
19
+ transition: all 0.5s ease-in-out;
 
 
 
 
 
 
 
 
20
  }
21
+ .arrow {
22
+ stroke-dasharray: 10;
23
+ animation: dash 1s linear infinite;
24
  }
25
+ @keyframes dash {
26
+ to {
27
+ stroke-dashoffset: -20;
28
+ }
29
+ }
30
+ .highlight {
31
+ animation: pulse 2s infinite;
32
+ }
33
+ @keyframes pulse {
34
+ 0% { transform: scale(1); opacity: 1; }
35
+ 50% { transform: scale(1.05); opacity: 0.9; }
36
+ 100% { transform: scale(1); opacity: 1; }
37
  }
38
+ .fade-in {
39
+ opacity: 0;
40
+ animation: fadeIn 1s forwards;
 
41
  }
42
+ @keyframes fadeIn {
43
+ to { opacity: 1; }
 
 
 
 
 
44
  }
45
+ .equation {
46
+ transition: all 0.5s ease;
 
47
  }
48
+ .panel {
49
+ transition: transform 0.5s ease-out;
 
 
 
 
 
50
  }
51
  </style>
52
  </head>
53
+ <body class="bg-gray-100 text-gray-900">
54
+ <div class="container mx-auto p-4 max-w-6xl">
55
+ <header class="text-center my-8">
56
+ <h1 class="text-4xl font-bold text-blue-800 mb-2">FlashInfer: Attention States & Recursive Merge</h1>
57
+ <p class="text-xl text-gray-600">Visualizing how FlashInfer accelerates LLM inference</p>
58
+ </header>
59
+
60
+ <div class="bg-white rounded-xl shadow-lg p-6 mb-8">
61
+ <h2 class="text-2xl font-semibold mb-4 text-blue-700">Key Innovation: Attention States</h2>
62
+ <p class="mb-4">FlashInfer introduces the concept of <strong>attention states</strong>, which fully characterize the attention between a query and a set of key/value pairs. Each attention state consists of two components:</p>
63
+
64
+ <div class="flex flex-col md:flex-row gap-4 mb-6">
65
+ <div class="flex-1 bg-blue-50 p-4 rounded-lg">
66
+ <h3 class="font-medium text-blue-800 mb-2">Generalized Score (s)</h3>
67
+ <div class="flex justify-center">
68
+ <div class="math text-xl">
69
+ s(I) = log(∑<sub>i∈I</sub> exp(s<sub>i</sub>))
70
+ </div>
71
+ </div>
72
+ <p class="mt-2 text-sm text-gray-600">The log-sum-exp (LSE) of pre-softmax attention scores</p>
73
+ </div>
74
+
75
+ <div class="flex-1 bg-blue-50 p-4 rounded-lg">
76
+ <h3 class="font-medium text-blue-800 mb-2">Generalized Value (v)</h3>
77
+ <div class="flex justify-center">
78
+ <div class="math text-xl">
79
+ v(I) = ∑<sub>i∈I</sub> softmax(s<sub>i</sub>)v<sub>i</sub>
80
+ </div>
81
+ </div>
82
+ <p class="mt-2 text-sm text-gray-600">The weighted sum of value vectors using the softmax of scores</p>
83
+ </div>
84
+ </div>
85
+ </div>
86
+
87
+ <div class="bg-white rounded-xl shadow-lg p-6 mb-8">
88
+ <h2 class="text-2xl font-semibold mb-4 text-blue-700">Recursive Merge Operator</h2>
89
+ <p class="mb-4">The key insight of FlashInfer is that attention states can be <strong>merged</strong> efficiently. Given two attention states corresponding to different subsets of KV pairs, we can compute the attention state for their union:</p>
90
+
91
+ <div class="bg-blue-50 p-4 rounded-lg mb-6">
92
+ <div class="flex justify-center">
93
+ <div class="math text-xl">
94
+ [v(I∪J), s(I∪J)] = [v(I), s(I)] ⊕ [v(J), s(J)]
95
+ </div>
96
+ </div>
97
+ </div>
98
+
99
+ <p class="mb-6">This merge operator (⊕) is <strong>commutative</strong> and <strong>associative</strong>, allowing flexible computation strategies.</p>
100
+ </div>
101
+
102
+ <!-- Interactive Visualization -->
103
+ <div class="bg-white rounded-xl shadow-lg p-6 mb-8">
104
+ <h2 class="text-2xl font-semibold mb-4 text-blue-700">Interactive Visualization</h2>
105
+ <p class="mb-6">This animation shows how FlashInfer computes attention for a query over 4 KV pairs by partitioning the work and merging results.</p>
106
+
107
+ <div class="flex justify-center">
108
+ <div class="relative" style="width: 700px; height: 650px;" id="visualization">
109
+ <!-- SVG will be inserted here -->
110
+ </div>
111
+ </div>
112
+
113
+ <div class="flex justify-center mt-4">
114
+ <button id="resetBtn" class="bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-lg mr-4">
115
+ Reset Animation
116
+ </button>
117
+ <button id="playBtn" class="bg-green-600 hover:bg-green-700 text-white px-4 py-2 rounded-lg">
118
+ Play Animation
119
+ </button>
120
+ </div>
121
+ </div>
122
+
123
+ <div class="bg-white rounded-xl shadow-lg p-6 mb-8">
124
+ <h2 class="text-2xl font-semibold mb-4 text-blue-700">Applications</h2>
125
+
126
+ <div class="grid grid-cols-1 md:grid-cols-2 gap-6">
127
+ <div class="bg-blue-50 p-4 rounded-lg">
128
+ <h3 class="font-medium text-blue-800 mb-2">Shared-Prefix Batch Decoding</h3>
129
+ <p>When multiple sequences share a common prefix (e.g., same prompt), compute the attention state for the shared part once, then merge with each sequence's unique suffix.</p>
130
+ <p class="mt-2 text-sm font-medium text-green-600">Up to 30x speedup in long-context scenarios</p>
131
+ </div>
132
+
133
+ <div class="bg-blue-50 p-4 rounded-lg">
134
+ <h3 class="font-medium text-blue-800 mb-2">KV Sequence Parallelism</h3>
135
+ <p>Partition long KV sequences across multiple processing units, compute partial attention states in parallel, then merge the results.</p>
136
+ <p class="mt-2 text-sm font-medium text-green-600">Improves GPU utilization for memory-constrained scenarios</p>
137
+ </div>
138
+ </div>
139
+ </div>
140
  </div>
141
 
142
  <script>
143
+ // SVG for the visualization
144
+ const svg = `
145
+ <svg width="100%" height="100%" viewBox="0 0 700 650">
146
+ <!-- Query Node -->
147
+ <g class="fade-in" style="animation-delay: 0.5s">
148
+ <circle id="queryNode" cx="350" cy="50" r="30" fill="#4299e1" />
149
+ <text x="350" y="55" text-anchor="middle" fill="white" font-weight="bold">Q</text>
150
+ </g>
151
+
152
+ <!-- KV Nodes -->
153
+ <g class="fade-in" style="animation-delay: 1s">
154
+ <circle id="kv1" cx="175" cy="150" r="25" fill="#9ae6b4" />
155
+ <text x="175" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₁</text>
156
+
157
+ <circle id="kv2" cx="275" cy="150" r="25" fill="#9ae6b4" />
158
+ <text x="275" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₂</text>
159
+
160
+ <circle id="kv3" cx="425" cy="150" r="25" fill="#9ae6b4" />
161
+ <text x="425" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₃</text>
162
+
163
+ <circle id="kv4" cx="525" cy="150" r="25" fill="#9ae6b4" />
164
+ <text x="525" y="155" text-anchor="middle" fill="#2f855a" font-weight="bold">KV₄</text>
165
+ </g>
166
+
167
+ <!-- Lines connecting Query to KVs -->
168
+ <g class="fade-in" style="animation-delay: 1.5s">
169
+ <path id="line1" d="M 350 80 L 175 125" stroke="#4299e1" stroke-width="2" fill="none" />
170
+ <path id="line2" d="M 350 80 L 275 125" stroke="#4299e1" stroke-width="2" fill="none" />
171
+ <path id="line3" d="M 350 80 L 425 125" stroke="#4299e1" stroke-width="2" fill="none" />
172
+ <path id="line4" d="M 350 80 L 525 125" stroke="#4299e1" stroke-width="2" fill="none" />
173
+ </g>
174
+
175
+ <!-- Attention State Nodes -->
176
+ <g id="stateNodes">
177
+ <!-- These will be animated in via JS -->
178
+ <g id="state1" opacity="0">
179
+ <circle cx="175" cy="250" r="25" fill="#feb2b2" />
180
+ <text x="175" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₁,v₁</text>
181
+ </g>
182
+
183
+ <g id="state2" opacity="0">
184
+ <circle cx="275" cy="250" r="25" fill="#feb2b2" />
185
+ <text x="275" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₂,v₂</text>
186
+ </g>
187
+
188
+ <g id="state3" opacity="0">
189
+ <circle cx="425" cy="250" r="25" fill="#feb2b2" />
190
+ <text x="425" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₃,v₃</text>
191
+ </g>
192
+
193
+ <g id="state4" opacity="0">
194
+ <circle cx="525" cy="250" r="25" fill="#feb2b2" />
195
+ <text x="525" y="255" text-anchor="middle" fill="#742a2a" font-weight="bold">s₄,v₄</text>
196
+ </g>
197
+
198
+ <!-- Merge Level 1 -->
199
+ <g id="merge1" opacity="0">
200
+ <circle cx="225" cy="350" r="30" fill="#fbd38d" />
201
+ <text x="225" y="355" text-anchor="middle" fill="#7b341e" font-weight="bold">s₁₂,v₁₂</text>
202
+ </g>
203
+
204
+ <g id="merge2" opacity="0">
205
+ <circle cx="475" cy="350" r="30" fill="#fbd38d" />
206
+ <text x="475" y="355" text-anchor="middle" fill="#7b341e" font-weight="bold">s₃₄,v₃₄</text>
207
+ </g>
208
+
209
+ <!-- Final Merge -->
210
+ <g id="finalMerge" opacity="0">
211
+ <circle cx="350" cy="450" r="35" fill="#d6bcfa" />
212
+ <text x="350" y="455" text-anchor="middle" fill="#553c9a" font-weight="bold">s₁₂₃₄,v₁₂₃₄</text>
213
+ </g>
214
+ </g>
215
+
216
+ <!-- Merge Arrows - will be animated in -->
217
+ <g id="mergeArrows">
218
+ <!-- State to Merge Level 1 -->
219
+ <path id="arrow1" opacity="0" d="M 175 275 L 210 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
220
+ <path id="arrow2" opacity="0" d="M 275 275 L 240 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
221
+ <path id="arrow3" opacity="0" d="M 425 275 L 460 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
222
+ <path id="arrow4" opacity="0" d="M 525 275 L 490 325" stroke="#f6ad55" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
223
+
224
+ <!-- Merge Level 1 to Final -->
225
+ <path id="arrow5" opacity="0" d="M 225 380 L 320 425" stroke="#b794f4" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
226
+ <path id="arrow6" opacity="0" d="M 475 380 L 380 425" stroke="#b794f4" stroke-width="3" fill="none" class="arrow" marker-end="url(#arrowhead)" />
227
+ </g>
228
+
229
+ <!-- Equation Panel -->
230
+ <g id="equationPanel" opacity="0" transform="translate(350, 580)">
231
+ <rect x="-330" y="-48" width="660" height="72" rx="10" fill="#ebf8ff" stroke="#4299e1" stroke-width="2" />
232
+ <text id="equationText" x="0" y="0" text-anchor="middle" font-family="monospace" font-size="14">
233
+ Attention States: Computing...
234
+ </text>
235
+ </g>
236
+
237
+ <!-- Arrow marker definition -->
238
+ <defs>
239
+ <marker id="arrowhead" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
240
+ <polygon points="0 0, 10 3.5, 0 7" fill="#000" />
241
+ </marker>
242
+ </defs>
243
+ </svg>`;
244
+
245
+ // Insert the SVG
246
+ document.getElementById('visualization').innerHTML = svg;
247
+
248
+ // Animation sequence
249
+ let animationStep = 0;
250
+ let animationInterval;
251
+ const steps = [
252
+ // Step 0: Already loaded query and KVs
253
+ function() {
254
+ document.getElementById('equationText').textContent = "Computing individual attention states for each KV pair";
255
+ document.getElementById('equationPanel').setAttribute('opacity', '1');
256
  },
257
+ // Step 1: Show state1
258
+ function() {
259
+ document.getElementById('state1').setAttribute('opacity', '1');
260
+ document.getElementById('equationText').textContent = "s₁ = q·k₁ᵀ, v₁ = softmax(s₁)·v₁";
261
+ },
262
+ // Step 2: Show state2
263
+ function() {
264
+ document.getElementById('state2').setAttribute('opacity', '1');
265
+ document.getElementById('equationText').textContent = "s₂ = q·k₂ᵀ, v₂ = softmax(s₂)·v₂";
266
+ },
267
+ // Step 3: Show state3
268
+ function() {
269
+ document.getElementById('state3').setAttribute('opacity', '1');
270
+ document.getElementById('equationText').textContent = "s₃ = q·k₃ᵀ, v₃ = softmax(s₃)·v₃";
271
+ },
272
+ // Step 4: Show state4
273
+ function() {
274
+ document.getElementById('state4').setAttribute('opacity', '1');
275
+ document.getElementById('equationText').textContent = "s₄ = q·k₄ᵀ, v₄ = softmax(s₄)·v₄";
276
+ },
277
+ // Step 5: First merge arrows
278
+ function() {
279
+ document.getElementById('arrow1').setAttribute('opacity', '1');
280
+ document.getElementById('arrow2').setAttribute('opacity', '1');
281
+ document.getElementById('equationText').textContent = "Merging s₁,v₁ and s₂,v₂ using the ⊕ operator";
282
+ },
283
+ // Step 6: First merge result
284
+ function() {
285
+ document.getElementById('merge1').setAttribute('opacity', '1');
286
+ document.getElementById('equationText').textContent = "s₁₂ = log(e^s₁ + e^s₂), v₁₂ = (e^s₁·v₁ + e^s₂·v₂)/(e^s₁ + e^s₂)";
287
+ },
288
+ // Step 7: Second merge arrows
289
+ function() {
290
+ document.getElementById('arrow3').setAttribute('opacity', '1');
291
+ document.getElementById('arrow4').setAttribute('opacity', '1');
292
+ document.getElementById('equationText').textContent = "Merging s₃,v₃ and s₄,v₄ using the ⊕ operator";
293
+ },
294
+ // Step 8: Second merge result
295
+ function() {
296
+ document.getElementById('merge2').setAttribute('opacity', '1');
297
+ document.getElementById('equationText').textContent = "s₃₄ = log(e^s₃ + e^s₄), v₃₄ = (e^s₃·v₃ + e^s₄·v₄)/(e^s₃ + e^s₄)";
298
+ },
299
+ // Step 9: Final merge arrows
300
+ function() {
301
+ document.getElementById('arrow5').setAttribute('opacity', '1');
302
+ document.getElementById('arrow6').setAttribute('opacity', '1');
303
+ document.getElementById('equationText').textContent = "Final merge: Combining s₁₂,v₁₂ and s₃₄,v₃₄";
304
+ },
305
+ // Step 10: Final result
306
+ function() {
307
+ document.getElementById('finalMerge').setAttribute('opacity', '1');
308
+ document.getElementById('finalMerge').classList.add('highlight');
309
+ document.getElementById('equationText').textContent = "s₁₂₃₄ = log(e^s₁₂ + e^s₃₄), v₁₂₃₄ = (e^s₁₂·v₁₂ + e^s₃₄·v₃₄)/(e^s₁₂ + e^s₃₄)";
310
+ },
311
+ // Step 11: Show final explanation
312
+ function() {
313
+ document.getElementById('equationText').innerHTML =
314
+ "Complete! This is equivalent to standard attention but computed in parallel.";
315
+ clearInterval(animationInterval);
316
+ animationStep = 0; // Reset for next time
317
  }
318
+ ];
319
+
320
+ function resetAnimation() {
321
+ // Reset all animated elements
322
+ clearInterval(animationInterval);
323
+
324
+ const stateNodes = document.querySelectorAll('#stateNodes > g');
325
+ stateNodes.forEach(node => {
326
+ node.setAttribute('opacity', '0');
327
+ node.classList.remove('highlight');
328
+ });
329
+
330
+ const arrows = document.querySelectorAll('#mergeArrows > path');
331
+ arrows.forEach(arrow => {
332
+ arrow.setAttribute('opacity', '0');
333
+ });
334
+
335
+ document.getElementById('equationPanel').setAttribute('opacity', '0');
336
+
337
+ animationStep = 0;
338
  }
339
+
340
+ function playAnimation() {
341
+ resetAnimation();
342
+
343
+ // Start the animation sequence
344
+ steps[animationStep]();
345
+ animationStep++;
346
+
347
+ animationInterval = setInterval(() => {
348
+ if (animationStep < steps.length) {
349
+ steps[animationStep]();
350
+ animationStep++;
351
+ } else {
352
+ clearInterval(animationInterval);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  }
354
+ }, 1500); // Animation step timing
 
 
 
 
355
  }
356
+
357
+ // Button handlers
358
+ document.getElementById('resetBtn').addEventListener('click', resetAnimation);
359
+ document.getElementById('playBtn').addEventListener('click', playAnimation);
360
+
361
+ // Auto-start animation when page loads
362
+ window.addEventListener('load', () => {
363
+ setTimeout(playAnimation, 1000);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  </script>
366
  </body>
367
+ </html>