3v324v23 commited on
Commit
f02caca
·
1 Parent(s): 9c44d64

Initial commit of Neural Network Playground with improved node display and Linear Regression support

Browse files
Files changed (6) hide show
  1. README.md +46 -9
  2. css/styles.css +119 -11
  3. index.html +3 -0
  4. js/drag-drop.js +348 -136
  5. js/main.js +240 -105
  6. js/neural-network.js +31 -6
README.md CHANGED
@@ -1,9 +1,46 @@
1
- ---
2
- title: Neural Network Playground
3
- emoji: 🌖
4
- colorFrom: pink
5
- colorTo: blue
6
- sdk: static
7
- pinned: false
8
- ---
9
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Neural Network Playground
2
+
3
+ An interactive web-based application for visualizing and experimenting with neural network architectures.
4
+
5
+ ## Features
6
+
7
+ - **Drag-and-Drop Interface**: Easily create neural network architectures by dragging and dropping different layer types
8
+ - **Multiple Layer Types**: Support for Input, Hidden, Output, Convolutional, and Pooling layers
9
+ - **Dynamic Connections**: Create connections between layers to define your network topology
10
+ - **Visual Styling**: Beautiful gradient-based styling for different layer types with animations
11
+ - **Layer Properties**: View and edit detailed properties for each layer
12
+ - **Network Validation**: Automatic validation of network architectures
13
+ - **Training Simulation**: Visual simulation of the training process
14
+ - **Responsive Design**: Works on desktop and mobile devices
15
+
16
+ ## Getting Started
17
+
18
+ 1. Clone this repository
19
+ 2. Open `index.html` in your browser or use a local server:
20
+ ```
21
+ python -m http.server
22
+ ```
23
+ 3. Visit `http://localhost:8000` in your browser
24
+
25
+ ## How to Use
26
+
27
+ 1. Drag layer components from the left panel onto the canvas
28
+ 2. Connect layers by dragging from output ports (right side) to input ports (left side)
29
+ 3. Click on a layer to view its properties
30
+ 4. Edit layer properties by clicking the edit button
31
+ 5. Click "Run Network" to simulate training
32
+
33
+ ## Technologies Used
34
+
35
+ - HTML5
36
+ - CSS3 (with animations and gradients)
37
+ - JavaScript (vanilla)
38
+ - No external libraries required!
39
+
40
+ ## License
41
+
42
+ MIT
43
+
44
+ ## Contributing
45
+
46
+ Contributions, issues, and feature requests are welcome!
css/styles.css CHANGED
@@ -32,6 +32,8 @@
32
  --pool-node-color-1: #e74c3c;
33
  --pool-node-color-2: #c0392b;
34
  --node-glow: 0 0 15px rgba(255, 255, 255, 0.8);
 
 
35
  }
36
 
37
  body {
@@ -157,6 +159,12 @@ header h1 {
157
  color: white;
158
  }
159
 
 
 
 
 
 
 
160
  .node-icon {
161
  width: 24px;
162
  height: 24px;
@@ -251,6 +259,12 @@ header h1 {
251
  color: white;
252
  }
253
 
 
 
 
 
 
 
254
  .controls {
255
  margin-top: 2rem;
256
  }
@@ -385,21 +399,66 @@ footer p {
385
  .canvas-node {
386
  position: absolute;
387
  width: 180px;
388
- padding: 0.8rem;
389
- border-radius: var(--border-radius);
390
- color: white;
391
  box-shadow: var(--shadow-md);
392
- z-index: 10;
393
- transition: all 0.3s ease;
394
  cursor: move;
395
- background-size: 300% 300%;
396
- animation: gradientShift 8s ease infinite;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  }
398
 
399
- @keyframes gradientShift {
400
- 0% { background-position: 0% 50%; }
401
- 50% { background-position: 100% 50%; }
402
- 100% { background-position: 0% 50%; }
 
 
 
 
 
 
 
 
 
 
 
 
403
  }
404
 
405
  .canvas-node.dragging {
@@ -434,6 +493,11 @@ footer p {
434
  border: 2px solid var(--pool-node-color-1);
435
  }
436
 
 
 
 
 
 
437
  .canvas-node .node-title {
438
  font-weight: 600;
439
  font-size: 0.9rem;
@@ -1222,4 +1286,48 @@ select {
1222
  100% {
1223
  box-shadow: 0 0 0 0 rgba(231, 76, 60, 0);
1224
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1225
  }
 
32
  --pool-node-color-1: #e74c3c;
33
  --pool-node-color-2: #c0392b;
34
  --node-glow: 0 0 15px rgba(255, 255, 255, 0.8);
35
+ --linear-node-color-1: #1abc9c;
36
+ --linear-node-color-2: #16a085;
37
  }
38
 
39
  body {
 
159
  color: white;
160
  }
161
 
162
+ .node-item[data-type="linear"] {
163
+ background-image: linear-gradient(135deg, var(--linear-node-color-1), var(--linear-node-color-2));
164
+ border-bottom: 3px solid var(--linear-node-color-2);
165
+ box-shadow: var(--shadow-sm);
166
+ }
167
+
168
  .node-icon {
169
  width: 24px;
170
  height: 24px;
 
259
  color: white;
260
  }
261
 
262
+ .linear-node {
263
+ background-image: linear-gradient(135deg, var(--linear-node-color-1), var(--linear-node-color-2));
264
+ position: relative;
265
+ border-radius: var(--border-radius);
266
+ }
267
+
268
  .controls {
269
  margin-top: 2rem;
270
  }
 
399
  .canvas-node {
400
  position: absolute;
401
  width: 180px;
402
+ min-height: 120px;
403
+ border-radius: 8px;
 
404
  box-shadow: var(--shadow-md);
405
+ padding: 0;
 
406
  cursor: move;
407
+ transition: all 0.2s ease-in-out;
408
+ z-index: 2;
409
+ display: flex;
410
+ flex-direction: column;
411
+ }
412
+
413
+ .canvas-node .node-header {
414
+ padding: 8px 12px;
415
+ font-weight: bold;
416
+ border-radius: 8px 8px 0 0;
417
+ color: white;
418
+ background-color: rgba(0, 0, 0, 0.2);
419
+ text-align: center;
420
+ font-size: 14px;
421
+ border-bottom: 1px solid rgba(255, 255, 255, 0.2);
422
+ }
423
+
424
+ .canvas-node .node-content {
425
+ padding: 8px;
426
+ flex-grow: 1;
427
+ display: flex;
428
+ flex-direction: column;
429
+ gap: 8px;
430
+ }
431
+
432
+ .shape-info {
433
+ background-color: rgba(255, 255, 255, 0.15);
434
+ border-radius: 4px;
435
+ padding: 6px;
436
+ }
437
+
438
+ .shape-row {
439
+ display: flex;
440
+ justify-content: space-between;
441
+ font-size: 12px;
442
+ color: white;
443
+ padding: 2px 0;
444
  }
445
 
446
+ .shape-label {
447
+ font-weight: bold;
448
+ }
449
+
450
+ .params-section {
451
+ background-color: rgba(255, 255, 255, 0.15);
452
+ border-radius: 4px;
453
+ padding: 6px;
454
+ }
455
+
456
+ .params-display {
457
+ font-family: monospace;
458
+ font-size: 11px;
459
+ color: white;
460
+ margin: 0;
461
+ white-space: pre-wrap;
462
  }
463
 
464
  .canvas-node.dragging {
 
493
  border: 2px solid var(--pool-node-color-1);
494
  }
495
 
496
+ .canvas-node[data-type="linear"] {
497
+ background-image: linear-gradient(135deg, var(--linear-node-color-1), var(--linear-node-color-2));
498
+ border: 2px solid var(--linear-node-color-2);
499
+ }
500
+
501
  .canvas-node .node-title {
502
  font-weight: 600;
503
  font-size: 0.9rem;
 
1286
  100% {
1287
  box-shadow: 0 0 0 0 rgba(231, 76, 60, 0);
1288
  }
1289
+ }
1290
+
1291
+ /* Style for connection ports */
1292
+ .port {
1293
+ position: absolute;
1294
+ width: 12px;
1295
+ height: 12px;
1296
+ border-radius: 50%;
1297
+ background-color: white;
1298
+ border: 2px solid rgba(0, 0, 0, 0.3);
1299
+ z-index: 3;
1300
+ transition: all 0.2s ease;
1301
+ }
1302
+
1303
+ .port:hover {
1304
+ transform: scale(1.5);
1305
+ box-shadow: 0 0 5px rgba(255, 255, 255, 0.8);
1306
+ }
1307
+
1308
+ .input-port {
1309
+ top: 50%;
1310
+ left: -6px;
1311
+ transform: translateY(-50%);
1312
+ }
1313
+
1314
+ .output-port {
1315
+ top: 50%;
1316
+ right: -6px;
1317
+ transform: translateY(-50%);
1318
+ }
1319
+
1320
+ /* Connection line styles */
1321
+ .connection-line {
1322
+ stroke: rgba(255, 255, 255, 0.7);
1323
+ stroke-width: 2;
1324
+ fill: none;
1325
+ pointer-events: none;
1326
+ }
1327
+
1328
+ .connection-line-temp {
1329
+ stroke: rgba(255, 255, 255, 0.5);
1330
+ stroke-dasharray: 5, 5;
1331
+ stroke-width: 2;
1332
+ fill: none;
1333
  }
index.html CHANGED
@@ -39,6 +39,9 @@
39
  <div class="node-item" draggable="true" data-type="pool">
40
  <div class="node pool-node">Pooling</div>
41
  </div>
 
 
 
42
  </div>
43
 
44
  <h3 class="section-title">Sample Data</h3>
 
39
  <div class="node-item" draggable="true" data-type="pool">
40
  <div class="node pool-node">Pooling</div>
41
  </div>
42
+ <div class="node-item" draggable="true" data-type="linear">
43
+ <div class="node linear-node">Linear Regression</div>
44
+ </div>
45
  </div>
46
 
47
  <h3 class="section-title">Sample Data</h3>
js/drag-drop.js CHANGED
@@ -76,107 +76,137 @@ function initializeDragAndDrop() {
76
  canvasNode.style.left = `${x}px`;
77
  canvasNode.style.top = `${y}px`;
78
 
79
- // Set node content based on type
80
- let nodeName, dimensions, units;
 
 
 
81
 
82
  switch(nodeType) {
83
  case 'input':
84
  nodeName = 'Input Layer';
85
- dimensions = '1 × 28 × 28';
 
 
86
  break;
87
  case 'hidden':
88
- // Customize if it's the first hidden layer
89
  const hiddenCount = document.querySelectorAll('.canvas-node[data-type="hidden"]').length;
90
- units = hiddenCount === 0 ? 128 : 64;
91
  nodeName = `Hidden Layer ${hiddenCount + 1}`;
92
- dimensions = `${units}`;
 
 
 
93
  break;
94
  case 'output':
95
  nodeName = 'Output Layer';
96
- dimensions = '10';
 
 
97
  break;
98
  case 'conv':
99
  const convCount = document.querySelectorAll('.canvas-node[data-type="conv"]').length;
100
- const filters = 32 * (convCount + 1);
101
  nodeName = `Conv2D ${convCount + 1}`;
102
- dimensions = `${filters} × 26 × 26`;
 
 
 
103
  break;
104
  case 'pool':
105
  const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
106
- nodeName = `MaxPool ${poolCount + 1}`;
107
- dimensions = '32 × 13 × 13';
 
 
108
  break;
109
  default:
110
- nodeName = 'Neural Node';
111
- dimensions = '64';
 
 
112
  }
113
 
114
- canvasNode.innerHTML = `
115
- <div class="node-title">${nodeName}</div>
116
- <div class="node-id">${layerId}</div>
117
- <div class="node-dimensions">${dimensions}</div>
118
- <div class="node-port port-in"></div>
119
- <div class="node-port port-out"></div>
120
- <div class="node-controls">
121
- <button class="node-edit-btn" title="Edit layer parameters"><i class="icon">⚙️</i></button>
122
- <button class="node-delete-btn" title="Delete layer"><i class="icon">🗑️</i></button>
123
- </div>
 
 
 
 
 
124
  `;
125
 
126
- // Store dimensions for hover display
127
- canvasNode.setAttribute('data-dimensions', dimensions);
128
- canvasNode.setAttribute('data-name', nodeName);
 
129
 
130
- // Add to network layers
131
- const layerInfo = {
132
- id: layerId,
133
- type: nodeType,
134
- name: nodeName,
135
- dimensions: dimensions,
136
- position: { x, y }
137
- };
138
 
139
- networkLayers.layers.push(layerInfo);
 
 
140
 
141
- // Add to canvas
142
- canvas.appendChild(canvasNode);
 
143
 
144
- // Add events for moving nodes on the canvas
145
- canvasNode.addEventListener('mousedown', startDrag);
 
 
146
 
147
- // Connection handling
148
- const portIn = canvasNode.querySelector('.port-in');
149
- const portOut = canvasNode.querySelector('.port-out');
150
 
151
- portOut.addEventListener('mousedown', (e) => {
 
 
 
 
 
 
 
 
152
  e.stopPropagation();
153
  startConnection(canvasNode, e);
154
  });
155
 
156
- portIn.addEventListener('mouseup', (e) => {
157
- e.stopPropagation();
158
- endConnection(canvasNode);
159
  });
160
 
161
- // Button event listeners
162
- const editBtn = canvasNode.querySelector('.node-edit-btn');
163
- if (editBtn) {
164
- editBtn.addEventListener('click', (e) => {
165
- e.stopPropagation();
166
- openLayerEditor(canvasNode);
167
- });
168
- }
169
 
170
- const deleteBtn = canvasNode.querySelector('.node-delete-btn');
171
- if (deleteBtn) {
172
- deleteBtn.addEventListener('click', (e) => {
173
- e.stopPropagation();
174
- deleteNode(canvasNode);
175
- });
176
- }
177
 
178
- // Update node parameters (for sequential model validation)
179
- updateLayerConnectivity();
 
 
 
 
180
  }
181
  }
182
 
@@ -448,87 +478,269 @@ function initializeDragAndDrop() {
448
 
449
  // End creating a connection
450
  function endConnection(targetNode) {
451
- if (!isConnecting) return;
452
-
453
- // Check if a valid node port was targeted
454
- if (targetNode && targetNode.classList && targetNode.classList.contains('canvas-node')) {
455
- // Get node IDs for the connection
456
- const sourceId = startNode.getAttribute('data-id');
457
- const targetId = targetNode.getAttribute('data-id');
458
-
459
- // Check if connection already exists
460
- const exists = networkLayers.connections.some(conn =>
461
- conn.source === sourceId && conn.target === targetId
462
- );
463
-
464
- if (!exists) {
465
- // Create permanent connection
466
- const connection = connectionLine.cloneNode(true);
467
- connection.classList.remove('temp-connection');
468
- connection.setAttribute('data-source', sourceId);
469
- connection.setAttribute('data-target', targetId);
470
- canvas.appendChild(connection);
471
-
472
- // Add to connections array
473
- networkLayers.connections.push({
474
- source: sourceId,
475
- target: targetId,
476
- sourceType: startNode.getAttribute('data-type'),
477
- targetType: targetNode.getAttribute('data-type')
478
- });
479
-
480
- // Update parameters for model consistency
481
- updateLayerConnectivity();
482
-
483
- console.log(`Connected ${sourceId} to ${targetId}`);
484
- }
485
- }
486
-
487
- // Remove temporary line
488
- if (connectionLine && connectionLine.parentNode) {
489
- connectionLine.parentNode.removeChild(connectionLine);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  }
491
 
492
- // Remove port highlights
493
  removePortHighlights();
494
-
495
- // Reset variables
 
 
496
  isConnecting = false;
497
  startNode = null;
498
- connectionLine = null;
499
-
500
- // Remove event listeners
501
- document.removeEventListener('mousemove', drawConnection);
502
- document.removeEventListener('mouseup', cancelConnection);
503
  }
504
 
505
- // Update layer connectivity to ensure model consistency
506
- function updateLayerConnectivity() {
507
- // This is where we'd propagate input/output shapes between connected layers
508
- // For now we'll just highlight connected nodes
509
-
510
- // Reset all nodes
511
- document.querySelectorAll('.canvas-node').forEach(node => {
512
- node.classList.remove('connected-node');
513
- });
514
-
515
- // Mark all nodes that have connections
516
- const connectedNodeIds = new Set();
517
- networkLayers.connections.forEach(conn => {
518
- connectedNodeIds.add(conn.source);
519
- connectedNodeIds.add(conn.target);
520
- });
521
-
522
- connectedNodeIds.forEach(id => {
523
- const node = document.querySelector(`.canvas-node[data-id="${id}"]`);
524
- if (node) {
525
- node.classList.add('connected-node');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  }
527
- });
528
-
529
- // Trigger a custom event that the main script can listen for
530
- const event = new CustomEvent('networkUpdated', { detail: networkLayers });
531
- document.dispatchEvent(event);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  }
533
 
534
  // Delete a node and its connections
 
76
  canvasNode.style.left = `${x}px`;
77
  canvasNode.style.top = `${y}px`;
78
 
79
+ // Get default config for this node type
80
+ const nodeConfig = window.neuralNetwork.createNodeConfig(nodeType);
81
+
82
+ // Create node content with input and output shape information
83
+ let nodeName, inputShape, outputShape, parameters;
84
 
85
  switch(nodeType) {
86
  case 'input':
87
  nodeName = 'Input Layer';
88
+ inputShape = 'N/A';
89
+ outputShape = '[' + nodeConfig.shape.join(' × ') + ']';
90
+ parameters = nodeConfig.parameters;
91
  break;
92
  case 'hidden':
 
93
  const hiddenCount = document.querySelectorAll('.canvas-node[data-type="hidden"]').length;
94
+ nodeConfig.units = hiddenCount === 0 ? 128 : 64;
95
  nodeName = `Hidden Layer ${hiddenCount + 1}`;
96
+ // Input shape will be updated when connections are made
97
+ inputShape = 'Connect input';
98
+ outputShape = `[${nodeConfig.units}]`;
99
+ parameters = 'Connect input to calculate';
100
  break;
101
  case 'output':
102
  nodeName = 'Output Layer';
103
+ inputShape = 'Connect input';
104
+ outputShape = `[${nodeConfig.units}]`;
105
+ parameters = 'Connect input to calculate';
106
  break;
107
  case 'conv':
108
  const convCount = document.querySelectorAll('.canvas-node[data-type="conv"]').length;
109
+ nodeConfig.filters = 32 * (convCount + 1);
110
  nodeName = `Conv2D ${convCount + 1}`;
111
+ inputShape = 'Connect input';
112
+ outputShape = 'Depends on input';
113
+ // Create parameter string
114
+ parameters = `In: ?, Out: ${nodeConfig.filters}\nKernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
115
  break;
116
  case 'pool':
117
  const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
118
+ nodeName = `Pooling ${poolCount + 1}`;
119
+ inputShape = 'Connect input';
120
+ outputShape = 'Depends on input';
121
+ parameters = `Pool size: ${nodeConfig.poolSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
122
  break;
123
  default:
124
+ nodeName = 'Unknown Layer';
125
+ inputShape = 'N/A';
126
+ outputShape = 'N/A';
127
+ parameters = 'N/A';
128
  }
129
 
130
+ // Create node header
131
+ const nodeHeader = document.createElement('div');
132
+ nodeHeader.className = 'node-header';
133
+ nodeHeader.textContent = nodeName;
134
+
135
+ // Create node content
136
+ const nodeContent = document.createElement('div');
137
+ nodeContent.className = 'node-content';
138
+
139
+ // Add shape information in a structured way
140
+ const shapeInfo = document.createElement('div');
141
+ shapeInfo.className = 'shape-info';
142
+ shapeInfo.innerHTML = `
143
+ <div class="shape-row"><span class="shape-label">Input:</span> <span class="input-shape">${inputShape}</span></div>
144
+ <div class="shape-row"><span class="shape-label">Output:</span> <span class="output-shape">${outputShape}</span></div>
145
  `;
146
 
147
+ // Add parameters section
148
+ const paramsSection = document.createElement('div');
149
+ paramsSection.className = 'params-section';
150
+ paramsSection.innerHTML = `<pre class="params-display">${parameters}</pre>`;
151
 
152
+ // Add connection ports
153
+ const inputPort = document.createElement('div');
154
+ inputPort.className = 'port input-port';
155
+ inputPort.setAttribute('data-port-type', 'input');
 
 
 
 
156
 
157
+ const outputPort = document.createElement('div');
158
+ outputPort.className = 'port output-port';
159
+ outputPort.setAttribute('data-port-type', 'output');
160
 
161
+ // Assemble the node
162
+ nodeContent.appendChild(shapeInfo);
163
+ nodeContent.appendChild(paramsSection);
164
 
165
+ canvasNode.appendChild(nodeHeader);
166
+ canvasNode.appendChild(nodeContent);
167
+ canvasNode.appendChild(inputPort);
168
+ canvasNode.appendChild(outputPort);
169
 
170
+ // Add node to the canvas
171
+ canvas.appendChild(canvasNode);
 
172
 
173
+ // Store node configuration
174
+ canvasNode.layerConfig = nodeConfig;
175
+
176
+ // Add event listeners for node manipulation
177
+ canvasNode.addEventListener('mousedown', startDrag);
178
+ inputPort.addEventListener('mousedown', (e) => {
179
+ e.stopPropagation();
180
+ });
181
+ outputPort.addEventListener('mousedown', (e) => {
182
  e.stopPropagation();
183
  startConnection(canvasNode, e);
184
  });
185
 
186
+ // Double-click to edit node properties
187
+ canvasNode.addEventListener('dblclick', () => {
188
+ openLayerEditor(canvasNode);
189
  });
190
 
191
+ // Right-click to delete
192
+ canvasNode.addEventListener('contextmenu', (e) => {
193
+ e.preventDefault();
194
+ deleteNode(canvasNode);
195
+ });
 
 
 
196
 
197
+ // Add to network layers for architecture building
198
+ networkLayers.layers.push({
199
+ id: layerId,
200
+ type: nodeType,
201
+ config: nodeConfig
202
+ });
 
203
 
204
+ // Notify about network changes
205
+ document.dispatchEvent(new CustomEvent('networkUpdated', {
206
+ detail: networkLayers
207
+ }));
208
+
209
+ updateConnections();
210
  }
211
  }
212
 
 
478
 
479
  // End creating a connection
480
  function endConnection(targetNode) {
481
+ if (!isConnecting || !connectionLine || !startNode) return;
482
+
483
+ const sourceType = startNode.getAttribute('data-type');
484
+ const targetType = targetNode.getAttribute('data-type');
485
+ const sourceId = startNode.getAttribute('data-id');
486
+ const targetId = targetNode.getAttribute('data-id');
487
+
488
+ // Check if this is a valid connection
489
+ if (isValidConnection(sourceType, targetType, sourceId, targetId)) {
490
+ // Create a permanent SVG connection
491
+ const canvas = document.getElementById('network-canvas');
492
+ const svgContainer = document.querySelector('#network-canvas .svg-container') || createSVGContainer();
493
+
494
+ // Get positions for source and target nodes
495
+ const sourceRect = startNode.getBoundingClientRect();
496
+ const targetRect = targetNode.getBoundingClientRect();
497
+ const canvasRect = canvas.getBoundingClientRect();
498
+
499
+ // Calculate port positions
500
+ const sourcePort = startNode.querySelector('.output-port');
501
+ const targetPort = targetNode.querySelector('.input-port');
502
+
503
+ const sourcePortRect = sourcePort.getBoundingClientRect();
504
+ const targetPortRect = targetPort.getBoundingClientRect();
505
+
506
+ const startX = sourcePortRect.left + (sourcePortRect.width / 2) - canvasRect.left;
507
+ const startY = sourcePortRect.top + (sourcePortRect.height / 2) - canvasRect.top;
508
+ const endX = targetPortRect.left + (targetPortRect.width / 2) - canvasRect.left;
509
+ const endY = targetPortRect.top + (targetPortRect.height / 2) - canvasRect.top;
510
+
511
+ // Create the connection
512
+ const pathId = `connection-${sourceId}-${targetId}`;
513
+ const connectionPath = document.createElementNS('http://www.w3.org/2000/svg', 'path');
514
+ connectionPath.setAttribute('id', pathId);
515
+ connectionPath.setAttribute('class', 'connection-line');
516
+
517
+ // Curved path (bezier)
518
+ const dx = Math.abs(endX - startX) * 0.7;
519
+ const path = `M ${startX} ${startY} C ${startX + dx} ${startY}, ${endX - dx} ${endY}, ${endX} ${endY}`;
520
+ connectionPath.setAttribute('d', path);
521
+
522
+ // Add connection to SVG container
523
+ svgContainer.appendChild(connectionPath);
524
+
525
+ // Add to connections
526
+ networkLayers.connections.push({
527
+ id: pathId,
528
+ source: sourceId,
529
+ target: targetId,
530
+ sourceType: sourceType,
531
+ targetType: targetType
532
+ });
533
+
534
+ // Update input and output shapes
535
+ updateNodeShapes(sourceId, targetId);
536
+
537
+ // Notify about connection
538
+ document.dispatchEvent(new CustomEvent('networkUpdated', {
539
+ detail: networkLayers
540
+ }));
541
  }
542
 
543
+ // Clean up
544
  removePortHighlights();
545
+ if (connectionLine) {
546
+ connectionLine.remove();
547
+ connectionLine = null;
548
+ }
549
  isConnecting = false;
550
  startNode = null;
 
 
 
 
 
551
  }
552
 
553
+ // Update input and output shapes when connections are made
554
+ function updateNodeShapes(sourceId, targetId) {
555
+ const sourceNode = document.querySelector(`.canvas-node[data-id="${sourceId}"]`);
556
+ const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
557
+
558
+ if (sourceNode && targetNode) {
559
+ const sourceConfig = sourceNode.layerConfig;
560
+ const targetConfig = targetNode.layerConfig;
561
+
562
+ // Update the target's input shape based on the source's output shape
563
+ if (sourceConfig && targetConfig) {
564
+ // Calculate output shape based on node type
565
+ let outputShape;
566
+ switch (sourceNode.getAttribute('data-type')) {
567
+ case 'input':
568
+ outputShape = sourceConfig.shape;
569
+ break;
570
+ case 'hidden':
571
+ outputShape = [sourceConfig.units];
572
+ break;
573
+ case 'output':
574
+ outputShape = [sourceConfig.units];
575
+ break;
576
+ case 'conv':
577
+ // For Conv2D, the output shape depends on the input and parameters
578
+ // This is a simplified calculation
579
+ if (targetConfig.inputShape) {
580
+ const h = targetConfig.inputShape[0];
581
+ const w = targetConfig.inputShape[1];
582
+ const kh = sourceConfig.kernelSize[0];
583
+ const kw = sourceConfig.kernelSize[1];
584
+ const sh = sourceConfig.strides[0];
585
+ const sw = sourceConfig.strides[1];
586
+ const padding = sourceConfig.padding;
587
+
588
+ let outHeight, outWidth;
589
+ if (padding === 'same') {
590
+ outHeight = Math.ceil(h / sh);
591
+ outWidth = Math.ceil(w / sw);
592
+ } else { // 'valid'
593
+ outHeight = Math.ceil((h - kh + 1) / sh);
594
+ outWidth = Math.ceil((w - kw + 1) / sw);
595
+ }
596
+
597
+ outputShape = [outHeight, outWidth, sourceConfig.filters];
598
+ } else {
599
+ outputShape = ['?', '?', sourceConfig.filters];
600
+ }
601
+ break;
602
+ case 'pool':
603
+ // For pooling, also depends on the input and parameters
604
+ if (targetConfig.inputShape) {
605
+ const h = targetConfig.inputShape[0];
606
+ const w = targetConfig.inputShape[1];
607
+ const c = targetConfig.inputShape[2];
608
+ const ph = sourceConfig.poolSize[0];
609
+ const pw = sourceConfig.poolSize[1];
610
+ const sh = sourceConfig.strides[0];
611
+ const sw = sourceConfig.strides[1];
612
+ const padding = sourceConfig.padding;
613
+
614
+ let outHeight, outWidth;
615
+ if (padding === 'same') {
616
+ outHeight = Math.ceil(h / sh);
617
+ outWidth = Math.ceil(w / sw);
618
+ } else { // 'valid'
619
+ outHeight = Math.ceil((h - ph + 1) / sh);
620
+ outWidth = Math.ceil((w - pw + 1) / sw);
621
+ }
622
+
623
+ outputShape = [outHeight, outWidth, c];
624
+ } else {
625
+ outputShape = ['?', '?', '?'];
626
+ }
627
+ break;
628
+ case 'linear':
629
+ outputShape = [sourceConfig.outputFeatures];
630
+ break;
631
+ default:
632
+ outputShape = ['?', '?', '?'];
633
+ }
634
+
635
+ // Update the target's input shape
636
+ targetConfig.inputShape = outputShape;
637
+
638
+ // Update UI
639
+ updateNodeDisplayShapes(sourceNode, targetNode);
640
  }
641
+ }
642
+ }
643
+
644
+ // Update the displayed shapes in the UI
645
+ function updateNodeDisplayShapes(sourceNode, targetNode) {
646
+ if (sourceNode && targetNode) {
647
+ const sourceType = sourceNode.getAttribute('data-type');
648
+ const targetType = targetNode.getAttribute('data-type');
649
+ const sourceConfig = sourceNode.layerConfig;
650
+ const targetConfig = targetNode.layerConfig;
651
+
652
+ // Update source node output shape display
653
+ const sourceOutputElem = sourceNode.querySelector('.output-shape');
654
+ if (sourceOutputElem && sourceConfig) {
655
+ let outputText;
656
+ switch (sourceType) {
657
+ case 'input':
658
+ outputText = `[${sourceConfig.shape.join(' × ')}]`;
659
+ break;
660
+ case 'hidden':
661
+ case 'output':
662
+ outputText = `[${sourceConfig.units}]`;
663
+ break;
664
+ case 'conv':
665
+ if (sourceConfig.outputShape) {
666
+ outputText = `[${sourceConfig.outputShape.join(' × ')}]`;
667
+ } else {
668
+ outputText = `[? × ? × ${sourceConfig.filters}]`;
669
+ }
670
+ break;
671
+ case 'pool':
672
+ if (sourceConfig.outputShape) {
673
+ outputText = `[${sourceConfig.outputShape.join(' × ')}]`;
674
+ } else {
675
+ outputText = 'Depends on input';
676
+ }
677
+ break;
678
+ case 'linear':
679
+ outputText = `[${sourceConfig.outputFeatures}]`;
680
+ break;
681
+ default:
682
+ outputText = 'Unknown';
683
+ }
684
+ sourceOutputElem.textContent = outputText;
685
+ }
686
+
687
+ // Update target node input shape display
688
+ const targetInputElem = targetNode.querySelector('.input-shape');
689
+ if (targetInputElem && targetConfig && targetConfig.inputShape) {
690
+ targetInputElem.textContent = `[${targetConfig.inputShape.join(' × ')}]`;
691
+
692
+ // Update parameters section
693
+ const targetParamsElem = targetNode.querySelector('.params-display');
694
+ if (targetParamsElem) {
695
+ // Calculate and display parameters
696
+ let paramsText = '';
697
+ switch (targetType) {
698
+ case 'hidden':
699
+ const inputUnits = Array.isArray(targetConfig.inputShape) ?
700
+ targetConfig.inputShape.reduce((acc, val) => acc * val, 1) :
701
+ targetConfig.inputShape;
702
+
703
+ const biasParams = targetConfig.useBias ? targetConfig.units : 0;
704
+ const totalParams = (inputUnits * targetConfig.units) + biasParams;
705
+
706
+ paramsText = `In: ${inputUnits}, Out: ${targetConfig.units}\nParams: ${totalParams.toLocaleString()}\nDropout: ${targetConfig.dropoutRate}`;
707
+ break;
708
+ case 'output':
709
+ const outInputUnits = Array.isArray(targetConfig.inputShape) ?
710
+ targetConfig.inputShape.reduce((acc, val) => acc * val, 1) :
711
+ targetConfig.inputShape;
712
+
713
+ const outBiasParams = targetConfig.useBias ? targetConfig.units : 0;
714
+ const outTotalParams = (outInputUnits * targetConfig.units) + outBiasParams;
715
+
716
+ paramsText = `In: ${outInputUnits}, Out: ${targetConfig.units}\nParams: ${outTotalParams.toLocaleString()}\nActivation: ${targetConfig.activation}`;
717
+ break;
718
+ case 'conv':
719
+ const channels = targetConfig.inputShape[2] || '?';
720
+ const kernelH = targetConfig.kernelSize[0];
721
+ const kernelW = targetConfig.kernelSize[1];
722
+ const kernelParams = kernelH * kernelW * channels * targetConfig.filters;
723
+ const convBiasParams = targetConfig.useBias ? targetConfig.filters : 0;
724
+ const convTotalParams = kernelParams + convBiasParams;
725
+
726
+ paramsText = `In: ${channels}, Out: ${targetConfig.filters}\nKernel: ${targetConfig.kernelSize.join('×')}\nStride: ${targetConfig.strides.join('×')}\nPadding: ${targetConfig.padding}\nParams: ${convTotalParams.toLocaleString()}`;
727
+ break;
728
+ case 'pool':
729
+ paramsText = `Pool size: ${targetConfig.poolSize.join('×')}\nStride: ${targetConfig.strides.join('×')}\nPadding: ${targetConfig.padding}\nParams: 0`;
730
+ break;
731
+ case 'linear':
732
+ const linearInputs = targetConfig.inputFeatures;
733
+ const linearBiasParams = targetConfig.useBias ? targetConfig.outputFeatures : 0;
734
+ const linearTotalParams = (linearInputs * targetConfig.outputFeatures) + linearBiasParams;
735
+
736
+ paramsText = `In: ${linearInputs}, Out: ${targetConfig.outputFeatures}\nParams: ${linearTotalParams.toLocaleString()}\nLearning Rate: ${targetConfig.learningRate}\nLoss: ${targetConfig.lossFunction}`;
737
+ break;
738
+ }
739
+
740
+ targetParamsElem.textContent = paramsText;
741
+ }
742
+ }
743
+ }
744
  }
745
 
746
  // Delete a node and its connections
js/main.js CHANGED
@@ -299,214 +299,352 @@ document.addEventListener('DOMContentLoaded', () => {
299
 
300
  // Handle opening the layer editor
301
  function handleOpenLayerEditor(e) {
302
- const layerDetails = e.detail;
303
- console.log('Opening layer editor for:', layerDetails);
 
304
 
305
- const layerEditorModal = document.getElementById('layer-editor-modal');
306
- if (!layerEditorModal) return;
307
-
308
- // Get the form and populate it
309
- const layerForm = layerEditorModal.querySelector('.layer-form');
310
- if (!layerForm) return;
311
 
312
- // Set the layer ID in a data attribute for retrieval when saving
313
- layerForm.setAttribute('data-layer-id', layerDetails.id);
314
- layerForm.setAttribute('data-layer-type', layerDetails.type);
315
-
316
- // Set modal title
317
- const modalTitle = layerEditorModal.querySelector('.modal-title');
318
  if (modalTitle) {
319
- modalTitle.textContent = `Edit ${layerDetails.name}`;
320
  }
321
 
322
- // Get layer config template
323
- const layerConfig = window.neuralNetwork.nodeConfigTemplates[layerDetails.type];
 
324
 
325
- // Generate form fields based on layer type
326
  layerForm.innerHTML = '';
327
 
328
- // Add common fields
329
- layerForm.innerHTML += `
330
- <div class="form-group">
331
- <label for="layer-name">Layer Name</label>
332
- <input type="text" id="layer-name" value="${layerDetails.name}">
333
- </div>
334
- `;
335
-
336
- // Add type-specific fields
337
- switch (layerDetails.type) {
338
  case 'input':
 
339
  layerForm.innerHTML += `
340
  <div class="form-group">
341
- <label for="input-shape">Input Shape</label>
342
- <input type="text" id="input-shape" value="${layerConfig.shape.join(' × ')}">
 
 
 
 
 
343
  </div>
344
  <div class="form-group">
345
- <label for="batch-size">Batch Size</label>
346
- <input type="number" id="batch-size" value="${layerConfig.batchSize}">
347
  </div>
348
  `;
349
  break;
350
 
351
  case 'hidden':
 
352
  layerForm.innerHTML += `
353
  <div class="form-group">
354
- <label for="units">Units</label>
355
- <input type="number" id="units" value="${layerConfig.units}">
 
356
  </div>
357
  <div class="form-group">
358
- <label for="activation">Activation</label>
359
- <select id="activation">
360
  <option value="relu" ${layerConfig.activation === 'relu' ? 'selected' : ''}>ReLU</option>
361
  <option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
362
  <option value="tanh" ${layerConfig.activation === 'tanh' ? 'selected' : ''}>Tanh</option>
363
- <option value="linear" ${layerConfig.activation === 'linear' ? 'selected' : ''}>Linear</option>
364
  </select>
365
  </div>
366
  <div class="form-group">
367
- <label for="use-bias">Use Bias</label>
368
- <select id="use-bias">
369
- <option value="true" ${layerConfig.useBias ? 'selected' : ''}>Yes</option>
370
- <option value="false" ${!layerConfig.useBias ? 'selected' : ''}>No</option>
371
- </select>
372
  </div>
373
  <div class="form-group">
374
- <label for="dropout-rate">Dropout Rate</label>
375
- <input type="number" id="dropout-rate" min="0" max="0.9" step="0.1" value="${layerConfig.dropoutRate}">
376
  </div>
377
  `;
 
 
 
 
 
 
 
 
 
 
 
378
  break;
379
 
380
  case 'output':
 
381
  layerForm.innerHTML += `
382
  <div class="form-group">
383
- <label for="units">Units</label>
384
- <input type="number" id="units" value="${layerConfig.units}">
 
385
  </div>
386
  <div class="form-group">
387
- <label for="activation">Activation</label>
388
- <select id="activation">
389
- <option value="softmax" ${layerConfig.activation === 'softmax' ? 'selected' : ''}>Softmax</option>
390
- <option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
391
- <option value="linear" ${layerConfig.activation === 'linear' ? 'selected' : ''}>Linear</option>
392
  </select>
393
  </div>
 
 
 
 
394
  `;
395
  break;
396
 
397
  case 'conv':
 
398
  layerForm.innerHTML += `
399
  <div class="form-group">
400
- <label for="filters">Filters</label>
401
- <input type="number" id="filters" value="${layerConfig.filters}">
 
402
  </div>
403
  <div class="form-group">
404
- <label for="kernel-size">Kernel Size</label>
405
- <input type="text" id="kernel-size" value="${layerConfig.kernelSize.join(' × ')}">
 
 
 
 
406
  </div>
407
  <div class="form-group">
408
- <label for="strides">Strides</label>
409
- <input type="text" id="strides" value="${layerConfig.strides.join(' × ')}">
 
 
 
410
  </div>
411
  <div class="form-group">
412
- <label for="padding">Padding</label>
413
- <select id="padding">
414
- <option value="valid" ${layerConfig.padding === 'valid' ? 'selected' : ''}>Valid</option>
415
- <option value="same" ${layerConfig.padding === 'same' ? 'selected' : ''}>Same</option>
416
  </select>
417
  </div>
418
  <div class="form-group">
419
- <label for="activation">Activation</label>
420
- <select id="activation">
421
  <option value="relu" ${layerConfig.activation === 'relu' ? 'selected' : ''}>ReLU</option>
422
  <option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
423
  <option value="tanh" ${layerConfig.activation === 'tanh' ? 'selected' : ''}>Tanh</option>
424
- <option value="linear" ${layerConfig.activation === 'linear' ? 'selected' : ''}>Linear</option>
425
  </select>
426
  </div>
427
  `;
428
  break;
429
 
430
  case 'pool':
 
431
  layerForm.innerHTML += `
432
  <div class="form-group">
433
- <label for="pool-size">Pool Size</label>
434
- <input type="text" id="pool-size" value="${layerConfig.poolSize.join(' × ')}">
 
 
 
 
 
 
 
 
 
 
435
  </div>
436
  <div class="form-group">
437
- <label for="strides">Strides</label>
438
- <input type="text" id="strides" value="${layerConfig.strides.join(' × ')}">
 
 
 
439
  </div>
440
  <div class="form-group">
441
- <label for="padding">Padding</label>
442
- <select id="padding">
443
- <option value="valid" ${layerConfig.padding === 'valid' ? 'selected' : ''}>Valid</option>
444
- <option value="same" ${layerConfig.padding === 'same' ? 'selected' : ''}>Same</option>
445
  </select>
446
  </div>
447
  `;
448
  break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
  }
450
 
451
- // Add save button
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
  layerForm.innerHTML += `
453
- <div class="form-group form-grid-full">
454
- <button type="button" class="btn btn-primary save-layer-btn">Save Changes</button>
 
455
  </div>
456
  `;
457
 
458
- // Show the modal
459
- openModal(layerEditorModal);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  }
461
 
462
  // Save layer configuration
463
- function saveLayerConfig() {
464
- const layerEditorModal = document.getElementById('layer-editor-modal');
465
- if (!layerEditorModal) return;
466
-
467
- const layerForm = layerEditorModal.querySelector('.layer-form');
468
- if (!layerForm) return;
469
-
470
- const layerId = layerForm.getAttribute('data-layer-id');
471
- const layerType = layerForm.getAttribute('data-layer-type');
472
 
473
- // Get node on canvas
474
- const node = document.querySelector(`.canvas-node[data-id="${layerId}"]`);
475
- if (!node) return;
 
 
476
 
477
- // Get form values
478
- const name = document.getElementById('layer-name').value;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
  // Update node title
481
  const nodeTitle = node.querySelector('.node-title');
482
  if (nodeTitle) {
483
- nodeTitle.textContent = name;
484
  }
485
 
486
  // Update node data attribute
487
- node.setAttribute('data-name', name);
488
 
489
  // Update dimensions based on layer type
490
  let dimensions = '';
491
- switch (layerType) {
492
  case 'input':
493
- const inputShape = document.getElementById('input-shape').value;
494
- dimensions = inputShape;
495
  break;
496
 
497
  case 'hidden':
498
  case 'output':
499
- const units = document.getElementById('units').value;
500
- dimensions = units;
501
  break;
502
 
503
  case 'conv':
504
- const filters = document.getElementById('filters').value;
505
- dimensions = `${filters} × 26 × 26`; // Simplified
506
  break;
507
 
508
  case 'pool':
509
- dimensions = '32 × 13 × 13'; // Simplified
 
 
 
 
510
  break;
511
  }
512
 
@@ -524,16 +662,13 @@ document.addEventListener('DOMContentLoaded', () => {
524
  const layerIndex = networkLayers.layers.findIndex(layer => layer.id === layerId);
525
 
526
  if (layerIndex !== -1) {
527
- networkLayers.layers[layerIndex].name = name;
528
  networkLayers.layers[layerIndex].dimensions = dimensions;
529
  }
530
 
531
  // Trigger network updated event
532
  const event = new CustomEvent('networkUpdated', { detail: networkLayers });
533
  document.dispatchEvent(event);
534
-
535
- // Close the modal
536
- closeModal(layerEditorModal);
537
  }
538
 
539
  // Handle sample selection
 
299
 
300
  // Handle opening the layer editor
301
  function handleOpenLayerEditor(e) {
302
+ const node = e.detail.node;
303
+ const nodeType = node.getAttribute('data-type');
304
+ const layerId = node.getAttribute('data-id');
305
 
306
+ // Get current configuration
307
+ const layerConfig = node.layerConfig || window.neuralNetwork.createNodeConfig(nodeType);
 
 
 
 
308
 
309
+ // Update modal title
310
+ const modalTitle = document.querySelector('.layer-editor-modal .modal-title');
 
 
 
 
311
  if (modalTitle) {
312
+ modalTitle.textContent = `Edit ${nodeType.charAt(0).toUpperCase() + nodeType.slice(1)} Layer`;
313
  }
314
 
315
+ // Get layer form
316
+ const layerForm = document.querySelector('.layer-form');
317
+ if (!layerForm) return;
318
 
319
+ // Clear previous form fields
320
  layerForm.innerHTML = '';
321
 
322
+ // Create form fields based on layer type
323
+ switch (nodeType) {
 
 
 
 
 
 
 
 
324
  case 'input':
325
+ // Input shape fields
326
  layerForm.innerHTML += `
327
  <div class="form-group">
328
+ <label>Input Dimensions:</label>
329
+ <div class="form-row">
330
+ <input type="number" id="input-height" min="1" value="${layerConfig.shape[0]}" placeholder="Height">
331
+ <input type="number" id="input-width" min="1" value="${layerConfig.shape[1]}" placeholder="Width">
332
+ <input type="number" id="input-channels" min="1" value="${layerConfig.shape[2]}" placeholder="Channels">
333
+ </div>
334
+ <small>Input shape: [${layerConfig.shape.join(' × ')}]</small>
335
  </div>
336
  <div class="form-group">
337
+ <label>Batch Size:</label>
338
+ <input type="number" id="batch-size" min="1" value="${layerConfig.batchSize}" placeholder="Batch Size">
339
  </div>
340
  `;
341
  break;
342
 
343
  case 'hidden':
344
+ // Units and activation function
345
  layerForm.innerHTML += `
346
  <div class="form-group">
347
+ <label>Units:</label>
348
+ <input type="number" id="hidden-units" min="1" value="${layerConfig.units}" placeholder="Number of units">
349
+ <small>Output shape: [${layerConfig.units}]</small>
350
  </div>
351
  <div class="form-group">
352
+ <label>Activation Function:</label>
353
+ <select id="hidden-activation">
354
  <option value="relu" ${layerConfig.activation === 'relu' ? 'selected' : ''}>ReLU</option>
355
  <option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
356
  <option value="tanh" ${layerConfig.activation === 'tanh' ? 'selected' : ''}>Tanh</option>
357
+ <option value="leaky_relu" ${layerConfig.activation === 'leaky_relu' ? 'selected' : ''}>Leaky ReLU</option>
358
  </select>
359
  </div>
360
  <div class="form-group">
361
+ <label>Dropout Rate:</label>
362
+ <input type="range" id="dropout-rate" min="0" max="0.9" step="0.1" value="${layerConfig.dropoutRate}">
363
+ <span id="dropout-value">${layerConfig.dropoutRate}</span>
 
 
364
  </div>
365
  <div class="form-group">
366
+ <label>Use Bias:</label>
367
+ <input type="checkbox" id="use-bias" ${layerConfig.useBias ? 'checked' : ''}>
368
  </div>
369
  `;
370
+
371
+ // Add listener for dropout rate slider
372
+ setTimeout(() => {
373
+ const dropoutSlider = document.getElementById('dropout-rate');
374
+ const dropoutValue = document.getElementById('dropout-value');
375
+ if (dropoutSlider && dropoutValue) {
376
+ dropoutSlider.addEventListener('input', (e) => {
377
+ dropoutValue.textContent = e.target.value;
378
+ });
379
+ }
380
+ }, 100);
381
  break;
382
 
383
  case 'output':
384
+ // Output units and activation
385
  layerForm.innerHTML += `
386
  <div class="form-group">
387
+ <label>Units:</label>
388
+ <input type="number" id="output-units" min="1" value="${layerConfig.units}" placeholder="Number of output units">
389
+ <small>Output shape: [${layerConfig.units}]</small>
390
  </div>
391
  <div class="form-group">
392
+ <label>Activation Function:</label>
393
+ <select id="output-activation">
394
+ <option value="softmax" ${layerConfig.activation === 'softmax' ? 'selected' : ''}>Softmax (Classification)</option>
395
+ <option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid (Binary Classification)</option>
396
+ <option value="linear" ${layerConfig.activation === 'linear' ? 'selected' : ''}>Linear (Regression)</option>
397
  </select>
398
  </div>
399
+ <div class="form-group">
400
+ <label>Use Bias:</label>
401
+ <input type="checkbox" id="output-use-bias" ${layerConfig.useBias ? 'checked' : ''}>
402
+ </div>
403
  `;
404
  break;
405
 
406
  case 'conv':
407
+ // Convolutional layer parameters
408
  layerForm.innerHTML += `
409
  <div class="form-group">
410
+ <label>Filters:</label>
411
+ <input type="number" id="conv-filters" min="1" value="${layerConfig.filters}" placeholder="Number of filters">
412
+ <small>Output channels</small>
413
  </div>
414
  <div class="form-group">
415
+ <label>Kernel Size:</label>
416
+ <div class="form-row">
417
+ <input type="number" id="kernel-size-h" min="1" max="7" value="${layerConfig.kernelSize[0]}" placeholder="Height">
418
+ <input type="number" id="kernel-size-w" min="1" max="7" value="${layerConfig.kernelSize[1]}" placeholder="Width">
419
+ </div>
420
+ <small>Filter dimensions: ${layerConfig.kernelSize.join(' × ')}</small>
421
  </div>
422
  <div class="form-group">
423
+ <label>Strides:</label>
424
+ <div class="form-row">
425
+ <input type="number" id="stride-h" min="1" max="4" value="${layerConfig.strides[0]}" placeholder="Height">
426
+ <input type="number" id="stride-w" min="1" max="4" value="${layerConfig.strides[1]}" placeholder="Width">
427
+ </div>
428
  </div>
429
  <div class="form-group">
430
+ <label>Padding:</label>
431
+ <select id="padding-type">
432
+ <option value="valid" ${layerConfig.padding === 'valid' ? 'selected' : ''}>Valid (No Padding)</option>
433
+ <option value="same" ${layerConfig.padding === 'same' ? 'selected' : ''}>Same (Preserve Dimensions)</option>
434
  </select>
435
  </div>
436
  <div class="form-group">
437
+ <label>Activation Function:</label>
438
+ <select id="conv-activation">
439
  <option value="relu" ${layerConfig.activation === 'relu' ? 'selected' : ''}>ReLU</option>
440
  <option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
441
  <option value="tanh" ${layerConfig.activation === 'tanh' ? 'selected' : ''}>Tanh</option>
442
+ <option value="leaky_relu" ${layerConfig.activation === 'leaky_relu' ? 'selected' : ''}>Leaky ReLU</option>
443
  </select>
444
  </div>
445
  `;
446
  break;
447
 
448
  case 'pool':
449
+ // Pooling layer parameters
450
  layerForm.innerHTML += `
451
  <div class="form-group">
452
+ <label>Pool Size:</label>
453
+ <div class="form-row">
454
+ <input type="number" id="pool-size-h" min="1" max="4" value="${layerConfig.poolSize[0]}" placeholder="Height">
455
+ <input type="number" id="pool-size-w" min="1" max="4" value="${layerConfig.poolSize[1]}" placeholder="Width">
456
+ </div>
457
+ </div>
458
+ <div class="form-group">
459
+ <label>Strides:</label>
460
+ <div class="form-row">
461
+ <input type="number" id="pool-stride-h" min="1" max="4" value="${layerConfig.strides[0]}" placeholder="Height">
462
+ <input type="number" id="pool-stride-w" min="1" max="4" value="${layerConfig.strides[1]}" placeholder="Width">
463
+ </div>
464
  </div>
465
  <div class="form-group">
466
+ <label>Padding:</label>
467
+ <select id="pool-padding">
468
+ <option value="valid" ${layerConfig.padding === 'valid' ? 'selected' : ''}>Valid (No Padding)</option>
469
+ <option value="same" ${layerConfig.padding === 'same' ? 'selected' : ''}>Same (Preserve Dimensions)</option>
470
+ </select>
471
  </div>
472
  <div class="form-group">
473
+ <label>Pool Type:</label>
474
+ <select id="pool-type">
475
+ <option value="max" selected>Max Pooling</option>
476
+ <option value="avg">Average Pooling</option>
477
  </select>
478
  </div>
479
  `;
480
  break;
481
+
482
+ case 'linear':
483
+ // Linear regression layer parameters
484
+ layerForm.innerHTML += `
485
+ <div class="form-group">
486
+ <label>Input Features:</label>
487
+ <input type="number" id="input-features" min="1" value="${layerConfig.inputFeatures}" placeholder="Number of input features">
488
+ <small>Input shape: [${layerConfig.inputFeatures}]</small>
489
+ </div>
490
+ <div class="form-group">
491
+ <label>Output Features:</label>
492
+ <input type="number" id="output-features" min="1" value="${layerConfig.outputFeatures}" placeholder="Number of output features">
493
+ <small>Output shape: [${layerConfig.outputFeatures}]</small>
494
+ </div>
495
+ <div class="form-group">
496
+ <label>Use Bias:</label>
497
+ <input type="checkbox" id="linear-use-bias" ${layerConfig.useBias ? 'checked' : ''}>
498
+ </div>
499
+ <div class="form-group">
500
+ <label>Learning Rate:</label>
501
+ <input type="range" id="learning-rate-slider" min="0.001" max="0.1" step="0.001" value="${layerConfig.learningRate}">
502
+ <span id="learning-rate-value">${layerConfig.learningRate}</span>
503
+ </div>
504
+ <div class="form-group">
505
+ <label>Loss Function:</label>
506
+ <select id="loss-function">
507
+ <option value="mse" ${layerConfig.lossFunction === 'mse' ? 'selected' : ''}>Mean Squared Error</option>
508
+ <option value="mae" ${layerConfig.lossFunction === 'mae' ? 'selected' : ''}>Mean Absolute Error</option>
509
+ <option value="huber" ${layerConfig.lossFunction === 'huber' ? 'selected' : ''}>Huber Loss</option>
510
+ </select>
511
+ </div>
512
+ <div class="form-group">
513
+ <label>Optimizer:</label>
514
+ <select id="optimizer">
515
+ <option value="sgd" ${layerConfig.optimizer === 'sgd' ? 'selected' : ''}>Stochastic Gradient Descent</option>
516
+ <option value="adam" ${layerConfig.optimizer === 'adam' ? 'selected' : ''}>Adam</option>
517
+ <option value="rmsprop" ${layerConfig.optimizer === 'rmsprop' ? 'selected' : ''}>RMSprop</option>
518
+ </select>
519
+ </div>
520
+ `;
521
+
522
+ // Add listener for learning rate slider
523
+ setTimeout(() => {
524
+ const learningRateSlider = document.getElementById('learning-rate-slider');
525
+ const learningRateValue = document.getElementById('learning-rate-value');
526
+ if (learningRateSlider && learningRateValue) {
527
+ learningRateSlider.addEventListener('input', (e) => {
528
+ learningRateValue.textContent = parseFloat(e.target.value).toFixed(3);
529
+ });
530
+ }
531
+ }, 100);
532
+ break;
533
+
534
+ default:
535
+ layerForm.innerHTML = '<p>No editable properties for this layer type.</p>';
536
  }
537
 
538
+ // Add a preview of calculated parameters if available
539
+ if (nodeType !== 'input') {
540
+ const parameterCount = window.neuralNetwork.calculateParameters(nodeType, layerConfig);
541
+ if (parameterCount) {
542
+ layerForm.innerHTML += `
543
+ <div class="form-group">
544
+ <label>Parameter Summary:</label>
545
+ <div class="parameters-summary">
546
+ <p>Total parameters: <strong>${formatNumber(parameterCount)}</strong></p>
547
+ <p>Memory usage (32-bit): ~${formatMemorySize(parameterCount * 4)}</p>
548
+ </div>
549
+ </div>
550
+ `;
551
+ }
552
+ }
553
+
554
+ // Add save and cancel buttons
555
  layerForm.innerHTML += `
556
+ <div class="form-buttons">
557
+ <button type="button" id="save-layer-config" class="btn-primary">Save Changes</button>
558
+ <button type="button" id="cancel-layer-edit" class="btn-secondary">Cancel</button>
559
  </div>
560
  `;
561
 
562
+ // Open the modal
563
+ const modal = document.getElementById('layer-editor-modal');
564
+ if (modal) {
565
+ openModal(modal);
566
+
567
+ // Add event listeners for buttons
568
+ const saveButton = document.getElementById('save-layer-config');
569
+ if (saveButton) {
570
+ saveButton.addEventListener('click', () => {
571
+ saveLayerConfig(node, nodeType, layerId);
572
+ closeModal(modal);
573
+ });
574
+ }
575
+
576
+ const cancelButton = document.getElementById('cancel-layer-edit');
577
+ if (cancelButton) {
578
+ cancelButton.addEventListener('click', () => {
579
+ closeModal(modal);
580
+ });
581
+ }
582
+ }
583
  }
584
 
585
  // Save layer configuration
586
+ function saveLayerConfig(node, nodeType, layerId) {
587
+ // Get form values
588
+ const form = document.querySelector('.layer-form');
589
+ if (!form) return;
 
 
 
 
 
590
 
591
+ const values = {};
592
+ const inputs = form.querySelectorAll('input, select');
593
+ inputs.forEach(input => {
594
+ values[input.id] = input.value;
595
+ });
596
 
597
+ // Update node configuration
598
+ node.layerConfig = {
599
+ type: nodeType,
600
+ shape: [
601
+ parseInt(values['input-height']),
602
+ parseInt(values['input-width']),
603
+ parseInt(values['input-channels'])
604
+ ],
605
+ batchSize: parseInt(values['batch-size']),
606
+ units: parseInt(values['hidden-units']),
607
+ activation: values['hidden-activation'],
608
+ dropoutRate: parseFloat(values['dropout-rate']),
609
+ useBias: values['use-bias'] === 'true',
610
+ learningRate: parseFloat(values['learning-rate-slider']),
611
+ lossFunction: values['loss-function'],
612
+ optimizer: values['optimizer'],
613
+ inputFeatures: parseInt(values['input-features']),
614
+ outputFeatures: parseInt(values['output-features'])
615
+ };
616
 
617
  // Update node title
618
  const nodeTitle = node.querySelector('.node-title');
619
  if (nodeTitle) {
620
+ nodeTitle.textContent = nodeType.charAt(0).toUpperCase() + nodeType.slice(1);
621
  }
622
 
623
  // Update node data attribute
624
+ node.setAttribute('data-name', nodeType.charAt(0).toUpperCase() + nodeType.slice(1));
625
 
626
  // Update dimensions based on layer type
627
  let dimensions = '';
628
+ switch (nodeType) {
629
  case 'input':
630
+ dimensions = values['input-height'] + ' × ' + values['input-width'] + ' × ' + values['input-channels'];
 
631
  break;
632
 
633
  case 'hidden':
634
  case 'output':
635
+ dimensions = values['hidden-units'];
 
636
  break;
637
 
638
  case 'conv':
639
+ dimensions = values['conv-filters'] + ' × ' + values['kernel-size-h'] + ' × ' + values['kernel-size-w'];
 
640
  break;
641
 
642
  case 'pool':
643
+ dimensions = values['pool-size-h'] + ' × ' + values['pool-size-w'];
644
+ break;
645
+
646
+ case 'linear':
647
+ dimensions = values['input-features'] + ' → ' + values['output-features'];
648
  break;
649
  }
650
 
 
662
  const layerIndex = networkLayers.layers.findIndex(layer => layer.id === layerId);
663
 
664
  if (layerIndex !== -1) {
665
+ networkLayers.layers[layerIndex].name = nodeType.charAt(0).toUpperCase() + nodeType.slice(1);
666
  networkLayers.layers[layerIndex].dimensions = dimensions;
667
  }
668
 
669
  // Trigger network updated event
670
  const event = new CustomEvent('networkUpdated', { detail: networkLayers });
671
  document.dispatchEvent(event);
 
 
 
672
  }
673
 
674
  // Handle sample selection
js/neural-network.js CHANGED
@@ -11,7 +11,8 @@
11
  'hidden': 0,
12
  'output': 0,
13
  'conv': 0,
14
- 'pool': 0
 
15
  };
16
 
17
  // Default configuration templates for different layer types
@@ -21,7 +22,9 @@
21
  shape: [28, 28, 1],
22
  batchSize: 32,
23
  description: 'Input layer for raw data',
24
- parameters: 0
 
 
25
  },
26
  'hidden': {
27
  units: 128,
@@ -30,7 +33,9 @@
30
  kernelInitializer: 'glorotUniform',
31
  biasInitializer: 'zeros',
32
  dropoutRate: 0.2,
33
- description: 'Dense hidden layer with ReLU activation'
 
 
34
  },
35
  'output': {
36
  units: 10,
@@ -38,7 +43,9 @@
38
  useBias: true,
39
  kernelInitializer: 'glorotUniform',
40
  biasInitializer: 'zeros',
41
- description: 'Output layer with Softmax activation for classification'
 
 
42
  },
43
  'conv': {
44
  filters: 32,
@@ -49,13 +56,31 @@
49
  useBias: true,
50
  kernelInitializer: 'glorotUniform',
51
  biasInitializer: 'zeros',
52
- description: 'Convolutional layer for feature extraction'
 
 
53
  },
54
  'pool': {
55
  poolSize: [2, 2],
56
  strides: [2, 2],
57
  padding: 'valid',
58
- description: 'Max pooling layer for spatial downsampling'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  }
60
  };
61
 
 
11
  'hidden': 0,
12
  'output': 0,
13
  'conv': 0,
14
+ 'pool': 0,
15
+ 'linear': 0
16
  };
17
 
18
  // Default configuration templates for different layer types
 
22
  shape: [28, 28, 1],
23
  batchSize: 32,
24
  description: 'Input layer for raw data',
25
+ parameters: 0,
26
+ inputShape: null,
27
+ outputShape: [784]
28
  },
29
  'hidden': {
30
  units: 128,
 
33
  kernelInitializer: 'glorotUniform',
34
  biasInitializer: 'zeros',
35
  dropoutRate: 0.2,
36
+ description: 'Dense hidden layer with ReLU activation',
37
+ inputShape: null,
38
+ outputShape: null
39
  },
40
  'output': {
41
  units: 10,
 
43
  useBias: true,
44
  kernelInitializer: 'glorotUniform',
45
  biasInitializer: 'zeros',
46
+ description: 'Output layer with Softmax activation for classification',
47
+ inputShape: null,
48
+ outputShape: [10]
49
  },
50
  'conv': {
51
  filters: 32,
 
56
  useBias: true,
57
  kernelInitializer: 'glorotUniform',
58
  biasInitializer: 'zeros',
59
+ description: 'Convolutional layer for feature extraction',
60
+ inputShape: null,
61
+ outputShape: null
62
  },
63
  'pool': {
64
  poolSize: [2, 2],
65
  strides: [2, 2],
66
  padding: 'valid',
67
+ description: 'Max pooling layer for spatial downsampling',
68
+ inputShape: null,
69
+ outputShape: null
70
+ },
71
+ 'linear': {
72
+ inputFeatures: 1,
73
+ outputFeatures: 1,
74
+ useBias: true,
75
+ activation: 'linear',
76
+ learningRate: 0.01,
77
+ optimizer: 'sgd',
78
+ lossFunction: 'mse',
79
+ biasInitializer: 'zeros',
80
+ kernelInitializer: 'glorotUniform',
81
+ description: 'Linear regression layer for numerical prediction',
82
+ inputShape: [1],
83
+ outputShape: [1]
84
  }
85
  };
86