Initial commit of Neural Network Playground with improved node display and Linear Regression support
Browse files- README.md +46 -9
- css/styles.css +119 -11
- index.html +3 -0
- js/drag-drop.js +348 -136
- js/main.js +240 -105
- js/neural-network.js +31 -6
README.md
CHANGED
@@ -1,9 +1,46 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Neural Network Playground
|
2 |
+
|
3 |
+
An interactive web-based application for visualizing and experimenting with neural network architectures.
|
4 |
+
|
5 |
+
## Features
|
6 |
+
|
7 |
+
- **Drag-and-Drop Interface**: Easily create neural network architectures by dragging and dropping different layer types
|
8 |
+
- **Multiple Layer Types**: Support for Input, Hidden, Output, Convolutional, and Pooling layers
|
9 |
+
- **Dynamic Connections**: Create connections between layers to define your network topology
|
10 |
+
- **Visual Styling**: Beautiful gradient-based styling for different layer types with animations
|
11 |
+
- **Layer Properties**: View and edit detailed properties for each layer
|
12 |
+
- **Network Validation**: Automatic validation of network architectures
|
13 |
+
- **Training Simulation**: Visual simulation of the training process
|
14 |
+
- **Responsive Design**: Works on desktop and mobile devices
|
15 |
+
|
16 |
+
## Getting Started
|
17 |
+
|
18 |
+
1. Clone this repository
|
19 |
+
2. Open `index.html` in your browser or use a local server:
|
20 |
+
```
|
21 |
+
python -m http.server
|
22 |
+
```
|
23 |
+
3. Visit `http://localhost:8000` in your browser
|
24 |
+
|
25 |
+
## How to Use
|
26 |
+
|
27 |
+
1. Drag layer components from the left panel onto the canvas
|
28 |
+
2. Connect layers by dragging from output ports (right side) to input ports (left side)
|
29 |
+
3. Click on a layer to view its properties
|
30 |
+
4. Edit layer properties by clicking the edit button
|
31 |
+
5. Click "Run Network" to simulate training
|
32 |
+
|
33 |
+
## Technologies Used
|
34 |
+
|
35 |
+
- HTML5
|
36 |
+
- CSS3 (with animations and gradients)
|
37 |
+
- JavaScript (vanilla)
|
38 |
+
- No external libraries required!
|
39 |
+
|
40 |
+
## License
|
41 |
+
|
42 |
+
MIT
|
43 |
+
|
44 |
+
## Contributing
|
45 |
+
|
46 |
+
Contributions, issues, and feature requests are welcome!
|
css/styles.css
CHANGED
@@ -32,6 +32,8 @@
|
|
32 |
--pool-node-color-1: #e74c3c;
|
33 |
--pool-node-color-2: #c0392b;
|
34 |
--node-glow: 0 0 15px rgba(255, 255, 255, 0.8);
|
|
|
|
|
35 |
}
|
36 |
|
37 |
body {
|
@@ -157,6 +159,12 @@ header h1 {
|
|
157 |
color: white;
|
158 |
}
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
.node-icon {
|
161 |
width: 24px;
|
162 |
height: 24px;
|
@@ -251,6 +259,12 @@ header h1 {
|
|
251 |
color: white;
|
252 |
}
|
253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
.controls {
|
255 |
margin-top: 2rem;
|
256 |
}
|
@@ -385,21 +399,66 @@ footer p {
|
|
385 |
.canvas-node {
|
386 |
position: absolute;
|
387 |
width: 180px;
|
388 |
-
|
389 |
-
border-radius:
|
390 |
-
color: white;
|
391 |
box-shadow: var(--shadow-md);
|
392 |
-
|
393 |
-
transition: all 0.3s ease;
|
394 |
cursor: move;
|
395 |
-
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
}
|
398 |
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
}
|
404 |
|
405 |
.canvas-node.dragging {
|
@@ -434,6 +493,11 @@ footer p {
|
|
434 |
border: 2px solid var(--pool-node-color-1);
|
435 |
}
|
436 |
|
|
|
|
|
|
|
|
|
|
|
437 |
.canvas-node .node-title {
|
438 |
font-weight: 600;
|
439 |
font-size: 0.9rem;
|
@@ -1222,4 +1286,48 @@ select {
|
|
1222 |
100% {
|
1223 |
box-shadow: 0 0 0 0 rgba(231, 76, 60, 0);
|
1224 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1225 |
}
|
|
|
32 |
--pool-node-color-1: #e74c3c;
|
33 |
--pool-node-color-2: #c0392b;
|
34 |
--node-glow: 0 0 15px rgba(255, 255, 255, 0.8);
|
35 |
+
--linear-node-color-1: #1abc9c;
|
36 |
+
--linear-node-color-2: #16a085;
|
37 |
}
|
38 |
|
39 |
body {
|
|
|
159 |
color: white;
|
160 |
}
|
161 |
|
162 |
+
.node-item[data-type="linear"] {
|
163 |
+
background-image: linear-gradient(135deg, var(--linear-node-color-1), var(--linear-node-color-2));
|
164 |
+
border-bottom: 3px solid var(--linear-node-color-2);
|
165 |
+
box-shadow: var(--shadow-sm);
|
166 |
+
}
|
167 |
+
|
168 |
.node-icon {
|
169 |
width: 24px;
|
170 |
height: 24px;
|
|
|
259 |
color: white;
|
260 |
}
|
261 |
|
262 |
+
.linear-node {
|
263 |
+
background-image: linear-gradient(135deg, var(--linear-node-color-1), var(--linear-node-color-2));
|
264 |
+
position: relative;
|
265 |
+
border-radius: var(--border-radius);
|
266 |
+
}
|
267 |
+
|
268 |
.controls {
|
269 |
margin-top: 2rem;
|
270 |
}
|
|
|
399 |
.canvas-node {
|
400 |
position: absolute;
|
401 |
width: 180px;
|
402 |
+
min-height: 120px;
|
403 |
+
border-radius: 8px;
|
|
|
404 |
box-shadow: var(--shadow-md);
|
405 |
+
padding: 0;
|
|
|
406 |
cursor: move;
|
407 |
+
transition: all 0.2s ease-in-out;
|
408 |
+
z-index: 2;
|
409 |
+
display: flex;
|
410 |
+
flex-direction: column;
|
411 |
+
}
|
412 |
+
|
413 |
+
.canvas-node .node-header {
|
414 |
+
padding: 8px 12px;
|
415 |
+
font-weight: bold;
|
416 |
+
border-radius: 8px 8px 0 0;
|
417 |
+
color: white;
|
418 |
+
background-color: rgba(0, 0, 0, 0.2);
|
419 |
+
text-align: center;
|
420 |
+
font-size: 14px;
|
421 |
+
border-bottom: 1px solid rgba(255, 255, 255, 0.2);
|
422 |
+
}
|
423 |
+
|
424 |
+
.canvas-node .node-content {
|
425 |
+
padding: 8px;
|
426 |
+
flex-grow: 1;
|
427 |
+
display: flex;
|
428 |
+
flex-direction: column;
|
429 |
+
gap: 8px;
|
430 |
+
}
|
431 |
+
|
432 |
+
.shape-info {
|
433 |
+
background-color: rgba(255, 255, 255, 0.15);
|
434 |
+
border-radius: 4px;
|
435 |
+
padding: 6px;
|
436 |
+
}
|
437 |
+
|
438 |
+
.shape-row {
|
439 |
+
display: flex;
|
440 |
+
justify-content: space-between;
|
441 |
+
font-size: 12px;
|
442 |
+
color: white;
|
443 |
+
padding: 2px 0;
|
444 |
}
|
445 |
|
446 |
+
.shape-label {
|
447 |
+
font-weight: bold;
|
448 |
+
}
|
449 |
+
|
450 |
+
.params-section {
|
451 |
+
background-color: rgba(255, 255, 255, 0.15);
|
452 |
+
border-radius: 4px;
|
453 |
+
padding: 6px;
|
454 |
+
}
|
455 |
+
|
456 |
+
.params-display {
|
457 |
+
font-family: monospace;
|
458 |
+
font-size: 11px;
|
459 |
+
color: white;
|
460 |
+
margin: 0;
|
461 |
+
white-space: pre-wrap;
|
462 |
}
|
463 |
|
464 |
.canvas-node.dragging {
|
|
|
493 |
border: 2px solid var(--pool-node-color-1);
|
494 |
}
|
495 |
|
496 |
+
.canvas-node[data-type="linear"] {
|
497 |
+
background-image: linear-gradient(135deg, var(--linear-node-color-1), var(--linear-node-color-2));
|
498 |
+
border: 2px solid var(--linear-node-color-2);
|
499 |
+
}
|
500 |
+
|
501 |
.canvas-node .node-title {
|
502 |
font-weight: 600;
|
503 |
font-size: 0.9rem;
|
|
|
1286 |
100% {
|
1287 |
box-shadow: 0 0 0 0 rgba(231, 76, 60, 0);
|
1288 |
}
|
1289 |
+
}
|
1290 |
+
|
1291 |
+
/* Style for connection ports */
|
1292 |
+
.port {
|
1293 |
+
position: absolute;
|
1294 |
+
width: 12px;
|
1295 |
+
height: 12px;
|
1296 |
+
border-radius: 50%;
|
1297 |
+
background-color: white;
|
1298 |
+
border: 2px solid rgba(0, 0, 0, 0.3);
|
1299 |
+
z-index: 3;
|
1300 |
+
transition: all 0.2s ease;
|
1301 |
+
}
|
1302 |
+
|
1303 |
+
.port:hover {
|
1304 |
+
transform: scale(1.5);
|
1305 |
+
box-shadow: 0 0 5px rgba(255, 255, 255, 0.8);
|
1306 |
+
}
|
1307 |
+
|
1308 |
+
.input-port {
|
1309 |
+
top: 50%;
|
1310 |
+
left: -6px;
|
1311 |
+
transform: translateY(-50%);
|
1312 |
+
}
|
1313 |
+
|
1314 |
+
.output-port {
|
1315 |
+
top: 50%;
|
1316 |
+
right: -6px;
|
1317 |
+
transform: translateY(-50%);
|
1318 |
+
}
|
1319 |
+
|
1320 |
+
/* Connection line styles */
|
1321 |
+
.connection-line {
|
1322 |
+
stroke: rgba(255, 255, 255, 0.7);
|
1323 |
+
stroke-width: 2;
|
1324 |
+
fill: none;
|
1325 |
+
pointer-events: none;
|
1326 |
+
}
|
1327 |
+
|
1328 |
+
.connection-line-temp {
|
1329 |
+
stroke: rgba(255, 255, 255, 0.5);
|
1330 |
+
stroke-dasharray: 5, 5;
|
1331 |
+
stroke-width: 2;
|
1332 |
+
fill: none;
|
1333 |
}
|
index.html
CHANGED
@@ -39,6 +39,9 @@
|
|
39 |
<div class="node-item" draggable="true" data-type="pool">
|
40 |
<div class="node pool-node">Pooling</div>
|
41 |
</div>
|
|
|
|
|
|
|
42 |
</div>
|
43 |
|
44 |
<h3 class="section-title">Sample Data</h3>
|
|
|
39 |
<div class="node-item" draggable="true" data-type="pool">
|
40 |
<div class="node pool-node">Pooling</div>
|
41 |
</div>
|
42 |
+
<div class="node-item" draggable="true" data-type="linear">
|
43 |
+
<div class="node linear-node">Linear Regression</div>
|
44 |
+
</div>
|
45 |
</div>
|
46 |
|
47 |
<h3 class="section-title">Sample Data</h3>
|
js/drag-drop.js
CHANGED
@@ -76,107 +76,137 @@ function initializeDragAndDrop() {
|
|
76 |
canvasNode.style.left = `${x}px`;
|
77 |
canvasNode.style.top = `${y}px`;
|
78 |
|
79 |
-
//
|
80 |
-
|
|
|
|
|
|
|
81 |
|
82 |
switch(nodeType) {
|
83 |
case 'input':
|
84 |
nodeName = 'Input Layer';
|
85 |
-
|
|
|
|
|
86 |
break;
|
87 |
case 'hidden':
|
88 |
-
// Customize if it's the first hidden layer
|
89 |
const hiddenCount = document.querySelectorAll('.canvas-node[data-type="hidden"]').length;
|
90 |
-
units = hiddenCount === 0 ? 128 : 64;
|
91 |
nodeName = `Hidden Layer ${hiddenCount + 1}`;
|
92 |
-
|
|
|
|
|
|
|
93 |
break;
|
94 |
case 'output':
|
95 |
nodeName = 'Output Layer';
|
96 |
-
|
|
|
|
|
97 |
break;
|
98 |
case 'conv':
|
99 |
const convCount = document.querySelectorAll('.canvas-node[data-type="conv"]').length;
|
100 |
-
|
101 |
nodeName = `Conv2D ${convCount + 1}`;
|
102 |
-
|
|
|
|
|
|
|
103 |
break;
|
104 |
case 'pool':
|
105 |
const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
|
106 |
-
nodeName = `
|
107 |
-
|
|
|
|
|
108 |
break;
|
109 |
default:
|
110 |
-
nodeName = '
|
111 |
-
|
|
|
|
|
112 |
}
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
124 |
`;
|
125 |
|
126 |
-
//
|
127 |
-
|
128 |
-
|
|
|
129 |
|
130 |
-
// Add
|
131 |
-
const
|
132 |
-
|
133 |
-
|
134 |
-
name: nodeName,
|
135 |
-
dimensions: dimensions,
|
136 |
-
position: { x, y }
|
137 |
-
};
|
138 |
|
139 |
-
|
|
|
|
|
140 |
|
141 |
-
//
|
142 |
-
|
|
|
143 |
|
144 |
-
|
145 |
-
canvasNode.
|
|
|
|
|
146 |
|
147 |
-
//
|
148 |
-
|
149 |
-
const portOut = canvasNode.querySelector('.port-out');
|
150 |
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
e.stopPropagation();
|
153 |
startConnection(canvasNode, e);
|
154 |
});
|
155 |
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
});
|
160 |
|
161 |
-
//
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
openLayerEditor(canvasNode);
|
167 |
-
});
|
168 |
-
}
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
}
|
177 |
|
178 |
-
//
|
179 |
-
|
|
|
|
|
|
|
|
|
180 |
}
|
181 |
}
|
182 |
|
@@ -448,87 +478,269 @@ function initializeDragAndDrop() {
|
|
448 |
|
449 |
// End creating a connection
|
450 |
function endConnection(targetNode) {
|
451 |
-
if (!isConnecting) return;
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
);
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
490 |
}
|
491 |
|
492 |
-
//
|
493 |
removePortHighlights();
|
494 |
-
|
495 |
-
|
|
|
|
|
496 |
isConnecting = false;
|
497 |
startNode = null;
|
498 |
-
connectionLine = null;
|
499 |
-
|
500 |
-
// Remove event listeners
|
501 |
-
document.removeEventListener('mousemove', drawConnection);
|
502 |
-
document.removeEventListener('mouseup', cancelConnection);
|
503 |
}
|
504 |
|
505 |
-
// Update
|
506 |
-
function
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
}
|
527 |
-
}
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
}
|
533 |
|
534 |
// Delete a node and its connections
|
|
|
76 |
canvasNode.style.left = `${x}px`;
|
77 |
canvasNode.style.top = `${y}px`;
|
78 |
|
79 |
+
// Get default config for this node type
|
80 |
+
const nodeConfig = window.neuralNetwork.createNodeConfig(nodeType);
|
81 |
+
|
82 |
+
// Create node content with input and output shape information
|
83 |
+
let nodeName, inputShape, outputShape, parameters;
|
84 |
|
85 |
switch(nodeType) {
|
86 |
case 'input':
|
87 |
nodeName = 'Input Layer';
|
88 |
+
inputShape = 'N/A';
|
89 |
+
outputShape = '[' + nodeConfig.shape.join(' × ') + ']';
|
90 |
+
parameters = nodeConfig.parameters;
|
91 |
break;
|
92 |
case 'hidden':
|
|
|
93 |
const hiddenCount = document.querySelectorAll('.canvas-node[data-type="hidden"]').length;
|
94 |
+
nodeConfig.units = hiddenCount === 0 ? 128 : 64;
|
95 |
nodeName = `Hidden Layer ${hiddenCount + 1}`;
|
96 |
+
// Input shape will be updated when connections are made
|
97 |
+
inputShape = 'Connect input';
|
98 |
+
outputShape = `[${nodeConfig.units}]`;
|
99 |
+
parameters = 'Connect input to calculate';
|
100 |
break;
|
101 |
case 'output':
|
102 |
nodeName = 'Output Layer';
|
103 |
+
inputShape = 'Connect input';
|
104 |
+
outputShape = `[${nodeConfig.units}]`;
|
105 |
+
parameters = 'Connect input to calculate';
|
106 |
break;
|
107 |
case 'conv':
|
108 |
const convCount = document.querySelectorAll('.canvas-node[data-type="conv"]').length;
|
109 |
+
nodeConfig.filters = 32 * (convCount + 1);
|
110 |
nodeName = `Conv2D ${convCount + 1}`;
|
111 |
+
inputShape = 'Connect input';
|
112 |
+
outputShape = 'Depends on input';
|
113 |
+
// Create parameter string
|
114 |
+
parameters = `In: ?, Out: ${nodeConfig.filters}\nKernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
|
115 |
break;
|
116 |
case 'pool':
|
117 |
const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
|
118 |
+
nodeName = `Pooling ${poolCount + 1}`;
|
119 |
+
inputShape = 'Connect input';
|
120 |
+
outputShape = 'Depends on input';
|
121 |
+
parameters = `Pool size: ${nodeConfig.poolSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
|
122 |
break;
|
123 |
default:
|
124 |
+
nodeName = 'Unknown Layer';
|
125 |
+
inputShape = 'N/A';
|
126 |
+
outputShape = 'N/A';
|
127 |
+
parameters = 'N/A';
|
128 |
}
|
129 |
|
130 |
+
// Create node header
|
131 |
+
const nodeHeader = document.createElement('div');
|
132 |
+
nodeHeader.className = 'node-header';
|
133 |
+
nodeHeader.textContent = nodeName;
|
134 |
+
|
135 |
+
// Create node content
|
136 |
+
const nodeContent = document.createElement('div');
|
137 |
+
nodeContent.className = 'node-content';
|
138 |
+
|
139 |
+
// Add shape information in a structured way
|
140 |
+
const shapeInfo = document.createElement('div');
|
141 |
+
shapeInfo.className = 'shape-info';
|
142 |
+
shapeInfo.innerHTML = `
|
143 |
+
<div class="shape-row"><span class="shape-label">Input:</span> <span class="input-shape">${inputShape}</span></div>
|
144 |
+
<div class="shape-row"><span class="shape-label">Output:</span> <span class="output-shape">${outputShape}</span></div>
|
145 |
`;
|
146 |
|
147 |
+
// Add parameters section
|
148 |
+
const paramsSection = document.createElement('div');
|
149 |
+
paramsSection.className = 'params-section';
|
150 |
+
paramsSection.innerHTML = `<pre class="params-display">${parameters}</pre>`;
|
151 |
|
152 |
+
// Add connection ports
|
153 |
+
const inputPort = document.createElement('div');
|
154 |
+
inputPort.className = 'port input-port';
|
155 |
+
inputPort.setAttribute('data-port-type', 'input');
|
|
|
|
|
|
|
|
|
156 |
|
157 |
+
const outputPort = document.createElement('div');
|
158 |
+
outputPort.className = 'port output-port';
|
159 |
+
outputPort.setAttribute('data-port-type', 'output');
|
160 |
|
161 |
+
// Assemble the node
|
162 |
+
nodeContent.appendChild(shapeInfo);
|
163 |
+
nodeContent.appendChild(paramsSection);
|
164 |
|
165 |
+
canvasNode.appendChild(nodeHeader);
|
166 |
+
canvasNode.appendChild(nodeContent);
|
167 |
+
canvasNode.appendChild(inputPort);
|
168 |
+
canvasNode.appendChild(outputPort);
|
169 |
|
170 |
+
// Add node to the canvas
|
171 |
+
canvas.appendChild(canvasNode);
|
|
|
172 |
|
173 |
+
// Store node configuration
|
174 |
+
canvasNode.layerConfig = nodeConfig;
|
175 |
+
|
176 |
+
// Add event listeners for node manipulation
|
177 |
+
canvasNode.addEventListener('mousedown', startDrag);
|
178 |
+
inputPort.addEventListener('mousedown', (e) => {
|
179 |
+
e.stopPropagation();
|
180 |
+
});
|
181 |
+
outputPort.addEventListener('mousedown', (e) => {
|
182 |
e.stopPropagation();
|
183 |
startConnection(canvasNode, e);
|
184 |
});
|
185 |
|
186 |
+
// Double-click to edit node properties
|
187 |
+
canvasNode.addEventListener('dblclick', () => {
|
188 |
+
openLayerEditor(canvasNode);
|
189 |
});
|
190 |
|
191 |
+
// Right-click to delete
|
192 |
+
canvasNode.addEventListener('contextmenu', (e) => {
|
193 |
+
e.preventDefault();
|
194 |
+
deleteNode(canvasNode);
|
195 |
+
});
|
|
|
|
|
|
|
196 |
|
197 |
+
// Add to network layers for architecture building
|
198 |
+
networkLayers.layers.push({
|
199 |
+
id: layerId,
|
200 |
+
type: nodeType,
|
201 |
+
config: nodeConfig
|
202 |
+
});
|
|
|
203 |
|
204 |
+
// Notify about network changes
|
205 |
+
document.dispatchEvent(new CustomEvent('networkUpdated', {
|
206 |
+
detail: networkLayers
|
207 |
+
}));
|
208 |
+
|
209 |
+
updateConnections();
|
210 |
}
|
211 |
}
|
212 |
|
|
|
478 |
|
479 |
// End creating a connection
|
480 |
function endConnection(targetNode) {
|
481 |
+
if (!isConnecting || !connectionLine || !startNode) return;
|
482 |
+
|
483 |
+
const sourceType = startNode.getAttribute('data-type');
|
484 |
+
const targetType = targetNode.getAttribute('data-type');
|
485 |
+
const sourceId = startNode.getAttribute('data-id');
|
486 |
+
const targetId = targetNode.getAttribute('data-id');
|
487 |
+
|
488 |
+
// Check if this is a valid connection
|
489 |
+
if (isValidConnection(sourceType, targetType, sourceId, targetId)) {
|
490 |
+
// Create a permanent SVG connection
|
491 |
+
const canvas = document.getElementById('network-canvas');
|
492 |
+
const svgContainer = document.querySelector('#network-canvas .svg-container') || createSVGContainer();
|
493 |
+
|
494 |
+
// Get positions for source and target nodes
|
495 |
+
const sourceRect = startNode.getBoundingClientRect();
|
496 |
+
const targetRect = targetNode.getBoundingClientRect();
|
497 |
+
const canvasRect = canvas.getBoundingClientRect();
|
498 |
+
|
499 |
+
// Calculate port positions
|
500 |
+
const sourcePort = startNode.querySelector('.output-port');
|
501 |
+
const targetPort = targetNode.querySelector('.input-port');
|
502 |
+
|
503 |
+
const sourcePortRect = sourcePort.getBoundingClientRect();
|
504 |
+
const targetPortRect = targetPort.getBoundingClientRect();
|
505 |
+
|
506 |
+
const startX = sourcePortRect.left + (sourcePortRect.width / 2) - canvasRect.left;
|
507 |
+
const startY = sourcePortRect.top + (sourcePortRect.height / 2) - canvasRect.top;
|
508 |
+
const endX = targetPortRect.left + (targetPortRect.width / 2) - canvasRect.left;
|
509 |
+
const endY = targetPortRect.top + (targetPortRect.height / 2) - canvasRect.top;
|
510 |
+
|
511 |
+
// Create the connection
|
512 |
+
const pathId = `connection-${sourceId}-${targetId}`;
|
513 |
+
const connectionPath = document.createElementNS('http://www.w3.org/2000/svg', 'path');
|
514 |
+
connectionPath.setAttribute('id', pathId);
|
515 |
+
connectionPath.setAttribute('class', 'connection-line');
|
516 |
+
|
517 |
+
// Curved path (bezier)
|
518 |
+
const dx = Math.abs(endX - startX) * 0.7;
|
519 |
+
const path = `M ${startX} ${startY} C ${startX + dx} ${startY}, ${endX - dx} ${endY}, ${endX} ${endY}`;
|
520 |
+
connectionPath.setAttribute('d', path);
|
521 |
+
|
522 |
+
// Add connection to SVG container
|
523 |
+
svgContainer.appendChild(connectionPath);
|
524 |
+
|
525 |
+
// Add to connections
|
526 |
+
networkLayers.connections.push({
|
527 |
+
id: pathId,
|
528 |
+
source: sourceId,
|
529 |
+
target: targetId,
|
530 |
+
sourceType: sourceType,
|
531 |
+
targetType: targetType
|
532 |
+
});
|
533 |
+
|
534 |
+
// Update input and output shapes
|
535 |
+
updateNodeShapes(sourceId, targetId);
|
536 |
+
|
537 |
+
// Notify about connection
|
538 |
+
document.dispatchEvent(new CustomEvent('networkUpdated', {
|
539 |
+
detail: networkLayers
|
540 |
+
}));
|
541 |
}
|
542 |
|
543 |
+
// Clean up
|
544 |
removePortHighlights();
|
545 |
+
if (connectionLine) {
|
546 |
+
connectionLine.remove();
|
547 |
+
connectionLine = null;
|
548 |
+
}
|
549 |
isConnecting = false;
|
550 |
startNode = null;
|
|
|
|
|
|
|
|
|
|
|
551 |
}
|
552 |
|
553 |
+
// Update input and output shapes when connections are made
|
554 |
+
function updateNodeShapes(sourceId, targetId) {
|
555 |
+
const sourceNode = document.querySelector(`.canvas-node[data-id="${sourceId}"]`);
|
556 |
+
const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
|
557 |
+
|
558 |
+
if (sourceNode && targetNode) {
|
559 |
+
const sourceConfig = sourceNode.layerConfig;
|
560 |
+
const targetConfig = targetNode.layerConfig;
|
561 |
+
|
562 |
+
// Update the target's input shape based on the source's output shape
|
563 |
+
if (sourceConfig && targetConfig) {
|
564 |
+
// Calculate output shape based on node type
|
565 |
+
let outputShape;
|
566 |
+
switch (sourceNode.getAttribute('data-type')) {
|
567 |
+
case 'input':
|
568 |
+
outputShape = sourceConfig.shape;
|
569 |
+
break;
|
570 |
+
case 'hidden':
|
571 |
+
outputShape = [sourceConfig.units];
|
572 |
+
break;
|
573 |
+
case 'output':
|
574 |
+
outputShape = [sourceConfig.units];
|
575 |
+
break;
|
576 |
+
case 'conv':
|
577 |
+
// For Conv2D, the output shape depends on the input and parameters
|
578 |
+
// This is a simplified calculation
|
579 |
+
if (targetConfig.inputShape) {
|
580 |
+
const h = targetConfig.inputShape[0];
|
581 |
+
const w = targetConfig.inputShape[1];
|
582 |
+
const kh = sourceConfig.kernelSize[0];
|
583 |
+
const kw = sourceConfig.kernelSize[1];
|
584 |
+
const sh = sourceConfig.strides[0];
|
585 |
+
const sw = sourceConfig.strides[1];
|
586 |
+
const padding = sourceConfig.padding;
|
587 |
+
|
588 |
+
let outHeight, outWidth;
|
589 |
+
if (padding === 'same') {
|
590 |
+
outHeight = Math.ceil(h / sh);
|
591 |
+
outWidth = Math.ceil(w / sw);
|
592 |
+
} else { // 'valid'
|
593 |
+
outHeight = Math.ceil((h - kh + 1) / sh);
|
594 |
+
outWidth = Math.ceil((w - kw + 1) / sw);
|
595 |
+
}
|
596 |
+
|
597 |
+
outputShape = [outHeight, outWidth, sourceConfig.filters];
|
598 |
+
} else {
|
599 |
+
outputShape = ['?', '?', sourceConfig.filters];
|
600 |
+
}
|
601 |
+
break;
|
602 |
+
case 'pool':
|
603 |
+
// For pooling, also depends on the input and parameters
|
604 |
+
if (targetConfig.inputShape) {
|
605 |
+
const h = targetConfig.inputShape[0];
|
606 |
+
const w = targetConfig.inputShape[1];
|
607 |
+
const c = targetConfig.inputShape[2];
|
608 |
+
const ph = sourceConfig.poolSize[0];
|
609 |
+
const pw = sourceConfig.poolSize[1];
|
610 |
+
const sh = sourceConfig.strides[0];
|
611 |
+
const sw = sourceConfig.strides[1];
|
612 |
+
const padding = sourceConfig.padding;
|
613 |
+
|
614 |
+
let outHeight, outWidth;
|
615 |
+
if (padding === 'same') {
|
616 |
+
outHeight = Math.ceil(h / sh);
|
617 |
+
outWidth = Math.ceil(w / sw);
|
618 |
+
} else { // 'valid'
|
619 |
+
outHeight = Math.ceil((h - ph + 1) / sh);
|
620 |
+
outWidth = Math.ceil((w - pw + 1) / sw);
|
621 |
+
}
|
622 |
+
|
623 |
+
outputShape = [outHeight, outWidth, c];
|
624 |
+
} else {
|
625 |
+
outputShape = ['?', '?', '?'];
|
626 |
+
}
|
627 |
+
break;
|
628 |
+
case 'linear':
|
629 |
+
outputShape = [sourceConfig.outputFeatures];
|
630 |
+
break;
|
631 |
+
default:
|
632 |
+
outputShape = ['?', '?', '?'];
|
633 |
+
}
|
634 |
+
|
635 |
+
// Update the target's input shape
|
636 |
+
targetConfig.inputShape = outputShape;
|
637 |
+
|
638 |
+
// Update UI
|
639 |
+
updateNodeDisplayShapes(sourceNode, targetNode);
|
640 |
}
|
641 |
+
}
|
642 |
+
}
|
643 |
+
|
644 |
+
// Update the displayed shapes in the UI
|
645 |
+
function updateNodeDisplayShapes(sourceNode, targetNode) {
|
646 |
+
if (sourceNode && targetNode) {
|
647 |
+
const sourceType = sourceNode.getAttribute('data-type');
|
648 |
+
const targetType = targetNode.getAttribute('data-type');
|
649 |
+
const sourceConfig = sourceNode.layerConfig;
|
650 |
+
const targetConfig = targetNode.layerConfig;
|
651 |
+
|
652 |
+
// Update source node output shape display
|
653 |
+
const sourceOutputElem = sourceNode.querySelector('.output-shape');
|
654 |
+
if (sourceOutputElem && sourceConfig) {
|
655 |
+
let outputText;
|
656 |
+
switch (sourceType) {
|
657 |
+
case 'input':
|
658 |
+
outputText = `[${sourceConfig.shape.join(' × ')}]`;
|
659 |
+
break;
|
660 |
+
case 'hidden':
|
661 |
+
case 'output':
|
662 |
+
outputText = `[${sourceConfig.units}]`;
|
663 |
+
break;
|
664 |
+
case 'conv':
|
665 |
+
if (sourceConfig.outputShape) {
|
666 |
+
outputText = `[${sourceConfig.outputShape.join(' × ')}]`;
|
667 |
+
} else {
|
668 |
+
outputText = `[? × ? × ${sourceConfig.filters}]`;
|
669 |
+
}
|
670 |
+
break;
|
671 |
+
case 'pool':
|
672 |
+
if (sourceConfig.outputShape) {
|
673 |
+
outputText = `[${sourceConfig.outputShape.join(' × ')}]`;
|
674 |
+
} else {
|
675 |
+
outputText = 'Depends on input';
|
676 |
+
}
|
677 |
+
break;
|
678 |
+
case 'linear':
|
679 |
+
outputText = `[${sourceConfig.outputFeatures}]`;
|
680 |
+
break;
|
681 |
+
default:
|
682 |
+
outputText = 'Unknown';
|
683 |
+
}
|
684 |
+
sourceOutputElem.textContent = outputText;
|
685 |
+
}
|
686 |
+
|
687 |
+
// Update target node input shape display
|
688 |
+
const targetInputElem = targetNode.querySelector('.input-shape');
|
689 |
+
if (targetInputElem && targetConfig && targetConfig.inputShape) {
|
690 |
+
targetInputElem.textContent = `[${targetConfig.inputShape.join(' × ')}]`;
|
691 |
+
|
692 |
+
// Update parameters section
|
693 |
+
const targetParamsElem = targetNode.querySelector('.params-display');
|
694 |
+
if (targetParamsElem) {
|
695 |
+
// Calculate and display parameters
|
696 |
+
let paramsText = '';
|
697 |
+
switch (targetType) {
|
698 |
+
case 'hidden':
|
699 |
+
const inputUnits = Array.isArray(targetConfig.inputShape) ?
|
700 |
+
targetConfig.inputShape.reduce((acc, val) => acc * val, 1) :
|
701 |
+
targetConfig.inputShape;
|
702 |
+
|
703 |
+
const biasParams = targetConfig.useBias ? targetConfig.units : 0;
|
704 |
+
const totalParams = (inputUnits * targetConfig.units) + biasParams;
|
705 |
+
|
706 |
+
paramsText = `In: ${inputUnits}, Out: ${targetConfig.units}\nParams: ${totalParams.toLocaleString()}\nDropout: ${targetConfig.dropoutRate}`;
|
707 |
+
break;
|
708 |
+
case 'output':
|
709 |
+
const outInputUnits = Array.isArray(targetConfig.inputShape) ?
|
710 |
+
targetConfig.inputShape.reduce((acc, val) => acc * val, 1) :
|
711 |
+
targetConfig.inputShape;
|
712 |
+
|
713 |
+
const outBiasParams = targetConfig.useBias ? targetConfig.units : 0;
|
714 |
+
const outTotalParams = (outInputUnits * targetConfig.units) + outBiasParams;
|
715 |
+
|
716 |
+
paramsText = `In: ${outInputUnits}, Out: ${targetConfig.units}\nParams: ${outTotalParams.toLocaleString()}\nActivation: ${targetConfig.activation}`;
|
717 |
+
break;
|
718 |
+
case 'conv':
|
719 |
+
const channels = targetConfig.inputShape[2] || '?';
|
720 |
+
const kernelH = targetConfig.kernelSize[0];
|
721 |
+
const kernelW = targetConfig.kernelSize[1];
|
722 |
+
const kernelParams = kernelH * kernelW * channels * targetConfig.filters;
|
723 |
+
const convBiasParams = targetConfig.useBias ? targetConfig.filters : 0;
|
724 |
+
const convTotalParams = kernelParams + convBiasParams;
|
725 |
+
|
726 |
+
paramsText = `In: ${channels}, Out: ${targetConfig.filters}\nKernel: ${targetConfig.kernelSize.join('×')}\nStride: ${targetConfig.strides.join('×')}\nPadding: ${targetConfig.padding}\nParams: ${convTotalParams.toLocaleString()}`;
|
727 |
+
break;
|
728 |
+
case 'pool':
|
729 |
+
paramsText = `Pool size: ${targetConfig.poolSize.join('×')}\nStride: ${targetConfig.strides.join('×')}\nPadding: ${targetConfig.padding}\nParams: 0`;
|
730 |
+
break;
|
731 |
+
case 'linear':
|
732 |
+
const linearInputs = targetConfig.inputFeatures;
|
733 |
+
const linearBiasParams = targetConfig.useBias ? targetConfig.outputFeatures : 0;
|
734 |
+
const linearTotalParams = (linearInputs * targetConfig.outputFeatures) + linearBiasParams;
|
735 |
+
|
736 |
+
paramsText = `In: ${linearInputs}, Out: ${targetConfig.outputFeatures}\nParams: ${linearTotalParams.toLocaleString()}\nLearning Rate: ${targetConfig.learningRate}\nLoss: ${targetConfig.lossFunction}`;
|
737 |
+
break;
|
738 |
+
}
|
739 |
+
|
740 |
+
targetParamsElem.textContent = paramsText;
|
741 |
+
}
|
742 |
+
}
|
743 |
+
}
|
744 |
}
|
745 |
|
746 |
// Delete a node and its connections
|
js/main.js
CHANGED
@@ -299,214 +299,352 @@ document.addEventListener('DOMContentLoaded', () => {
|
|
299 |
|
300 |
// Handle opening the layer editor
|
301 |
function handleOpenLayerEditor(e) {
|
302 |
-
const
|
303 |
-
|
|
|
304 |
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
// Get the form and populate it
|
309 |
-
const layerForm = layerEditorModal.querySelector('.layer-form');
|
310 |
-
if (!layerForm) return;
|
311 |
|
312 |
-
//
|
313 |
-
|
314 |
-
layerForm.setAttribute('data-layer-type', layerDetails.type);
|
315 |
-
|
316 |
-
// Set modal title
|
317 |
-
const modalTitle = layerEditorModal.querySelector('.modal-title');
|
318 |
if (modalTitle) {
|
319 |
-
modalTitle.textContent = `Edit ${
|
320 |
}
|
321 |
|
322 |
-
// Get layer
|
323 |
-
const
|
|
|
324 |
|
325 |
-
//
|
326 |
layerForm.innerHTML = '';
|
327 |
|
328 |
-
//
|
329 |
-
|
330 |
-
<div class="form-group">
|
331 |
-
<label for="layer-name">Layer Name</label>
|
332 |
-
<input type="text" id="layer-name" value="${layerDetails.name}">
|
333 |
-
</div>
|
334 |
-
`;
|
335 |
-
|
336 |
-
// Add type-specific fields
|
337 |
-
switch (layerDetails.type) {
|
338 |
case 'input':
|
|
|
339 |
layerForm.innerHTML += `
|
340 |
<div class="form-group">
|
341 |
-
<label
|
342 |
-
<
|
|
|
|
|
|
|
|
|
|
|
343 |
</div>
|
344 |
<div class="form-group">
|
345 |
-
<label
|
346 |
-
<input type="number" id="batch-size" value="${layerConfig.batchSize}">
|
347 |
</div>
|
348 |
`;
|
349 |
break;
|
350 |
|
351 |
case 'hidden':
|
|
|
352 |
layerForm.innerHTML += `
|
353 |
<div class="form-group">
|
354 |
-
<label
|
355 |
-
<input type="number" id="units" value="${layerConfig.units}">
|
|
|
356 |
</div>
|
357 |
<div class="form-group">
|
358 |
-
<label
|
359 |
-
<select id="activation">
|
360 |
<option value="relu" ${layerConfig.activation === 'relu' ? 'selected' : ''}>ReLU</option>
|
361 |
<option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
|
362 |
<option value="tanh" ${layerConfig.activation === 'tanh' ? 'selected' : ''}>Tanh</option>
|
363 |
-
<option value="
|
364 |
</select>
|
365 |
</div>
|
366 |
<div class="form-group">
|
367 |
-
<label
|
368 |
-
<
|
369 |
-
|
370 |
-
<option value="false" ${!layerConfig.useBias ? 'selected' : ''}>No</option>
|
371 |
-
</select>
|
372 |
</div>
|
373 |
<div class="form-group">
|
374 |
-
<label
|
375 |
-
<input type="
|
376 |
</div>
|
377 |
`;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
break;
|
379 |
|
380 |
case 'output':
|
|
|
381 |
layerForm.innerHTML += `
|
382 |
<div class="form-group">
|
383 |
-
<label
|
384 |
-
<input type="number" id="units" value="${layerConfig.units}">
|
|
|
385 |
</div>
|
386 |
<div class="form-group">
|
387 |
-
<label
|
388 |
-
<select id="activation">
|
389 |
-
<option value="softmax" ${layerConfig.activation === 'softmax' ? 'selected' : ''}>Softmax</option>
|
390 |
-
<option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
|
391 |
-
<option value="linear" ${layerConfig.activation === 'linear' ? 'selected' : ''}>Linear</option>
|
392 |
</select>
|
393 |
</div>
|
|
|
|
|
|
|
|
|
394 |
`;
|
395 |
break;
|
396 |
|
397 |
case 'conv':
|
|
|
398 |
layerForm.innerHTML += `
|
399 |
<div class="form-group">
|
400 |
-
<label
|
401 |
-
<input type="number" id="filters" value="${layerConfig.filters}">
|
|
|
402 |
</div>
|
403 |
<div class="form-group">
|
404 |
-
<label
|
405 |
-
<
|
|
|
|
|
|
|
|
|
406 |
</div>
|
407 |
<div class="form-group">
|
408 |
-
<label
|
409 |
-
<
|
|
|
|
|
|
|
410 |
</div>
|
411 |
<div class="form-group">
|
412 |
-
<label
|
413 |
-
<select id="padding">
|
414 |
-
<option value="valid" ${layerConfig.padding === 'valid' ? 'selected' : ''}>Valid</option>
|
415 |
-
<option value="same" ${layerConfig.padding === 'same' ? 'selected' : ''}>Same</option>
|
416 |
</select>
|
417 |
</div>
|
418 |
<div class="form-group">
|
419 |
-
<label
|
420 |
-
<select id="activation">
|
421 |
<option value="relu" ${layerConfig.activation === 'relu' ? 'selected' : ''}>ReLU</option>
|
422 |
<option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
|
423 |
<option value="tanh" ${layerConfig.activation === 'tanh' ? 'selected' : ''}>Tanh</option>
|
424 |
-
<option value="
|
425 |
</select>
|
426 |
</div>
|
427 |
`;
|
428 |
break;
|
429 |
|
430 |
case 'pool':
|
|
|
431 |
layerForm.innerHTML += `
|
432 |
<div class="form-group">
|
433 |
-
<label
|
434 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
</div>
|
436 |
<div class="form-group">
|
437 |
-
<label
|
438 |
-
<
|
|
|
|
|
|
|
439 |
</div>
|
440 |
<div class="form-group">
|
441 |
-
<label
|
442 |
-
<select id="
|
443 |
-
<option value="
|
444 |
-
<option value="
|
445 |
</select>
|
446 |
</div>
|
447 |
`;
|
448 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
}
|
450 |
|
451 |
-
// Add
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
452 |
layerForm.innerHTML += `
|
453 |
-
<div class="form-
|
454 |
-
<button type="button"
|
|
|
455 |
</div>
|
456 |
`;
|
457 |
|
458 |
-
//
|
459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
460 |
}
|
461 |
|
462 |
// Save layer configuration
|
463 |
-
function saveLayerConfig() {
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
const layerForm = layerEditorModal.querySelector('.layer-form');
|
468 |
-
if (!layerForm) return;
|
469 |
-
|
470 |
-
const layerId = layerForm.getAttribute('data-layer-id');
|
471 |
-
const layerType = layerForm.getAttribute('data-layer-type');
|
472 |
|
473 |
-
|
474 |
-
const
|
475 |
-
|
|
|
|
|
476 |
|
477 |
-
//
|
478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
|
480 |
// Update node title
|
481 |
const nodeTitle = node.querySelector('.node-title');
|
482 |
if (nodeTitle) {
|
483 |
-
nodeTitle.textContent =
|
484 |
}
|
485 |
|
486 |
// Update node data attribute
|
487 |
-
node.setAttribute('data-name',
|
488 |
|
489 |
// Update dimensions based on layer type
|
490 |
let dimensions = '';
|
491 |
-
switch (
|
492 |
case 'input':
|
493 |
-
|
494 |
-
dimensions = inputShape;
|
495 |
break;
|
496 |
|
497 |
case 'hidden':
|
498 |
case 'output':
|
499 |
-
|
500 |
-
dimensions = units;
|
501 |
break;
|
502 |
|
503 |
case 'conv':
|
504 |
-
|
505 |
-
dimensions = `${filters} × 26 × 26`; // Simplified
|
506 |
break;
|
507 |
|
508 |
case 'pool':
|
509 |
-
dimensions = '
|
|
|
|
|
|
|
|
|
510 |
break;
|
511 |
}
|
512 |
|
@@ -524,16 +662,13 @@ document.addEventListener('DOMContentLoaded', () => {
|
|
524 |
const layerIndex = networkLayers.layers.findIndex(layer => layer.id === layerId);
|
525 |
|
526 |
if (layerIndex !== -1) {
|
527 |
-
networkLayers.layers[layerIndex].name =
|
528 |
networkLayers.layers[layerIndex].dimensions = dimensions;
|
529 |
}
|
530 |
|
531 |
// Trigger network updated event
|
532 |
const event = new CustomEvent('networkUpdated', { detail: networkLayers });
|
533 |
document.dispatchEvent(event);
|
534 |
-
|
535 |
-
// Close the modal
|
536 |
-
closeModal(layerEditorModal);
|
537 |
}
|
538 |
|
539 |
// Handle sample selection
|
|
|
299 |
|
300 |
// Handle opening the layer editor
|
301 |
function handleOpenLayerEditor(e) {
|
302 |
+
const node = e.detail.node;
|
303 |
+
const nodeType = node.getAttribute('data-type');
|
304 |
+
const layerId = node.getAttribute('data-id');
|
305 |
|
306 |
+
// Get current configuration
|
307 |
+
const layerConfig = node.layerConfig || window.neuralNetwork.createNodeConfig(nodeType);
|
|
|
|
|
|
|
|
|
308 |
|
309 |
+
// Update modal title
|
310 |
+
const modalTitle = document.querySelector('.layer-editor-modal .modal-title');
|
|
|
|
|
|
|
|
|
311 |
if (modalTitle) {
|
312 |
+
modalTitle.textContent = `Edit ${nodeType.charAt(0).toUpperCase() + nodeType.slice(1)} Layer`;
|
313 |
}
|
314 |
|
315 |
+
// Get layer form
|
316 |
+
const layerForm = document.querySelector('.layer-form');
|
317 |
+
if (!layerForm) return;
|
318 |
|
319 |
+
// Clear previous form fields
|
320 |
layerForm.innerHTML = '';
|
321 |
|
322 |
+
// Create form fields based on layer type
|
323 |
+
switch (nodeType) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
case 'input':
|
325 |
+
// Input shape fields
|
326 |
layerForm.innerHTML += `
|
327 |
<div class="form-group">
|
328 |
+
<label>Input Dimensions:</label>
|
329 |
+
<div class="form-row">
|
330 |
+
<input type="number" id="input-height" min="1" value="${layerConfig.shape[0]}" placeholder="Height">
|
331 |
+
<input type="number" id="input-width" min="1" value="${layerConfig.shape[1]}" placeholder="Width">
|
332 |
+
<input type="number" id="input-channels" min="1" value="${layerConfig.shape[2]}" placeholder="Channels">
|
333 |
+
</div>
|
334 |
+
<small>Input shape: [${layerConfig.shape.join(' × ')}]</small>
|
335 |
</div>
|
336 |
<div class="form-group">
|
337 |
+
<label>Batch Size:</label>
|
338 |
+
<input type="number" id="batch-size" min="1" value="${layerConfig.batchSize}" placeholder="Batch Size">
|
339 |
</div>
|
340 |
`;
|
341 |
break;
|
342 |
|
343 |
case 'hidden':
|
344 |
+
// Units and activation function
|
345 |
layerForm.innerHTML += `
|
346 |
<div class="form-group">
|
347 |
+
<label>Units:</label>
|
348 |
+
<input type="number" id="hidden-units" min="1" value="${layerConfig.units}" placeholder="Number of units">
|
349 |
+
<small>Output shape: [${layerConfig.units}]</small>
|
350 |
</div>
|
351 |
<div class="form-group">
|
352 |
+
<label>Activation Function:</label>
|
353 |
+
<select id="hidden-activation">
|
354 |
<option value="relu" ${layerConfig.activation === 'relu' ? 'selected' : ''}>ReLU</option>
|
355 |
<option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
|
356 |
<option value="tanh" ${layerConfig.activation === 'tanh' ? 'selected' : ''}>Tanh</option>
|
357 |
+
<option value="leaky_relu" ${layerConfig.activation === 'leaky_relu' ? 'selected' : ''}>Leaky ReLU</option>
|
358 |
</select>
|
359 |
</div>
|
360 |
<div class="form-group">
|
361 |
+
<label>Dropout Rate:</label>
|
362 |
+
<input type="range" id="dropout-rate" min="0" max="0.9" step="0.1" value="${layerConfig.dropoutRate}">
|
363 |
+
<span id="dropout-value">${layerConfig.dropoutRate}</span>
|
|
|
|
|
364 |
</div>
|
365 |
<div class="form-group">
|
366 |
+
<label>Use Bias:</label>
|
367 |
+
<input type="checkbox" id="use-bias" ${layerConfig.useBias ? 'checked' : ''}>
|
368 |
</div>
|
369 |
`;
|
370 |
+
|
371 |
+
// Add listener for dropout rate slider
|
372 |
+
setTimeout(() => {
|
373 |
+
const dropoutSlider = document.getElementById('dropout-rate');
|
374 |
+
const dropoutValue = document.getElementById('dropout-value');
|
375 |
+
if (dropoutSlider && dropoutValue) {
|
376 |
+
dropoutSlider.addEventListener('input', (e) => {
|
377 |
+
dropoutValue.textContent = e.target.value;
|
378 |
+
});
|
379 |
+
}
|
380 |
+
}, 100);
|
381 |
break;
|
382 |
|
383 |
case 'output':
|
384 |
+
// Output units and activation
|
385 |
layerForm.innerHTML += `
|
386 |
<div class="form-group">
|
387 |
+
<label>Units:</label>
|
388 |
+
<input type="number" id="output-units" min="1" value="${layerConfig.units}" placeholder="Number of output units">
|
389 |
+
<small>Output shape: [${layerConfig.units}]</small>
|
390 |
</div>
|
391 |
<div class="form-group">
|
392 |
+
<label>Activation Function:</label>
|
393 |
+
<select id="output-activation">
|
394 |
+
<option value="softmax" ${layerConfig.activation === 'softmax' ? 'selected' : ''}>Softmax (Classification)</option>
|
395 |
+
<option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid (Binary Classification)</option>
|
396 |
+
<option value="linear" ${layerConfig.activation === 'linear' ? 'selected' : ''}>Linear (Regression)</option>
|
397 |
</select>
|
398 |
</div>
|
399 |
+
<div class="form-group">
|
400 |
+
<label>Use Bias:</label>
|
401 |
+
<input type="checkbox" id="output-use-bias" ${layerConfig.useBias ? 'checked' : ''}>
|
402 |
+
</div>
|
403 |
`;
|
404 |
break;
|
405 |
|
406 |
case 'conv':
|
407 |
+
// Convolutional layer parameters
|
408 |
layerForm.innerHTML += `
|
409 |
<div class="form-group">
|
410 |
+
<label>Filters:</label>
|
411 |
+
<input type="number" id="conv-filters" min="1" value="${layerConfig.filters}" placeholder="Number of filters">
|
412 |
+
<small>Output channels</small>
|
413 |
</div>
|
414 |
<div class="form-group">
|
415 |
+
<label>Kernel Size:</label>
|
416 |
+
<div class="form-row">
|
417 |
+
<input type="number" id="kernel-size-h" min="1" max="7" value="${layerConfig.kernelSize[0]}" placeholder="Height">
|
418 |
+
<input type="number" id="kernel-size-w" min="1" max="7" value="${layerConfig.kernelSize[1]}" placeholder="Width">
|
419 |
+
</div>
|
420 |
+
<small>Filter dimensions: ${layerConfig.kernelSize.join(' × ')}</small>
|
421 |
</div>
|
422 |
<div class="form-group">
|
423 |
+
<label>Strides:</label>
|
424 |
+
<div class="form-row">
|
425 |
+
<input type="number" id="stride-h" min="1" max="4" value="${layerConfig.strides[0]}" placeholder="Height">
|
426 |
+
<input type="number" id="stride-w" min="1" max="4" value="${layerConfig.strides[1]}" placeholder="Width">
|
427 |
+
</div>
|
428 |
</div>
|
429 |
<div class="form-group">
|
430 |
+
<label>Padding:</label>
|
431 |
+
<select id="padding-type">
|
432 |
+
<option value="valid" ${layerConfig.padding === 'valid' ? 'selected' : ''}>Valid (No Padding)</option>
|
433 |
+
<option value="same" ${layerConfig.padding === 'same' ? 'selected' : ''}>Same (Preserve Dimensions)</option>
|
434 |
</select>
|
435 |
</div>
|
436 |
<div class="form-group">
|
437 |
+
<label>Activation Function:</label>
|
438 |
+
<select id="conv-activation">
|
439 |
<option value="relu" ${layerConfig.activation === 'relu' ? 'selected' : ''}>ReLU</option>
|
440 |
<option value="sigmoid" ${layerConfig.activation === 'sigmoid' ? 'selected' : ''}>Sigmoid</option>
|
441 |
<option value="tanh" ${layerConfig.activation === 'tanh' ? 'selected' : ''}>Tanh</option>
|
442 |
+
<option value="leaky_relu" ${layerConfig.activation === 'leaky_relu' ? 'selected' : ''}>Leaky ReLU</option>
|
443 |
</select>
|
444 |
</div>
|
445 |
`;
|
446 |
break;
|
447 |
|
448 |
case 'pool':
|
449 |
+
// Pooling layer parameters
|
450 |
layerForm.innerHTML += `
|
451 |
<div class="form-group">
|
452 |
+
<label>Pool Size:</label>
|
453 |
+
<div class="form-row">
|
454 |
+
<input type="number" id="pool-size-h" min="1" max="4" value="${layerConfig.poolSize[0]}" placeholder="Height">
|
455 |
+
<input type="number" id="pool-size-w" min="1" max="4" value="${layerConfig.poolSize[1]}" placeholder="Width">
|
456 |
+
</div>
|
457 |
+
</div>
|
458 |
+
<div class="form-group">
|
459 |
+
<label>Strides:</label>
|
460 |
+
<div class="form-row">
|
461 |
+
<input type="number" id="pool-stride-h" min="1" max="4" value="${layerConfig.strides[0]}" placeholder="Height">
|
462 |
+
<input type="number" id="pool-stride-w" min="1" max="4" value="${layerConfig.strides[1]}" placeholder="Width">
|
463 |
+
</div>
|
464 |
</div>
|
465 |
<div class="form-group">
|
466 |
+
<label>Padding:</label>
|
467 |
+
<select id="pool-padding">
|
468 |
+
<option value="valid" ${layerConfig.padding === 'valid' ? 'selected' : ''}>Valid (No Padding)</option>
|
469 |
+
<option value="same" ${layerConfig.padding === 'same' ? 'selected' : ''}>Same (Preserve Dimensions)</option>
|
470 |
+
</select>
|
471 |
</div>
|
472 |
<div class="form-group">
|
473 |
+
<label>Pool Type:</label>
|
474 |
+
<select id="pool-type">
|
475 |
+
<option value="max" selected>Max Pooling</option>
|
476 |
+
<option value="avg">Average Pooling</option>
|
477 |
</select>
|
478 |
</div>
|
479 |
`;
|
480 |
break;
|
481 |
+
|
482 |
+
case 'linear':
|
483 |
+
// Linear regression layer parameters
|
484 |
+
layerForm.innerHTML += `
|
485 |
+
<div class="form-group">
|
486 |
+
<label>Input Features:</label>
|
487 |
+
<input type="number" id="input-features" min="1" value="${layerConfig.inputFeatures}" placeholder="Number of input features">
|
488 |
+
<small>Input shape: [${layerConfig.inputFeatures}]</small>
|
489 |
+
</div>
|
490 |
+
<div class="form-group">
|
491 |
+
<label>Output Features:</label>
|
492 |
+
<input type="number" id="output-features" min="1" value="${layerConfig.outputFeatures}" placeholder="Number of output features">
|
493 |
+
<small>Output shape: [${layerConfig.outputFeatures}]</small>
|
494 |
+
</div>
|
495 |
+
<div class="form-group">
|
496 |
+
<label>Use Bias:</label>
|
497 |
+
<input type="checkbox" id="linear-use-bias" ${layerConfig.useBias ? 'checked' : ''}>
|
498 |
+
</div>
|
499 |
+
<div class="form-group">
|
500 |
+
<label>Learning Rate:</label>
|
501 |
+
<input type="range" id="learning-rate-slider" min="0.001" max="0.1" step="0.001" value="${layerConfig.learningRate}">
|
502 |
+
<span id="learning-rate-value">${layerConfig.learningRate}</span>
|
503 |
+
</div>
|
504 |
+
<div class="form-group">
|
505 |
+
<label>Loss Function:</label>
|
506 |
+
<select id="loss-function">
|
507 |
+
<option value="mse" ${layerConfig.lossFunction === 'mse' ? 'selected' : ''}>Mean Squared Error</option>
|
508 |
+
<option value="mae" ${layerConfig.lossFunction === 'mae' ? 'selected' : ''}>Mean Absolute Error</option>
|
509 |
+
<option value="huber" ${layerConfig.lossFunction === 'huber' ? 'selected' : ''}>Huber Loss</option>
|
510 |
+
</select>
|
511 |
+
</div>
|
512 |
+
<div class="form-group">
|
513 |
+
<label>Optimizer:</label>
|
514 |
+
<select id="optimizer">
|
515 |
+
<option value="sgd" ${layerConfig.optimizer === 'sgd' ? 'selected' : ''}>Stochastic Gradient Descent</option>
|
516 |
+
<option value="adam" ${layerConfig.optimizer === 'adam' ? 'selected' : ''}>Adam</option>
|
517 |
+
<option value="rmsprop" ${layerConfig.optimizer === 'rmsprop' ? 'selected' : ''}>RMSprop</option>
|
518 |
+
</select>
|
519 |
+
</div>
|
520 |
+
`;
|
521 |
+
|
522 |
+
// Add listener for learning rate slider
|
523 |
+
setTimeout(() => {
|
524 |
+
const learningRateSlider = document.getElementById('learning-rate-slider');
|
525 |
+
const learningRateValue = document.getElementById('learning-rate-value');
|
526 |
+
if (learningRateSlider && learningRateValue) {
|
527 |
+
learningRateSlider.addEventListener('input', (e) => {
|
528 |
+
learningRateValue.textContent = parseFloat(e.target.value).toFixed(3);
|
529 |
+
});
|
530 |
+
}
|
531 |
+
}, 100);
|
532 |
+
break;
|
533 |
+
|
534 |
+
default:
|
535 |
+
layerForm.innerHTML = '<p>No editable properties for this layer type.</p>';
|
536 |
}
|
537 |
|
538 |
+
// Add a preview of calculated parameters if available
|
539 |
+
if (nodeType !== 'input') {
|
540 |
+
const parameterCount = window.neuralNetwork.calculateParameters(nodeType, layerConfig);
|
541 |
+
if (parameterCount) {
|
542 |
+
layerForm.innerHTML += `
|
543 |
+
<div class="form-group">
|
544 |
+
<label>Parameter Summary:</label>
|
545 |
+
<div class="parameters-summary">
|
546 |
+
<p>Total parameters: <strong>${formatNumber(parameterCount)}</strong></p>
|
547 |
+
<p>Memory usage (32-bit): ~${formatMemorySize(parameterCount * 4)}</p>
|
548 |
+
</div>
|
549 |
+
</div>
|
550 |
+
`;
|
551 |
+
}
|
552 |
+
}
|
553 |
+
|
554 |
+
// Add save and cancel buttons
|
555 |
layerForm.innerHTML += `
|
556 |
+
<div class="form-buttons">
|
557 |
+
<button type="button" id="save-layer-config" class="btn-primary">Save Changes</button>
|
558 |
+
<button type="button" id="cancel-layer-edit" class="btn-secondary">Cancel</button>
|
559 |
</div>
|
560 |
`;
|
561 |
|
562 |
+
// Open the modal
|
563 |
+
const modal = document.getElementById('layer-editor-modal');
|
564 |
+
if (modal) {
|
565 |
+
openModal(modal);
|
566 |
+
|
567 |
+
// Add event listeners for buttons
|
568 |
+
const saveButton = document.getElementById('save-layer-config');
|
569 |
+
if (saveButton) {
|
570 |
+
saveButton.addEventListener('click', () => {
|
571 |
+
saveLayerConfig(node, nodeType, layerId);
|
572 |
+
closeModal(modal);
|
573 |
+
});
|
574 |
+
}
|
575 |
+
|
576 |
+
const cancelButton = document.getElementById('cancel-layer-edit');
|
577 |
+
if (cancelButton) {
|
578 |
+
cancelButton.addEventListener('click', () => {
|
579 |
+
closeModal(modal);
|
580 |
+
});
|
581 |
+
}
|
582 |
+
}
|
583 |
}
|
584 |
|
585 |
// Save layer configuration
|
586 |
+
function saveLayerConfig(node, nodeType, layerId) {
|
587 |
+
// Get form values
|
588 |
+
const form = document.querySelector('.layer-form');
|
589 |
+
if (!form) return;
|
|
|
|
|
|
|
|
|
|
|
590 |
|
591 |
+
const values = {};
|
592 |
+
const inputs = form.querySelectorAll('input, select');
|
593 |
+
inputs.forEach(input => {
|
594 |
+
values[input.id] = input.value;
|
595 |
+
});
|
596 |
|
597 |
+
// Update node configuration
|
598 |
+
node.layerConfig = {
|
599 |
+
type: nodeType,
|
600 |
+
shape: [
|
601 |
+
parseInt(values['input-height']),
|
602 |
+
parseInt(values['input-width']),
|
603 |
+
parseInt(values['input-channels'])
|
604 |
+
],
|
605 |
+
batchSize: parseInt(values['batch-size']),
|
606 |
+
units: parseInt(values['hidden-units']),
|
607 |
+
activation: values['hidden-activation'],
|
608 |
+
dropoutRate: parseFloat(values['dropout-rate']),
|
609 |
+
useBias: values['use-bias'] === 'true',
|
610 |
+
learningRate: parseFloat(values['learning-rate-slider']),
|
611 |
+
lossFunction: values['loss-function'],
|
612 |
+
optimizer: values['optimizer'],
|
613 |
+
inputFeatures: parseInt(values['input-features']),
|
614 |
+
outputFeatures: parseInt(values['output-features'])
|
615 |
+
};
|
616 |
|
617 |
// Update node title
|
618 |
const nodeTitle = node.querySelector('.node-title');
|
619 |
if (nodeTitle) {
|
620 |
+
nodeTitle.textContent = nodeType.charAt(0).toUpperCase() + nodeType.slice(1);
|
621 |
}
|
622 |
|
623 |
// Update node data attribute
|
624 |
+
node.setAttribute('data-name', nodeType.charAt(0).toUpperCase() + nodeType.slice(1));
|
625 |
|
626 |
// Update dimensions based on layer type
|
627 |
let dimensions = '';
|
628 |
+
switch (nodeType) {
|
629 |
case 'input':
|
630 |
+
dimensions = values['input-height'] + ' × ' + values['input-width'] + ' × ' + values['input-channels'];
|
|
|
631 |
break;
|
632 |
|
633 |
case 'hidden':
|
634 |
case 'output':
|
635 |
+
dimensions = values['hidden-units'];
|
|
|
636 |
break;
|
637 |
|
638 |
case 'conv':
|
639 |
+
dimensions = values['conv-filters'] + ' × ' + values['kernel-size-h'] + ' × ' + values['kernel-size-w'];
|
|
|
640 |
break;
|
641 |
|
642 |
case 'pool':
|
643 |
+
dimensions = values['pool-size-h'] + ' × ' + values['pool-size-w'];
|
644 |
+
break;
|
645 |
+
|
646 |
+
case 'linear':
|
647 |
+
dimensions = values['input-features'] + ' → ' + values['output-features'];
|
648 |
break;
|
649 |
}
|
650 |
|
|
|
662 |
const layerIndex = networkLayers.layers.findIndex(layer => layer.id === layerId);
|
663 |
|
664 |
if (layerIndex !== -1) {
|
665 |
+
networkLayers.layers[layerIndex].name = nodeType.charAt(0).toUpperCase() + nodeType.slice(1);
|
666 |
networkLayers.layers[layerIndex].dimensions = dimensions;
|
667 |
}
|
668 |
|
669 |
// Trigger network updated event
|
670 |
const event = new CustomEvent('networkUpdated', { detail: networkLayers });
|
671 |
document.dispatchEvent(event);
|
|
|
|
|
|
|
672 |
}
|
673 |
|
674 |
// Handle sample selection
|
js/neural-network.js
CHANGED
@@ -11,7 +11,8 @@
|
|
11 |
'hidden': 0,
|
12 |
'output': 0,
|
13 |
'conv': 0,
|
14 |
-
'pool': 0
|
|
|
15 |
};
|
16 |
|
17 |
// Default configuration templates for different layer types
|
@@ -21,7 +22,9 @@
|
|
21 |
shape: [28, 28, 1],
|
22 |
batchSize: 32,
|
23 |
description: 'Input layer for raw data',
|
24 |
-
parameters: 0
|
|
|
|
|
25 |
},
|
26 |
'hidden': {
|
27 |
units: 128,
|
@@ -30,7 +33,9 @@
|
|
30 |
kernelInitializer: 'glorotUniform',
|
31 |
biasInitializer: 'zeros',
|
32 |
dropoutRate: 0.2,
|
33 |
-
description: 'Dense hidden layer with ReLU activation'
|
|
|
|
|
34 |
},
|
35 |
'output': {
|
36 |
units: 10,
|
@@ -38,7 +43,9 @@
|
|
38 |
useBias: true,
|
39 |
kernelInitializer: 'glorotUniform',
|
40 |
biasInitializer: 'zeros',
|
41 |
-
description: 'Output layer with Softmax activation for classification'
|
|
|
|
|
42 |
},
|
43 |
'conv': {
|
44 |
filters: 32,
|
@@ -49,13 +56,31 @@
|
|
49 |
useBias: true,
|
50 |
kernelInitializer: 'glorotUniform',
|
51 |
biasInitializer: 'zeros',
|
52 |
-
description: 'Convolutional layer for feature extraction'
|
|
|
|
|
53 |
},
|
54 |
'pool': {
|
55 |
poolSize: [2, 2],
|
56 |
strides: [2, 2],
|
57 |
padding: 'valid',
|
58 |
-
description: 'Max pooling layer for spatial downsampling'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
}
|
60 |
};
|
61 |
|
|
|
11 |
'hidden': 0,
|
12 |
'output': 0,
|
13 |
'conv': 0,
|
14 |
+
'pool': 0,
|
15 |
+
'linear': 0
|
16 |
};
|
17 |
|
18 |
// Default configuration templates for different layer types
|
|
|
22 |
shape: [28, 28, 1],
|
23 |
batchSize: 32,
|
24 |
description: 'Input layer for raw data',
|
25 |
+
parameters: 0,
|
26 |
+
inputShape: null,
|
27 |
+
outputShape: [784]
|
28 |
},
|
29 |
'hidden': {
|
30 |
units: 128,
|
|
|
33 |
kernelInitializer: 'glorotUniform',
|
34 |
biasInitializer: 'zeros',
|
35 |
dropoutRate: 0.2,
|
36 |
+
description: 'Dense hidden layer with ReLU activation',
|
37 |
+
inputShape: null,
|
38 |
+
outputShape: null
|
39 |
},
|
40 |
'output': {
|
41 |
units: 10,
|
|
|
43 |
useBias: true,
|
44 |
kernelInitializer: 'glorotUniform',
|
45 |
biasInitializer: 'zeros',
|
46 |
+
description: 'Output layer with Softmax activation for classification',
|
47 |
+
inputShape: null,
|
48 |
+
outputShape: [10]
|
49 |
},
|
50 |
'conv': {
|
51 |
filters: 32,
|
|
|
56 |
useBias: true,
|
57 |
kernelInitializer: 'glorotUniform',
|
58 |
biasInitializer: 'zeros',
|
59 |
+
description: 'Convolutional layer for feature extraction',
|
60 |
+
inputShape: null,
|
61 |
+
outputShape: null
|
62 |
},
|
63 |
'pool': {
|
64 |
poolSize: [2, 2],
|
65 |
strides: [2, 2],
|
66 |
padding: 'valid',
|
67 |
+
description: 'Max pooling layer for spatial downsampling',
|
68 |
+
inputShape: null,
|
69 |
+
outputShape: null
|
70 |
+
},
|
71 |
+
'linear': {
|
72 |
+
inputFeatures: 1,
|
73 |
+
outputFeatures: 1,
|
74 |
+
useBias: true,
|
75 |
+
activation: 'linear',
|
76 |
+
learningRate: 0.01,
|
77 |
+
optimizer: 'sgd',
|
78 |
+
lossFunction: 'mse',
|
79 |
+
biasInitializer: 'zeros',
|
80 |
+
kernelInitializer: 'glorotUniform',
|
81 |
+
description: 'Linear regression layer for numerical prediction',
|
82 |
+
inputShape: [1],
|
83 |
+
outputShape: [1]
|
84 |
}
|
85 |
};
|
86 |
|