|
|
|
function initializeDragAndDrop() { |
|
// Drop target for new layers, and the palette entries that can be dragged onto it.
const canvas = document.getElementById('network-canvas');
const nodeItems = document.querySelectorAll('.node-item');

// Move state. draggedNode is the element currently being dragged: the palette
// item during an HTML5 drag (handleDragStart), or a canvas node while it is
// being moved (startDrag). offsetX/offsetY hold where inside the node it was grabbed.
let draggedNode = null;
let offsetX;
let offsetY;
let isDragging = false;

// Connection-drawing state: the node the connection starts from and the
// temporary line element that follows the pointer.
let isConnecting = false;
let startNode = null;
let connectionLine = null;

// NOTE(review): not read by any handler visible here — confirm it is still used.
let nodeCounter = {};

// In-memory model of what is on the canvas: placed layers and the connections between them.
let networkLayers = {
  layers: [],
  connections: [],
};
|
|
|
|
|
/**
 * Formats a (parameter) count for display, abbreviating with K / M / B and two
 * decimals once it reaches a thousand.
 *
 * When rounding pushes a value up to 1000 of one unit, it is shown in the next
 * unit instead (999,999 -> "1.00M", not "1000.00K"). Negative values keep their
 * sign and are abbreviated the same way.
 *
 * @param {number} num - value to format.
 * @returns {string} '0' for zero, 'N/A' for null / undefined / NaN, otherwise
 *   the abbreviated value (values below 1000 are returned via toString()).
 */
function formatNumber(num) {
  if (num === 0) return '0';
  if (!num) return 'N/A';

  // Largest unit first; each entry is [divisor, suffix].
  const units = [[1e9, 'B'], [1e6, 'M'], [1e3, 'K']];
  const sign = num < 0 ? '-' : '';
  const magnitude = Math.abs(num);

  for (let i = 0; i < units.length; i++) {
    const [scale, suffix] = units[i];
    if (magnitude < scale) continue;

    const scaled = (magnitude / scale).toFixed(2);
    // Rounding can reach 1000 of this unit (e.g. 999.999K -> "1000.00");
    // promote to the next larger unit when there is one.
    if (Number(scaled) >= 1000 && i > 0) {
      const [largerScale, largerSuffix] = units[i - 1];
      return `${sign}${(magnitude / largerScale).toFixed(2)}${largerSuffix}`;
    }
    return `${sign}${scaled}${suffix}`;
  }

  return num.toString();
}
|
|
|
|
|
// Every palette entry starts an HTML5 drag that carries its layer type.
for (const paletteItem of nodeItems) {
  paletteItem.addEventListener('dragstart', handleDragStart);
}

// The canvas accepts those drags: dragover marks it as a drop target and drop creates the layer.
canvas.addEventListener('dragover', handleDragOver);
canvas.addEventListener('drop', handleDrop);
|
|
|
|
|
/**
 * `dragstart` handler for palette items (`this` is the item being dragged).
 *
 * Remembers the item in `draggedNode`, puts its data-type into the drag payload
 * as 'text/plain', and uses a half-transparent clone of the item as the drag image.
 *
 * @param {DragEvent} e - the dragstart event.
 */
function handleDragStart(e) {
  const paletteItem = this;
  draggedNode = paletteItem;

  const layerType = paletteItem.getAttribute('data-type');
  e.dataTransfer.setData('text/plain', layerType);

  // The clone must be in the document for setDragImage to use it; it is taken
  // out again on the next task (presumably after the browser has captured the image).
  const preview = paletteItem.cloneNode(true);
  preview.style.opacity = '0.5';
  document.body.appendChild(preview);
  e.dataTransfer.setDragImage(preview, 0, 0);

  setTimeout(() => {
    document.body.removeChild(preview);
  }, 0);
}
|
|
|
|
|
/**
 * `dragover` handler for the canvas. Cancelling the event is what lets the
 * canvas accept a drop; the drop effect is shown as a copy.
 *
 * @param {DragEvent} event - the dragover event.
 */
function handleDragOver(event) {
  const { dataTransfer } = event;
  event.preventDefault();
  dataTransfer.dropEffect = 'copy';
}
|
|
|
|
|
/**
 * `drop` handler for the network canvas.
 *
 * Reads the layer type from the drag payload (set by handleDragStart) and
 * creates a `.canvas-node` for it at the drop point. It then wires up the
 * node's move / connect / edit / delete interactions, records it in
 * `networkLayers.layers`, dispatches a `networkUpdated` event and redraws
 * connections.
 *
 * @param {DragEvent} e - the drop event.
 */
function handleDrop(e) {
  e.preventDefault();

  // The payload is whatever text was dropped. An empty one (e.g. a file drop)
  // means there is no layer to create, so leave the canvas untouched.
  const nodeType = e.dataTransfer.getData('text/plain');
  if (!nodeType) return;

  const layerId = window.neuralNetwork.getNextLayerId(nodeType);
  const nodeConfig = window.neuralNetwork.createNodeConfig(nodeType);
  if (!nodeConfig) {
    // Every nodeConfig.* access below would throw without a config.
    console.warn(`No layer configuration for node type "${nodeType}"; drop ignored.`);
    return;
  }

  // Hide the placeholder hint only once a layer is actually being placed.
  const canvasHint = document.querySelector('.canvas-hint');
  if (canvasHint) {
    canvasHint.style.display = 'none';
  }

  const canvasNode = document.createElement('div');
  canvasNode.className = `canvas-node ${nodeType}-node`;
  canvasNode.setAttribute('data-type', nodeType);
  canvasNode.setAttribute('data-id', layerId);

  // Drop point relative to the canvas' top-left corner.
  const rect = canvas.getBoundingClientRect();
  const x = e.clientX - rect.left;
  const y = e.clientY - rect.top;
  canvasNode.style.left = `${x}px`;
  canvasNode.style.top = `${y}px`;

  // Display name plus the shape / parameter summary shown on the node. Apart
  // from `input`, shapes are only known once an input is connected.
  // NOTE(review): 'linear' is handled in the dimensions switch below but not
  // here, so a linear node would be labelled 'Unknown Layer' — confirm.
  let nodeName;
  let inputShape;
  let outputShape;
  let parameters;

  switch (nodeType) {
    case 'input': {
      nodeName = 'Input Layer';
      inputShape = 'N/A';
      outputShape = `[${nodeConfig.shape.join(' × ')}]`;
      parameters = nodeConfig.parameters;
      break;
    }
    case 'hidden': {
      // Counted before this node is appended: the first hidden layer gets 128 units, later ones 64.
      const hiddenCount = document.querySelectorAll('.canvas-node[data-type="hidden"]').length;
      nodeConfig.units = hiddenCount === 0 ? 128 : 64;
      nodeName = `Hidden Layer ${hiddenCount + 1}`;
      inputShape = 'Connect input';
      outputShape = `[${nodeConfig.units}]`;
      parameters = 'Connect input to calculate';
      break;
    }
    case 'output': {
      nodeName = 'Output Layer';
      inputShape = 'Connect input';
      outputShape = `[${nodeConfig.units}]`;
      parameters = 'Connect input to calculate';
      break;
    }
    case 'conv': {
      // Each additional conv layer gets 32 more filters than the previous count implies.
      const convCount = document.querySelectorAll('.canvas-node[data-type="conv"]').length;
      nodeConfig.filters = 32 * (convCount + 1);
      nodeName = `Conv2D ${convCount + 1}`;
      inputShape = 'Connect input';
      outputShape = 'Depends on input';
      parameters = `Kernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
      break;
    }
    case 'pool': {
      const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
      nodeName = `Pooling ${poolCount + 1}`;
      inputShape = 'Connect input';
      outputShape = 'Depends on input';
      parameters = `Pool size: ${nodeConfig.poolSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
      break;
    }
    default: {
      nodeName = 'Unknown Layer';
      inputShape = 'N/A';
      outputShape = 'N/A';
      parameters = 'N/A';
    }
  }

  // Build the summary with textContent instead of innerHTML, so config-derived
  // values are always rendered as text and never parsed as markup.
  const makeElement = (tag, className, text) => {
    const el = document.createElement(tag);
    el.className = className;
    if (text !== undefined) {
      el.textContent = text;
    }
    return el;
  };
  const makeShapeRow = (label, valueClass, value) => {
    const row = makeElement('div', 'shape-row');
    // The ' ' text node keeps the space between label and value.
    row.append(makeElement('span', 'shape-label', label), ' ', makeElement('span', valueClass, `${value}`));
    return row;
  };

  const nodeContent = makeElement('div', 'node-content');

  const shapeInfo = makeElement('div', 'shape-info');
  shapeInfo.append(
    makeShapeRow('Input:', 'input-shape', inputShape),
    makeShapeRow('Output:', 'output-shape', outputShape)
  );

  const paramCountText = nodeConfig.parameters !== undefined ? formatNumber(nodeConfig.parameters) : '?';
  const paramsSection = makeElement('div', 'params-section');
  paramsSection.append(
    makeElement('div', 'params-details', `${parameters}`),
    makeElement('div', 'node-parameters', `Params: ${paramCountText}`)
  );

  nodeContent.appendChild(shapeInfo);
  nodeContent.appendChild(paramsSection);

  // Compact dimension summary shown under the title (left empty for unhandled types).
  const dimensionsSection = document.createElement('div');
  dimensionsSection.className = 'node-dimensions';

  let dimensionsText = '';
  switch (nodeType) {
    case 'input':
      dimensionsText = nodeConfig.shape.join(' × ');
      break;
    case 'hidden':
    case 'output':
      dimensionsText = nodeConfig.units.toString();
      break;
    case 'conv':
      if (nodeConfig.inputShape && nodeConfig.outputShape) {
        dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
      } else {
        dimensionsText = `? → ${nodeConfig.filters} filters`;
      }
      break;
    case 'pool':
      if (nodeConfig.inputShape && nodeConfig.outputShape) {
        dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
      } else {
        dimensionsText = `? → ?`;
      }
      break;
    case 'linear':
      dimensionsText = `${nodeConfig.inputFeatures} → ${nodeConfig.outputFeatures}`;
      break;
  }
  dimensionsSection.textContent = dimensionsText;

  const nodeTitle = document.createElement('div');
  nodeTitle.className = 'node-title';
  nodeTitle.textContent = nodeName;

  // Connection handles: connections start at port-out and end at port-in.
  const portIn = document.createElement('div');
  portIn.className = 'node-port port-in';
  const portOut = document.createElement('div');
  portOut.className = 'node-port port-out';

  canvasNode.appendChild(nodeTitle);
  canvasNode.appendChild(dimensionsSection);
  canvasNode.appendChild(nodeContent);
  canvasNode.appendChild(portIn);
  canvasNode.appendChild(portOut);

  canvasNode.setAttribute('data-name', nodeName);
  canvasNode.setAttribute('data-dimensions', dimensionsText);

  canvas.appendChild(canvasNode);

  canvasNode.layerConfig = nodeConfig;

  // Interactions: press on the body to move, press on port-out to connect,
  // double-click to edit, right-click to delete.
  canvasNode.addEventListener('mousedown', startDrag);

  // Keep presses on the input port from starting a node move.
  portIn.addEventListener('mousedown', (event) => {
    event.stopPropagation();
  });

  portOut.addEventListener('mousedown', (event) => {
    event.stopPropagation();
    startConnection(canvasNode, event);
  });

  canvasNode.addEventListener('dblclick', () => {
    openLayerEditor(canvasNode);
  });

  canvasNode.addEventListener('contextmenu', (event) => {
    event.preventDefault();
    deleteNode(canvasNode);
  });

  networkLayers.layers.push({
    id: layerId,
    type: nodeType,
    name: nodeName,
    position: { x, y },
    dimensions: dimensionsText,
    config: nodeConfig,
    parameters: nodeConfig.parameters || 0,
  });

  document.dispatchEvent(new CustomEvent('networkUpdated', {
    detail: networkLayers,
  }));

  updateConnections();
}
|
|
|
|
|
/**
 * `mousedown` handler for canvas nodes: starts moving the node under the pointer.
 *
 * The press is ignored in these cases:
 *  - a connection is being drawn;
 *  - it is not the primary button;
 *  - it lands on a node port or `.node-controls`.
 *
 * Otherwise it records where inside the node the pointer grabbed it, marks the
 * node as dragging and raises it. It also attaches document-level
 * mousemove / mouseup listeners (dragNode / stopDrag), so the move keeps
 * tracking when the pointer leaves the node.
 *
 * @param {MouseEvent} e - the mousedown event.
 */
function startDrag(e) {
  if (isConnecting) return;

  // Only the primary button moves nodes. Right-click is the delete gesture
  // (contextmenu handler), and a right or middle press should not also drag.
  if (e.button !== 0) return;

  if (e.target.closest('.node-controls') || e.target.closest('.node-port')) {
    return;
  }

  const target = e.target.closest('.canvas-node');
  const rect = target.getBoundingClientRect();

  // Flag the drag only after the node has been resolved, so a failed lookup
  // cannot leave isDragging stuck at true.
  isDragging = true;

  // Grab offset, so the node does not jump to put its corner under the cursor.
  offsetX = e.clientX - rect.left;
  offsetY = e.clientY - rect.top;

  document.addEventListener('mousemove', dragNode);
  document.addEventListener('mouseup', stopDrag);

  draggedNode = target;
  // Raised while moving; stopDrag sets it back to "10".
  draggedNode.style.zIndex = "100";
  draggedNode.classList.add('dragging');

  // Suppress the browser's default mousedown behaviour (e.g. starting a text selection).
  e.preventDefault();
}
|
|
|
|
|
/**
 * Document-level `mousemove` handler installed by startDrag.
 *
 * Moves `draggedNode` to follow the pointer (keeping the original grab offset),
 * clamps it inside the canvas, mirrors the new position into
 * `networkLayers.layers`, and redraws connections.
 *
 * @param {MouseEvent} e - the mousemove event.
 */
function dragNode(e) {
  if (!isDragging) return;

  const canvasRect = canvas.getBoundingClientRect();

  // Fallback size for nodes that report 0 (e.g. not laid out yet).
  const nodeWidth = draggedNode.offsetWidth || 150;
  const nodeHeight = draggedNode.offsetHeight || 100;

  // Pointer position minus grab offset, clamped so the node stays inside the
  // canvas (pinned to 0 if the canvas is smaller than the node).
  const rawX = e.clientX - canvasRect.left - offsetX;
  const rawY = e.clientY - canvasRect.top - offsetY;
  const x = Math.max(0, Math.min(canvasRect.width - nodeWidth, rawX));
  const y = Math.max(0, Math.min(canvasRect.height - nodeHeight, rawY));

  draggedNode.style.position = 'absolute';
  draggedNode.style.left = `${x}px`;
  draggedNode.style.top = `${y}px`;
  // Width is fixed at its measured size (presumably so the node does not
  // reflow while positioned absolutely).
  draggedNode.style.width = `${nodeWidth}px`;

  // The data-id attribute is always a string, but the stored id is whatever
  // getNextLayerId returned (possibly a number), so compare as strings.
  const nodeId = draggedNode.getAttribute('data-id');
  const layer = networkLayers.layers.find(entry => String(entry.id) === nodeId);
  if (layer) {
    layer.position = { x, y };
  }

  updateConnections();
}
|
|
|
|
|
/**
 * Document-level `mouseup` handler installed by startDrag.
 *
 * Ends the node move: clears the dragging flag, detaches the document
 * listeners, drops the node back to z-index 10, removes its `dragging` class
 * and redraws connections. Does nothing if no move is in progress.
 */
function stopDrag() {
  if (!isDragging) {
    return;
  }
  isDragging = false;

  ['mousemove', 'mouseup'].forEach((eventName) => {
    document.removeEventListener(eventName, eventName === 'mousemove' ? dragNode : stopDrag);
  });

  if (!draggedNode) {
    return;
  }

  const { style, classList } = draggedNode;
  style.zIndex = "10";
  classList.remove('dragging');

  updateConnections();
}
|
|
|
|
|
/**
 * Starts drawing a connection from `node`'s output port.
 *
 * Enters connecting mode and records `node` as the source. It creates a
 * zero-length temporary line anchored at the centre of the output port, marks
 * that port active and highlights which other nodes may receive the connection.
 * It then attaches document listeners: drawConnection follows the pointer and
 * cancelConnection finishes or abandons the connection on release.
 *
 * @param {HTMLElement} node - canvas node whose output port was pressed.
 * @param {MouseEvent} e - the mousedown event on that port.
 */
function startConnection(node, e) {
  isConnecting = true;
  startNode = node;

  const tempLine = document.createElement('div');
  tempLine.className = 'connection temp-connection';
  connectionLine = tempLine;

  // Anchor the line at the centre of the output port, in canvas coordinates.
  const outPort = node.querySelector('.port-out');
  const portBox = outPort.getBoundingClientRect();
  const canvasBox = canvas.getBoundingClientRect();
  const originX = portBox.left + portBox.width / 2 - canvasBox.left;
  const originY = portBox.top + portBox.height / 2 - canvasBox.top;

  Object.assign(tempLine.style, {
    left: `${originX}px`,
    top: `${originY}px`,
    width: '0px',
    transform: 'rotate(0deg)',
  });

  outPort.classList.add('active-port');
  highlightValidConnectionTargets(node);
  canvas.appendChild(tempLine);

  document.addEventListener('mousemove', drawConnection);
  document.addEventListener('mouseup', cancelConnection);

  e.preventDefault();
}
|
|
|
|
|
/**
 * Marks the input port of every other canvas node as a possible target of a
 * connection from `sourceNode`. The port gets `valid-target` when
 * isValidConnection allows the connection, and `invalid-target` otherwise.
 *
 * @param {HTMLElement} sourceNode - node the connection is being drawn from.
 */
function highlightValidConnectionTargets(sourceNode) {
  const fromType = sourceNode.getAttribute('data-type');
  const fromId = sourceNode.getAttribute('data-id');

  for (const candidate of document.querySelectorAll('.canvas-node')) {
    if (candidate === sourceNode) {
      continue;
    }

    const allowed = isValidConnection(
      fromType,
      candidate.getAttribute('data-type'),
      fromId,
      candidate.getAttribute('data-id')
    );

    candidate.querySelector('.port-in').classList.add(allowed ? 'valid-target' : 'invalid-target');
  }
}
|
|
|
|
|
/**
 * Clears the connection-drawing highlight classes (`active-port`,
 * `valid-target`, `invalid-target`) from every input and output port.
 */
function removePortHighlights() {
  const highlightClasses = ['active-port', 'valid-target', 'invalid-target'];
  const ports = document.querySelectorAll('.port-in, .port-out');
  for (const port of ports) {
    port.classList.remove(...highlightClasses);
  }
}
|
|
|
|
|
/**
 * Decides whether a connection may be drawn from one canvas node to another.
 *
 * A connection is rejected when:
 *  - it leaves an output layer or enters an input layer;
 *  - the source layer type may not feed the target layer type;
 *  - the same source → target connection already exists; or
 *  - it would close a cycle, i.e. the source is already reachable from the
 *    target through existing connections. This covers the direct reverse
 *    connection and connecting a node to itself.
 *
 * @param {string} sourceType - data-type of the node the connection starts at.
 * @param {string} targetType - data-type of the node it would end at.
 * @param {string} sourceId - data-id of the source node.
 * @param {string} targetId - data-id of the target node.
 * @returns {boolean} true when the connection is allowed.
 */
function isValidConnection(sourceType, targetType, sourceId, targetId) {
  if (sourceType === 'output' || targetType === 'input') {
    return false;
  }

  // Which layer types each layer type may feed into. A Map is used so that
  // arbitrary data-type strings cannot hit Object.prototype keys.
  const allowedTargets = new Map([
    ['input', ['hidden', 'conv']],
    ['conv', ['conv', 'pool', 'hidden']],
    ['pool', ['conv', 'hidden']],
    ['hidden', ['hidden', 'output']],
  ]);
  const targets = allowedTargets.get(sourceType);
  if (!targets || !targets.includes(targetType)) {
    return false;
  }

  const { connections } = networkLayers;

  // The same edge may only be drawn once.
  if (connections.some(conn => conn.source === sourceId && conn.target === targetId)) {
    return false;
  }

  // Depth-first search along existing edges from the target. Reaching the
  // source means the new edge would create a cycle.
  const outgoing = new Map();
  for (const { source, target } of connections) {
    if (!outgoing.has(source)) {
      outgoing.set(source, []);
    }
    outgoing.get(source).push(target);
  }

  const visited = new Set();
  const stack = [targetId];
  while (stack.length > 0) {
    const current = stack.pop();
    if (current === sourceId) {
      return false;
    }
    if (visited.has(current)) {
      continue;
    }
    visited.add(current);
    stack.push(...(outgoing.get(current) || []));
  }

  return true;
}
|
|
|
|
|
/**
 * Document-level `mousemove` handler installed by startConnection.
 *
 * Stretches and rotates the temporary connection line from the centre of the
 * source node's output port to the pointer. It also toggles `port-hover` on
 * whichever other node's input port is under the pointer.
 *
 * @param {MouseEvent} e - the mousemove event.
 */
function drawConnection(e) {
  if (!isConnecting || !connectionLine) return;

  const canvasRect = canvas.getBoundingClientRect();
  const portRect = startNode.querySelector('.port-out').getBoundingClientRect();

  // Line start (output-port centre) and end (pointer), in canvas coordinates.
  const startX = portRect.left + portRect.width / 2 - canvasRect.left;
  const startY = portRect.top + portRect.height / 2 - canvasRect.top;
  const dx = e.clientX - canvasRect.left - startX;
  const dy = e.clientY - canvasRect.top - startY;

  // The line is a div anchored at the start point: its width is the distance
  // to the pointer and it is rotated towards it (this assumes the CSS
  // transform-origin is the line's left end — not visible here, confirm).
  connectionLine.style.width = `${Math.hypot(dx, dy)}px`;
  connectionLine.style.transform = `rotate(${Math.atan2(dy, dx) * 180 / Math.PI}deg)`;

  document.querySelectorAll('.canvas-node').forEach(node => {
    if (node === startNode) return;

    const portIn = node.querySelector('.port-in');
    const inRect = portIn.getBoundingClientRect();
    const isOverPort = e.clientX >= inRect.left && e.clientX <= inRect.right &&
      e.clientY >= inRect.top && e.clientY <= inRect.bottom;

    portIn.classList.toggle('port-hover', isOverPort);
  });
}
|
|
|
|
|
/**
 * Document-level `mouseup` handler installed by startConnection.
 *
 * If the pointer was released over the input port of a node that may receive
 * the connection (per isValidConnection), the connection is completed via
 * endConnection; otherwise the temporary line is removed. In every case,
 * including when endConnection throws, the port highlights, connection state
 * and document listeners are reset.
 *
 * @param {MouseEvent} e - the mouseup event.
 */
function cancelConnection(e) {
  if (!isConnecting) return;

  const removeTempLine = () => {
    if (connectionLine && connectionLine.parentNode) {
      connectionLine.parentNode.removeChild(connectionLine);
    }
  };

  try {
    // Valid node whose input port is under the pointer. If several ports
    // overlap, the last one in document order wins.
    let targetNode = null;
    document.querySelectorAll('.canvas-node').forEach(node => {
      if (node === startNode) return;

      const portRect = node.querySelector('.port-in').getBoundingClientRect();
      const isOverPort = e.clientX >= portRect.left && e.clientX <= portRect.right &&
        e.clientY >= portRect.top && e.clientY <= portRect.bottom;
      if (!isOverPort) return;

      const sourceType = startNode.getAttribute('data-type');
      const targetType = node.getAttribute('data-type');
      const sourceId = startNode.getAttribute('data-id');
      const targetId = node.getAttribute('data-id');
      if (isValidConnection(sourceType, targetType, sourceId, targetId)) {
        targetNode = node;
      }
    });

    if (targetNode) {
      // NOTE(review): the visible part of endConnection looks up '.output-port'
      // and '.input-port', but handleDrop builds ports with 'port-out' and
      // 'port-in' — confirm; a mismatch makes endConnection throw here.
      endConnection(targetNode);
    } else {
      removeTempLine();
    }
  } catch (err) {
    // The connection will never be completed, so don't leave its temporary
    // line on the canvas; then let the error surface.
    removeTempLine();
    throw err;
  } finally {
    // Always leave connecting mode. Otherwise isConnecting would stay true
    // (which makes startDrag ignore every press) and these listeners would
    // keep firing.
    removePortHighlights();
    document.querySelectorAll('.port-hover').forEach(port => {
      port.classList.remove('port-hover');
    });

    isConnecting = false;
    startNode = null;
    connectionLine = null;

    document.removeEventListener('mousemove', drawConnection);
    document.removeEventListener('mouseup', cancelConnection);
  }
}
|
|
|
|
|
function endConnection(targetNode) { |
|
if (!isConnecting || !connectionLine || !startNode) return; |
|
|
|
const sourceType = startNode.getAttribute('data-type'); |
|
const targetType = targetNode.getAttribute('data-type'); |
|
const sourceId = startNode.getAttribute('data-id'); |
|
const targetId = targetNode.getAttribute('data-id'); |
|
|
|
|
|
if (isValidConnection(sourceType, targetType, sourceId, targetId)) { |
|
|
|
const canvas = document.getElementById('network-canvas'); |
|
const svgContainer = document.querySelector('#network-canvas .svg-container') || createSVGContainer(); |
|
|
|
|
|
const sourceRect = startNode.getBoundingClientRect(); |
|
const targetRect = targetNode.getBoundingClientRect(); |
|
const canvasRect = canvas.getBoundingClientRect(); |
|
|
|
|
|
const sourcePort = startNode.querySelector('.output-port'); |
|
const targetPort = targetNode.querySelector('.input-port'); |
|
|
|
const sourcePortRect = sourcePort.getBoundingClientRect(); |
|
const targetPortRect = targetPort.getBoundingClientRect(); |
|
|
|
const startX = sourcePortRect.left + (sourcePortRect.width / 2) - canvasRect.left; |
|
const startY = sourcePortRect.top + (sourcePortRect.height / 2) - canvasRect.top; |
|
const endX = targetPortRect.left + (targetPortRect.width / 2) - canvasRect.left; |
|
const endY = targetPortRect.top + (targetPortRect.height / 2) - canvasRect.top; |
|
|
|
|
|
const pathId = `connection-${sourceId}-${targetId}` |