ameerazam08's picture
Upload 7 files
fd832fc verified
raw
history blame
23.8 kB
// Initialize drag and drop functionality
function initializeDragAndDrop() {
// Palette items (elements with class `node-item`) that can be dragged onto the canvas.
const nodeItems = document.querySelectorAll('.node-item');
// Drop target and container for every placed layer node.
const canvas = document.getElementById('network-canvas');
// Element currently being dragged: a palette item during an HTML5 drag
// (set in handleDragStart) or a `.canvas-node` being repositioned (set in startDrag).
let draggedNode = null;
// Pointer offset inside the node being repositioned, captured in startDrag so
// the node does not jump under the cursor.
let offsetX, offsetY;
// True while a canvas node is being repositioned with the mouse.
let isDragging = false;
// True while a connection is being drawn from an output port.
let isConnecting = false;
// Node whose output port started the in-progress connection.
let startNode = null;
// Temporary line element that follows the cursor while a connection is drawn.
let connectionLine = null;
// NOTE(review): not referenced in the visible part of this file — confirm it is still used.
let nodeCounter = {};
// Track layers for proper architecture building
// `layers` receives one record per dropped node (handleDrop); `connections`
// holds edges with `source`/`target` ids, consulted by isValidConnection.
let networkLayers = {
layers: [],
connections: []
};
// Helper function to format numbers with K, M, B suffixes
/**
 * Formats a count compactly with a K / M / B suffix and two decimals
 * (1500 → "1.50K", 2500000 → "2.50M"). Values below 1000 in magnitude are
 * returned unabbreviated via toString().
 *
 * Negative values are scaled by their magnitude (-1500 → "-1.50K"). A value
 * whose two-decimal rounding reaches 1000 of one unit is shown in the next
 * larger unit instead (999999 → "1.00M", not "1000.00K").
 *
 * @param {number} num - value to format.
 * @returns {string} the formatted value; '0' for 0 and 'N/A' for any other
 *   falsy input (undefined, null, NaN, '').
 */
function formatNumber(num) {
  if (num === 0) return '0';
  if (!num) return 'N/A';
  // Ordered from the largest unit down.
  const units = [[1e9, 'B'], [1e6, 'M'], [1e3, 'K']];
  const magnitude = Math.abs(num);
  for (let i = 0; i < units.length; i++) {
    const [scale, suffix] = units[i];
    if (magnitude < scale) continue;
    const scaled = (num / scale).toFixed(2);
    // toFixed can round e.g. 999.999 up to "1000.00"; promote to the next unit.
    if (i > 0 && Math.abs(Number(scaled)) >= 1000) {
      const [largerScale, largerSuffix] = units[i - 1];
      return (num / largerScale).toFixed(2) + largerSuffix;
    }
    return scaled + suffix;
  }
  return num.toString();
}
// Wire up drag-and-drop: every palette item can start a drag, and the canvas
// accepts drops. handleDragStart relies on `this` being the palette item, so it
// is registered directly rather than wrapped in an arrow function.
for (const paletteItem of nodeItems) {
  paletteItem.addEventListener('dragstart', handleDragStart);
}
canvas.addEventListener('dragover', handleDragOver);
canvas.addEventListener('drop', handleDrop);
// Handle drag start event
/**
 * `dragstart` listener for palette items. Records the item as the dragged
 * node, puts its `data-type` into the drag payload, and uses a
 * semi-transparent clone of the item as the drag image.
 * Must be registered directly as a listener: it reads the item from `this`.
 * @param {DragEvent} e
 */
function handleDragStart(e) {
  const paletteItem = this;
  const { dataTransfer } = e;
  draggedNode = paletteItem;
  dataTransfer.setData('text/plain', paletteItem.getAttribute('data-type'));

  // The faded copy is attached to the document so it can serve as the drag
  // image, then detached again on the next tick.
  const dragPreview = paletteItem.cloneNode(true);
  dragPreview.style.opacity = '0.5';
  document.body.appendChild(dragPreview);
  dataTransfer.setDragImage(dragPreview, 0, 0);
  setTimeout(() => {
    document.body.removeChild(dragPreview);
  }, 0);
}
// Handle drag over event
/**
 * `dragover` listener for the canvas. Cancelling the event's default is what
 * allows a subsequent `drop` on the canvas; the cursor feedback is "copy".
 * @param {DragEvent} e
 */
function handleDragOver(e) {
  const { dataTransfer } = e;
  e.preventDefault();
  dataTransfer.dropEffect = 'copy';
}
// Handle drop event to create new nodes on the canvas
/**
 * `drop` listener for the canvas. Builds a `.canvas-node` for the dropped
 * layer type at the drop point, fills in its title, compact dimensions,
 * shape rows and parameter summary, wires its interaction handlers, records
 * it in `networkLayers.layers`, and dispatches a `networkUpdated` event.
 *
 * @param {DragEvent} e - drop event whose `text/plain` payload is the layer
 *   type written by handleDragStart (e.g. 'input', 'hidden', 'conv').
 */
function handleDrop(e) {
  e.preventDefault();
  // Hide the canvas hint when nodes are added
  const canvasHint = document.querySelector('.canvas-hint');
  if (canvasHint) {
    canvasHint.style.display = 'none';
  }
  const nodeType = e.dataTransfer.getData('text/plain');
  // Nothing to create when the drop carries no layer type.
  if (!nodeType) return;

  // Generate unique layer ID
  const layerId = window.neuralNetwork.getNextLayerId(nodeType);

  // Create a new node on the canvas
  const canvasNode = document.createElement('div');
  canvasNode.className = `canvas-node ${nodeType}-node`;
  canvasNode.setAttribute('data-type', nodeType);
  canvasNode.setAttribute('data-id', layerId);

  // Place the node's top-left corner at the drop point, in canvas coordinates.
  const rect = canvas.getBoundingClientRect();
  const x = e.clientX - rect.left;
  const y = e.clientY - rect.top;
  canvasNode.style.left = `${x}px`;
  canvasNode.style.top = `${y}px`;

  // Default config for this node type; some fields are adjusted below.
  const nodeConfig = window.neuralNetwork.createNodeConfig(nodeType);

  // Per-type display text. Existing-node counts are taken before this node is
  // appended to the canvas, so the first layer of each kind is numbered 1.
  // Each case that declares a local is braced so its binding stays scoped.
  let nodeName, inputShape, outputShape, parameters;
  switch (nodeType) {
    case 'input':
      nodeName = 'Input Layer';
      inputShape = 'N/A';
      outputShape = `[${nodeConfig.shape.join(' × ')}]`;
      parameters = nodeConfig.parameters;
      break;
    case 'hidden': {
      const hiddenCount = document.querySelectorAll('.canvas-node[data-type="hidden"]').length;
      // The first hidden layer gets 128 units, later ones 64.
      nodeConfig.units = hiddenCount === 0 ? 128 : 64;
      nodeName = `Hidden Layer ${hiddenCount + 1}`;
      // Input shape will be updated when connections are made
      inputShape = 'Connect input';
      outputShape = `[${nodeConfig.units}]`;
      parameters = 'Connect input to calculate';
      break;
    }
    case 'output':
      nodeName = 'Output Layer';
      inputShape = 'Connect input';
      outputShape = `[${nodeConfig.units}]`;
      parameters = 'Connect input to calculate';
      break;
    case 'conv': {
      const convCount = document.querySelectorAll('.canvas-node[data-type="conv"]').length;
      // Filter count grows with each conv layer: 32, 64, 96, ...
      nodeConfig.filters = 32 * (convCount + 1);
      nodeName = `Conv2D ${convCount + 1}`;
      inputShape = 'Connect input';
      outputShape = 'Depends on input';
      parameters = `Kernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
      break;
    }
    case 'pool': {
      const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
      nodeName = `Pooling ${poolCount + 1}`;
      inputShape = 'Connect input';
      outputShape = 'Depends on input';
      parameters = `Pool size: ${nodeConfig.poolSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
      break;
    }
    default:
      nodeName = 'Unknown Layer';
      inputShape = 'N/A';
      outputShape = 'N/A';
      parameters = 'N/A';
  }

  // Node body: input/output shape rows followed by the parameter summary.
  const nodeContent = document.createElement('div');
  nodeContent.className = 'node-content';
  const shapeInfo = document.createElement('div');
  shapeInfo.className = 'shape-info';
  shapeInfo.innerHTML = `
<div class="shape-row"><span class="shape-label">Input:</span> <span class="input-shape">${inputShape}</span></div>
<div class="shape-row"><span class="shape-label">Output:</span> <span class="output-shape">${outputShape}</span></div>
`;
  const paramsSection = document.createElement('div');
  paramsSection.className = 'params-section';
  paramsSection.innerHTML = `
<div class="params-details">${parameters}</div>
<div class="node-parameters">Params: ${nodeConfig.parameters !== undefined ? formatNumber(nodeConfig.parameters) : '?'}</div>
`;
  nodeContent.appendChild(shapeInfo);
  nodeContent.appendChild(paramsSection);

  // Compact shape summary shown under the title (also stored as data-dimensions).
  // Known shapes are rendered as "in → out", matching the "? → …" placeholders.
  const dimensionsSection = document.createElement('div');
  dimensionsSection.className = 'node-dimensions';
  let dimensionsText = '';
  switch (nodeType) {
    case 'input':
      dimensionsText = nodeConfig.shape.join(' × ');
      break;
    case 'hidden':
    case 'output':
      dimensionsText = nodeConfig.units.toString();
      break;
    case 'conv':
      if (nodeConfig.inputShape && nodeConfig.outputShape) {
        dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
      } else {
        dimensionsText = `? → ${nodeConfig.filters} filters`;
      }
      break;
    case 'pool':
      if (nodeConfig.inputShape && nodeConfig.outputShape) {
        dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
      } else {
        dimensionsText = `? → ?`;
      }
      break;
    case 'linear':
      // NOTE(review): 'linear' has no case in the naming switch above, so it is
      // titled 'Unknown Layer' — confirm that is intended.
      dimensionsText = `${nodeConfig.inputFeatures} → ${nodeConfig.outputFeatures}`;
      break;
  }
  dimensionsSection.textContent = dimensionsText;

  // Add node title for clearer identification
  const nodeTitle = document.createElement('div');
  nodeTitle.className = 'node-title';
  nodeTitle.textContent = nodeName;

  // Connection ports: connections are started from the output port.
  const portIn = document.createElement('div');
  portIn.className = 'node-port port-in';
  const portOut = document.createElement('div');
  portOut.className = 'node-port port-out';

  canvasNode.appendChild(nodeTitle);
  canvasNode.appendChild(dimensionsSection);
  canvasNode.appendChild(nodeContent);
  canvasNode.appendChild(portIn);
  canvasNode.appendChild(portOut);

  // Store node data attributes for easier access
  canvasNode.setAttribute('data-name', nodeName);
  canvasNode.setAttribute('data-dimensions', dimensionsText);

  canvas.appendChild(canvasNode);
  // Keep the (possibly adjusted) config on the element itself.
  canvasNode.layerConfig = nodeConfig;

  // Mousedown on the body repositions the node; on a port it must not.
  canvasNode.addEventListener('mousedown', startDrag);
  portIn.addEventListener('mousedown', (ev) => {
    ev.stopPropagation();
  });
  portOut.addEventListener('mousedown', (ev) => {
    ev.stopPropagation();
    startConnection(canvasNode, ev);
  });
  // Double-click to edit node properties
  canvasNode.addEventListener('dblclick', () => {
    openLayerEditor(canvasNode);
  });
  // Right-click to delete
  canvasNode.addEventListener('contextmenu', (ev) => {
    ev.preventDefault();
    deleteNode(canvasNode);
  });

  // Add to network layers for architecture building
  networkLayers.layers.push({
    id: layerId,
    type: nodeType,
    name: nodeName,
    position: { x, y },
    dimensions: dimensionsText,
    config: nodeConfig,
    parameters: nodeConfig.parameters || 0
  });

  // Notify about network changes
  document.dispatchEvent(new CustomEvent('networkUpdated', {
    detail: networkLayers
  }));
  updateConnections();
}
// Start dragging an existing node on the canvas
/**
 * `mousedown` handler on a canvas node: begins repositioning it.
 * Ignored while a connection is being drawn, for non-primary mouse buttons,
 * and when the press lands on node controls or a port (the output port starts
 * a connection instead). Installs document-level mousemove/mouseup listeners
 * (dragNode / stopDrag) so the drag keeps tracking outside the node.
 * @param {MouseEvent} e
 */
function startDrag(e) {
  if (isConnecting) return;
  // Only the primary button repositions nodes. Right-click is the node's
  // delete gesture (contextmenu handler) and must not leave a drag armed.
  if (e.button !== 0) return;
  // Only start drag if not clicking on buttons or ports
  if (e.target.closest('.node-controls') || e.target.closest('.node-port')) {
    return;
  }
  const target = e.target.closest('.canvas-node');
  // Bail out rather than throw (and leave isDragging stuck on) if the press
  // did not originate inside a canvas node.
  if (!target) return;

  isDragging = true;
  // Remember where inside the node it was grabbed so dragNode can keep that
  // point under the cursor.
  const rect = target.getBoundingClientRect();
  offsetX = e.clientX - rect.left;
  offsetY = e.clientY - rect.top;

  document.addEventListener('mousemove', dragNode);
  document.addEventListener('mouseup', stopDrag);

  draggedNode = target;
  // Raise the node above its siblings while dragging; stopDrag resets it.
  draggedNode.style.zIndex = "100";
  // Add dragging class for visual feedback
  draggedNode.classList.add('dragging');
  // Prevent the default mousedown behavior (presumably to avoid text selection).
  e.preventDefault();
}
// Drag node on the canvas
/**
 * Document-level `mousemove` handler installed by startDrag. Moves the
 * dragged node so its grab point stays under the cursor, keeps it fully
 * inside the canvas, mirrors the new position into `networkLayers.layers`,
 * and redraws connections. Does nothing unless a drag is active.
 * @param {MouseEvent} e
 */
function dragNode(e) {
  if (!isDragging) return;
  const bounds = canvas.getBoundingClientRect();
  // Nominal size fallbacks for a node that reports a zero offset size.
  const nodeWidth = draggedNode.offsetWidth || 150;
  const nodeHeight = draggedNode.offsetHeight || 100;
  // Upper limit first, then 0, so a node larger than the canvas pins to 0.
  const keepInside = (value, limit) => Math.max(0, Math.min(limit, value));
  const left = keepInside(e.clientX - bounds.left - offsetX, bounds.width - nodeWidth);
  const top = keepInside(e.clientY - bounds.top - offsetY, bounds.height - nodeHeight);

  const { style } = draggedNode;
  style.position = 'absolute';
  style.left = `${left}px`;
  style.top = `${top}px`;
  // Pin the width so the node does not expand while it moves.
  style.width = `${nodeWidth}px`;

  const nodeId = draggedNode.getAttribute('data-id');
  const layer = networkLayers.layers.find((entry) => entry.id === nodeId);
  if (layer) {
    layer.position = { x: left, y: top };
  }
  updateConnections();
}
// Stop dragging
/**
 * Document-level `mouseup` handler that ends a node reposition: detaches the
 * drag listeners, drops the node back to its normal stacking level, removes
 * the `dragging` class, and redraws connections one final time.
 * No-op when no drag is in progress.
 */
function stopDrag() {
  if (!isDragging) return;
  isDragging = false;
  document.removeEventListener('mousemove', dragNode);
  document.removeEventListener('mouseup', stopDrag);

  const node = draggedNode;
  if (!node) return;
  // Undo the raised z-index and visual feedback applied in startDrag.
  node.style.zIndex = "10";
  node.classList.remove('dragging');
  updateConnections();
}
// Start creating a connection between nodes
/**
 * Begins drawing a connection from `node`'s output port. Enters connection
 * mode, creates a zero-length temporary line anchored at the port's centre
 * (in canvas coordinates), marks the port active, highlights candidate
 * targets, and installs the document listeners that draw (drawConnection)
 * and finish (cancelConnection) the connection.
 * @param {HTMLElement} node - canvas node whose `.port-out` was pressed.
 * @param {MouseEvent} e - originating mousedown; its default is prevented.
 */
function startConnection(node, e) {
  isConnecting = true;
  startNode = node;

  const line = document.createElement('div');
  line.className = 'connection temp-connection';
  connectionLine = line;

  const outPort = node.querySelector('.port-out');
  const portBox = outPort.getBoundingClientRect();
  const canvasBox = canvas.getBoundingClientRect();
  const anchorX = portBox.left + portBox.width / 2 - canvasBox.left;
  const anchorY = portBox.top + portBox.height / 2 - canvasBox.top;

  // drawConnection later stretches and rotates the line toward the cursor.
  Object.assign(line.style, {
    left: `${anchorX}px`,
    top: `${anchorY}px`,
    width: '0px',
    transform: 'rotate(0deg)'
  });

  outPort.classList.add('active-port');
  highlightValidConnectionTargets(node);
  canvas.appendChild(line);

  document.addEventListener('mousemove', drawConnection);
  document.addEventListener('mouseup', cancelConnection);
  e.preventDefault();
}
// Highlight valid targets for connection
/**
 * Marks the input port of every other canvas node as a valid or invalid
 * target for a connection starting at `sourceNode`, according to
 * isValidConnection. The source node itself is left untouched.
 * @param {HTMLElement} sourceNode - node whose output port started the connection.
 */
function highlightValidConnectionTargets(sourceNode) {
  const sourceType = sourceNode.getAttribute('data-type');
  const sourceId = sourceNode.getAttribute('data-id');
  for (const candidate of document.querySelectorAll('.canvas-node')) {
    if (candidate === sourceNode) continue;
    const allowed = isValidConnection(
      sourceType,
      candidate.getAttribute('data-type'),
      sourceId,
      candidate.getAttribute('data-id')
    );
    candidate.querySelector('.port-in').classList.add(allowed ? 'valid-target' : 'invalid-target');
  }
}
// Remove highlights from all ports
/**
 * Clears every connection-related highlight class (active source port,
 * valid target, invalid target) from all input and output ports.
 */
function removePortHighlights() {
  const highlightClasses = ['active-port', 'valid-target', 'invalid-target'];
  for (const port of document.querySelectorAll('.port-in, .port-out')) {
    port.classList.remove(...highlightClasses);
  }
}
// Check if a connection between two node types is valid
/**
 * Decides whether a new edge source → target may be created.
 *
 * Rejected:
 * - edges out of an `output` layer or into an `input` layer;
 * - a node connected to itself;
 * - any edge that would close a cycle through `networkLayers.connections`;
 * - type pairs not allowed by the rules table below.
 *
 * @param {string} sourceType - `data-type` of the source node.
 * @param {string} targetType - `data-type` of the target node.
 * @param {string} [sourceId] - `data-id` of the source node. When either id is
 *   omitted, only the type rules are checked.
 * @param {string} [targetId] - `data-id` of the target node.
 * @returns {boolean} true if the connection is allowed.
 */
function isValidConnection(sourceType, targetType, sourceId, targetId) {
  // Basic hierarchy validation: output has no outgoing edges, input no incoming ones.
  if (sourceType === 'output' || targetType === 'input') {
    return false;
  }

  if (sourceId != null && targetId != null) {
    if (sourceId === targetId) {
      return false;
    }
    // Prevent cycles: adding source → target closes a loop exactly when source
    // is already reachable from target. Checking only the direct reverse edge
    // would miss longer loops such as h1 → h2 → h3 → h1.
    const successors = new Map();
    for (const { source, target } of networkLayers.connections) {
      if (!successors.has(source)) {
        successors.set(source, []);
      }
      successors.get(source).push(target);
    }
    const seen = new Set([targetId]);
    const pending = [targetId];
    while (pending.length > 0) {
      const current = pending.pop();
      for (const next of successors.get(current) || []) {
        if (next === sourceId) {
          return false;
        }
        if (!seen.has(next)) {
          seen.add(next);
          pending.push(next);
        }
      }
    }
  }

  // Specific connection rules: which layer types may follow each source type.
  switch (sourceType) {
    case 'input':
      return ['hidden', 'conv'].includes(targetType);
    case 'conv':
      return ['conv', 'pool', 'hidden'].includes(targetType);
    case 'pool':
      return ['conv', 'hidden'].includes(targetType);
    case 'hidden':
      return ['hidden', 'output'].includes(targetType);
    default:
      return false;
  }
}
// Draw the connection line as mouse moves
/**
 * Document-level `mousemove` handler while a connection is being drawn.
 * Stretches and rotates the temporary line from the centre of the start
 * node's output port to the cursor (canvas coordinates). It also toggles
 * `port-hover` on each other node's input port depending on whether the
 * cursor is over it.
 * @param {MouseEvent} e
 */
function drawConnection(e) {
  if (!isConnecting || !connectionLine) return;
  const canvasRect = canvas.getBoundingClientRect();
  const portRect = startNode.querySelector('.port-out').getBoundingClientRect();

  const startX = portRect.left + portRect.width / 2 - canvasRect.left;
  const startY = portRect.top + portRect.height / 2 - canvasRect.top;
  const dx = e.clientX - canvasRect.left - startX;
  const dy = e.clientY - canvasRect.top - startY;

  // Only length and rotation change here; the line's left/top anchor was
  // placed on the port when the connection started.
  connectionLine.style.width = `${Math.hypot(dx, dy)}px`;
  connectionLine.style.transform = `rotate(${Math.atan2(dy, dx) * 180 / Math.PI}deg)`;

  // Highlight the input port under the cursor. Only the port's own rect is
  // needed; this runs on every mousemove, so no other layout reads are made.
  for (const node of document.querySelectorAll('.canvas-node')) {
    if (node === startNode) continue;
    const portIn = node.querySelector('.port-in');
    const box = portIn.getBoundingClientRect();
    const isOver = e.clientX >= box.left && e.clientX <= box.right &&
      e.clientY >= box.top && e.clientY <= box.bottom;
    if (isOver) {
      portIn.classList.add('port-hover');
    } else {
      portIn.classList.remove('port-hover');
    }
  }
}
// Cancel connection creation
/**
 * Document-level `mouseup` handler that finishes connection drawing.
 * If the pointer is released over another node's input port and that edge
 * passes isValidConnection, the connection is handed to endConnection.
 * Otherwise the temporary line is removed. In both cases port highlighting
 * is cleared, connection state is reset, and the draw/finish listeners are
 * detached. No-op when not connecting.
 * @param {MouseEvent} e
 */
function cancelConnection(e) {
  if (!isConnecting) return;

  const { clientX, clientY } = e;
  const isUnderPointer = (box) =>
    clientX >= box.left && clientX <= box.right &&
    clientY >= box.top && clientY <= box.bottom;

  // Scan every node without stopping early: if several qualify, the last one
  // in document order is used.
  let targetNode = null;
  for (const candidate of document.querySelectorAll('.canvas-node')) {
    if (candidate === startNode) continue;
    const portBox = candidate.querySelector('.port-in').getBoundingClientRect();
    if (!isUnderPointer(portBox)) continue;
    const allowed = isValidConnection(
      startNode.getAttribute('data-type'),
      candidate.getAttribute('data-type'),
      startNode.getAttribute('data-id'),
      candidate.getAttribute('data-id')
    );
    if (allowed) {
      targetNode = candidate;
    }
  }

  if (targetNode) {
    endConnection(targetNode);
  } else if (connectionLine && connectionLine.parentNode) {
    // No valid drop target: discard the temporary line.
    connectionLine.parentNode.removeChild(connectionLine);
  }

  removePortHighlights();
  for (const port of document.querySelectorAll('.port-hover')) {
    port.classList.remove('port-hover');
  }

  isConnecting = false;
  startNode = null;
  connectionLine = null;

  document.removeEventListener('mousemove', drawConnection);
  document.removeEventListener('mouseup', cancelConnection);
}
// End creating a connection
function endConnection(targetNode) {
if (!isConnecting || !connectionLine || !startNode) return;
const sourceType = startNode.getAttribute('data-type');
const targetType = targetNode.getAttribute('data-type');
const sourceId = startNode.getAttribute('data-id');
const targetId = targetNode.getAttribute('data-id');
// Check if this is a valid connection
if (isValidConnection(sourceType, targetType, sourceId, targetId)) {
// Create a permanent SVG connection
const canvas = document.getElementById('network-canvas');
const svgContainer = document.querySelector('#network-canvas .svg-container') || createSVGContainer();
// Get positions for source and target nodes
const sourceRect = startNode.getBoundingClientRect();
const targetRect = targetNode.getBoundingClientRect();
const canvasRect = canvas.getBoundingClientRect();
// Calculate port positions
const sourcePort = startNode.querySelector('.output-port');
const targetPort = targetNode.querySelector('.input-port');
const sourcePortRect = sourcePort.getBoundingClientRect();
const targetPortRect = targetPort.getBoundingClientRect();
const startX = sourcePortRect.left + (sourcePortRect.width / 2) - canvasRect.left;
const startY = sourcePortRect.top + (sourcePortRect.height / 2) - canvasRect.top;
const endX = targetPortRect.left + (targetPortRect.width / 2) - canvasRect.left;
const endY = targetPortRect.top + (targetPortRect.height / 2) - canvasRect.top;
// Create the connection
const pathId = `connection-${sourceId}-${targetId}`