ameerazam08 commited on
Commit
fd832fc
·
verified ·
1 Parent(s): 5ccc1b8

Upload 7 files

Browse files
Files changed (4) hide show
  1. README.md +50 -3
  2. js/drag-drop.js +93 -388
  3. js/main.js +417 -26
  4. js/neural-network.js +53 -10
README.md CHANGED
@@ -1,9 +1,56 @@
1
  ---
2
  title: Neural Network Playground
3
- emoji: 🌖
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: static
7
- pinned: false
 
8
  ---
9
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Neural Network Playground
3
+ emoji: 🧠
4
  colorFrom: pink
5
  colorTo: blue
6
  sdk: static
7
+ pinned: true
8
+ license: mit
9
  ---
10
+
11
+ # Neural Network Playground
12
+
13
+ ![Neural Network Playground](https://raw.githubusercontent.com/huggingface/hub-docs/main/static/logos/huggingface_logo-noborder.svg)
14
+
15
+ ## Introduction
16
+
17
+ Neural Network Playground is an interactive visualization tool that helps you understand how neural networks work. Built with plain HTML, CSS, and JavaScript, it allows you to:
18
+
19
+ - Create custom neural network architectures by dragging and dropping different types of layers
20
+ - Connect layers and see how data flows through the network
21
+ - View input and output shapes for each layer
22
+ - Visualize layer parameters and configurations
23
+
24
+ ## Features
25
+
26
+ - **Interactive Interface**: Drag and drop nodes to create neural networks
27
+ - **Shape Information**: See input and output shapes for each node
28
+ - **Detailed Parameters**: View kernel size, stride, and padding for applicable layers
29
+ - **Layer Types**:
30
+ - Input Layer
31
+ - Hidden Layer
32
+ - Output Layer
33
+ - Convolutional Layer
34
+ - Pooling Layer
35
+ - Linear Regression Layer
36
+
37
+ ## How to Use
38
+
39
+ 1. Drag components from the left panel onto the canvas
40
+ 2. Connect them by dragging from output (right) ports to input (left) ports
41
+ 3. Double-click on nodes to edit their properties
42
+ 4. Use the network settings to adjust learning rate and activation functions
43
+
44
+ ## Technical Details
45
+
46
+ The playground visualizes how neural networks process data and helps users understand concepts like:
47
+
48
+ - Shape transformations between layers
49
+ - Parameter calculations
50
+ - The effects of different layer configurations
51
+
52
+ This is an educational tool designed to make neural networks more accessible and understandable.
53
+
54
+ ## License
55
+
56
+ MIT
js/drag-drop.js CHANGED
@@ -16,6 +16,17 @@ function initializeDragAndDrop() {
16
  connections: []
17
  };
18
 
 
 
 
 
 
 
 
 
 
 
 
19
  // Add event listeners to draggable items
20
  nodeItems.forEach(item => {
21
  item.addEventListener('dragstart', handleDragStart);
@@ -111,7 +122,7 @@ function initializeDragAndDrop() {
111
  inputShape = 'Connect input';
112
  outputShape = 'Depends on input';
113
  // Create parameter string
114
- parameters = `In: ?, Out: ${nodeConfig.filters}\nKernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
115
  break;
116
  case 'pool':
117
  const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
@@ -127,11 +138,6 @@ function initializeDragAndDrop() {
127
  parameters = 'N/A';
128
  }
129
 
130
- // Create node header
131
- const nodeHeader = document.createElement('div');
132
- nodeHeader.className = 'node-header';
133
- nodeHeader.textContent = nodeName;
134
-
135
  // Create node content
136
  const nodeContent = document.createElement('div');
137
  nodeContent.className = 'node-content';
@@ -147,25 +153,71 @@ function initializeDragAndDrop() {
147
  // Add parameters section
148
  const paramsSection = document.createElement('div');
149
  paramsSection.className = 'params-section';
150
- paramsSection.innerHTML = `<pre class="params-display">${parameters}</pre>`;
151
-
152
- // Add connection ports
153
- const inputPort = document.createElement('div');
154
- inputPort.className = 'port input-port';
155
- inputPort.setAttribute('data-port-type', 'input');
156
-
157
- const outputPort = document.createElement('div');
158
- outputPort.className = 'port output-port';
159
- outputPort.setAttribute('data-port-type', 'output');
160
 
161
- // Assemble the node
162
  nodeContent.appendChild(shapeInfo);
163
  nodeContent.appendChild(paramsSection);
164
 
165
- canvasNode.appendChild(nodeHeader);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  canvasNode.appendChild(nodeContent);
167
- canvasNode.appendChild(inputPort);
168
- canvasNode.appendChild(outputPort);
 
 
 
 
169
 
170
  // Add node to the canvas
171
  canvas.appendChild(canvasNode);
@@ -175,10 +227,13 @@ function initializeDragAndDrop() {
175
 
176
  // Add event listeners for node manipulation
177
  canvasNode.addEventListener('mousedown', startDrag);
178
- inputPort.addEventListener('mousedown', (e) => {
 
 
179
  e.stopPropagation();
180
  });
181
- outputPort.addEventListener('mousedown', (e) => {
 
182
  e.stopPropagation();
183
  startConnection(canvasNode, e);
184
  });
@@ -198,7 +253,11 @@ function initializeDragAndDrop() {
198
  networkLayers.layers.push({
199
  id: layerId,
200
  type: nodeType,
201
- config: nodeConfig
 
 
 
 
202
  });
203
 
204
  // Notify about network changes
@@ -251,12 +310,19 @@ function initializeDragAndDrop() {
251
  let x = e.clientX - canvasRect.left - offsetX;
252
  let y = e.clientY - canvasRect.top - offsetY;
253
 
254
- // Constrain to canvas
255
- x = Math.max(0, Math.min(canvasRect.width - draggedNode.offsetWidth, x));
256
- y = Math.max(0, Math.min(canvasRect.height - draggedNode.offsetHeight, y));
 
 
 
 
257
 
 
 
258
  draggedNode.style.left = `${x}px`;
259
  draggedNode.style.top = `${y}px`;
 
260
 
261
  // Update node position in network layers
262
  const nodeId = draggedNode.getAttribute('data-id');
@@ -509,365 +575,4 @@ function initializeDragAndDrop() {
509
  const endY = targetPortRect.top + (targetPortRect.height / 2) - canvasRect.top;
510
 
511
  // Create the connection
512
- const pathId = `connection-${sourceId}-${targetId}`;
513
- const connectionPath = document.createElementNS('http://www.w3.org/2000/svg', 'path');
514
- connectionPath.setAttribute('id', pathId);
515
- connectionPath.setAttribute('class', 'connection-line');
516
-
517
- // Curved path (bezier)
518
- const dx = Math.abs(endX - startX) * 0.7;
519
- const path = `M ${startX} ${startY} C ${startX + dx} ${startY}, ${endX - dx} ${endY}, ${endX} ${endY}`;
520
- connectionPath.setAttribute('d', path);
521
-
522
- // Add connection to SVG container
523
- svgContainer.appendChild(connectionPath);
524
-
525
- // Add to connections
526
- networkLayers.connections.push({
527
- id: pathId,
528
- source: sourceId,
529
- target: targetId,
530
- sourceType: sourceType,
531
- targetType: targetType
532
- });
533
-
534
- // Update input and output shapes
535
- updateNodeShapes(sourceId, targetId);
536
-
537
- // Notify about connection
538
- document.dispatchEvent(new CustomEvent('networkUpdated', {
539
- detail: networkLayers
540
- }));
541
- }
542
-
543
- // Clean up
544
- removePortHighlights();
545
- if (connectionLine) {
546
- connectionLine.remove();
547
- connectionLine = null;
548
- }
549
- isConnecting = false;
550
- startNode = null;
551
- }
552
-
553
- // Update input and output shapes when connections are made
554
- function updateNodeShapes(sourceId, targetId) {
555
- const sourceNode = document.querySelector(`.canvas-node[data-id="${sourceId}"]`);
556
- const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
557
-
558
- if (sourceNode && targetNode) {
559
- const sourceConfig = sourceNode.layerConfig;
560
- const targetConfig = targetNode.layerConfig;
561
-
562
- // Update the target's input shape based on the source's output shape
563
- if (sourceConfig && targetConfig) {
564
- // Calculate output shape based on node type
565
- let outputShape;
566
- switch (sourceNode.getAttribute('data-type')) {
567
- case 'input':
568
- outputShape = sourceConfig.shape;
569
- break;
570
- case 'hidden':
571
- outputShape = [sourceConfig.units];
572
- break;
573
- case 'output':
574
- outputShape = [sourceConfig.units];
575
- break;
576
- case 'conv':
577
- // For Conv2D, the output shape depends on the input and parameters
578
- // This is a simplified calculation
579
- if (targetConfig.inputShape) {
580
- const h = targetConfig.inputShape[0];
581
- const w = targetConfig.inputShape[1];
582
- const kh = sourceConfig.kernelSize[0];
583
- const kw = sourceConfig.kernelSize[1];
584
- const sh = sourceConfig.strides[0];
585
- const sw = sourceConfig.strides[1];
586
- const padding = sourceConfig.padding;
587
-
588
- let outHeight, outWidth;
589
- if (padding === 'same') {
590
- outHeight = Math.ceil(h / sh);
591
- outWidth = Math.ceil(w / sw);
592
- } else { // 'valid'
593
- outHeight = Math.ceil((h - kh + 1) / sh);
594
- outWidth = Math.ceil((w - kw + 1) / sw);
595
- }
596
-
597
- outputShape = [outHeight, outWidth, sourceConfig.filters];
598
- } else {
599
- outputShape = ['?', '?', sourceConfig.filters];
600
- }
601
- break;
602
- case 'pool':
603
- // For pooling, also depends on the input and parameters
604
- if (targetConfig.inputShape) {
605
- const h = targetConfig.inputShape[0];
606
- const w = targetConfig.inputShape[1];
607
- const c = targetConfig.inputShape[2];
608
- const ph = sourceConfig.poolSize[0];
609
- const pw = sourceConfig.poolSize[1];
610
- const sh = sourceConfig.strides[0];
611
- const sw = sourceConfig.strides[1];
612
- const padding = sourceConfig.padding;
613
-
614
- let outHeight, outWidth;
615
- if (padding === 'same') {
616
- outHeight = Math.ceil(h / sh);
617
- outWidth = Math.ceil(w / sw);
618
- } else { // 'valid'
619
- outHeight = Math.ceil((h - ph + 1) / sh);
620
- outWidth = Math.ceil((w - pw + 1) / sw);
621
- }
622
-
623
- outputShape = [outHeight, outWidth, c];
624
- } else {
625
- outputShape = ['?', '?', '?'];
626
- }
627
- break;
628
- case 'linear':
629
- outputShape = [sourceConfig.outputFeatures];
630
- break;
631
- default:
632
- outputShape = ['?', '?', '?'];
633
- }
634
-
635
- // Update the target's input shape
636
- targetConfig.inputShape = outputShape;
637
-
638
- // Update UI
639
- updateNodeDisplayShapes(sourceNode, targetNode);
640
- }
641
- }
642
- }
643
-
644
- // Update the displayed shapes in the UI
645
- function updateNodeDisplayShapes(sourceNode, targetNode) {
646
- if (sourceNode && targetNode) {
647
- const sourceType = sourceNode.getAttribute('data-type');
648
- const targetType = targetNode.getAttribute('data-type');
649
- const sourceConfig = sourceNode.layerConfig;
650
- const targetConfig = targetNode.layerConfig;
651
-
652
- // Update source node output shape display
653
- const sourceOutputElem = sourceNode.querySelector('.output-shape');
654
- if (sourceOutputElem && sourceConfig) {
655
- let outputText;
656
- switch (sourceType) {
657
- case 'input':
658
- outputText = `[${sourceConfig.shape.join(' × ')}]`;
659
- break;
660
- case 'hidden':
661
- case 'output':
662
- outputText = `[${sourceConfig.units}]`;
663
- break;
664
- case 'conv':
665
- if (sourceConfig.outputShape) {
666
- outputText = `[${sourceConfig.outputShape.join(' × ')}]`;
667
- } else {
668
- outputText = `[? × ? × ${sourceConfig.filters}]`;
669
- }
670
- break;
671
- case 'pool':
672
- if (sourceConfig.outputShape) {
673
- outputText = `[${sourceConfig.outputShape.join(' × ')}]`;
674
- } else {
675
- outputText = 'Depends on input';
676
- }
677
- break;
678
- case 'linear':
679
- outputText = `[${sourceConfig.outputFeatures}]`;
680
- break;
681
- default:
682
- outputText = 'Unknown';
683
- }
684
- sourceOutputElem.textContent = outputText;
685
- }
686
-
687
- // Update target node input shape display
688
- const targetInputElem = targetNode.querySelector('.input-shape');
689
- if (targetInputElem && targetConfig && targetConfig.inputShape) {
690
- targetInputElem.textContent = `[${targetConfig.inputShape.join(' × ')}]`;
691
-
692
- // Update parameters section
693
- const targetParamsElem = targetNode.querySelector('.params-display');
694
- if (targetParamsElem) {
695
- // Calculate and display parameters
696
- let paramsText = '';
697
- switch (targetType) {
698
- case 'hidden':
699
- const inputUnits = Array.isArray(targetConfig.inputShape) ?
700
- targetConfig.inputShape.reduce((acc, val) => acc * val, 1) :
701
- targetConfig.inputShape;
702
-
703
- const biasParams = targetConfig.useBias ? targetConfig.units : 0;
704
- const totalParams = (inputUnits * targetConfig.units) + biasParams;
705
-
706
- paramsText = `In: ${inputUnits}, Out: ${targetConfig.units}\nParams: ${totalParams.toLocaleString()}\nDropout: ${targetConfig.dropoutRate}`;
707
- break;
708
- case 'output':
709
- const outInputUnits = Array.isArray(targetConfig.inputShape) ?
710
- targetConfig.inputShape.reduce((acc, val) => acc * val, 1) :
711
- targetConfig.inputShape;
712
-
713
- const outBiasParams = targetConfig.useBias ? targetConfig.units : 0;
714
- const outTotalParams = (outInputUnits * targetConfig.units) + outBiasParams;
715
-
716
- paramsText = `In: ${outInputUnits}, Out: ${targetConfig.units}\nParams: ${outTotalParams.toLocaleString()}\nActivation: ${targetConfig.activation}`;
717
- break;
718
- case 'conv':
719
- const channels = targetConfig.inputShape[2] || '?';
720
- const kernelH = targetConfig.kernelSize[0];
721
- const kernelW = targetConfig.kernelSize[1];
722
- const kernelParams = kernelH * kernelW * channels * targetConfig.filters;
723
- const convBiasParams = targetConfig.useBias ? targetConfig.filters : 0;
724
- const convTotalParams = kernelParams + convBiasParams;
725
-
726
- paramsText = `In: ${channels}, Out: ${targetConfig.filters}\nKernel: ${targetConfig.kernelSize.join('×')}\nStride: ${targetConfig.strides.join('×')}\nPadding: ${targetConfig.padding}\nParams: ${convTotalParams.toLocaleString()}`;
727
- break;
728
- case 'pool':
729
- paramsText = `Pool size: ${targetConfig.poolSize.join('×')}\nStride: ${targetConfig.strides.join('×')}\nPadding: ${targetConfig.padding}\nParams: 0`;
730
- break;
731
- case 'linear':
732
- const linearInputs = targetConfig.inputFeatures;
733
- const linearBiasParams = targetConfig.useBias ? targetConfig.outputFeatures : 0;
734
- const linearTotalParams = (linearInputs * targetConfig.outputFeatures) + linearBiasParams;
735
-
736
- paramsText = `In: ${linearInputs}, Out: ${targetConfig.outputFeatures}\nParams: ${linearTotalParams.toLocaleString()}\nLearning Rate: ${targetConfig.learningRate}\nLoss: ${targetConfig.lossFunction}`;
737
- break;
738
- }
739
-
740
- targetParamsElem.textContent = paramsText;
741
- }
742
- }
743
- }
744
- }
745
-
746
- // Delete a node and its connections
747
- function deleteNode(node) {
748
- if (!node) return;
749
-
750
- const nodeId = node.getAttribute('data-id');
751
-
752
- // Remove all connections to/from this node
753
- document.querySelectorAll(`.connection[data-source="${nodeId}"], .connection[data-target="${nodeId}"]`).forEach(conn => {
754
- conn.parentNode.removeChild(conn);
755
- });
756
-
757
- // Remove from network layers
758
- networkLayers.layers = networkLayers.layers.filter(layer => layer.id !== nodeId);
759
- networkLayers.connections = networkLayers.connections.filter(conn =>
760
- conn.source !== nodeId && conn.target !== nodeId
761
- );
762
-
763
- // Remove the node
764
- node.parentNode.removeChild(node);
765
-
766
- // Update layer connectivity
767
- updateLayerConnectivity();
768
- }
769
-
770
- // Open layer editor modal
771
- function openLayerEditor(node) {
772
- if (!node) return;
773
-
774
- const nodeId = node.getAttribute('data-id');
775
- const nodeType = node.getAttribute('data-type');
776
- const nodeName = node.getAttribute('data-name');
777
- const dimensions = node.getAttribute('data-dimensions');
778
-
779
- // Trigger custom event
780
- const event = new CustomEvent('openLayerEditor', {
781
- detail: { id: nodeId, type: nodeType, name: nodeName, dimensions: dimensions }
782
- });
783
- document.dispatchEvent(event);
784
- }
785
-
786
- // Update connections when nodes are moved
787
- function updateConnections() {
788
- const connections = document.querySelectorAll('.connection');
789
- connections.forEach(connection => {
790
- const sourceId = connection.getAttribute('data-source');
791
- const targetId = connection.getAttribute('data-target');
792
-
793
- const sourceNode = document.querySelector(`.canvas-node[data-id="${sourceId}"]`);
794
- const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
795
-
796
- if (sourceNode && targetNode) {
797
- const sourcePort = sourceNode.querySelector('.port-out');
798
- const targetPort = targetNode.querySelector('.port-in');
799
-
800
- if (sourcePort && targetPort) {
801
- const sourceRect = sourcePort.getBoundingClientRect();
802
- const targetRect = targetPort.getBoundingClientRect();
803
- const canvasRect = canvas.getBoundingClientRect();
804
-
805
- const startX = sourceRect.left + sourceRect.width / 2 - canvasRect.left;
806
- const startY = sourceRect.top + sourceRect.height / 2 - canvasRect.top;
807
- const endX = targetRect.left + targetRect.width / 2 - canvasRect.left;
808
- const endY = targetRect.top + targetRect.height / 2 - canvasRect.top;
809
-
810
- const length = Math.sqrt(Math.pow(endX - startX, 2) + Math.pow(endY - startY, 2));
811
- const angle = Math.atan2(endY - startY, endX - startX) * 180 / Math.PI;
812
-
813
- connection.style.left = `${startX}px`;
814
- connection.style.top = `${startY}px`;
815
- connection.style.width = `${length}px`;
816
- connection.style.transform = `rotate(${angle}deg)`;
817
- }
818
- } else {
819
- // If either node is missing, remove the connection
820
- if (connection.parentNode) {
821
- connection.parentNode.removeChild(connection);
822
-
823
- // Remove from the connections array
824
- const connIndex = networkLayers.connections.findIndex(conn =>
825
- conn.source === sourceId && conn.target === targetId
826
- );
827
- if (connIndex !== -1) {
828
- networkLayers.connections.splice(connIndex, 1);
829
- }
830
- }
831
- }
832
- });
833
- }
834
-
835
- // Get the current network architecture
836
- function getNetworkArchitecture() {
837
- return networkLayers;
838
- }
839
-
840
- // Clear all nodes from the canvas
841
- function clearAllNodes() {
842
- // Clear all nodes and connections
843
- document.querySelectorAll('.canvas-node, .connection').forEach(el => {
844
- el.parentNode.removeChild(el);
845
- });
846
-
847
- // Reset network layers
848
- networkLayers = {
849
- layers: [],
850
- connections: []
851
- };
852
-
853
- // Reset layer counter
854
- window.neuralNetwork.resetLayerCounter();
855
-
856
- // Show the canvas hint
857
- const canvasHint = document.querySelector('.canvas-hint');
858
- if (canvasHint) {
859
- canvasHint.style.display = 'block';
860
- }
861
-
862
- // Trigger network updated event
863
- const event = new CustomEvent('networkUpdated', { detail: networkLayers });
864
- document.dispatchEvent(event);
865
- }
866
-
867
- // Export functions
868
- window.dragDrop = {
869
- getNetworkArchitecture,
870
- clearAllNodes,
871
- updateConnections
872
- };
873
- }
 
16
  connections: []
17
  };
18
 
19
+ // Helper function to format numbers with K, M, B suffixes
20
+ function formatNumber(num) {
21
+ if (num === 0) return '0';
22
+ if (!num) return 'N/A';
23
+
24
+ if (num >= 1e9) return (num / 1e9).toFixed(2) + 'B';
25
+ if (num >= 1e6) return (num / 1e6).toFixed(2) + 'M';
26
+ if (num >= 1e3) return (num / 1e3).toFixed(2) + 'K';
27
+ return num.toString();
28
+ }
29
+
30
  // Add event listeners to draggable items
31
  nodeItems.forEach(item => {
32
  item.addEventListener('dragstart', handleDragStart);
 
122
  inputShape = 'Connect input';
123
  outputShape = 'Depends on input';
124
  // Create parameter string
125
+ parameters = `Kernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
126
  break;
127
  case 'pool':
128
  const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
 
138
  parameters = 'N/A';
139
  }
140
 
 
 
 
 
 
141
  // Create node content
142
  const nodeContent = document.createElement('div');
143
  nodeContent.className = 'node-content';
 
153
  // Add parameters section
154
  const paramsSection = document.createElement('div');
155
  paramsSection.className = 'params-section';
156
+ paramsSection.innerHTML = `
157
+ <div class="params-details">${parameters}</div>
158
+ <div class="node-parameters">Params: ${nodeConfig.parameters !== undefined ? formatNumber(nodeConfig.parameters) : '?'}</div>
159
+ `;
 
 
 
 
 
 
160
 
161
+ // Assemble content
162
  nodeContent.appendChild(shapeInfo);
163
  nodeContent.appendChild(paramsSection);
164
 
165
+ // Add dimensions section to show shapes compactly
166
+ const dimensionsSection = document.createElement('div');
167
+ dimensionsSection.className = 'node-dimensions';
168
+
169
+ // Set dimensions text based on node type
170
+ let dimensionsText = '';
171
+ switch(nodeType) {
172
+ case 'input':
173
+ dimensionsText = nodeConfig.shape.join(' × ');
174
+ break;
175
+ case 'hidden':
176
+ case 'output':
177
+ dimensionsText = nodeConfig.units.toString();
178
+ break;
179
+ case 'conv':
180
+ if (nodeConfig.inputShape && nodeConfig.outputShape) {
181
+ dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
182
+ } else {
183
+ dimensionsText = `? → ${nodeConfig.filters} filters`;
184
+ }
185
+ break;
186
+ case 'pool':
187
+ if (nodeConfig.inputShape && nodeConfig.outputShape) {
188
+ dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
189
+ } else {
190
+ dimensionsText = `? → ?`;
191
+ }
192
+ break;
193
+ case 'linear':
194
+ dimensionsText = `${nodeConfig.inputFeatures} → ${nodeConfig.outputFeatures}`;
195
+ break;
196
+ }
197
+ dimensionsSection.textContent = dimensionsText;
198
+
199
+ // Add node title for clearer identification
200
+ const nodeTitle = document.createElement('div');
201
+ nodeTitle.className = 'node-title';
202
+ nodeTitle.textContent = nodeName;
203
+
204
+ // Add connection ports
205
+ const portIn = document.createElement('div');
206
+ portIn.className = 'node-port port-in';
207
+
208
+ const portOut = document.createElement('div');
209
+ portOut.className = 'node-port port-out';
210
+
211
+ // Assemble the node with the new structure
212
+ canvasNode.appendChild(nodeTitle);
213
+ canvasNode.appendChild(dimensionsSection);
214
  canvasNode.appendChild(nodeContent);
215
+ canvasNode.appendChild(portIn);
216
+ canvasNode.appendChild(portOut);
217
+
218
+ // Store node data attributes for easier access
219
+ canvasNode.setAttribute('data-name', nodeName);
220
+ canvasNode.setAttribute('data-dimensions', dimensionsText);
221
 
222
  // Add node to the canvas
223
  canvas.appendChild(canvasNode);
 
227
 
228
  // Add event listeners for node manipulation
229
  canvasNode.addEventListener('mousedown', startDrag);
230
+
231
+ // Update port event listeners for the new class names
232
+ portIn.addEventListener('mousedown', (e) => {
233
  e.stopPropagation();
234
  });
235
+
236
+ portOut.addEventListener('mousedown', (e) => {
237
  e.stopPropagation();
238
  startConnection(canvasNode, e);
239
  });
 
253
  networkLayers.layers.push({
254
  id: layerId,
255
  type: nodeType,
256
+ name: nodeName,
257
+ position: { x, y },
258
+ dimensions: dimensionsText,
259
+ config: nodeConfig,
260
+ parameters: nodeConfig.parameters || 0
261
  });
262
 
263
  // Notify about network changes
 
310
  let x = e.clientX - canvasRect.left - offsetX;
311
  let y = e.clientY - canvasRect.top - offsetY;
312
 
313
+ // Constrain to canvas with better boundary checks
314
+ const nodeWidth = draggedNode.offsetWidth || 150; // Default width if not set
315
+ const nodeHeight = draggedNode.offsetHeight || 100; // Default height if not set
316
+
317
+ // Ensure the node stays completely within the canvas
318
+ x = Math.max(0, Math.min(canvasRect.width - nodeWidth, x));
319
+ y = Math.max(0, Math.min(canvasRect.height - nodeHeight, y));
320
 
321
+ // Apply position with fixed sizing to prevent layout expansion
322
+ draggedNode.style.position = 'absolute';
323
  draggedNode.style.left = `${x}px`;
324
  draggedNode.style.top = `${y}px`;
325
+ draggedNode.style.width = `${nodeWidth}px`; // Maintain fixed width
326
 
327
  // Update node position in network layers
328
  const nodeId = draggedNode.getAttribute('data-id');
 
575
  const endY = targetPortRect.top + (targetPortRect.height / 2) - canvasRect.top;
576
 
577
  // Create the connection
578
+ const pathId = `connection-${sourceId}-${targetId}`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
js/main.js CHANGED
@@ -405,7 +405,20 @@ document.addEventListener('DOMContentLoaded', () => {
405
 
406
  case 'conv':
407
  // Convolutional layer parameters
 
 
 
 
408
  layerForm.innerHTML += `
 
 
 
 
 
 
 
 
 
409
  <div class="form-group">
410
  <label>Filters:</label>
411
  <input type="number" id="conv-filters" min="1" value="${layerConfig.filters}" placeholder="Number of filters">
@@ -442,12 +455,92 @@ document.addEventListener('DOMContentLoaded', () => {
442
  <option value="leaky_relu" ${layerConfig.activation === 'leaky_relu' ? 'selected' : ''}>Leaky ReLU</option>
443
  </select>
444
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  `;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  break;
447
 
448
  case 'pool':
449
  // Pooling layer parameters
 
 
 
 
450
  layerForm.innerHTML += `
 
 
 
 
 
 
 
 
 
451
  <div class="form-group">
452
  <label>Pool Size:</label>
453
  <div class="form-row">
@@ -476,7 +569,61 @@ document.addEventListener('DOMContentLoaded', () => {
476
  <option value="avg">Average Pooling</option>
477
  </select>
478
  </div>
 
 
 
 
 
 
 
479
  `;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  break;
481
 
482
  case 'linear':
@@ -591,28 +738,169 @@ document.addEventListener('DOMContentLoaded', () => {
591
  const values = {};
592
  const inputs = form.querySelectorAll('input, select');
593
  inputs.forEach(input => {
594
- values[input.id] = input.value;
 
 
 
 
595
  });
596
 
597
  // Update node configuration
598
- node.layerConfig = {
599
- type: nodeType,
600
- shape: [
601
- parseInt(values['input-height']),
602
- parseInt(values['input-width']),
603
- parseInt(values['input-channels'])
604
- ],
605
- batchSize: parseInt(values['batch-size']),
606
- units: parseInt(values['hidden-units']),
607
- activation: values['hidden-activation'],
608
- dropoutRate: parseFloat(values['dropout-rate']),
609
- useBias: values['use-bias'] === 'true',
610
- learningRate: parseFloat(values['learning-rate-slider']),
611
- lossFunction: values['loss-function'],
612
- optimizer: values['optimizer'],
613
- inputFeatures: parseInt(values['input-features']),
614
- outputFeatures: parseInt(values['output-features'])
615
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
 
617
  // Update node title
618
  const nodeTitle = node.querySelector('.node-title');
@@ -623,37 +911,55 @@ document.addEventListener('DOMContentLoaded', () => {
623
  // Update node data attribute
624
  node.setAttribute('data-name', nodeType.charAt(0).toUpperCase() + nodeType.slice(1));
625
 
626
- // Update dimensions based on layer type
627
  let dimensions = '';
628
  switch (nodeType) {
629
  case 'input':
630
- dimensions = values['input-height'] + ' × ' + values['input-width'] + ' × ' + values['input-channels'];
631
  break;
632
 
633
  case 'hidden':
634
  case 'output':
635
- dimensions = values['hidden-units'];
636
  break;
637
 
638
  case 'conv':
639
- dimensions = values['conv-filters'] + ' × ' + values['kernel-size-h'] + ' × ' + values['kernel-size-w'];
 
 
 
 
 
640
  break;
641
 
642
  case 'pool':
643
- dimensions = values['pool-size-h'] + ' × ' + values['pool-size-w'];
 
 
 
 
 
644
  break;
645
 
646
  case 'linear':
647
- dimensions = values['input-features'] + ' ' + values['output-features'];
648
  break;
649
  }
650
 
651
- // Update node dimensions
652
  const nodeDimensions = node.querySelector('.node-dimensions');
653
  if (nodeDimensions) {
654
  nodeDimensions.textContent = dimensions;
655
  }
656
 
 
 
 
 
 
 
 
 
657
  // Update node data attribute
658
  node.setAttribute('data-dimensions', dimensions);
659
 
@@ -664,11 +970,96 @@ document.addEventListener('DOMContentLoaded', () => {
664
  if (layerIndex !== -1) {
665
  networkLayers.layers[layerIndex].name = nodeType.charAt(0).toUpperCase() + nodeType.slice(1);
666
  networkLayers.layers[layerIndex].dimensions = dimensions;
 
 
 
 
667
  }
668
 
669
  // Trigger network updated event
670
  const event = new CustomEvent('networkUpdated', { detail: networkLayers });
671
  document.dispatchEvent(event);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672
  }
673
 
674
  // Handle sample selection
 
405
 
406
  case 'conv':
407
  // Convolutional layer parameters
408
+ // Get input and output shapes - may be calculated or null at first
409
+ const inputShape = layerConfig.inputShape || ['?', '?', '?'];
410
+ const outputShape = layerConfig.outputShape || ['?', '?', layerConfig.filters];
411
+
412
  layerForm.innerHTML += `
413
+ <div class="form-group">
414
+ <label>Input Shape:</label>
415
+ <div class="form-row">
416
+ <input type="number" id="conv-input-h" min="1" value="${inputShape[0] === '?' ? 28 : inputShape[0]}" placeholder="Height">
417
+ <input type="number" id="conv-input-w" min="1" value="${inputShape[1] === '?' ? 28 : inputShape[1]}" placeholder="Width">
418
+ <input type="number" id="conv-input-c" min="1" value="${inputShape[2] === '?' ? 1 : inputShape[2]}" placeholder="Channels">
419
+ </div>
420
+ <small>Input dimensions: H × W × C</small>
421
+ </div>
422
  <div class="form-group">
423
  <label>Filters:</label>
424
  <input type="number" id="conv-filters" min="1" value="${layerConfig.filters}" placeholder="Number of filters">
 
455
  <option value="leaky_relu" ${layerConfig.activation === 'leaky_relu' ? 'selected' : ''}>Leaky ReLU</option>
456
  </select>
457
  </div>
458
+ <div class="form-group">
459
+ <label>Output Shape (calculated):</label>
460
+ <div class="output-shape-display" id="conv-output-shape">
461
+ [${outputShape.join(' × ')}]
462
+ </div>
463
+ <small>Output dimensions: H × W × Filters</small>
464
+ </div>
465
+ <div class="form-group">
466
+ <label>Parameters (calculated):</label>
467
+ <div class="parameters-display" id="conv-parameters">
468
+ Calculating...
469
+ </div>
470
+ </div>
471
  `;
472
+
473
+ // Add event listeners to calculate output shape and parameters in real-time
474
+ setTimeout(() => {
475
+ const inputH = document.getElementById('conv-input-h');
476
+ const inputW = document.getElementById('conv-input-w');
477
+ const inputC = document.getElementById('conv-input-c');
478
+ const filters = document.getElementById('conv-filters');
479
+ const kernelH = document.getElementById('kernel-size-h');
480
+ const kernelW = document.getElementById('kernel-size-w');
481
+ const strideH = document.getElementById('stride-h');
482
+ const strideW = document.getElementById('stride-w');
483
+ const paddingType = document.getElementById('padding-type');
484
+ const outputShapeDisplay = document.getElementById('conv-output-shape');
485
+ const parametersDisplay = document.getElementById('conv-parameters');
486
+
487
+ const updateOutputShape = () => {
488
+ const h = parseInt(inputH.value);
489
+ const w = parseInt(inputW.value);
490
+ const c = parseInt(inputC.value);
491
+ const f = parseInt(filters.value);
492
+ const kh = parseInt(kernelH.value);
493
+ const kw = parseInt(kernelW.value);
494
+ const sh = parseInt(strideH.value);
495
+ const sw = parseInt(strideW.value);
496
+ const padding = paddingType.value;
497
+
498
+ // Calculate output dimensions
499
+ const pH = padding === 'same' ? Math.floor(kh / 2) : 0;
500
+ const pW = padding === 'same' ? Math.floor(kw / 2) : 0;
501
+
502
+ const outH = Math.floor((h - kh + 2 * pH) / sh) + 1;
503
+ const outW = Math.floor((w - kw + 2 * pW) / sw) + 1;
504
+
505
+ // Update output shape display
506
+ outputShapeDisplay.textContent = `[${outH} × ${outW} × ${f}]`;
507
+
508
+ // Calculate parameters
509
+ const params = kh * kw * c * f + f; // weights + bias
510
+ parametersDisplay.textContent = formatNumber(params);
511
+
512
+ // Store for saving
513
+ layerConfig.inputShape = [h, w, c];
514
+ layerConfig.outputShape = [outH, outW, f];
515
+ layerConfig.parameters = params;
516
+ };
517
+
518
+ // Attach event listeners to all inputs
519
+ [inputH, inputW, inputC, filters, kernelH, kernelW, strideH, strideW, paddingType].forEach(
520
+ input => input.addEventListener('input', updateOutputShape)
521
+ );
522
+
523
+ // Initialize values
524
+ updateOutputShape();
525
+ }, 100);
526
  break;
527
 
528
  case 'pool':
529
  // Pooling layer parameters
530
+ // Get input and output shapes
531
+ const poolInputShape = layerConfig.inputShape || ['?', '?', '?'];
532
+ const poolOutputShape = layerConfig.outputShape || ['?', '?', '?'];
533
+
534
  layerForm.innerHTML += `
535
+ <div class="form-group">
536
+ <label>Input Shape:</label>
537
+ <div class="form-row">
538
+ <input type="number" id="pool-input-h" min="1" value="${poolInputShape[0] === '?' ? 28 : poolInputShape[0]}" placeholder="Height">
539
+ <input type="number" id="pool-input-w" min="1" value="${poolInputShape[1] === '?' ? 28 : poolInputShape[1]}" placeholder="Width">
540
+ <input type="number" id="pool-input-c" min="1" value="${poolInputShape[2] === '?' ? 1 : poolInputShape[2]}" placeholder="Channels">
541
+ </div>
542
+ <small>Input dimensions: H × W × C</small>
543
+ </div>
544
  <div class="form-group">
545
  <label>Pool Size:</label>
546
  <div class="form-row">
 
569
  <option value="avg">Average Pooling</option>
570
  </select>
571
  </div>
572
+ <div class="form-group">
573
+ <label>Output Shape (calculated):</label>
574
+ <div class="output-shape-display" id="pool-output-shape">
575
+ [${poolOutputShape.join(' × ')}]
576
+ </div>
577
+ <small>Output dimensions: H × W × C</small>
578
+ </div>
579
  `;
580
+
581
+ // Add event listeners to calculate output shape in real-time
582
+ setTimeout(() => {
583
+ const inputH = document.getElementById('pool-input-h');
584
+ const inputW = document.getElementById('pool-input-w');
585
+ const inputC = document.getElementById('pool-input-c');
586
+ const poolH = document.getElementById('pool-size-h');
587
+ const poolW = document.getElementById('pool-size-w');
588
+ const strideH = document.getElementById('pool-stride-h');
589
+ const strideW = document.getElementById('pool-stride-w');
590
+ const paddingType = document.getElementById('pool-padding');
591
+ const outputShapeDisplay = document.getElementById('pool-output-shape');
592
+
593
+ const updateOutputShape = () => {
594
+ const h = parseInt(inputH.value);
595
+ const w = parseInt(inputW.value);
596
+ const c = parseInt(inputC.value);
597
+ const ph = parseInt(poolH.value);
598
+ const pw = parseInt(poolW.value);
599
+ const sh = parseInt(strideH.value);
600
+ const sw = parseInt(strideW.value);
601
+ const padding = paddingType.value;
602
+
603
+ // Calculate output dimensions
604
+ const padH = padding === 'same' ? Math.floor(ph / 2) : 0;
605
+ const padW = padding === 'same' ? Math.floor(pw / 2) : 0;
606
+
607
+ const outH = Math.floor((h - ph + 2 * padH) / sh) + 1;
608
+ const outW = Math.floor((w - pw + 2 * padW) / sw) + 1;
609
+
610
+ // Update output shape display
611
+ outputShapeDisplay.textContent = `[${outH} × ${outW} × ${c}]`;
612
+
613
+ // Store for saving
614
+ layerConfig.inputShape = [h, w, c];
615
+ layerConfig.outputShape = [outH, outW, c];
616
+ layerConfig.parameters = 0; // Pooling has no parameters
617
+ };
618
+
619
+ // Attach event listeners to all inputs
620
+ [inputH, inputW, inputC, poolH, poolW, strideH, strideW, paddingType].forEach(
621
+ input => input.addEventListener('input', updateOutputShape)
622
+ );
623
+
624
+ // Initialize values
625
+ updateOutputShape();
626
+ }, 100);
627
  break;
628
 
629
  case 'linear':
 
738
  const values = {};
739
  const inputs = form.querySelectorAll('input, select');
740
  inputs.forEach(input => {
741
+ if (input.type === 'checkbox') {
742
+ values[input.id] = input.checked;
743
+ } else {
744
+ values[input.id] = input.value;
745
+ }
746
  });
747
 
748
  // Update node configuration
749
+ node.layerConfig = node.layerConfig || {};
750
+ const layerConfig = node.layerConfig;
751
+
752
+ switch (nodeType) {
753
+ case 'input':
754
+ layerConfig.shape = [
755
+ parseInt(values['input-height']) || 28,
756
+ parseInt(values['input-width']) || 28,
757
+ parseInt(values['input-channels']) || 1
758
+ ];
759
+ layerConfig.batchSize = parseInt(values['batch-size']) || 32;
760
+ layerConfig.outputShape = layerConfig.shape;
761
+ layerConfig.parameters = 0;
762
+ break;
763
+
764
+ case 'hidden':
765
+ layerConfig.units = parseInt(values['hidden-units']) || 128;
766
+ layerConfig.activation = values['hidden-activation'] || 'relu';
767
+ layerConfig.dropoutRate = parseFloat(values['dropout-rate']) || 0.2;
768
+ layerConfig.useBias = values['use-bias'] === true;
769
+ layerConfig.outputShape = [layerConfig.units];
770
+
771
+ // Calculate parameters if input shape is available
772
+ if (layerConfig.inputShape) {
773
+ const inputUnits = Array.isArray(layerConfig.inputShape) ?
774
+ layerConfig.inputShape.reduce((a, b) => a * b, 1) : layerConfig.inputShape;
775
+ layerConfig.parameters = (inputUnits * layerConfig.units) + (layerConfig.useBias ? layerConfig.units : 0);
776
+ }
777
+ break;
778
+
779
+ case 'output':
780
+ layerConfig.units = parseInt(values['output-units']) || 10;
781
+ layerConfig.activation = values['output-activation'] || 'softmax';
782
+ layerConfig.useBias = values['output-use-bias'] === true;
783
+ layerConfig.outputShape = [layerConfig.units];
784
+
785
+ // Calculate parameters if input shape is available
786
+ if (layerConfig.inputShape) {
787
+ const inputUnits = Array.isArray(layerConfig.inputShape) ?
788
+ layerConfig.inputShape.reduce((a, b) => a * b, 1) : layerConfig.inputShape;
789
+ layerConfig.parameters = (inputUnits * layerConfig.units) + (layerConfig.useBias ? layerConfig.units : 0);
790
+ }
791
+ break;
792
+
793
+ case 'conv':
794
+ // Process input shape if available in form
795
+ if (values['conv-input-h'] && values['conv-input-w'] && values['conv-input-c']) {
796
+ layerConfig.inputShape = [
797
+ parseInt(values['conv-input-h']) || 28,
798
+ parseInt(values['conv-input-w']) || 28,
799
+ parseInt(values['conv-input-c']) || 1
800
+ ];
801
+ }
802
+
803
+ // Process configuration
804
+ layerConfig.filters = parseInt(values['conv-filters']) || 32;
805
+ layerConfig.kernelSize = [
806
+ parseInt(values['kernel-size-h']) || 3,
807
+ parseInt(values['kernel-size-w']) || 3
808
+ ];
809
+ layerConfig.strides = [
810
+ parseInt(values['stride-h']) || 1,
811
+ parseInt(values['stride-w']) || 1
812
+ ];
813
+ layerConfig.padding = values['padding-type'] || 'valid';
814
+ layerConfig.activation = values['conv-activation'] || 'relu';
815
+ layerConfig.useBias = true; // Default to true for CNN
816
+
817
+ // Calculate output shape if input shape is available
818
+ if (layerConfig.inputShape) {
819
+ const padding = layerConfig.padding === 'same' ?
820
+ Math.floor(layerConfig.kernelSize[0] / 2) : 0;
821
+
822
+ const outH = Math.floor(
823
+ (layerConfig.inputShape[0] - layerConfig.kernelSize[0] + 2 * padding) /
824
+ layerConfig.strides[0]
825
+ ) + 1;
826
+
827
+ const outW = Math.floor(
828
+ (layerConfig.inputShape[1] - layerConfig.kernelSize[1] + 2 * padding) /
829
+ layerConfig.strides[1]
830
+ ) + 1;
831
+
832
+ layerConfig.outputShape = [outH, outW, layerConfig.filters];
833
+
834
+ // Calculate parameters
835
+ const kernelParams = layerConfig.kernelSize[0] * layerConfig.kernelSize[1] *
836
+ layerConfig.inputShape[2] * layerConfig.filters;
837
+ const biasParams = layerConfig.filters;
838
+ layerConfig.parameters = kernelParams + biasParams;
839
+ }
840
+ break;
841
+
842
+ case 'pool':
843
+ // Process input shape if available in form
844
+ if (values['pool-input-h'] && values['pool-input-w'] && values['pool-input-c']) {
845
+ layerConfig.inputShape = [
846
+ parseInt(values['pool-input-h']) || 28,
847
+ parseInt(values['pool-input-w']) || 28,
848
+ parseInt(values['pool-input-c']) || 1
849
+ ];
850
+ }
851
+
852
+ // Process configuration
853
+ layerConfig.poolSize = [
854
+ parseInt(values['pool-size-h']) || 2,
855
+ parseInt(values['pool-size-w']) || 2
856
+ ];
857
+ layerConfig.strides = [
858
+ parseInt(values['pool-stride-h']) || 2,
859
+ parseInt(values['pool-stride-w']) || 2
860
+ ];
861
+ layerConfig.padding = values['pool-padding'] || 'valid';
862
+ layerConfig.poolType = values['pool-type'] || 'max';
863
+
864
+ // Calculate output shape if input shape is available
865
+ if (layerConfig.inputShape) {
866
+ const poolPadding = layerConfig.padding === 'same' ?
867
+ Math.floor(layerConfig.poolSize[0] / 2) : 0;
868
+
869
+ const poolOutH = Math.floor(
870
+ (layerConfig.inputShape[0] - layerConfig.poolSize[0] + 2 * poolPadding) /
871
+ layerConfig.strides[0]
872
+ ) + 1;
873
+
874
+ const poolOutW = Math.floor(
875
+ (layerConfig.inputShape[1] - layerConfig.poolSize[1] + 2 * poolPadding) /
876
+ layerConfig.strides[1]
877
+ ) + 1;
878
+
879
+ layerConfig.outputShape = [poolOutH, poolOutW, layerConfig.inputShape[2]];
880
+ }
881
+
882
+ // Pooling has no parameters
883
+ layerConfig.parameters = 0;
884
+ break;
885
+
886
+ case 'linear':
887
+ layerConfig.inputFeatures = parseInt(values['input-features']) || 1;
888
+ layerConfig.outputFeatures = parseInt(values['output-features']) || 1;
889
+ layerConfig.useBias = values['linear-use-bias'] === true;
890
+ layerConfig.learningRate = parseFloat(values['learning-rate-slider']) || 0.01;
891
+ layerConfig.activation = values['linear-activation'] || 'linear';
892
+ layerConfig.optimizer = values['optimizer'] || 'sgd';
893
+ layerConfig.lossFunction = values['loss-function'] || 'mse';
894
+ layerConfig.inputShape = [layerConfig.inputFeatures];
895
+ layerConfig.outputShape = [layerConfig.outputFeatures];
896
+
897
+ // Calculate parameters
898
+ layerConfig.parameters = layerConfig.inputFeatures * layerConfig.outputFeatures;
899
+ if (layerConfig.useBias) {
900
+ layerConfig.parameters += layerConfig.outputFeatures;
901
+ }
902
+ break;
903
+ }
904
 
905
  // Update node title
906
  const nodeTitle = node.querySelector('.node-title');
 
911
  // Update node data attribute
912
  node.setAttribute('data-name', nodeType.charAt(0).toUpperCase() + nodeType.slice(1));
913
 
914
+ // Update dimensions and parameter display based on layer type
915
  let dimensions = '';
916
  switch (nodeType) {
917
  case 'input':
918
+ dimensions = layerConfig.shape.join(' × ');
919
  break;
920
 
921
  case 'hidden':
922
  case 'output':
923
+ dimensions = layerConfig.units.toString();
924
  break;
925
 
926
  case 'conv':
927
+ if (layerConfig.inputShape && layerConfig.outputShape) {
928
+ // Show input -> output shape transformation
929
+ dimensions = `${layerConfig.inputShape[0]}×${layerConfig.inputShape[1]}×${layerConfig.inputShape[2]} → ${layerConfig.outputShape[0]}×${layerConfig.outputShape[1]}×${layerConfig.outputShape[2]}`;
930
+ } else {
931
+ dimensions = `? → ${layerConfig.filters} filters`;
932
+ }
933
  break;
934
 
935
  case 'pool':
936
+ if (layerConfig.inputShape && layerConfig.outputShape) {
937
+ // Show input -> output shape transformation
938
+ dimensions = `${layerConfig.inputShape[0]}×${layerConfig.inputShape[1]}×${layerConfig.inputShape[2]} → ${layerConfig.outputShape[0]}×${layerConfig.outputShape[1]}×${layerConfig.outputShape[2]}`;
939
+ } else {
940
+ dimensions = `? → ?`;
941
+ }
942
  break;
943
 
944
  case 'linear':
945
+ dimensions = `${layerConfig.inputFeatures}${layerConfig.outputFeatures}`;
946
  break;
947
  }
948
 
949
+ // Update node dimensions display
950
  const nodeDimensions = node.querySelector('.node-dimensions');
951
  if (nodeDimensions) {
952
  nodeDimensions.textContent = dimensions;
953
  }
954
 
955
+ // Update parameters display if available
956
+ const nodeParameters = node.querySelector('.node-parameters');
957
+ if (nodeParameters && layerConfig.parameters !== undefined) {
958
+ nodeParameters.textContent = `Params: ${formatNumber(layerConfig.parameters)}`;
959
+ } else if (nodeParameters) {
960
+ nodeParameters.textContent = 'Params: ?';
961
+ }
962
+
963
  // Update node data attribute
964
  node.setAttribute('data-dimensions', dimensions);
965
 
 
970
  if (layerIndex !== -1) {
971
  networkLayers.layers[layerIndex].name = nodeType.charAt(0).toUpperCase() + nodeType.slice(1);
972
  networkLayers.layers[layerIndex].dimensions = dimensions;
973
+ networkLayers.layers[layerIndex].config = layerConfig;
974
+
975
+ // Add parameter count to the layer
976
+ networkLayers.layers[layerIndex].parameters = layerConfig.parameters;
977
  }
978
 
979
  // Trigger network updated event
980
  const event = new CustomEvent('networkUpdated', { detail: networkLayers });
981
  document.dispatchEvent(event);
982
+
983
+ // Update connected nodes to propagate shape changes
984
+ updateNodeConnections(node, layerId);
985
+ }
986
+
987
+ // Helper function to update connections between nodes when shapes change
988
+ function updateNodeConnections(sourceNode, sourceId) {
989
+ // Find all connections from this source node
990
+ const connections = document.querySelectorAll(`.connection[data-source="${sourceId}"]`);
991
+
992
+ connections.forEach(connection => {
993
+ const targetId = connection.getAttribute('data-target');
994
+ const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
995
+
996
+ if (targetNode && sourceNode.layerConfig && sourceNode.layerConfig.outputShape) {
997
+ // Update target node with source node's output shape as its input shape
998
+ if (!targetNode.layerConfig) {
999
+ targetNode.layerConfig = {};
1000
+ }
1001
+
1002
+ targetNode.layerConfig.inputShape = sourceNode.layerConfig.outputShape;
1003
+
1004
+ // Update parameter calculation
1005
+ window.neuralNetwork.calculateParameters(
1006
+ targetNode.getAttribute('data-type'),
1007
+ targetNode.layerConfig,
1008
+ sourceNode.layerConfig
1009
+ );
1010
+
1011
+ // Update display
1012
+ updateNodeDisplay(targetNode);
1013
+
1014
+ // Recursively update downstream nodes
1015
+ updateNodeConnections(targetNode, targetId);
1016
+ }
1017
+ });
1018
+ }
1019
+
1020
+ // Helper function to update a node's display
1021
+ function updateNodeDisplay(node) {
1022
+ if (!node || !node.layerConfig) return;
1023
+
1024
+ const nodeType = node.getAttribute('data-type');
1025
+ const layerConfig = node.layerConfig;
1026
+
1027
+ // Create dimensions string
1028
+ let dimensions = '';
1029
+ switch (nodeType) {
1030
+ case 'conv':
1031
+ case 'pool':
1032
+ if (layerConfig.inputShape && layerConfig.outputShape) {
1033
+ dimensions = `${layerConfig.inputShape[0]}×${layerConfig.inputShape[1]}×${layerConfig.inputShape[2]} → ${layerConfig.outputShape[0]}×${layerConfig.outputShape[1]}×${layerConfig.outputShape[2]}`;
1034
+ }
1035
+ break;
1036
+
1037
+ case 'hidden':
1038
+ case 'output':
1039
+ dimensions = layerConfig.units.toString();
1040
+ break;
1041
+
1042
+ case 'linear':
1043
+ dimensions = `${layerConfig.inputFeatures} → ${layerConfig.outputFeatures}`;
1044
+ break;
1045
+ }
1046
+
1047
+ // Update dimensions display
1048
+ if (dimensions) {
1049
+ const nodeDimensions = node.querySelector('.node-dimensions');
1050
+ if (nodeDimensions) {
1051
+ nodeDimensions.textContent = dimensions;
1052
+ node.setAttribute('data-dimensions', dimensions);
1053
+ }
1054
+ }
1055
+
1056
+ // Update parameters display
1057
+ if (layerConfig.parameters !== undefined) {
1058
+ const nodeParameters = node.querySelector('.node-parameters');
1059
+ if (nodeParameters) {
1060
+ nodeParameters.textContent = `Params: ${formatNumber(layerConfig.parameters)}`;
1061
+ }
1062
+ }
1063
  }
1064
 
1065
  // Handle sample selection
js/neural-network.js CHANGED
@@ -153,10 +153,10 @@
153
  }
154
 
155
  /**
156
- * Calculate the number of parameters for a layer
157
  * @param {string} layerType - The type of the layer
158
  * @param {Object} config - Layer configuration
159
- * @param {Object} prevLayerConfig - Previous layer configuration (for connections)
160
  * @returns {number} - Number of trainable parameters
161
  */
162
  function calculateParameters(layerType, config, prevLayerConfig = null) {
@@ -169,10 +169,17 @@
169
 
170
  case 'hidden':
171
  if (prevLayerConfig) {
172
- const inputUnits = prevLayerConfig.units ||
173
- (prevLayerConfig.shape ?
174
- prevLayerConfig.shape.reduce((a, b) => a * b, 1) :
175
- 784);
 
 
 
 
 
 
 
176
 
177
  // Weight parameters: input_units * output_units
178
  parameters = inputUnits * config.units;
@@ -186,7 +193,15 @@
186
 
187
  case 'output':
188
  if (prevLayerConfig) {
189
- const inputUnits = prevLayerConfig.units || 128;
 
 
 
 
 
 
 
 
190
 
191
  // Weight parameters: input_units * output_units
192
  parameters = inputUnits * config.units;
@@ -200,9 +215,17 @@
200
 
201
  case 'conv':
202
  if (prevLayerConfig) {
203
- const inputChannels = prevLayerConfig.shape ?
204
- prevLayerConfig.shape[2] || 1 :
205
- (prevLayerConfig.filters || 1);
 
 
 
 
 
 
 
 
206
 
207
  // Weight parameters: kernel_height * kernel_width * input_channels * filters
208
  const kernelSize = Array.isArray(config.kernelSize) ?
@@ -215,11 +238,31 @@
215
  if (config.useBias) {
216
  parameters += config.filters;
217
  }
 
 
 
 
 
 
 
 
 
 
218
  }
219
  break;
220
 
221
  case 'pool':
222
  parameters = 0; // Pooling layers have no trainable parameters
 
 
 
 
 
 
 
 
 
 
223
  break;
224
 
225
  default:
 
153
  }
154
 
155
  /**
156
+ * Calculate parameters for a layer
157
  * @param {string} layerType - The type of the layer
158
  * @param {Object} config - Layer configuration
159
+ * @param {Object} prevLayerConfig - Configuration of the previous connected layer
160
  * @returns {number} - Number of trainable parameters
161
  */
162
  function calculateParameters(layerType, config, prevLayerConfig = null) {
 
169
 
170
  case 'hidden':
171
  if (prevLayerConfig) {
172
+ // Calculate input units from previous layer shape or units
173
+ let inputUnits;
174
+ if (prevLayerConfig.outputShape && Array.isArray(prevLayerConfig.outputShape)) {
175
+ inputUnits = prevLayerConfig.outputShape.reduce((a, b) => a * b, 1);
176
+ } else if (prevLayerConfig.units) {
177
+ inputUnits = prevLayerConfig.units;
178
+ } else if (prevLayerConfig.shape) {
179
+ inputUnits = prevLayerConfig.shape.reduce((a, b) => a * b, 1);
180
+ } else {
181
+ inputUnits = 784; // Default fallback
182
+ }
183
 
184
  // Weight parameters: input_units * output_units
185
  parameters = inputUnits * config.units;
 
193
 
194
  case 'output':
195
  if (prevLayerConfig) {
196
+ // Calculate input units from previous layer
197
+ let inputUnits;
198
+ if (prevLayerConfig.outputShape && Array.isArray(prevLayerConfig.outputShape)) {
199
+ inputUnits = prevLayerConfig.outputShape.reduce((a, b) => a * b, 1);
200
+ } else if (prevLayerConfig.units) {
201
+ inputUnits = prevLayerConfig.units;
202
+ } else {
203
+ inputUnits = 128; // Default fallback
204
+ }
205
 
206
  // Weight parameters: input_units * output_units
207
  parameters = inputUnits * config.units;
 
215
 
216
  case 'conv':
217
  if (prevLayerConfig) {
218
+ // Get input channels from previous layer
219
+ let inputChannels;
220
+ if (prevLayerConfig.outputShape && prevLayerConfig.outputShape.length > 2) {
221
+ inputChannels = prevLayerConfig.outputShape[2];
222
+ } else if (prevLayerConfig.shape && prevLayerConfig.shape.length > 2) {
223
+ inputChannels = prevLayerConfig.shape[2];
224
+ } else if (prevLayerConfig.filters) {
225
+ inputChannels = prevLayerConfig.filters;
226
+ } else {
227
+ inputChannels = 1; // Default fallback
228
+ }
229
 
230
  // Weight parameters: kernel_height * kernel_width * input_channels * filters
231
  const kernelSize = Array.isArray(config.kernelSize) ?
 
238
  if (config.useBias) {
239
  parameters += config.filters;
240
  }
241
+
242
+ // Calculate and store output shape
243
+ if (prevLayerConfig.shape || prevLayerConfig.outputShape) {
244
+ const inputShape = prevLayerConfig.outputShape || prevLayerConfig.shape;
245
+ const padding = config.padding === 'same' ? Math.floor(config.kernelSize[0] / 2) : 0;
246
+ const outputHeight = Math.floor((inputShape[0] - config.kernelSize[0] + 2 * padding) / config.strides[0]) + 1;
247
+ const outputWidth = Math.floor((inputShape[1] - config.kernelSize[1] + 2 * padding) / config.strides[1]) + 1;
248
+
249
+ config.outputShape = [outputHeight, outputWidth, config.filters];
250
+ }
251
  }
252
  break;
253
 
254
  case 'pool':
255
  parameters = 0; // Pooling layers have no trainable parameters
256
+
257
+ // Calculate and store output shape
258
+ if (prevLayerConfig && (prevLayerConfig.shape || prevLayerConfig.outputShape)) {
259
+ const inputShape = prevLayerConfig.outputShape || prevLayerConfig.shape;
260
+ const padding = config.padding === 'same' ? Math.floor(config.poolSize[0] / 2) : 0;
261
+ const outputHeight = Math.floor((inputShape[0] - config.poolSize[0] + 2 * padding) / config.strides[0]) + 1;
262
+ const outputWidth = Math.floor((inputShape[1] - config.poolSize[1] + 2 * padding) / config.strides[1]) + 1;
263
+
264
+ config.outputShape = [outputHeight, outputWidth, inputShape[2]];
265
+ }
266
  break;
267
 
268
  default: