|
|
|
function initializeDragAndDrop() { |
|
// Palette entries the user can drag onto the canvas.
const nodeItems = document.querySelectorAll('.node-item');
// Drop surface that hosts layer nodes and their connections.
const canvas = document.getElementById('network-canvas');

// ---- Interaction state shared by the nested handlers below ----
// Element currently being dragged (palette item during HTML5 drag, canvas node during mouse drag).
let draggedNode = null;
// Pointer offset inside the dragged canvas node, captured on mousedown.
let offsetX;
let offsetY;
// True while a canvas node is being moved with the mouse.
let isDragging = false;
// True while the user is pulling a connection out of an output port.
let isConnecting = false;
// Node the in-progress connection starts from.
let startNode = null;
// Temporary element that visualises the in-progress connection.
let connectionLine = null;

// In-memory model of the network; sent as `detail` of every 'networkUpdated' event.
let networkLayers = {
    layers: [],
    connections: []
};

// Palette items start HTML5 drags.
nodeItems.forEach(item => {
    item.addEventListener('dragstart', handleDragStart);
});

// The canvas accepts those drags and turns them into layer nodes.
canvas.addEventListener('dragover', handleDragOver);
canvas.addEventListener('drop', handleDrop);
|
|
|
|
|
/**
 * 'dragstart' handler for palette items.
 *
 * Stores the item's `data-type` in the drag payload (read back by the drop
 * handler) and installs a half-transparent clone as the drag image.
 *
 * @this {HTMLElement} The palette item the listener is attached to.
 * @param {DragEvent} e
 */
function handleDragStart(e) {
    draggedNode = this;
    e.dataTransfer.setData('text/plain', this.getAttribute('data-type'));
    // Matches the 'copy' dropEffect announced by the canvas dragover handler.
    e.dataTransfer.effectAllowed = 'copy';

    // The drag image must be in the document when setDragImage is called;
    // park it off-screen so it does not flash on the page.
    const ghost = this.cloneNode(true);
    ghost.style.opacity = '0.5';
    ghost.style.position = 'absolute';
    ghost.style.top = '-1000px';
    ghost.style.left = '-1000px';
    document.body.appendChild(ghost);
    e.dataTransfer.setDragImage(ghost, 0, 0);

    // The browser snapshots the image synchronously, so the clone can go
    // away on the next tick.
    setTimeout(() => {
        if (ghost.parentNode) {
            ghost.parentNode.removeChild(ghost);
        }
    }, 0);
}
|
|
|
|
|
/**
 * 'dragover' handler for the canvas.
 *
 * Cancelling the event is what marks the canvas as a valid drop target;
 * the drop effect is shown as a copy.
 *
 * @param {DragEvent} e
 */
function handleDragOver(e) {
    e.preventDefault();
    e.dataTransfer.dropEffect = 'copy';
}
|
|
|
|
|
/**
 * 'drop' handler for the canvas: turns the dragged palette type into a new
 * layer node at the drop position.
 *
 * Hides the canvas hint, asks `window.neuralNetwork` for an id and a default
 * config, builds the node DOM (header, shape rows, parameter text, input and
 * output ports), wires its mouse handlers, records the layer in
 * `networkLayers.layers` and dispatches 'networkUpdated'.
 *
 * @param {DragEvent} e
 */
function handleDrop(e) {
    e.preventDefault();

    const canvasHint = document.querySelector('.canvas-hint');
    if (canvasHint) {
        canvasHint.style.display = 'none';
    }

    const nodeType = e.dataTransfer.getData('text/plain');
    if (!nodeType) {
        return;
    }

    const layerId = window.neuralNetwork.getNextLayerId(nodeType);
    const nodeConfig = window.neuralNetwork.createNodeConfig(nodeType);

    // Must run before the new node is attached: the per-type counters in the
    // summary count the nodes that are already on the canvas.
    const summary = describeDroppedLayer(nodeType, nodeConfig);

    const canvasNode = document.createElement('div');
    canvasNode.className = `canvas-node ${nodeType}-node`;
    canvasNode.setAttribute('data-type', nodeType);
    canvasNode.setAttribute('data-id', layerId);

    // Position relative to the canvas, not the viewport.
    const rect = canvas.getBoundingClientRect();
    const x = e.clientX - rect.left;
    const y = e.clientY - rect.top;
    canvasNode.style.left = `${x}px`;
    canvasNode.style.top = `${y}px`;

    const nodeHeader = document.createElement('div');
    nodeHeader.className = 'node-header';
    nodeHeader.textContent = summary.nodeName;

    const shapeInfo = document.createElement('div');
    shapeInfo.className = 'shape-info';
    shapeInfo.appendChild(createShapeRow('Input:', 'input-shape', summary.inputShape));
    shapeInfo.appendChild(createShapeRow('Output:', 'output-shape', summary.outputShape));

    // Text is assigned through textContent so config values can never be
    // interpreted as markup.
    const paramsDisplay = document.createElement('pre');
    paramsDisplay.className = 'params-display';
    paramsDisplay.textContent =
        summary.parameters === undefined || summary.parameters === null
            ? ''
            : String(summary.parameters);
    const paramsSection = document.createElement('div');
    paramsSection.className = 'params-section';
    paramsSection.appendChild(paramsDisplay);

    const nodeContent = document.createElement('div');
    nodeContent.className = 'node-content';
    nodeContent.appendChild(shapeInfo);
    nodeContent.appendChild(paramsSection);

    const inputPort = document.createElement('div');
    inputPort.className = 'port input-port';
    inputPort.setAttribute('data-port-type', 'input');

    const outputPort = document.createElement('div');
    outputPort.className = 'port output-port';
    outputPort.setAttribute('data-port-type', 'output');

    canvasNode.appendChild(nodeHeader);
    canvasNode.appendChild(nodeContent);
    canvasNode.appendChild(inputPort);
    canvasNode.appendChild(outputPort);
    canvas.appendChild(canvasNode);

    // The config travels with the element; shape propagation reads it back.
    canvasNode.layerConfig = nodeConfig;

    canvasNode.addEventListener('mousedown', startDrag);
    // Ports swallow mousedown so that grabbing a port never moves the node.
    inputPort.addEventListener('mousedown', (portEvent) => {
        portEvent.stopPropagation();
    });
    outputPort.addEventListener('mousedown', (portEvent) => {
        portEvent.stopPropagation();
        startConnection(canvasNode, portEvent);
    });
    canvasNode.addEventListener('dblclick', () => {
        openLayerEditor(canvasNode);
    });
    canvasNode.addEventListener('contextmenu', (menuEvent) => {
        menuEvent.preventDefault();
        deleteNode(canvasNode);
    });

    networkLayers.layers.push({
        id: layerId,
        type: nodeType,
        config: nodeConfig,
        // Same shape the drag handler writes when the node is moved later.
        position: { x, y }
    });

    document.dispatchEvent(new CustomEvent('networkUpdated', {
        detail: networkLayers
    }));

    updateConnections();
}

/**
 * Builds the display strings for a freshly dropped layer and applies the
 * per-type defaults (hidden units, conv filters) to `nodeConfig`.
 *
 * @param {string} nodeType Palette type taken from the drag payload.
 * @param {Object} nodeConfig Config returned by `createNodeConfig`; mutated for 'hidden' and 'conv'.
 * @returns {{nodeName: string, inputShape: string, outputShape: string, parameters: *}}
 */
function describeDroppedLayer(nodeType, nodeConfig) {
    const countOnCanvas = (type) =>
        document.querySelectorAll(`.canvas-node[data-type="${type}"]`).length;

    switch (nodeType) {
        case 'input':
            return {
                nodeName: 'Input Layer',
                inputShape: 'N/A',
                outputShape: `[${nodeConfig.shape.join(' × ')}]`,
                parameters: nodeConfig.parameters
            };
        case 'hidden': {
            const hiddenCount = countOnCanvas('hidden');
            // First hidden layer gets 128 units, later ones 64.
            nodeConfig.units = hiddenCount === 0 ? 128 : 64;
            return {
                nodeName: `Hidden Layer ${hiddenCount + 1}`,
                inputShape: 'Connect input',
                outputShape: `[${nodeConfig.units}]`,
                parameters: 'Connect input to calculate'
            };
        }
        case 'output':
            return {
                nodeName: 'Output Layer',
                inputShape: 'Connect input',
                outputShape: `[${nodeConfig.units}]`,
                parameters: 'Connect input to calculate'
            };
        case 'conv': {
            const convCount = countOnCanvas('conv');
            // Filter count grows with each conv layer: 32, 64, 96, ...
            nodeConfig.filters = 32 * (convCount + 1);
            return {
                nodeName: `Conv2D ${convCount + 1}`,
                inputShape: 'Connect input',
                outputShape: 'Depends on input',
                parameters: `In: ?, Out: ${nodeConfig.filters}\nKernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`
            };
        }
        case 'pool': {
            const poolCount = countOnCanvas('pool');
            return {
                nodeName: `Pooling ${poolCount + 1}`,
                inputShape: 'Connect input',
                outputShape: 'Depends on input',
                parameters: `Pool size: ${nodeConfig.poolSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`
            };
        }
        default:
            return {
                nodeName: 'Unknown Layer',
                inputShape: 'N/A',
                outputShape: 'N/A',
                parameters: 'N/A'
            };
    }
}

/**
 * Creates one "label: value" row of the shape panel.
 *
 * @param {string} label Visible label, e.g. 'Input:'.
 * @param {string} valueClass Class of the value span ('input-shape' / 'output-shape').
 * @param {string} value Text shown as the shape.
 * @returns {HTMLDivElement}
 */
function createShapeRow(label, valueClass, value) {
    const row = document.createElement('div');
    row.className = 'shape-row';

    const labelSpan = document.createElement('span');
    labelSpan.className = 'shape-label';
    labelSpan.textContent = label;

    const valueSpan = document.createElement('span');
    valueSpan.className = valueClass;
    valueSpan.textContent = value;

    row.appendChild(labelSpan);
    row.appendChild(document.createTextNode(' '));
    row.appendChild(valueSpan);
    return row;
}
|
|
|
|
|
/**
 * 'mousedown' handler for canvas nodes: begins moving the node.
 *
 * Ignored while a connection is being drawn, for non-primary buttons
 * (right-click is the delete gesture) and for presses that land on node
 * controls or ports.
 *
 * @param {MouseEvent} e
 */
function startDrag(e) {
    if (isConnecting) return;

    if (typeof e.button === 'number' && e.button !== 0) return;

    if (e.target.closest('.node-controls, .node-port, .port')) {
        return;
    }

    const target = e.target.closest('.canvas-node');
    if (!target) {
        return;
    }

    // Remember where inside the node it was grabbed so it does not jump
    // under the cursor on the first move.
    const rect = target.getBoundingClientRect();
    offsetX = e.clientX - rect.left;
    offsetY = e.clientY - rect.top;

    draggedNode = target;
    isDragging = true;

    // Lift the node above its siblings for the duration of the drag.
    draggedNode.style.zIndex = '100';
    draggedNode.classList.add('dragging');

    // Listen on the document so the drag survives the cursor leaving the node.
    document.addEventListener('mousemove', dragNode);
    document.addEventListener('mouseup', stopDrag);

    // Prevents text selection while dragging.
    e.preventDefault();
}
|
|
|
|
|
/**
 * 'mousemove' handler active during a node drag: moves the node, keeping it
 * inside the canvas, mirrors the position into the model and redraws
 * connections.
 *
 * @param {MouseEvent} e
 */
function dragNode(e) {
    if (!isDragging || !draggedNode) return;

    const canvasRect = canvas.getBoundingClientRect();
    const maxX = canvasRect.width - draggedNode.offsetWidth;
    const maxY = canvasRect.height - draggedNode.offsetHeight;

    // Canvas-relative position of the node's top-left corner, clamped so the
    // node cannot leave the canvas.
    const x = Math.max(0, Math.min(maxX, e.clientX - canvasRect.left - offsetX));
    const y = Math.max(0, Math.min(maxY, e.clientY - canvasRect.top - offsetY));

    draggedNode.style.left = `${x}px`;
    draggedNode.style.top = `${y}px`;

    // The attribute is always a string while the stored id may not be, so
    // compare in string form.
    const nodeId = draggedNode.getAttribute('data-id');
    const layer = networkLayers.layers.find(entry => String(entry.id) === nodeId);
    if (layer) {
        layer.position = { x, y };
    }

    updateConnections();
}
|
|
|
|
|
/**
 * 'mouseup' handler that ends a node drag: detaches the document-level
 * listeners, restores the node's resting style and redraws connections.
 */
function stopDrag() {
    if (!isDragging) return;

    isDragging = false;
    document.removeEventListener('mousemove', dragNode);
    document.removeEventListener('mouseup', stopDrag);

    if (draggedNode) {
        draggedNode.style.zIndex = '10';
        draggedNode.classList.remove('dragging');

        updateConnections();

        // Do not keep a reference to a node that may be deleted later.
        draggedNode = null;
    }
}
|
|
|
|
|
/**
 * Begins drawing a connection from `node`'s output port.
 *
 * Creates the temporary line element at the port centre, highlights the
 * ports that may receive the connection and installs the document-level
 * move/up listeners that drive and finish the gesture.
 *
 * @param {HTMLElement} node Canvas node whose output port was pressed.
 * @param {MouseEvent} e The originating mousedown event.
 */
function startConnection(node, e) {
    // Right-click is the delete gesture; it must not start a connection.
    if (typeof e.button === 'number' && e.button !== 0) return;

    // Nodes built by the drop handler use '.output-port'; '.port-out' is kept
    // for markup that may use the other spelling.
    const portOut = node.querySelector('.output-port, .port-out');
    if (!portOut) {
        // Bail out before touching any state so dragging stays usable.
        return;
    }

    isConnecting = true;
    startNode = node;

    connectionLine = document.createElement('div');
    connectionLine.className = 'connection temp-connection';

    // Anchor the line at the centre of the port, in canvas coordinates.
    const portRect = portOut.getBoundingClientRect();
    const canvasRect = canvas.getBoundingClientRect();
    const startX = portRect.left + portRect.width / 2 - canvasRect.left;
    const startY = portRect.top + portRect.height / 2 - canvasRect.top;

    connectionLine.style.left = `${startX}px`;
    connectionLine.style.top = `${startY}px`;
    connectionLine.style.width = '0px';
    connectionLine.style.transform = 'rotate(0deg)';

    portOut.classList.add('active-port');
    highlightValidConnectionTargets(node);

    canvas.appendChild(connectionLine);

    document.addEventListener('mousemove', drawConnection);
    document.addEventListener('mouseup', cancelConnection);

    e.preventDefault();
}
|
|
|
|
|
/**
 * Marks the input port of every other canvas node as a valid or invalid
 * target for a connection starting at `sourceNode`.
 *
 * @param {HTMLElement} sourceNode Node the connection is being drawn from.
 */
function highlightValidConnectionTargets(sourceNode) {
    const sourceType = sourceNode.getAttribute('data-type');
    const sourceId = sourceNode.getAttribute('data-id');

    document.querySelectorAll('.canvas-node').forEach(node => {
        if (node === sourceNode) {
            return;
        }

        // '.input-port' is what the drop handler creates; '.port-in' is the
        // alternative spelling kept for other markup.
        const portIn = node.querySelector('.input-port, .port-in');
        if (!portIn) {
            return;
        }

        const accepted = isValidConnection(
            sourceType,
            node.getAttribute('data-type'),
            sourceId,
            node.getAttribute('data-id')
        );
        portIn.classList.add(accepted ? 'valid-target' : 'invalid-target');
    });
}
|
|
|
|
|
/**
 * Clears the connection-gesture highlight classes from every port.
 *
 * Both port spellings are covered: '.input-port' / '.output-port' (created
 * by the drop handler) and '.port-in' / '.port-out'.
 */
function removePortHighlights() {
    document
        .querySelectorAll('.input-port, .output-port, .port-in, .port-out')
        .forEach(port => {
            port.classList.remove('active-port', 'valid-target', 'invalid-target');
        });
}
|
|
|
|
|
/**
 * Decides whether a connection from a source layer to a target layer is
 * allowed.
 *
 * Rejected: anything leaving an 'output' layer or entering an 'input'
 * layer, a node connected to itself, and any pair that is already linked in
 * either direction (a second link would duplicate the connection's DOM id
 * or create a two-node loop). Otherwise the layer-type table below applies.
 *
 * @param {string} sourceType `data-type` of the source node.
 * @param {string} targetType `data-type` of the target node.
 * @param {string} sourceId `data-id` of the source node.
 * @param {string} targetId `data-id` of the target node.
 * @returns {boolean} True when the connection may be created.
 */
function isValidConnection(sourceType, targetType, sourceId, targetId) {
    if (sourceType === 'output' || targetType === 'input') {
        return false;
    }

    // Only meaningful when ids were actually supplied.
    const hasIds = sourceId !== undefined && sourceId !== null;
    if (hasIds && sourceId === targetId) {
        return false;
    }

    const alreadyLinked = networkLayers.connections.some(conn =>
        (conn.source === sourceId && conn.target === targetId) ||
        (conn.source === targetId && conn.target === sourceId)
    );
    if (alreadyLinked) {
        return false;
    }

    // Target types each source type may feed.
    const allowedTargets = {
        input: ['hidden', 'conv'],
        conv: ['conv', 'pool', 'hidden'],
        pool: ['conv', 'hidden'],
        hidden: ['hidden', 'output']
    };

    // Own-property check so names like 'toString' are not treated as types.
    if (!Object.prototype.hasOwnProperty.call(allowedTargets, sourceType)) {
        return false;
    }
    return allowedTargets[sourceType].includes(targetType);
}
|
|
|
|
|
/**
 * 'mousemove' handler active while a connection is being drawn: stretches
 * and rotates the temporary line from the source port to the cursor and
 * toggles 'port-hover' on the input port under the cursor.
 *
 * @param {MouseEvent} e
 */
function drawConnection(e) {
    if (!isConnecting || !connectionLine || !startNode) return;

    const portOut = startNode.querySelector('.output-port, .port-out');
    if (!portOut) return;

    const canvasRect = canvas.getBoundingClientRect();
    const portRect = portOut.getBoundingClientRect();

    // Both endpoints in canvas coordinates.
    const startX = portRect.left + portRect.width / 2 - canvasRect.left;
    const startY = portRect.top + portRect.height / 2 - canvasRect.top;
    const deltaX = (e.clientX - canvasRect.left) - startX;
    const deltaY = (e.clientY - canvasRect.top) - startY;

    // The line is a div anchored at the port: its width is the distance to
    // the cursor and the rotation points it at the cursor.
    connectionLine.style.width = `${Math.hypot(deltaX, deltaY)}px`;
    connectionLine.style.transform = `rotate(${Math.atan2(deltaY, deltaX) * 180 / Math.PI}deg)`;

    document.querySelectorAll('.canvas-node').forEach(node => {
        if (node === startNode) {
            return;
        }
        const portIn = node.querySelector('.input-port, .port-in');
        if (!portIn) {
            return;
        }
        const box = portIn.getBoundingClientRect();
        const cursorInside =
            e.clientX >= box.left && e.clientX <= box.right &&
            e.clientY >= box.top && e.clientY <= box.bottom;
        portIn.classList.toggle('port-hover', cursorInside);
    });
}
|
|
|
|
|
/**
 * 'mouseup' handler that finishes the connection gesture.
 *
 * If the button is released over the input port of a node that may legally
 * receive the connection, the connection is committed through
 * `endConnection`; otherwise the temporary line is discarded. In both cases
 * highlights, gesture state and the document listeners are cleaned up.
 *
 * @param {MouseEvent} e
 */
function cancelConnection(e) {
    if (!isConnecting) return;

    let targetNode = null;

    if (startNode) {
        const sourceType = startNode.getAttribute('data-type');
        const sourceId = startNode.getAttribute('data-id');

        document.querySelectorAll('.canvas-node').forEach(node => {
            if (node === startNode) {
                return;
            }
            const portIn = node.querySelector('.input-port, .port-in');
            if (!portIn) {
                return;
            }

            const box = portIn.getBoundingClientRect();
            const releasedOnPort =
                e.clientX >= box.left && e.clientX <= box.right &&
                e.clientY >= box.top && e.clientY <= box.bottom;
            if (!releasedOnPort) {
                return;
            }

            const targetType = node.getAttribute('data-type');
            const targetId = node.getAttribute('data-id');
            if (isValidConnection(sourceType, targetType, sourceId, targetId)) {
                // Later nodes win when ports overlap (last match is kept).
                targetNode = node;
            }
        });
    }

    if (targetNode) {
        endConnection(targetNode);
    } else if (connectionLine && connectionLine.parentNode) {
        connectionLine.parentNode.removeChild(connectionLine);
    }

    removePortHighlights();
    document.querySelectorAll('.port-hover').forEach(port => {
        port.classList.remove('port-hover');
    });

    isConnecting = false;
    startNode = null;
    connectionLine = null;

    document.removeEventListener('mousemove', drawConnection);
    document.removeEventListener('mouseup', cancelConnection);
}
|
|
|
|
|
/**
 * Commits the in-progress connection from `startNode` to `targetNode`.
 *
 * When the connection is valid, an SVG `<path>` with id
 * `connection-<sourceId>-<targetId>` is drawn between the two ports, the
 * connection is recorded in `networkLayers.connections`, shapes are
 * propagated and 'networkUpdated' is dispatched. The temporary line and the
 * gesture state are always cleaned up.
 *
 * @param {HTMLElement} targetNode Node whose input port receives the connection.
 */
function endConnection(targetNode) {
    if (!isConnecting || !connectionLine || !startNode) return;

    const sourceType = startNode.getAttribute('data-type');
    const targetType = targetNode.getAttribute('data-type');
    const sourceId = startNode.getAttribute('data-id');
    const targetId = targetNode.getAttribute('data-id');

    const sourcePort = startNode.querySelector('.output-port, .port-out');
    const targetPort = targetNode.querySelector('.input-port, .port-in');

    if (sourcePort && targetPort &&
        isValidConnection(sourceType, targetType, sourceId, targetId)) {
        const svgNamespace = 'http://www.w3.org/2000/svg';

        let svgContainer = document.querySelector('#network-canvas .svg-container');
        if (!svgContainer) {
            if (typeof createSVGContainer === 'function') {
                // Helper is not defined in this file; use it when the page provides it.
                svgContainer = createSVGContainer();
            } else {
                // Fallback overlay covering the canvas. It sits first in the
                // canvas and ignores pointer events so nodes stay interactive.
                svgContainer = document.createElementNS(svgNamespace, 'svg');
                svgContainer.setAttribute('class', 'svg-container');
                svgContainer.style.position = 'absolute';
                svgContainer.style.left = '0';
                svgContainer.style.top = '0';
                svgContainer.style.width = '100%';
                svgContainer.style.height = '100%';
                svgContainer.style.overflow = 'visible';
                svgContainer.style.pointerEvents = 'none';
                canvas.insertBefore(svgContainer, canvas.firstChild);
            }
        }

        // Port centres in canvas coordinates.
        const canvasRect = canvas.getBoundingClientRect();
        const sourcePortRect = sourcePort.getBoundingClientRect();
        const targetPortRect = targetPort.getBoundingClientRect();
        const startX = sourcePortRect.left + (sourcePortRect.width / 2) - canvasRect.left;
        const startY = sourcePortRect.top + (sourcePortRect.height / 2) - canvasRect.top;
        const endX = targetPortRect.left + (targetPortRect.width / 2) - canvasRect.left;
        const endY = targetPortRect.top + (targetPortRect.height / 2) - canvasRect.top;

        const pathId = `connection-${sourceId}-${targetId}`;
        const connectionPath = document.createElementNS(svgNamespace, 'path');
        connectionPath.setAttribute('id', pathId);
        connectionPath.setAttribute('class', 'connection-line');

        // Cubic Bezier whose control points are pushed horizontally by 70% of
        // the horizontal distance, giving an S-shaped curve.
        const dx = Math.abs(endX - startX) * 0.7;
        connectionPath.setAttribute(
            'd',
            `M ${startX} ${startY} C ${startX + dx} ${startY}, ${endX - dx} ${endY}, ${endX} ${endY}`
        );

        svgContainer.appendChild(connectionPath);

        networkLayers.connections.push({
            id: pathId,
            source: sourceId,
            target: targetId,
            sourceType: sourceType,
            targetType: targetType
        });

        updateNodeShapes(sourceId, targetId);

        document.dispatchEvent(new CustomEvent('networkUpdated', {
            detail: networkLayers
        }));
    }

    removePortHighlights();
    if (connectionLine) {
        connectionLine.remove();
        connectionLine = null;
    }
    isConnecting = false;
    startNode = null;
}
|
|
|
|
|
/**
 * Propagates the source layer's output shape into the target layer's
 * `inputShape` after a connection is made, then refreshes both nodes' text.
 *
 * For 'conv' and 'pool' sources the output is derived from the source's OWN
 * `inputShape` (set when something was connected into it earlier); when that
 * is unknown, '?' placeholders are propagated instead. A computed conv/pool
 * shape is also stored as `sourceConfig.outputShape`, which is the field the
 * display code reads.
 *
 * @param {string} sourceId `data-id` of the source node.
 * @param {string} targetId `data-id` of the target node.
 */
function updateNodeShapes(sourceId, targetId) {
    const sourceNode = document.querySelector(`.canvas-node[data-id="${sourceId}"]`);
    const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
    if (!sourceNode || !targetNode) {
        return;
    }

    const sourceConfig = sourceNode.layerConfig;
    const targetConfig = targetNode.layerConfig;
    if (!sourceConfig || !targetConfig) {
        return;
    }

    let outputShape;
    switch (sourceNode.getAttribute('data-type')) {
        case 'input':
            outputShape = sourceConfig.shape;
            break;
        case 'hidden':
        case 'output':
            outputShape = [sourceConfig.units];
            break;
        case 'conv': {
            const spatial = computeSpatialOutputSize(
                sourceConfig.inputShape,
                sourceConfig.kernelSize,
                sourceConfig.strides,
                sourceConfig.padding
            );
            if (spatial) {
                outputShape = [spatial[0], spatial[1], sourceConfig.filters];
                sourceConfig.outputShape = outputShape.slice();
            } else {
                outputShape = ['?', '?', sourceConfig.filters];
            }
            break;
        }
        case 'pool': {
            const spatial = computeSpatialOutputSize(
                sourceConfig.inputShape,
                sourceConfig.poolSize,
                sourceConfig.strides,
                sourceConfig.padding
            );
            if (spatial) {
                // Pooling keeps the channel count of its input.
                outputShape = [spatial[0], spatial[1], sourceConfig.inputShape[2]];
                sourceConfig.outputShape = outputShape.slice();
            } else {
                outputShape = ['?', '?', '?'];
            }
            break;
        }
        case 'linear':
            outputShape = [sourceConfig.outputFeatures];
            break;
        default:
            outputShape = ['?', '?', '?'];
    }

    targetConfig.inputShape = outputShape;

    updateNodeDisplayShapes(sourceNode, targetNode);
}

/**
 * Output height/width of a sliding-window layer (convolution or pooling).
 *
 * 'same' padding: ceil(size / stride). Any other padding is treated as
 * 'valid': ceil((size - window + 1) / stride).
 *
 * @param {Array} inputShape [height, width, ...] of the layer's input.
 * @param {number[]} windowSize [height, width] of the kernel / pool window.
 * @param {number[]} strides [vertical, horizontal] stride.
 * @param {string} padding Padding mode.
 * @returns {number[]|null} [outHeight, outWidth], or null when the input
 *     height/width are not known numbers.
 */
function computeSpatialOutputSize(inputShape, windowSize, strides, padding) {
    if (!Array.isArray(inputShape)) {
        return null;
    }
    const height = inputShape[0];
    const width = inputShape[1];
    if (typeof height !== 'number' || typeof width !== 'number') {
        return null;
    }

    if (padding === 'same') {
        return [Math.ceil(height / strides[0]), Math.ceil(width / strides[1])];
    }
    return [
        Math.ceil((height - windowSize[0] + 1) / strides[0]),
        Math.ceil((width - windowSize[1] + 1) / strides[1])
    ];
}
|
|
|
|
|
/**
 * Refreshes the text shown on both ends of a connection: the source node's
 * output shape, and the target node's input shape and parameter summary.
 *
 * Dimensions that are not known numbers are rendered as '?' (never 'NaN').
 *
 * @param {HTMLElement} sourceNode Node the connection leaves.
 * @param {HTMLElement} targetNode Node the connection enters.
 */
function updateNodeDisplayShapes(sourceNode, targetNode) {
    if (!sourceNode || !targetNode) {
        return;
    }

    const sourceType = sourceNode.getAttribute('data-type');
    const targetType = targetNode.getAttribute('data-type');
    const sourceConfig = sourceNode.layerConfig;
    const targetConfig = targetNode.layerConfig;

    const joinShape = (shape) =>
        (Array.isArray(shape) ? shape.join(' × ') : String(shape));
    // Product of all dimensions, or null when any of them is unknown.
    const flattenedSize = (shape) => {
        const dims = Array.isArray(shape) ? shape : [shape];
        const allKnown = dims.every(dim => typeof dim === 'number' && Number.isFinite(dim));
        return allKnown ? dims.reduce((product, dim) => product * dim, 1) : null;
    };
    const showNumber = (value) => (value === null ? '?' : value);
    const showCount = (value) => (value === null ? '?' : value.toLocaleString());

    // ---- Source side: output shape ----
    const sourceOutputElem = sourceNode.querySelector('.output-shape');
    if (sourceOutputElem && sourceConfig) {
        let outputText;
        switch (sourceType) {
            case 'input':
                outputText = `[${sourceConfig.shape.join(' × ')}]`;
                break;
            case 'hidden':
            case 'output':
                outputText = `[${sourceConfig.units}]`;
                break;
            case 'conv':
                outputText = sourceConfig.outputShape
                    ? `[${sourceConfig.outputShape.join(' × ')}]`
                    : `[? × ? × ${sourceConfig.filters}]`;
                break;
            case 'pool':
                outputText = sourceConfig.outputShape
                    ? `[${sourceConfig.outputShape.join(' × ')}]`
                    : 'Depends on input';
                break;
            case 'linear':
                outputText = `[${sourceConfig.outputFeatures}]`;
                break;
            default:
                outputText = 'Unknown';
        }
        sourceOutputElem.textContent = outputText;
    }

    // ---- Target side: input shape and parameter summary ----
    const targetInputElem = targetNode.querySelector('.input-shape');
    if (!targetInputElem || !targetConfig || !targetConfig.inputShape) {
        return;
    }
    targetInputElem.textContent = `[${joinShape(targetConfig.inputShape)}]`;

    const targetParamsElem = targetNode.querySelector('.params-display');
    if (!targetParamsElem) {
        return;
    }

    let paramsText = '';
    switch (targetType) {
        case 'hidden':
        case 'output': {
            // Dense layer: weights = inputs × units, plus one bias per unit.
            const inputUnits = flattenedSize(targetConfig.inputShape);
            const biasParams = targetConfig.useBias ? targetConfig.units : 0;
            const totalParams = inputUnits === null
                ? null
                : (inputUnits * targetConfig.units) + biasParams;
            const lastLine = targetType === 'hidden'
                ? `Dropout: ${targetConfig.dropoutRate}`
                : `Activation: ${targetConfig.activation}`;
            paramsText = `In: ${showNumber(inputUnits)}, Out: ${targetConfig.units}\nParams: ${showCount(totalParams)}\n${lastLine}`;
            break;
        }
        case 'conv': {
            // Conv layer: kernelH × kernelW × inChannels × filters, plus one
            // bias per filter.
            const channelDim = targetConfig.inputShape[2];
            const channels = typeof channelDim === 'number' && channelDim > 0
                ? channelDim
                : null;
            const biasParams = targetConfig.useBias ? targetConfig.filters : 0;
            const totalParams = channels === null
                ? null
                : (targetConfig.kernelSize[0] * targetConfig.kernelSize[1] *
                    channels * targetConfig.filters) + biasParams;
            paramsText = `In: ${showNumber(channels)}, Out: ${targetConfig.filters}\nKernel: ${targetConfig.kernelSize.join('×')}\nStride: ${targetConfig.strides.join('×')}\nPadding: ${targetConfig.padding}\nParams: ${showCount(totalParams)}`;
            break;
        }
        case 'pool':
            paramsText = `Pool size: ${targetConfig.poolSize.join('×')}\nStride: ${targetConfig.strides.join('×')}\nPadding: ${targetConfig.padding}\nParams: 0`;
            break;
        case 'linear': {
            const linearInputs = targetConfig.inputFeatures;
            const biasParams = targetConfig.useBias ? targetConfig.outputFeatures : 0;
            const totalParams = (linearInputs * targetConfig.outputFeatures) + biasParams;
            paramsText = `In: ${linearInputs}, Out: ${targetConfig.outputFeatures}\nParams: ${totalParams.toLocaleString()}\nLearning Rate: ${targetConfig.learningRate}\nLoss: ${targetConfig.lossFunction}`;
            break;
        }
    }

    targetParamsElem.textContent = paramsText;
}
|
|
|
|
|
/**
 * Removes a layer node together with every connection that touches it, both
 * from the DOM and from `networkLayers`, then notifies listeners with
 * 'networkUpdated'.
 *
 * @param {HTMLElement|null} node Canvas node to delete; falsy values are ignored.
 */
function deleteNode(node) {
    if (!node) return;

    const nodeId = node.getAttribute('data-id');
    const touchesNode = conn => conn.source === nodeId || conn.target === nodeId;

    // Div-style connections tagged with data-source / data-target.
    document
        .querySelectorAll(`.connection[data-source="${nodeId}"], .connection[data-target="${nodeId}"]`)
        .forEach(conn => {
            if (conn.parentNode) {
                conn.parentNode.removeChild(conn);
            }
        });

    // Committed connections are SVG paths whose element id is stored on the
    // model entry; remove those as well.
    networkLayers.connections.filter(touchesNode).forEach(conn => {
        const pathEl = conn.id ? document.getElementById(conn.id) : null;
        if (pathEl && pathEl.parentNode) {
            pathEl.parentNode.removeChild(pathEl);
        }
    });

    // The attribute is a string while the stored layer id may not be.
    networkLayers.layers = networkLayers.layers.filter(layer => String(layer.id) !== nodeId);
    networkLayers.connections = networkLayers.connections.filter(conn => !touchesNode(conn));

    if (node.parentNode) {
        node.parentNode.removeChild(node);
    }

    // Not defined in this file; call it only when the page provides it.
    if (typeof updateLayerConnectivity === 'function') {
        updateLayerConnectivity();
    }

    document.dispatchEvent(new CustomEvent('networkUpdated', {
        detail: networkLayers
    }));
}
|
|
|
|
|
/**
 * Asks the page to open the editor for a layer by dispatching an
 * 'openLayerEditor' event on `document`.
 *
 * The event detail carries the node's id, type, name, `data-dimensions`
 * attribute and its attached layer config. The name falls back to the
 * node's header text when no `data-name` attribute is present.
 *
 * @param {HTMLElement|null} node Canvas node to edit; falsy values are ignored.
 */
function openLayerEditor(node) {
    if (!node) return;

    let nodeName = node.getAttribute('data-name');
    if (nodeName === null) {
        const header = node.querySelector('.node-header');
        nodeName = header ? header.textContent : null;
    }

    document.dispatchEvent(new CustomEvent('openLayerEditor', {
        detail: {
            id: node.getAttribute('data-id'),
            type: node.getAttribute('data-type'),
            name: nodeName,
            dimensions: node.getAttribute('data-dimensions'),
            config: node.layerConfig
        }
    }));
}
|
|
|
|
|
/**
 * Re-routes every connection so it follows the current node positions.
 *
 * Handles both representations: SVG paths recorded in
 * `networkLayers.connections` (looked up by their stored element id) and
 * div-style `.connection` elements tagged with `data-source` /
 * `data-target`. A connection whose source or target node no longer exists
 * is removed from the DOM and from the model. The temporary line used while
 * drawing a connection carries no data attributes and is left alone.
 */
function updateConnections() {
    const canvasRect = canvas.getBoundingClientRect();

    // Centre of a port in canvas coordinates.
    const centreOf = (port) => {
        const box = port.getBoundingClientRect();
        return {
            x: box.left + box.width / 2 - canvasRect.left,
            y: box.top + box.height / 2 - canvasRect.top
        };
    };

    // Resolves both endpoints; `nodeMissing` means the connection is stale,
    // a missing `start` means the nodes exist but a port could not be found.
    const resolveEndpoints = (sourceId, targetId) => {
        const sourceNode = document.querySelector(`.canvas-node[data-id="${sourceId}"]`);
        const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
        if (!sourceNode || !targetNode) {
            return { nodeMissing: true };
        }
        const sourcePort = sourceNode.querySelector('.output-port, .port-out');
        const targetPort = targetNode.querySelector('.input-port, .port-in');
        if (!sourcePort || !targetPort) {
            return { nodeMissing: false };
        }
        return { nodeMissing: false, start: centreOf(sourcePort), end: centreOf(targetPort) };
    };

    // ---- SVG paths (iterate over a copy: entries may be spliced out) ----
    networkLayers.connections.slice().forEach(conn => {
        const pathEl = conn.id ? document.getElementById(conn.id) : null;
        if (!pathEl) {
            return;
        }

        const endpoints = resolveEndpoints(conn.source, conn.target);
        if (endpoints.nodeMissing) {
            if (pathEl.parentNode) {
                pathEl.parentNode.removeChild(pathEl);
            }
            const staleIndex = networkLayers.connections.indexOf(conn);
            if (staleIndex !== -1) {
                networkLayers.connections.splice(staleIndex, 1);
            }
            return;
        }
        if (!endpoints.start) {
            return;
        }

        // Same S-curve used when the connection is first drawn.
        const start = endpoints.start;
        const end = endpoints.end;
        const bend = Math.abs(end.x - start.x) * 0.7;
        pathEl.setAttribute(
            'd',
            `M ${start.x} ${start.y} C ${start.x + bend} ${start.y}, ${end.x - bend} ${end.y}, ${end.x} ${end.y}`
        );
    });

    // ---- Div-style connections: a rotated, stretched line ----
    document.querySelectorAll('.connection[data-source][data-target]').forEach(connection => {
        const sourceId = connection.getAttribute('data-source');
        const targetId = connection.getAttribute('data-target');

        const endpoints = resolveEndpoints(sourceId, targetId);
        if (endpoints.nodeMissing) {
            if (connection.parentNode) {
                connection.parentNode.removeChild(connection);

                const staleIndex = networkLayers.connections.findIndex(conn =>
                    conn.source === sourceId && conn.target === targetId
                );
                if (staleIndex !== -1) {
                    networkLayers.connections.splice(staleIndex, 1);
                }
            }
            return;
        }
        if (!endpoints.start) {
            return;
        }

        const deltaX = endpoints.end.x - endpoints.start.x;
        const deltaY = endpoints.end.y - endpoints.start.y;
        connection.style.left = `${endpoints.start.x}px`;
        connection.style.top = `${endpoints.start.y}px`;
        connection.style.width = `${Math.hypot(deltaX, deltaY)}px`;
        connection.style.transform = `rotate(${Math.atan2(deltaY, deltaX) * 180 / Math.PI}deg)`;
    });
}
|
|
|
|
|
/**
 * Returns the live network model.
 *
 * The returned object is the internal one, not a copy, so callers see (and
 * can cause) later changes.
 *
 * @returns {{layers: Array, connections: Array}}
 */
function getNetworkArchitecture() {
    return networkLayers;
}
|
|
|
|
|
/**
 * Resets the editor: removes every node and connection element (including
 * the SVG connection paths), empties the model, resets the layer-id counter
 * on `window.neuralNetwork`, shows the canvas hint again and dispatches
 * 'networkUpdated' with the empty model.
 */
function clearAllNodes() {
    document
        .querySelectorAll('.canvas-node, .connection, .connection-line')
        .forEach(el => {
            if (el.parentNode) {
                el.parentNode.removeChild(el);
            }
        });

    networkLayers = {
        layers: [],
        connections: []
    };

    window.neuralNetwork.resetLayerCounter();

    const canvasHint = document.querySelector('.canvas-hint');
    if (canvasHint) {
        canvasHint.style.display = 'block';
    }

    document.dispatchEvent(new CustomEvent('networkUpdated', { detail: networkLayers }));
}
|
|
|
|
|
// Public surface for other scripts on the page: read the model, reset the
// canvas, and force a redraw of the connections.
window.dragDrop = {
    getNetworkArchitecture,
    clearAllNodes,
    updateConnections
};
|
} |