// Complete drag and drop fix for neural network playground
// This handles both initial node creation and moving existing nodes
(function() {
console.log('Loading complete drag and drop fix...');
// Schedule the initializer once the DOM is ready. DOMContentLoaded fires only
// once, so when this script is evaluated after that point (for example when it
// is loaded late or injected dynamically) the listener would never run; in that
// case the timer is started immediately instead.
if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', function() {
        // Wait a bit to ensure other scripts have loaded
        setTimeout(initializeCompleteDragFix, 1000);
    });
} else {
    // DOM already parsed: keep the same grace period for other scripts
    setTimeout(initializeCompleteDragFix, 1000);
}
function initializeCompleteDragFix() {
console.log('Initializing complete drag and drop fix');
// Get necessary elements: the canvas that receives dropped nodes and the
// draggable palette entries (.node-item)
const canvas = document.getElementById('network-canvas');
const nodeItems = document.querySelectorAll('.node-item');
// Every handler below works on the canvas, so stop initializing without it
if (!canvas) {
console.error('Canvas element not found!');
return;
}
// Track state for moving existing nodes
// activeNode: the node currently being dragged, or null when idle
let activeNode = null;
// offsetX / offsetY: pointer position inside the node, captured on mousedown
// and subtracted while moving so the node keeps its grip point under the cursor
let offsetX = 0;
let offsetY = 0;
let isDragging = false;
// Track node counts for naming (keyed by node type, incremented in createNode)
const nodeCounter = {};
// Anti-duplication system for new nodes
// timestamp is recorded by handleDrop and compared against the next drop to
// debounce rapid repeats
// NOTE(review): nodeIds is never read or written in the visible part of this
// file - confirm whether it is still needed
const recentlyCreated = {
nodeIds: new Set(),
timestamp: 0
};
// Network model structure (reused from original code)
// Passed as the detail of the 'networkUpdated' events dispatched below
let networkLayers = {
layers: [],
connections: []
};
// Helper function for formatting numbers
// Renders a count compactly: values of at least one thousand, million or
// billion are divided down and shown with two decimals and a K, M or B suffix.
// Zero is returned as '0'; any other falsy value (null, undefined, NaN, '')
// is returned as 'N/A'. Everything else is returned via toString().
function formatNumber(num) {
    if (num === 0) {
        return '0';
    }
    if (!num) {
        return 'N/A';
    }
    // Largest unit first so the first match is the right one
    const units = [
        [1e9, 'B'],
        [1e6, 'M'],
        [1e3, 'K']
    ];
    for (const [threshold, suffix] of units) {
        if (num >= threshold) {
            return (num / threshold).toFixed(2) + suffix;
        }
    }
    return num.toString();
}
// Add debug button for Conv2D parameters
addConv2DFixButton();
// 1. DRAGGING NEW NODES FROM PANEL TO CANVAS
// Give every palette item a dragstart listener that hands its layer type to
// the canvas drop handler through the drag data. This listener is added
// alongside any dragstart listeners other scripts have registered.
for (const paletteItem of nodeItems) {
    paletteItem.addEventListener('dragstart', (dragEvent) => {
        const nodeType = paletteItem.getAttribute('data-type');
        console.log(`Starting drag for new ${nodeType} node`);
        // The layer type travels as plain text; handleDrop reads it back
        const transfer = dragEvent.dataTransfer;
        transfer.setData('text/plain', nodeType);
        transfer.effectAllowed = 'copy';
        // Use a half-transparent clone of the palette item as the drag image
        const dragPreview = paletteItem.cloneNode(true);
        dragPreview.style.opacity = '0.5';
        document.body.appendChild(dragPreview);
        transfer.setDragImage(dragPreview, 0, 0);
        // Take the clone back out of the page once dragstart has completed
        setTimeout(() => {
            document.body.removeChild(dragPreview);
        }, 0);
    });
}
// Add canvas event handlers for dropping new nodes
// dragover must be cancelled for the canvas to accept a drop; the drop effect
// matches the 'copy' effect allowed by the palette items.
function handleDragOver(dragEvent) {
    dragEvent.preventDefault();
    dragEvent.dataTransfer.dropEffect = 'copy';
}
// removeEventListener only detaches this exact function reference, so this
// pair guards against registering handleDragOver itself twice; it does not
// remove dragover listeners installed by other scripts.
canvas.removeEventListener('dragover', handleDragOver);
canvas.addEventListener('dragover', handleDragOver);
// Create drop handler for new nodes
// Reads the node type placed in the drag data by the palette's dragstart
// listener, converts the pointer position into canvas coordinates and creates
// the node there. Drops arriving within 500 ms of the previous node-creating
// drop are ignored.
function handleDrop(e) {
    e.preventDefault();
    console.log('Drop event triggered');
    // Debounce: prevent multiple drops in quick succession
    const now = Date.now();
    if (now - recentlyCreated.timestamp < 500) {
        console.log('Debouncing drop event');
        return;
    }
    // Get node type from dataTransfer
    const nodeType = e.dataTransfer.getData('text/plain');
    if (!nodeType) {
        console.error('No node type found in drop data');
        return;
    }
    // Start the debounce window only now that the drop is known to be usable;
    // a drop without a node type must not block the next valid drop.
    recentlyCreated.timestamp = now;
    console.log(`Creating new ${nodeType} node`);
    // Calculate position for new node. The fixed offsets shift the node up and
    // left of the pointer - presumably to roughly centre it; confirm against
    // the node's CSS size.
    const canvasRect = canvas.getBoundingClientRect();
    const x = e.clientX - canvasRect.left - 75;
    const y = e.clientY - canvasRect.top - 30;
    // Ensure position is within canvas
    const posX = Math.max(0, Math.min(canvasRect.width - 150, x));
    const posY = Math.max(0, Math.min(canvasRect.height - 100, y));
    // Generate unique ID: type, drop time and a random suffix
    const layerId = `${nodeType}-${now}-${Math.floor(Math.random() * 1000)}`;
    // Create the new node
    createNode(nodeType, layerId, posX, posY);
}
// removeEventListener only detaches this exact function reference, so this
// guards against registering handleDrop itself twice; it does not remove drop
// listeners installed by other scripts.
canvas.removeEventListener('drop', handleDrop);
canvas.addEventListener('drop', handleDrop);
// Function to create a new node
// Builds the DOM element for one layer, appends it to the canvas, records the
// layer in networkLayers and announces the change with a 'networkUpdated' event.
//   nodeType - layer type key ('input', 'hidden', 'output', 'conv', 'pool',
//              'lstm', 'rnn', 'gru'); any other value yields an 'Unknown Layer'
//   layerId  - id stored in the node's data-id attribute and in the model
//   posX     - left offset of the node in pixels, relative to the canvas
//   posY     - top offset of the node in pixels, relative to the canvas
// Returns the created .canvas-node element.
function createNode(nodeType, layerId, posX, posY) {
    // Built-in default configuration for each known layer type; unknown types
    // get an empty config.
    function fallbackConfig(type) {
        switch (type) {
            case 'input':
                return {
                    shape: [28, 28, 1],
                    outputShape: [28, 28, 1],
                    parameters: 0
                };
            case 'hidden':
                return {
                    units: 128,
                    activation: 'relu',
                    outputShape: [128],
                    parameters: 0
                };
            case 'output':
                return {
                    units: 10,
                    activation: 'softmax',
                    outputShape: [10],
                    parameters: 0
                };
            case 'conv':
                return {
                    filters: 32,
                    kernelSize: [3, 3],
                    strides: [1, 1],
                    padding: 'same',
                    activation: 'relu',
                    outputShape: ['?', '?', 32],
                    parameters: 0
                };
            case 'pool':
                return {
                    poolSize: [2, 2],
                    strides: [2, 2],
                    padding: 'valid',
                    poolType: 'max',
                    outputShape: ['?', '?', '?'],
                    parameters: 0
                };
            case 'lstm':
                return {
                    units: 64,
                    returnSequences: true,
                    activation: 'tanh',
                    recurrentActivation: 'sigmoid',
                    useBias: true,
                    outputShape: ['?', 64],
                    parameters: 0
                };
            case 'rnn':
                return {
                    units: 32,
                    returnSequences: true,
                    activation: 'tanh',
                    useBias: true,
                    outputShape: ['?', 32],
                    parameters: 0
                };
            case 'gru':
                return {
                    units: 48,
                    returnSequences: true,
                    activation: 'tanh',
                    recurrentActivation: 'sigmoid',
                    useBias: true,
                    outputShape: ['?', 48],
                    parameters: 0
                };
            default:
                return {};
        }
    }

    // Create an element with a class name and, optionally, text content
    function makeElement(tag, className, text) {
        const element = document.createElement(tag);
        element.className = className;
        if (text !== undefined) {
            element.textContent = text;
        }
        return element;
    }

    // Increment counter for this node type; the count numbers the layer name
    nodeCounter[nodeType] = (nodeCounter[nodeType] || 0) + 1;
    const ordinal = nodeCounter[nodeType];

    // Get default configuration from neural network module or use our own defaults
    let nodeConfig;
    if (window.neuralNetwork && window.neuralNetwork.createNodeConfig) {
        nodeConfig = window.neuralNetwork.createNodeConfig(nodeType);
    }
    // Also fall back when the module returns nothing, so the property reads
    // below cannot fail on an undefined config
    if (!nodeConfig) {
        nodeConfig = fallbackConfig(nodeType);
    }

    // Ensure Conv2D has properly formatted array values
    if (nodeType === 'conv') {
        if (!nodeConfig.kernelSize || typeof nodeConfig.kernelSize === 'string') {
            nodeConfig.kernelSize = [3, 3];
        }
        if (!nodeConfig.strides || typeof nodeConfig.strides === 'string') {
            nodeConfig.strides = [1, 1];
        }
        if (!nodeConfig.filters || isNaN(nodeConfig.filters)) {
            nodeConfig.filters = 32;
        }
        nodeConfig.padding = nodeConfig.padding || 'same';
        nodeConfig.activation = nodeConfig.activation || 'relu';
    }

    // Create node element
    const canvasNode = makeElement('div', `canvas-node ${nodeType}-node`);
    canvasNode.setAttribute('data-type', nodeType);
    canvasNode.setAttribute('data-id', layerId);
    canvasNode.style.position = 'absolute';
    canvasNode.style.left = `${posX}px`;
    canvasNode.style.top = `${posY}px`;

    // Set up node content (input/output shape, parameters)
    let nodeName, inputShape, outputShape, parameters;
    switch (nodeType) {
        case 'input':
            nodeName = 'Input Layer';
            inputShape = 'N/A';
            outputShape = '[' + nodeConfig.shape.join(' × ') + ']';
            parameters = nodeConfig.parameters;
            break;
        case 'hidden':
            nodeConfig.units = ordinal === 1 ? 128 : 64;
            // Keep the stored output shape in step with the overridden unit
            // count; it is what a downstream layer reads as its input shape
            nodeConfig.outputShape = [nodeConfig.units];
            nodeName = `Hidden Layer ${ordinal}`;
            inputShape = 'Connect input';
            outputShape = `[${nodeConfig.units}]`;
            parameters = 'Connect input to calculate';
            break;
        case 'output':
            nodeName = 'Output Layer';
            inputShape = 'Connect input';
            outputShape = `[${nodeConfig.units}]`;
            parameters = 'Connect input to calculate';
            break;
        case 'conv': {
            nodeConfig.filters = 32 * ordinal;
            // The channel dimension of the stored output shape must follow the
            // overridden filter count; spatial dimensions are left as they were
            const spatialDims = Array.isArray(nodeConfig.outputShape) && nodeConfig.outputShape.length === 3
                ? nodeConfig.outputShape.slice(0, 2)
                : ['?', '?'];
            nodeConfig.outputShape = [...spatialDims, nodeConfig.filters];
            nodeName = `Conv2D ${ordinal}`;
            inputShape = 'Connect input';
            outputShape = 'Depends on input';
            parameters = `Kernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
            break;
        }
        case 'pool':
            nodeName = `Pooling ${ordinal}`;
            inputShape = 'Connect input';
            outputShape = 'Depends on input';
            parameters = `Pool size: ${nodeConfig.poolSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
            break;
        case 'lstm':
            nodeName = `LSTM ${ordinal}`;
            inputShape = 'Connect input';
            outputShape = `[?, ${nodeConfig.units}]`;
            parameters = `Units: ${nodeConfig.units}\nReturn Sequences: ${nodeConfig.returnSequences ? 'Yes' : 'No'}\nGates: 4`;
            break;
        case 'rnn':
            nodeName = `RNN ${ordinal}`;
            inputShape = 'Connect input';
            outputShape = `[?, ${nodeConfig.units}]`;
            parameters = `Units: ${nodeConfig.units}\nReturn Sequences: ${nodeConfig.returnSequences ? 'Yes' : 'No'}`;
            break;
        case 'gru':
            nodeName = `GRU ${ordinal}`;
            inputShape = 'Connect input';
            outputShape = `[?, ${nodeConfig.units}]`;
            parameters = `Units: ${nodeConfig.units}\nReturn Sequences: ${nodeConfig.returnSequences ? 'Yes' : 'No'}\nGates: 3`;
            break;
        default:
            nodeName = 'Unknown Layer';
            inputShape = 'N/A';
            outputShape = 'N/A';
            parameters = 'N/A';
    }

    // Create node content
    const nodeContent = makeElement('div', 'node-content');
    // Add shape information. The values sit in their own spans because
    // updateParametersAfterConnection looks up .input-shape and .output-shape
    // inside the node to refresh them; text nodes are used instead of
    // innerHTML so config values are never parsed as markup.
    const shapeInfo = makeElement('div', 'shape-info');
    shapeInfo.append(
        '\nInput: ',
        makeElement('span', 'input-shape', String(inputShape)),
        '\nOutput: ',
        makeElement('span', 'output-shape', String(outputShape)),
        '\n'
    );
    // Add parameters section. The total lives in .node-parameters, which
    // updateParametersAfterConnection rewrites as a whole "Params: ..." string.
    const paramsTotal = nodeConfig.parameters !== undefined ? formatNumber(nodeConfig.parameters) : '?';
    const paramsSection = makeElement('div', 'params-section');
    paramsSection.append(
        `\n${parameters}\n`,
        makeElement('span', 'node-parameters', `Params: ${paramsTotal}`),
        '\n'
    );
    // Assemble content
    nodeContent.appendChild(shapeInfo);
    nodeContent.appendChild(paramsSection);

    // Set dimensions text based on node type (empty for types not listed)
    let dimensionsText = '';
    switch (nodeType) {
        case 'input':
            dimensionsText = nodeConfig.shape.join(' × ');
            break;
        case 'hidden':
        case 'output':
            dimensionsText = nodeConfig.units.toString();
            break;
        case 'conv':
            if (nodeConfig.inputShape && nodeConfig.outputShape) {
                dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
            } else {
                dimensionsText = `? → ${nodeConfig.filters} filters`;
            }
            break;
        case 'pool':
            if (nodeConfig.inputShape && nodeConfig.outputShape) {
                dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
            } else {
                dimensionsText = `? → ?`;
            }
            break;
    }
    // Add dimensions section
    const dimensionsSection = makeElement('div', 'node-dimensions', dimensionsText);

    // Create node title
    const nodeTitle = makeElement('div', 'node-title', nodeName);

    // Add node controls (edit and delete buttons)
    const nodeControls = makeElement('div', 'node-controls');
    const editButton = makeElement('button', 'node-edit-btn', '✎');
    editButton.title = 'Edit Layer';
    const deleteButton = makeElement('button', 'node-delete-btn', '×');
    deleteButton.title = 'Delete Layer';
    nodeControls.appendChild(editButton);
    nodeControls.appendChild(deleteButton);

    // Add connection ports
    const portIn = makeElement('div', 'node-port port-in');
    const portOut = makeElement('div', 'node-port port-out');

    // Assemble the node
    canvasNode.appendChild(nodeTitle);
    canvasNode.appendChild(nodeControls);
    canvasNode.appendChild(dimensionsSection);
    canvasNode.appendChild(nodeContent);
    canvasNode.appendChild(portIn);
    canvasNode.appendChild(portOut);

    // Store metadata; layerConfig is read back when connections are made
    canvasNode.setAttribute('data-name', nodeName);
    canvasNode.setAttribute('data-dimensions', dimensionsText);
    canvasNode.layerConfig = nodeConfig;

    // Add node to canvas
    canvas.appendChild(canvasNode);

    // Add to network model
    networkLayers.layers.push({
        id: layerId,
        type: nodeType,
        name: nodeName,
        position: { x: posX, y: posY },
        dimensions: dimensionsText,
        config: nodeConfig,
        parameters: nodeConfig.parameters || 0
    });

    // Set up event handlers (edit, delete, connections)
    setupNodeEventHandlers(canvasNode);

    // Hide canvas hint
    const canvasHint = document.querySelector('.canvas-hint');
    if (canvasHint) {
        canvasHint.style.display = 'none';
    }

    // Notify model update
    document.dispatchEvent(new CustomEvent('networkUpdated', {
        detail: networkLayers
    }));
    console.log(`Node created: ${nodeType} (${layerId})`);
    return canvasNode;
}
// 2. MOVING EXISTING NODES ON CANVAS
// Setup event handlers for node actions
// Wires one canvas node for: dragging (mousedown on the node body), editing
// (edit button or double-click), deleting (delete button or right-click) and
// starting a connection (mousedown on the output port).
//   node - the .canvas-node element to wire up
function setupNodeEventHandlers(node) {
    // Setup direct mouse handlers for dragging
    node.addEventListener('mousedown', function(e) {
        // Only the primary button starts a drag. Without this check a
        // right-click, which deletes the node through the contextmenu handler
        // below, would also begin dragging the node it is about to remove.
        if (e.button !== 0) {
            return;
        }
        // Skip if clicking on controls or ports
        if (e.target.closest('.node-controls') || e.target.closest('.node-port')) {
            return;
        }
        console.log(`Mouse down on node: ${node.getAttribute('data-id')}`);
        // Initialize drag: remember where inside the node it was grabbed
        activeNode = node;
        const rect = node.getBoundingClientRect();
        offsetX = e.clientX - rect.left;
        offsetY = e.clientY - rect.top;
        isDragging = true;
        // Visual indication
        node.classList.add('dragging');
        document.body.classList.add('node-dragging');
        node.style.zIndex = '1000';
        e.preventDefault();
    });
    // Edit button click
    const editButton = node.querySelector('.node-edit-btn');
    if (editButton) {
        editButton.addEventListener('click', function(e) {
            e.stopPropagation();
            openLayerEditor(node);
        });
    }
    // Delete button click
    const deleteButton = node.querySelector('.node-delete-btn');
    if (deleteButton) {
        deleteButton.addEventListener('click', function(e) {
            e.stopPropagation();
            deleteNode(node);
        });
    }
    // Double-click to edit
    node.addEventListener('dblclick', function() {
        openLayerEditor(node);
    });
    // Right-click to delete (the browser context menu is suppressed)
    node.addEventListener('contextmenu', function(e) {
        e.preventDefault();
        deleteNode(node);
    });
    // Connection port events
    const portOut = node.querySelector('.port-out');
    if (portOut) {
        portOut.addEventListener('mousedown', function(e) {
            // Keep the node's own mousedown (drag) handler from running
            e.stopPropagation();
            // Use our own connection handler instead of relying on window.startConnection
            startConnectionHandler(node, e);
        });
    }
}
// Global mouse handlers for dragging
// While a node drag is active, keep the node under the pointer (clamped to the
// canvas), mirror the new position into the model and redraw the node's
// connections.
document.addEventListener('mousemove', function(moveEvent) {
    const dragInProgress = isDragging && activeNode;
    if (!dragInProgress) {
        return;
    }
    // Log occasionally for debugging
    if (Math.random() < 0.05) {
        console.log('Node is being dragged...');
    }
    const bounds = canvas.getBoundingClientRect();
    // Rendered node size, with fixed fallbacks when the node reports zero
    const nodeWidth = activeNode.offsetWidth || 180;
    const nodeHeight = activeNode.offsetHeight || 120;
    // Keep within canvas
    const clampToCanvas = (value, upper) => Math.max(0, Math.min(upper, value));
    const x = clampToCanvas(moveEvent.clientX - bounds.left - offsetX, bounds.width - nodeWidth);
    const y = clampToCanvas(moveEvent.clientY - bounds.top - offsetY, bounds.height - nodeHeight);
    // Move node
    activeNode.style.left = `${x}px`;
    activeNode.style.top = `${y}px`;
    // Update model
    const nodeId = activeNode.getAttribute('data-id');
    const draggedLayer = networkLayers.layers.find(layer => layer.id === nodeId);
    if (draggedLayer) {
        draggedLayer.position = { x, y };
    }
    // Update connections
    updateConnections(nodeId);
});
// Finish a node drag: undo the drag styling, redraw all connections, clear the
// drag state and announce the updated model.
document.addEventListener('mouseup', function() {
    if (!(isDragging && activeNode)) {
        return;
    }
    console.log('Node drag complete');
    // Visual cleanup
    const releasedNode = activeNode;
    releasedNode.classList.remove('dragging');
    document.body.classList.remove('node-dragging');
    releasedNode.style.zIndex = '10';
    // Final connection update
    updateConnections();
    // Cleanup
    isDragging = false;
    activeNode = null;
    // Notify model update
    const updateEvent = new CustomEvent('networkUpdated', {
        detail: networkLayers
    });
    document.dispatchEvent(updateEvent);
});
// Add handlers to existing nodes (for page refresh cases)
for (const existingNode of document.querySelectorAll('.canvas-node')) {
    setupNodeEventHandlers(existingNode);
}
// 3. SUPPORTING FUNCTIONS
// Delete a node
// Removes the node, every connection element that starts or ends at it and
// the matching entries in networkLayers, shows the canvas hint again when no
// nodes remain, and dispatches 'networkUpdated'. Does nothing for a falsy node.
//   node - the .canvas-node element to remove
function deleteNode(node) {
    if (!node) return;
    const nodeId = node.getAttribute('data-id');
    console.log(`Deleting node: ${nodeId}`);
    // Node ids are derived from the dropped type text, so escape the id before
    // embedding it in a selector; a quote or backslash in it would otherwise
    // make querySelectorAll throw and leave the node undeletable.
    const selectorId = CSS.escape(nodeId);
    // Remove connections
    const connections = document.querySelectorAll(
        `.connection[data-source="${selectorId}"], .connection[data-target="${selectorId}"]`
    );
    connections.forEach(conn => conn.remove());
    // Update model
    networkLayers.connections = networkLayers.connections.filter(conn =>
        conn.source !== nodeId && conn.target !== nodeId
    );
    const layerIndex = networkLayers.layers.findIndex(layer => layer.id === nodeId);
    if (layerIndex !== -1) {
        networkLayers.layers.splice(layerIndex, 1);
    }
    // Remove from DOM (a no-op when the node is already detached)
    node.remove();
    // Show hint if no nodes left
    if (document.querySelectorAll('.canvas-node').length === 0) {
        const canvasHint = document.querySelector('.canvas-hint');
        if (canvasHint) {
            canvasHint.style.display = 'block';
        }
    }
    // Notify model update
    document.dispatchEvent(new CustomEvent('networkUpdated', {
        detail: networkLayers
    }));
}
// Open layer editor
// Asks for the layer editor by dispatching an 'openLayerEditor' event on the
// document. The event detail carries the node's id, type, name and dimensions
// (read from its data-* attributes) plus the node element itself. Does nothing
// for a falsy node.
function openLayerEditor(node) {
    if (!node) {
        return;
    }
    const readData = (key) => node.getAttribute(`data-${key}`);
    const detail = {
        id: readData('id'),
        type: readData('type'),
        name: readData('name'),
        dimensions: readData('dimensions'),
        node: node
    };
    console.log(`Opening editor for node: ${detail.id}`);
    // Trigger editor event
    document.dispatchEvent(new CustomEvent('openLayerEditor', { detail }));
}
// Update connections
// Repositions connection elements so each one runs from the centre of its
// source node's output port to the centre of its target node's input port.
// A connection whose source or target node no longer exists is removed.
//   specificNodeId - when given, only connections that start or end at this
//                    node are refreshed; otherwise every connection is.
// The in-progress (.temp-connection) line is never touched here.
function updateConnections(specificNodeId = null) {
    // Ids end up inside quoted attribute selectors, so escape them first
    const byNodeId = (id) => `.canvas-node[data-id="${CSS.escape(id)}"]`;
    // Get connections to update
    let selector;
    if (specificNodeId) {
        const id = CSS.escape(specificNodeId);
        // Exclude the temporary line here as well: it has a data-source but no
        // data-target yet, so it would be mistaken for an orphan and removed.
        selector = `.connection:not(.temp-connection)[data-source="${id}"], ` +
            `.connection:not(.temp-connection)[data-target="${id}"]`;
    } else {
        selector = '.connection:not(.temp-connection)';
    }
    document.querySelectorAll(selector).forEach(connection => {
        const sourceId = connection.getAttribute('data-source');
        const targetId = connection.getAttribute('data-target');
        const sourceNode = document.querySelector(byNodeId(sourceId));
        const targetNode = document.querySelector(byNodeId(targetId));
        if (!sourceNode || !targetNode) {
            // Remove orphaned connection
            connection.remove();
            return;
        }
        const sourcePort = sourceNode.querySelector('.port-out');
        const targetPort = targetNode.querySelector('.port-in');
        if (!sourcePort || !targetPort) {
            return;
        }
        // Port centres, converted from viewport to canvas coordinates
        const canvasRect = canvas.getBoundingClientRect();
        const sourceRect = sourcePort.getBoundingClientRect();
        const targetRect = targetPort.getBoundingClientRect();
        const startX = sourceRect.left + sourceRect.width / 2 - canvasRect.left;
        const startY = sourceRect.top + sourceRect.height / 2 - canvasRect.top;
        const endX = targetRect.left + targetRect.width / 2 - canvasRect.left;
        const endY = targetRect.top + targetRect.height / 2 - canvasRect.top;
        // The line is an element anchored at the start point, stretched to the
        // distance between the ports and rotated towards the end point
        const length = Math.hypot(endX - startX, endY - startY);
        const angle = Math.atan2(endY - startY, endX - startX) * 180 / Math.PI;
        connection.style.left = `${startX}px`;
        connection.style.top = `${startY}px`;
        connection.style.width = `${length}px`;
        connection.style.transform = `rotate(${angle}deg)`;
    });
}
// 5. CONNECTION HANDLING
// Connection state tracking
// tempConnection: the line element being drawn, or null when idle
// connectionSource: the node the line starts from, or null when idle
let tempConnection = null;
let connectionSource = null;
// Start creating a connection
// Adds a temporary line anchored at the centre of the source node's output
// port and installs document-level listeners that stretch it with the pointer
// and finish it on mouseup. An earlier unfinished line is discarded first.
//   sourceNode - the .canvas-node whose output port was pressed
//   event      - the mousedown event from that port
function startConnectionHandler(sourceNode, event) {
    console.log('Starting connection from node:', sourceNode.getAttribute('data-id'));
    // Cancel any existing connection attempt
    const unfinishedLine = tempConnection;
    if (unfinishedLine && unfinishedLine.parentNode) {
        unfinishedLine.parentNode.removeChild(unfinishedLine);
    }
    // Create a temporary connection element
    tempConnection = document.createElement('div');
    tempConnection.className = 'connection temp-connection';
    canvas.appendChild(tempConnection);
    // Store the source node
    connectionSource = sourceNode;
    // Centre of the output port, in canvas coordinates
    const sourceId = sourceNode.getAttribute('data-id');
    const outputPort = sourceNode.querySelector('.port-out');
    const canvasBox = canvas.getBoundingClientRect();
    const portBox = outputPort.getBoundingClientRect();
    const anchorX = portBox.left + portBox.width / 2 - canvasBox.left;
    const anchorY = portBox.top + portBox.height / 2 - canvasBox.top;
    // Set initial position
    Object.assign(tempConnection.style, {
        left: `${anchorX}px`,
        top: `${anchorY}px`
    });
    tempConnection.setAttribute('data-source', sourceId);
    // Add event listeners for moving and completing the connection
    document.addEventListener('mousemove', moveConnectionHandler);
    document.addEventListener('mouseup', endConnectionHandler);
    event.preventDefault();
    event.stopPropagation();
}
// Update the temporary connection during drag
// Stretches and rotates the temporary line so it runs from the centre of the
// source node's output port to the current pointer position. Does nothing
// when no connection is being drawn.
function moveConnectionHandler(event) {
    if (!(tempConnection && connectionSource)) {
        return;
    }
    const canvasBox = canvas.getBoundingClientRect();
    const portBox = connectionSource.querySelector('.port-out').getBoundingClientRect();
    // Both end points are expressed in canvas coordinates
    const fromX = portBox.left + portBox.width / 2 - canvasBox.left;
    const fromY = portBox.top + portBox.height / 2 - canvasBox.top;
    const toX = event.clientX - canvasBox.left;
    const toY = event.clientY - canvasBox.top;
    const deltaX = toX - fromX;
    const deltaY = toY - fromY;
    const lineLength = Math.sqrt(Math.pow(deltaX, 2) + Math.pow(deltaY, 2));
    const lineAngle = Math.atan2(deltaY, deltaX) * 180 / Math.PI;
    Object.assign(tempConnection.style, {
        left: `${fromX}px`,
        top: `${fromY}px`,
        width: `${lineLength}px`,
        transform: `rotate(${lineAngle}deg)`
    });
}
// Complete or cancel the connection
// Runs on mouseup after startConnectionHandler. When the pointer is released
// over an input port of a different node that is not already connected from
// the same source, the temporary line becomes a permanent connection, the
// model is updated and 'networkUpdated' is dispatched. In every other case the
// temporary line is removed.
//   event - the mouseup event
// NOTE(review): the target is found with document.elementFromPoint, which
// returns the topmost element at the pointer - this presumably relies on the
// temporary line not intercepting pointer events; confirm against the CSS.
function endConnectionHandler(event) {
    // Remove the event listeners
    document.removeEventListener('mousemove', moveConnectionHandler);
    document.removeEventListener('mouseup', endConnectionHandler);
    if (!tempConnection || !connectionSource) return;
    const pendingConnection = tempConnection;
    const discardPending = () => {
        if (pendingConnection.parentNode) {
            pendingConnection.parentNode.removeChild(pendingConnection);
        }
    };
    try {
        // Check if we're over a valid target (port-in)
        const hovered = document.elementFromPoint(event.clientX, event.clientY);
        const targetNode = hovered && hovered.classList.contains('port-in')
            ? hovered.closest('.canvas-node')
            : null;
        if (!targetNode) {
            // No valid target, remove the temp connection
            discardPending();
            return;
        }
        const sourceId = connectionSource.getAttribute('data-id');
        const targetId = targetNode.getAttribute('data-id');
        // Prevent self-connections
        if (sourceId === targetId) {
            console.log('Cannot connect a node to itself');
            discardPending();
            return;
        }
        // Check if connection already exists. Ids are escaped because they are
        // embedded in quoted attribute selectors. The pending line has no
        // data-target yet, so it cannot match itself.
        const existingConnection = document.querySelector(
            `.connection[data-source="${CSS.escape(sourceId)}"][data-target="${CSS.escape(targetId)}"]`
        );
        if (existingConnection) {
            console.log('Connection already exists');
            discardPending();
            return;
        }
        console.log(`Creating connection: ${sourceId} → ${targetId}`);
        // Create the permanent connection
        pendingConnection.classList.remove('temp-connection');
        pendingConnection.setAttribute('data-target', targetId);
        // Add to network model
        networkLayers.connections.push({
            source: sourceId,
            target: targetId
        });
        // Update connection display
        updateConnections();
        // Update parameters based on the new connection
        updateParametersAfterConnection(sourceId, targetId);
        // Notify model update
        document.dispatchEvent(new CustomEvent('networkUpdated', {
            detail: networkLayers
        }));
    } finally {
        // Reset state on every path, including when a step above throws.
        // Otherwise tempConnection would keep pointing at the connection that
        // was just made permanent, and the next startConnectionHandler call
        // would remove it as a stale attempt.
        tempConnection = null;
        connectionSource = null;
    }
}
// Update parameters after a connection is made
function updateParametersAfterConnection(sourceId, targetId) {
const sourceNode = document.querySelector(`.canvas-node[data-id="${sourceId}"]`);
const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
if (!sourceNode || !targetNode) return;
const sourceType = sourceNode.getAttribute('data-type');
const targetType = targetNode.getAttribute('data-type');
const sourceConfig = sourceNode.layerConfig || {};
const targetConfig = targetNode.layerConfig || {};
console.log(`Updating parameters: ${sourceType} → ${targetType}`);
// Check if target has a manual output shape (user set)
const hasManualOutputShape = targetConfig.outputShape &&
Array.isArray(targetConfig.outputShape) &&
targetConfig.outputShape.length > 0 &&
targetConfig.outputShape.some(dim => dim !== '?' && dim !== '');
console.log(`Target has manual output shape: ${hasManualOutputShape}`, targetConfig.outputShape);
// Set input shape of target based on output shape of source
if (sourceConfig.outputShape) {
targetConfig.inputShape = [...sourceConfig.outputShape];
// Update the display
const inputShapeDisplay = targetNode.querySelector('.input-shape');
if (inputShapeDisplay) {
inputShapeDisplay.textContent = `[${sourceConfig.outputShape.join(' × ')}]`;
}
}
// If target has a manual output shape, don't recalculate the output shape
if (hasManualOutputShape) {
console.log('Preserving manual output shape:', targetConfig.outputShape);
} else {
// Calculate output shape and parameters based on node type
if (window.neuralNetwork && window.neuralNetwork.calculateOutputShape) {
// Use neural network module if available
const outputShape = window.neuralNetwork.calculateOutputShape(targetConfig, targetType);
const parameters = window.neuralNetwork.calculateParameters(targetConfig, targetType);
if (outputShape) {
targetConfig.outputShape = outputShape;
// Update output shape display
const outputShapeDisplay = targetNode.querySelector('.output-shape');
if (outputShapeDisplay) {
outputShapeDisplay.textContent = `[${outputShape.join(' × ')}]`;
}
}
if (parameters !== undefined) {
targetConfig.parameters = parameters;
// Update parameters display
const paramsDisplay = targetNode.querySelector('.node-parameters');
if (paramsDisplay) {
paramsDisplay.textContent = `Params: ${formatNumber(parameters)}`;
}
}
} else {
// Fallback calculations if neural network module is not available
let outputShape, parameters;
switch (targetType) {
case 'hidden':
outputShape = [targetConfig.units || 64];
if (sourceConfig.outputShape) {
const inputSize = sourceConfig.outputShape.reduce((a, b) => a * b, 1);
parameters = inputSize * targetConfig.units + targetConfig.units; // weights + biases
}
break;
case 'output':
outputShape = [targetConfig.units || 10];
if (sourceConfig.outputShape) {
const inputSize = sourceConfig.outputShape.reduce((a, b) => a * b, 1);
parameters = inputSize * targetConfig.units + targetConfig.units; // weights + biases
}
break;
case 'rnn':
// Get units and check if returning sequences
const rnnUnits = parseInt(targetConfig.units) || 32;
const rnnReturnSequences = targetConfig.returnSequences === 'true' || targetConfig.returnSequences === true;
// Set output shape based on return_sequences setting
if (rnnReturnSequences && sourceConfig.outputShape && sourceConfig.outputShape.length > 0) {
// If return_sequences is true, output is [sequence_length, units]
outputShape = [sourceConfig.outputShape[0], rnnUnits];
} else {
// If return_sequences is false, output is just [units]
outputShape = [rnnUnits];
}
// Calculate parameters if we have input shape
if (sourceConfig.outputShape && sourceConfig.outputShape.length > 0) {
// Get the last dimension of the input as input_features
const inputFeatures = sourceConfig.outputShape[sourceConfig.outputShape.length - 1];
const useBias = targetConfig.useBias !== 'false' && targetConfig.useBias !== false;
// Formula: input_features * units + units * units + units (bias)
const inputParams = inputFeatures * rnnUnits;
const recurrentParams = rnnUnits * rnnUnits;
const biasParams = useBias ? rnnUnits : 0;
parameters = inputParams + recurrentParams + biasParams;
console.log(`RNN parameter calculation:
Input features: ${inputFeatures}
Units: ${rnnUnits}
Input weights: ${inputParams}
Recurrent weights: ${recurrentParams}
Bias: ${biasParams}
Total: ${parameters}`);
}
break;
case 'lstm':
// Get units and check if returning sequences
const lstmUnits = parseInt(targetConfig.units) || 64;
const lstmReturnSequences = targetConfig.returnSequences === 'true' || targetConfig.returnSequences === true;
// Set output shape based on return_sequences setting
if (lstmReturnSequences && sourceConfig.outputShape && sourceConfig.outputShape.length > 0) {
outputShape = [sourceConfig.outputShape[0], lstmUnits];
} else {
outputShape = [lstmUnits];
}
// Calculate parameters if we have input shape
if (sourceConfig.outputShape && sourceConfig.outputShape.length > 0) {
// LSTM has 4 gates, each with its own weights and biases
const inputFeatures = sourceConfig.outputShape[sourceConfig.outputShape.length - 1];
const useBias = targetConfig.useBias !== 'false' && targetConfig.useBias !== false;
// Formula: 4 * (input_features * units + units * units + units (bias))
const inputParams = 4 * (inputFeatures * lstmUnits);
const recurrentParams = 4 * (lstmUnits * lstmUnits);
const biasParams = useBias ? 4 * lstmUnits : 0;
parameters = inputParams + recurrentParams + biasParams;
console.log(`LSTM parameter calculation:
Input features: ${inputFeatures}
Units: ${lstmUnits}
Gates: 4 (input, forget, cell, output)
Input weights: ${inputParams}
Recurrent weights: ${recurrentParams}
Bias: ${biasParams}
Total: ${parameters}`);
}
break;
case 'gru':
// Get units and check if returning sequences
const gruUnits = parseInt(targetConfig.units) || 48;
const gruReturnSequences = targetConfig.returnSequences === 'true' || targetConfig.returnSequences === true;
// Set output shape based on return_sequences setting
if (gruReturnSequences && sourceConfig.outputShape && sourceConfig.outputShape.length > 0) {
outputShape = [sourceConfig.outputShape[0], gruUnits];
} else {
outputShape = [gruUnits];
}
// Calculate parameters if we have input shape
if (sourceConfig.outputShape && sourceConfig.outputShape.length > 0) {
// GRU has 3 gates, each with its own weights and biases
const inputFeatures = sourceConfig.outputShape[sourceConfig.outputShape.length - 1];
const useBias = targetConfig.useBias !== 'false' && targetConfig.useBias !== false;
// Formula: 3 * (input_features * units + units * units + units (bias))
const inputParams = 3 * (inputFeatures * gruUnits);
const recurrentParams = 3 * (gruUnits * gruUnits);
const biasParams = useBias ? 3 * gruUnits : 0;
parameters = inputParams + recurrentParams + biasParams;
console.log(`GRU parameter calculation:
Input features: ${inputFeatures}
Units: ${gruUnits}
Gates: 3 (update, reset, new)
Input weights: ${inputParams}
Recurrent weights: ${recurrentParams}
Bias: ${biasParams}
Total: ${parameters}`);
}
break;
case 'conv':
if (sourceConfig.outputShape && sourceConfig.outputShape.length >= 3) {
// Very explicit type conversion - ensure all values are numbers
const height = Math.max(1, parseInt(sourceConfig.outputShape[0]) || 1); // Ensure at least 1
const width = Math.max(1, parseInt(sourceConfig.outputShape[1]) || 1); // Ensure at least 1
const channels = Math.max(1, parseInt(sourceConfig.outputShape[2]) || 1); // Ensure at least 1
console.log(`Conv2D CONNECTION INPUT SHAPE: [${height}, ${width}, ${channels}]`,
{original: sourceConfig.outputShape, parsed: [height, width, channels]});
// Ensure filters is a positive number
const filters = Math.max(1, parseInt(targetConfig.filters) || 32);
// Explicit processing of kernelSize with safety checks
let kernelSize = [3, 3]; // Default fallback
if (targetConfig.kernelSize) {
if (typeof targetConfig.kernelSize === 'string') {
kernelSize = targetConfig.kernelSize.split(',')
.map(v => Math.max(1, parseInt(v.trim()) || 1)); // Ensure at least 1
} else if (Array.isArray(targetConfig.kernelSize)) {
kernelSize = targetConfig.kernelSize
.map(v => Math.max(1, parseInt(v) || 1)); // Ensure at least 1
}
}
// Explicit processing of strides with safety checks
let strides = [1, 1]; // Default fallback
if (targetConfig.strides) {
if (typeof targetConfig.strides === 'string') {
strides = targetConfig.strides.split(',')
.map(v => Math.max(1, parseInt(v.trim()) || 1)); // Ensure at least 1
} else if (Array.isArray(targetConfig.strides)) {
strides = targetConfig.strides
.map(v => Math.max(1, parseInt(v) || 1)); // Ensure at least 1
}
}
// Ensure we have at least 2 elements for kernelSize and strides and all values are at least 1
kernelSize = kernelSize.length >= 2 ?
[Math.max(1, kernelSize[0]), Math.max(1, kernelSize[1])] :
[Math.max(1, kernelSize[0] || 3), Math.max(1, kernelSize[0] || 3)];
strides = strides.length >= 2 ?
[Math.max(1, strides[0]), Math.max(1, strides[1])] :
[Math.max(1, strides[0] || 1), Math.max(1, strides[0] || 1)];
console.log(`Conv2D CONNECTION CONFIG:`, {
filters: filters,
kernelSize: kernelSize,
strides: strides
});
// Store cleaned values back in config
targetConfig.filters = filters;
targetConfig.kernelSize = kernelSize;
targetConfig.strides = strides;
const padding = targetConfig.padding || 'same';
// Calculate output dimensions based on padding
let outHeight, outWidth;
if (padding === 'same') {
outHeight = Math.ceil(height / strides[0]);
outWidth = Math.ceil(width / strides[1]);
} else { // 'valid' padding
outHeight = Math.ceil((height - kernelSize[0] + 1) / strides[0]);
outWidth = Math.ceil((width - kernelSize[1] + 1) / strides[1]);
}
// Ensure output dimensions are at least 1
outHeight = Math.max(1, outHeight);
outWidth = Math.max(1, outWidth);
// Final output shape with proper validation
outputShape = [outHeight, outWidth, filters];
// Calculate parameters step by step to avoid any overflow or multiplication errors
const kh = Number(kernelSize[0]);
const kw = Number(kernelSize[1]);
const c = Number(channels);
const f = Number(filters);
// Check for any zeros or negative values that would make the calculation invalid
if (kh <= 0 || kw <= 0 || c <= 0 || f <= 0) {
console.error(`Invalid Conv2D connection parameter values: kh=${kh}, kw=${kw}, c=${c}, f=${f}`);
parameters = 0;
} else {
// Calculate with explicit steps to avoid any overflow
const kernelParams = kh * kw * c * f;
const biasParams = f;
parameters = kernelParams + biasParams;
console.log(`Conv2D CONNECTION CALCULATION STEPS:
Kernel height (kh) = ${kh}
Kernel width (kw) = ${kw}
Input channels (c) = ${c}
Filters (f) = ${f}
Kernel params = ${kh} × ${kw} × ${c} × ${f} = ${kernelParams}
Bias params = ${biasParams}
Total params = ${kernelParams} + ${biasParams} = ${parameters}
`);
}
console.log(`Conv2D connection output shape: ${outHeight}×${outWidth}×${filters}`);
} else {
console.log('Cannot calculate Conv2D connection parameters - invalid input shape:', sourceConfig.outputShape);
const filters = parseInt(targetConfig.filters) || 32;
outputShape = ['?', '?', filters];
parameters = 0; // Set to 0 instead of '?' to avoid display issues
}
break;
case 'pool':
if (sourceConfig.outputShape && sourceConfig.outputShape.length >= 3) {
const [height, width, channels] = sourceConfig.outputShape;
const poolSize = targetConfig.poolSize || [2, 2];
const stride = targetConfig.strides || poolSize;
const padding = targetConfig.padding || 'valid';
// Calculate output dimensions
let outHeight, outWidth;
if (padding === 'same') {
outHeight = Math.ceil(height / stride[0]);
outWidth = Math.ceil(width / stride[1]);
} else { // 'valid' padding
outHeight = Math.ceil((height - poolSize[0] + 1) / stride[0]);
outWidth = Math.ceil((width - poolSize[1] + 1) / stride[1]);
}
outputShape = [outHeight, outWidth, channels];
parameters = 0; // Pooling layers have no parameters
}
break;
}
// Update target config and display only for automatically calculated shapes
if (outputShape) {
targetConfig.outputShape = outputShape;
// Update output shape display
const outputShapeDisplay = targetNode.querySelector('.output-shape');
if (outputShapeDisplay) {
outputShapeDisplay.textContent = `[${outputShape.join(' × ')}]`;
}
}
if (parameters !== undefined) {
targetConfig.parameters = parameters;
// Update parameters display
const paramsDisplay = targetNode.querySelector('.node-parameters');
if (paramsDisplay) {
paramsDisplay.textContent = `Params: ${formatNumber(parameters)}`;
}
}
}
}
// Store updated config back to the node
targetNode.layerConfig = targetConfig;
// Update model
const layerIndex = networkLayers.layers.findIndex(layer => layer.id === targetId);
if (layerIndex !== -1) {
networkLayers.layers[layerIndex].config = targetConfig;
if (targetConfig.parameters) {
networkLayers.layers[layerIndex].parameters = targetConfig.parameters;
}
}
// Force re-render the node to show updated info
const dimensions = targetNode.querySelector('.node-dimensions');
if (dimensions && targetConfig.outputShape) {
let dimensionsText = '';
if (targetType === 'hidden' || targetType === 'output') {
dimensionsText = targetConfig.units || '';
} else if (targetType === 'conv' || targetType === 'pool') {
dimensionsText = targetConfig.outputShape.join('×');
}
dimensions.textContent = dimensionsText;
}
}
// 4. EXPORT GLOBAL FUNCTIONS
// Expose functions to window for compatibility
window.dragDrop = {
    /**
     * Returns the network model maintained by this script.
     *
     * The live object is returned, not a copy. Note that clearAllNodes() replaces
     * the model with a fresh object, so a reference obtained earlier goes stale.
     *
     * @returns {{layers: Array, connections: Array}} The current network model.
     */
    getNetworkArchitecture: function() {
        return networkLayers;
    },

    /**
     * Removes every `.canvas-node` and `.connection` element from the page and
     * resets the model, the naming counters and the canvas hint.
     *
     * Also calls `window.neuralNetwork.resetLayerCounter()` when that hook exists
     * and dispatches a `networkUpdated` event on `document` whose `detail` is the
     * new, empty model.
     */
    clearAllNodes: function() {
        // Clear all nodes and connection elements
        document.querySelectorAll('.canvas-node, .connection').forEach(el => {
            el.remove();
        });

        // Reset model
        networkLayers = {
            layers: [],
            connections: []
        };

        // Reset counters (keys are kept so later lookups still find a number)
        Object.keys(nodeCounter).forEach(key => {
            nodeCounter[key] = 0;
        });

        // Show hint
        const canvasHint = document.querySelector('.canvas-hint');
        if (canvasHint) {
            canvasHint.style.display = 'block';
        }

        // Reset layer counter in neural network module
        if (window.neuralNetwork && window.neuralNetwork.resetLayerCounter) {
            window.neuralNetwork.resetLayerCounter();
        }

        // Notify update
        document.dispatchEvent(new CustomEvent('networkUpdated', {
            detail: networkLayers
        }));
    },

    // Visual refresh of the connections; the implementation lives elsewhere in this file
    updateConnections: updateConnections,

    /**
     * Force update all node parameters in the network.
     *
     * Starting from every root node (a layer that is never the target of a
     * connection), walks the connections depth-first and, for each target:
     *   - copies the source's output shape into the target's `inputShape` and
     *     refreshes the `.input-shape` display;
     *   - if the target already carries an output shape with at least one concrete
     *     dimension (treated as "manually set"), keeps it and only recalculates the
     *     parameter count through `window.neuralNetwork.calculateParameters`;
     *   - otherwise delegates to `window.updateParametersAfterConnection`.
     *
     * Afterwards the connections are refreshed and a `networkUpdated` event is
     * dispatched on `document` with the model as `detail`.
     *
     * A failure while updating one node is logged and does not abort the rest of
     * the propagation. Cyclic connections are detected and not followed twice on
     * the same path, so the walk always terminates.
     */
    forceUpdateNetworkParameters: function() {
        console.log('Force updating all network parameters');

        // Nodes on the current recursion path; guards against cyclic connections.
        const nodesInProgress = new Set();

        // Get all root nodes (nodes with no incoming connections)
        const targetNodeIds = new Set(networkLayers.connections.map(conn => conn.target));
        const rootNodes = networkLayers.layers
            .map(layer => layer.id)
            .filter(nodeId => !targetNodeIds.has(nodeId));
        console.log('Root nodes for parameter propagation:', rootNodes);

        // Start update from root nodes
        rootNodes.forEach(nodeId => {
            updateDownstreamNodes(nodeId);
        });

        // Update all connections visually
        updateConnections();

        // Notify that network has been updated
        document.dispatchEvent(new CustomEvent('networkUpdated', {
            detail: networkLayers
        }));
        console.log('Finished force updating network parameters');

        // Recursive entry point: propagates from nodeId unless it is already being
        // processed further up the call stack (which would mean a connection cycle).
        function updateDownstreamNodes(nodeId) {
            if (nodesInProgress.has(nodeId)) {
                console.warn(`Connection cycle detected at node ${nodeId}, not propagating further on this path`);
                return;
            }
            nodesInProgress.add(nodeId);
            try {
                propagateFromNode(nodeId);
            } finally {
                // Only the current path is tracked, so shared downstream nodes in an
                // acyclic graph are still updated once per incoming path.
                nodesInProgress.delete(nodeId);
            }
        }

        // Updates every direct target of nodeId and recurses into each of them.
        function propagateFromNode(nodeId) {
            console.log(`Updating downstream from node: ${nodeId}`);

            // Find all connections from this node
            const outgoingConnections = networkLayers.connections.filter(conn => conn.source === nodeId);

            // If no outgoing connections, we're done with this branch
            if (outgoingConnections.length === 0) {
                console.log(`Node ${nodeId} has no outgoing connections`);
                return;
            }

            // Get source node and its config
            const sourceNode = document.querySelector(`.canvas-node[data-id="${nodeId}"]`);
            if (!sourceNode || !sourceNode.layerConfig) {
                console.warn(`Source node ${nodeId} not found or has no config`);
                return;
            }
            const sourceConfig = sourceNode.layerConfig;
            const sourceType = sourceNode.getAttribute('data-type');

            // Double check source outputShape is valid
            if (!Array.isArray(sourceConfig.outputShape)) {
                console.warn(`Source node ${nodeId} (${sourceType}) has invalid output shape:`, sourceConfig.outputShape);
                // Try to fix based on node type
                if (sourceType === 'input' && Array.isArray(sourceConfig.shape)) {
                    sourceConfig.outputShape = [...sourceConfig.shape];
                    console.log(`Fixed input node output shape to:`, sourceConfig.outputShape);
                }
            }
            console.log(`Source node ${nodeId} (${sourceType}) output shape:`, sourceConfig.outputShape);

            // For each outgoing connection, update the target node
            outgoingConnections.forEach(conn => {
                updateTargetNode(nodeId, sourceType, sourceConfig, conn.target);
            });
        }

        // Applies the source's output shape to one target node, recalculates the
        // target and continues the propagation from it.
        function updateTargetNode(nodeId, sourceType, sourceConfig, targetId) {
            const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
            if (!targetNode) {
                console.warn(`Target node ${targetId} not found`);
                return;
            }

            const targetType = targetNode.getAttribute('data-type');
            const targetConfig = targetNode.layerConfig || {};
            console.log(`Updating connection: ${sourceType}(${nodeId}) → ${targetType}(${targetId})`);

            // Any concrete (non-placeholder) dimension already present on the target
            // counts as a manually set output shape.
            const hasManualOutputShape = Array.isArray(targetConfig.outputShape) &&
                targetConfig.outputShape.some(dim => dim !== '?' && dim !== '');
            console.log(`Target node ${targetId} has manual output shape: ${hasManualOutputShape}`,
                targetConfig.outputShape);

            // Without a usable (array) source shape there is nothing to propagate;
            // checking for an array also keeps the join below from throwing.
            if (!Array.isArray(sourceConfig.outputShape)) {
                console.warn(`Source node ${nodeId} has no output shape, cannot update target ${targetId}`);
                return;
            }

            // Make a deep copy to avoid reference issues
            targetConfig.inputShape = JSON.parse(JSON.stringify(sourceConfig.outputShape));
            console.log(`Set target node ${targetId} input shape to:`, targetConfig.inputShape);

            // Update the input shape display
            const inputShapeDisplay = targetNode.querySelector('.input-shape');
            if (inputShapeDisplay) {
                inputShapeDisplay.textContent = `[${sourceConfig.outputShape.join(' × ')}]`;
            }

            // Only update output shape if not manually set
            if (hasManualOutputShape) {
                console.log(`Preserving manual output shape for node ${targetId}:`, targetConfig.outputShape);
                recalculateParametersOnly(targetNode, targetConfig, targetType);
            } else {
                recalculateThroughConnectionHandler(nodeId, targetId, targetNode, targetConfig, targetType);
            }

            // Store updated config back to the node
            targetNode.layerConfig = targetConfig;

            // Continue propagation down the network
            updateDownstreamNodes(targetId);
        }

        // Recalculates shape and parameters of a target through the global
        // connection handler. Errors are logged so one bad node cannot abort the
        // whole refresh.
        function recalculateThroughConnectionHandler(nodeId, targetId, targetNode, targetConfig, targetType) {
            if (targetType === 'conv') {
                console.log(`Special handling for Conv2D target node ${targetId}`);
            }
            if (!window.updateParametersAfterConnection) {
                console.warn('updateParametersAfterConnection not available');
                // Fall back to a display-only refresh of the target node
                updateNodeDisplay(targetNode, targetConfig);
                return;
            }
            try {
                window.updateParametersAfterConnection(nodeId, targetId);
                if (targetType === 'conv') {
                    console.log(`Updated Conv2D node ${targetId} parameters through connection handler`);
                }
            } catch (error) {
                console.error(`Error updating parameters for ${targetType} node ${targetId}:`, error);
            }
        }

        // Still update parameters even if output shape is manual
        function recalculateParametersOnly(targetNode, targetConfig, targetType) {
            if (!(window.neuralNetwork && window.neuralNetwork.calculateParameters)) {
                return;
            }
            try {
                const parameters = window.neuralNetwork.calculateParameters(targetConfig, targetType);
                if (parameters !== undefined) {
                    targetConfig.parameters = parameters;
                    // Update parameters display
                    const paramsDisplay = targetNode.querySelector('.node-parameters');
                    if (paramsDisplay) {
                        paramsDisplay.textContent = `Params: ${formatNumber(parameters)}`;
                    }
                }
            } catch (error) {
                console.error(`Error calculating parameters with manual shape:`, error);
            }
        }

        // Update node's display without trigger events that would cause loops.
        // This is just a basic update without recalculating everything.
        function updateNodeDisplay(node, config) {
            if (!node) return;
            node.layerConfig = config;

            // Update input shape display
            const inputShapeDisplay = node.querySelector('.input-shape');
            if (inputShapeDisplay && config.inputShape) {
                inputShapeDisplay.textContent = `[${config.inputShape.join(' × ')}]`;
            }
        }
    }
};
// Publish the connection handlers on window under the names existing code expects
Object.entries({
    startConnection: startConnectionHandler,
    updateParametersAfterConnection: updateParametersAfterConnection
}).forEach(([globalName, handler]) => {
    window[globalName] = handler;
});
// Debugging aid: marks the end of initialization in the console
console.log('Complete drag and drop fix initialized');
/**
 * Adds a "Fix Conv2D Params" button to the page body, unless an element with
 * the id `fix-conv2d-button` already exists.
 *
 * Clicking the button calls `window.forceRecalculateConv2DParameters()` when that
 * function is present and shows a confirmation label for two seconds; otherwise
 * the problem is logged and reported to the user through `alert`.
 */
function addConv2DFixButton() {
    const BUTTON_ID = 'fix-conv2d-button';
    const IDLE_LABEL = 'Fix Conv2D Params';
    const IDLE_COLOR = '#4285f4';
    const HOVER_COLOR = '#3367d6';

    // Bail out when the button has been added before
    if (document.getElementById(BUTTON_ID)) {
        return;
    }

    // Build the button element
    const button = document.createElement('button');
    button.id = BUTTON_ID;
    button.textContent = IDLE_LABEL;
    button.title = 'Manually recalculate parameters for Conv2D nodes';

    // Inline styling: pinned to the top-right corner above other content
    const buttonStyles = {
        position: 'absolute',
        right: '10px',
        top: '10px',
        zIndex: '9999',
        padding: '5px 10px',
        backgroundColor: IDLE_COLOR,
        color: 'white',
        border: 'none',
        borderRadius: '4px',
        cursor: 'pointer',
        fontSize: '12px',
        fontWeight: 'bold',
        boxShadow: '0 2px 5px rgba(0,0,0,0.2)'
    };
    for (const [property, value] of Object.entries(buttonStyles)) {
        button.style[property] = value;
    }

    // Hover effect: each handler repaints the element it is invoked on
    const paintBackground = function(color) {
        return function() {
            this.style.backgroundColor = color;
        };
    };
    button.onmouseover = paintBackground(HOVER_COLOR);
    button.onmouseout = paintBackground(IDLE_COLOR);

    // Click: run the recalculation helper if it is available
    button.addEventListener('click', function() {
        console.log('Manually fixing Conv2D parameters...');

        if (!window.forceRecalculateConv2DParameters) {
            console.error('Conv2D helper function not found');
            alert('Conv2D helper function not found! Please refresh the page and try again.');
            return;
        }

        window.forceRecalculateConv2DParameters();
        button.textContent = 'Conv2D Fixed!';
        // Restore the idle label after a short confirmation period
        setTimeout(() => {
            button.textContent = IDLE_LABEL;
        }, 2000);
    });

    // Attach to the page
    document.body.appendChild(button);
    console.log('Added Conv2D fix button');
}
}
})();