aldigobbler's picture
Update index.html
0d3897e verified
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Embeddings Visualizer</title>
<style>
/* Full-window canvas: no page margin and no page scrollbars. */
body {
  margin: 0;
  overflow: hidden;
  font-family: Arial, sans-serif;
}
canvas {
  display: block;
}
/* Hover tooltip (top-left). Hidden until the script sets display: block. */
#info {
  position: absolute;
  top: 10px;
  left: 10px;
  background-color: rgba(0, 0, 0, 0.7);
  color: white;
  padding: 10px;
  border-radius: 5px;
  display: none;
  max-width: 300px;
  max-height: 200px;
  overflow: auto;
}
/* Centered loading / error overlay. */
#loading {
  position: absolute;
  top: 50%;
  left: 50%;
  transform: translate(-50%, -50%);
  background-color: rgba(0, 0, 0, 0.7);
  color: white;
  padding: 20px;
  border-radius: 5px;
  font-size: 18px;
}
/* Category legend (top-right). Because body has overflow: hidden, the legend
   has to scroll on its own; otherwise long category lists are cut off below
   the viewport. 40px = 10px top offset + 20px vertical padding + 10px gap. */
#legend {
  position: absolute;
  top: 10px;
  right: 10px;
  background-color: rgba(0, 0, 0, 0.7);
  color: white;
  padding: 10px;
  border-radius: 5px;
  max-width: 200px;
  max-height: calc(100vh - 40px);
  overflow-y: auto;
}
/* Color swatch shown next to each legend label. */
.color-box {
  display: inline-block;
  width: 15px;
  height: 15px;
  margin-right: 8px;
  border-radius: 3px;
}
.legend-item {
  margin: 5px 0;
  display: flex;
  align-items: center;
}
/* Filter / point-size / reset panel (bottom-left). */
#controls {
  position: absolute;
  bottom: 10px;
  left: 10px;
  background-color: rgba(0, 0, 0, 0.7);
  color: white;
  padding: 10px;
  border-radius: 5px;
}
select, button {
  margin: 5px 0;
  padding: 5px;
  border-radius: 3px;
  border: none;
}
</style>
</head>
<body>
<div id="loading">Loading embeddings, please wait...</div>
<div id="info"></div>
<div id="legend"></div>
<div id="controls">
  <div>
    <label for="category-filter">Filter by category:</label>
    <select id="category-filter">
      <option value="all">All Categories</option>
    </select>
  </div>
  <div>
    <label for="point-size">Point Size:</label>
    <input type="range" id="point-size" min="0.05" max="0.3" step="0.01" value="0.1">
  </div>
  <button id="reset-view">Reset View</button>
</div>
<!-- three.js is loaded as classic (non-module) scripts. The script below uses
     the global THREE.OrbitControls, which only the examples/js/ build provides;
     that directory was dropped from the package in r148. Both URLs must pin the
     same version. NOTE(review): the version in these URLs had been replaced by
     an email-obfuscation placeholder. 0.128.0 is a compatible choice
     (BufferGeometry.setAttribute, boolean vertexColors, examples/js present).
     Confirm it matches the release the page was developed against. -->
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/build/three.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
<script>
// ---------------------------------------------------------------------------
// Scene, camera, renderer, controls, lights and shared state
// ---------------------------------------------------------------------------
const scene = new THREE.Scene();
scene.background = new THREE.Color(0x111111);

// Perspective camera: 75° vertical FOV, near plane 0.1, far plane 1000.
const camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
camera.position.z = 5;

const renderer = new THREE.WebGLRenderer({ antialias: true });
// Render at the device pixel ratio so the canvas is not upscaled (and blurry)
// on high-DPI screens. setSize() below still takes CSS pixels.
renderer.setPixelRatio(window.devicePixelRatio);
renderer.setSize(window.innerWidth, window.innerHeight);
document.body.appendChild(renderer.domElement);

// Orbit controls for rotating, panning and zooming. With damping enabled,
// controls.update() has to be called every frame for the damping to apply.
const controls = new THREE.OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;

// NOTE(review): the points are drawn with PointsMaterial, which does not
// appear to use scene lighting, so these lights may have no visible effect.
// Confirm before removing them.
const ambientLight = new THREE.AmbientLight(0xffffff, 0.6);
scene.add(ambientLight);
const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
directionalLight.position.set(1, 1, 1);
scene.add(directionalLight);

// Hover detection
const raycaster = new THREE.Raycaster();
const mouse = new THREE.Vector2(); // mouse position in normalized device coordinates
let hoveredPoint = null;           // index into pointsData of the hovered point, or null

// Point-cloud state, filled in once the embeddings have loaded
let points = [];                   // { position: THREE.Vector3, index } per point, for hover tests
let pointsData = [];               // raw records from embeddings.json
let pointCloud = null;             // the THREE.Points object
let pointSize = 0.1;               // current PointsMaterial size (driven by the slider)
let categoryColors = {};           // category name -> THREE.Color
let categories = [];               // category names, in legend order
let originalPositions = null;      // normalized positions before any filtering
let visibleCategories = new Set(); // categories passing the current filter
/**
 * Give each category its own color by spacing hues evenly around the HSL
 * color wheel (saturation 0.7, lightness 0.5).
 * @param {Array} categories - category names, in legend order.
 * @returns {Object} map from category name to THREE.Color. If a name repeats,
 *   the color from its last position wins.
 */
function generateColorMap(categories) {
  const palette = {};
  const step = 1 / categories.length;
  for (const [position, name] of categories.entries()) {
    palette[name] = new THREE.Color().setHSL(position * step, 0.7, 0.5);
  }
  return palette;
}
/**
 * Fill the #legend panel with a "Categories" heading and one row per entry:
 * a colored swatch followed by the category name.
 * @param {Object} colorMap - category name -> THREE.Color (anything with getHexString()).
 */
function createLegend(colorMap) {
  const legend = document.getElementById('legend');
  legend.innerHTML = '<h3>Categories</h3>';

  for (const [name, color] of Object.entries(colorMap)) {
    const swatch = document.createElement('span');
    swatch.className = 'color-box';
    swatch.style.backgroundColor = `#${color.getHexString()}`;

    // textContent keeps category names from being parsed as HTML
    const caption = document.createElement('span');
    caption.textContent = name;

    const row = document.createElement('div');
    row.className = 'legend-item';
    row.appendChild(swatch);
    row.appendChild(caption);
    legend.appendChild(row);
  }
}
/**
 * Rebuild the #category-filter dropdown. The first option ("All Categories")
 * is kept, any other existing options are removed, and one option per
 * category is appended with the name as both value and label.
 * @param {Array} categories - category names, in display order.
 */
function updateCategoryFilter(categories) {
  const select = document.getElementById('category-filter');

  // Walk backwards so the indices still to be removed are not shifted.
  for (let i = select.options.length - 1; i > 0; i--) {
    select.remove(i);
  }

  for (const name of categories) {
    const option = document.createElement('option');
    option.value = name;
    option.textContent = name;
    select.appendChild(option);
  }
}
/**
 * Show only the points of one category, or every point for 'all'.
 *
 * Hidden points are moved to a far-away sentinel coordinate rather than
 * removed. The position buffer therefore keeps its size, and point i always
 * corresponds to pointsData[i]. Shown points are restored from
 * originalPositions. visibleCategories is updated to match the filter.
 * Does nothing until the point cloud has been built.
 *
 * @param {string} category - a category name, or 'all'.
 */
function filterByCategory(category) {
  if (!pointCloud || !originalPositions) return;

  // Far outside the normalized [-5, 5] cloud and well beyond the camera's far
  // plane (1000), so hidden points are never drawn.
  const HIDDEN_COORD = 10000;
  const positionAttr = pointCloud.geometry.attributes.position;
  const positions = positionAttr.array;

  if (category === 'all') {
    positions.set(originalPositions);
    visibleCategories = new Set(categories);
  } else {
    visibleCategories = new Set([category]);
    pointsData.forEach((point, i) => {
      const base = i * 3;
      const shown = point.category === category;
      for (let axis = 0; axis < 3; axis++) {
        positions[base + axis] = shown ? originalPositions[base + axis] : HIDDEN_COORD;
      }
    });
  }

  // Tell three.js to re-upload the modified buffer to the GPU.
  positionAttr.needsUpdate = true;
}
/**
 * Aim the camera at the points currently shown and back it off until their
 * bounding sphere fits in view.
 *
 * A point counts as shown when its current x coordinate equals its entry in
 * originalPositions and all three coordinates are finite. filterByCategory
 * hides points by moving them elsewhere. The camera is placed on the +z side
 * of the sphere's center, looking along -z, as in the initial view. The fit
 * uses the narrower of the vertical and horizontal fields of view, so portrait
 * windows are framed too. Does nothing before the point cloud exists or when
 * no point is shown.
 */
function resetView() {
  if (!pointCloud || !originalPositions) return;
  const current = pointCloud.geometry.attributes.position.array;

  // Pass 1: axis-aligned bounds of the shown points.
  const min = [Infinity, Infinity, Infinity];
  const max = [-Infinity, -Infinity, -Infinity];
  const shownOffsets = [];
  for (let i = 0; i < current.length; i += 3) {
    const isShown = current[i] === originalPositions[i]
      && Number.isFinite(current[i])
      && Number.isFinite(current[i + 1])
      && Number.isFinite(current[i + 2]);
    if (!isShown) continue;
    shownOffsets.push(i);
    for (let axis = 0; axis < 3; axis++) {
      min[axis] = Math.min(min[axis], current[i + axis]);
      max[axis] = Math.max(max[axis], current[i + axis]);
    }
  }
  if (shownOffsets.length === 0) return;
  const [cx, cy, cz] = min.map((lo, axis) => (lo + max[axis]) / 2);

  // Pass 2: radius = distance from the center to the farthest shown point.
  let radiusSq = 0;
  for (const i of shownOffsets) {
    const dx = current[i] - cx;
    const dy = current[i + 1] - cy;
    const dz = current[i + 2] - cz;
    radiusSq = Math.max(radiusSq, dx * dx + dy * dy + dz * dz);
  }
  // Keep a minimum size so a single point (radius 0) does not put the camera on top of it.
  const radius = Math.max(Math.sqrt(radiusSq), 1);

  const halfVFov = (camera.fov * Math.PI) / 360;
  const halfHFov = Math.atan(Math.tan(halfVFov) * camera.aspect);
  const distance = radius / Math.sin(Math.min(halfVFov, halfHFov));

  // The 1.2 factor leaves a margin around the sphere.
  camera.position.set(cx, cy, cz + distance * 1.2);
  camera.lookAt(cx, cy, cz);
  controls.target.set(cx, cy, cz);
  controls.update();
}
// ---------------------------------------------------------------------------
// Data loading
// ---------------------------------------------------------------------------

/**
 * Fetch a URL and parse the body as JSON. A non-2xx response rejects with an
 * HTTP error, rather than surfacing as a JSON parse error on an HTML error page.
 * @param {string} url
 * @returns {Promise<*>} the parsed JSON.
 */
function fetchJson(url) {
  return fetch(url).then((res) => {
    if (!res.ok) {
      throw new Error(`HTTP ${res.status} while fetching ${url}`);
    }
    return res.json();
  });
}

/**
 * Choose the category list. Use categoryData.categories when it is a
 * non-empty array; otherwise use the distinct truthy `category` values of the
 * points, in order of first appearance.
 * @param {Array} data - point records.
 * @param {*} categoryData - parsed categories.json (may be null or malformed).
 * @returns {Array} category names.
 */
function resolveCategories(data, categoryData) {
  const provided = categoryData && categoryData.categories;
  if (Array.isArray(provided) && provided.length > 0) {
    return provided;
  }
  const seen = new Set();
  data.forEach((point) => {
    if (point.category) seen.add(point.category);
  });
  return Array.from(seen);
}

/**
 * Return the { min, max } of one numeric field across all points.
 * A plain loop is used instead of Math.min(...values), which throws
 * RangeError ("Maximum call stack size exceeded") for very large arrays.
 * Values that convert to NaN (e.g. a missing z) are skipped. If no value is
 * numeric, the result is { min: Infinity, max: -Infinity }.
 * @param {Array} items
 * @param {string} key
 * @returns {{min: number, max: number}}
 */
function computeExtent(items, key) {
  let min = Infinity;
  let max = -Infinity;
  for (const item of items) {
    const value = Number(item[key]);
    if (Number.isNaN(value)) continue;
    if (value < min) min = value;
    if (value > max) max = value;
  }
  return { min, max };
}

/**
 * Map a value from [min, max] onto [-5, 5]. An empty or zero-width range
 * (e.g. every z equal, as in a 2D projection) maps to 0 instead of the NaN
 * that 0/0 would produce.
 * @param {*} value - coordinate, converted with Number().
 * @param {{min: number, max: number}} extent - from computeExtent.
 * @returns {number}
 */
function normalizeToScene(value, { min, max }) {
  const span = max - min;
  if (!(span > 0)) return 0;
  return ((Number(value) - min) / span) * 10 - 5;
}

/**
 * Color for one point: its category's color when the category is a key of
 * categoryColors, otherwise a random hue (saturation 0.7, lightness 0.5).
 * The own-property check keeps names like "constructor" from resolving to
 * Object.prototype members.
 * @param {Object} point - a point record.
 * @returns {THREE.Color}
 */
function colorForPoint(point) {
  const category = point.category;
  if (category && Object.prototype.hasOwnProperty.call(categoryColors, category)) {
    return categoryColors[category];
  }
  const color = new THREE.Color();
  color.setHSL(Math.random(), 0.7, 0.5);
  return color;
}

// Load the embeddings (required) and the category list (optional; any failure
// falls back to deriving categories from the points), then build the scene.
Promise.all([
  fetchJson('/static/embeddings.json'),
  fetchJson('/static/categories.json').catch(() => ({ categories: [] }))
]).then(([data, categoryData]) => {
  if (!Array.isArray(data)) {
    throw new Error('embeddings.json must contain an array of points');
  }
  // Keep the raw records for tooltips and filtering
  pointsData = data;

  categories = resolveCategories(data, categoryData);
  categoryColors = generateColorMap(categories);
  createLegend(categoryColors);
  updateCategoryFilter(categories);
  visibleCategories = new Set(categories);

  // Normalize each axis independently onto [-5, 5].
  const extentX = computeExtent(data, 'x');
  const extentY = computeExtent(data, 'y');
  const extentZ = computeExtent(data, 'z');

  const positions = new Float32Array(data.length * 3);
  const colors = new Float32Array(data.length * 3);
  data.forEach((point, i) => {
    const base = i * 3;
    positions[base] = normalizeToScene(point.x, extentX);
    positions[base + 1] = normalizeToScene(point.y, extentY);
    positions[base + 2] = normalizeToScene(point.z, extentZ);

    const color = colorForPoint(point);
    colors[base] = color.r;
    colors[base + 1] = color.g;
    colors[base + 2] = color.b;
  });

  // Untouched copy; filterByCategory restores shown points from it.
  originalPositions = positions.slice();

  const geometry = new THREE.BufferGeometry();
  geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
  geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));

  const material = new THREE.PointsMaterial({
    size: pointSize,
    vertexColors: true,
    sizeAttenuation: true
  });

  pointCloud = new THREE.Points(geometry, material);
  scene.add(pointCloud);

  // Cache each point's position as a Vector3 for hover hit-testing.
  for (let i = 0; i < data.length; i++) {
    points.push({
      position: new THREE.Vector3(
        positions[i * 3],
        positions[i * 3 + 1],
        positions[i * 3 + 2]
      ),
      index: i
    });
  }

  // Frame the whole cloud
  resetView();

  document.getElementById('point-size').addEventListener('input', (e) => {
    pointSize = parseFloat(e.target.value);
    if (pointCloud) {
      pointCloud.material.size = pointSize;
    }
  });
  document.getElementById('category-filter').addEventListener('change', (e) => {
    filterByCategory(e.target.value);
  });
  document.getElementById('reset-view').addEventListener('click', resetView);

  // Hide the overlay only after everything succeeded. If any step above
  // throws, the overlay stays up and .catch puts the error text into it.
  document.getElementById('loading').style.display = 'none';
})
.catch((error) => {
  console.error('Error loading embeddings:', error);
  const loadingEl = document.getElementById('loading');
  loadingEl.textContent = 'Error loading embeddings. Check console for details.';
  loadingEl.style.display = 'block';
});
/**
 * mousemove handler: store the cursor position in `mouse` as normalized
 * device coordinates. Both axes run from -1 to +1, and +y points up
 * (screen-space y points down).
 * @param {MouseEvent} event
 */
function onMouseMove(event) {
  const { clientX, clientY } = event;
  mouse.x = 2 * (clientX / window.innerWidth) - 1;
  mouse.y = 1 - 2 * (clientY / window.innerHeight);
}
/**
 * resize handler: match the camera's aspect ratio and the renderer's size to
 * the new window dimensions. The projection matrix is rebuilt so the new
 * aspect takes effect.
 */
function onWindowResize() {
  const { innerWidth: width, innerHeight: height } = window;
  camera.aspect = width / height;
  camera.updateProjectionMatrix();
  renderer.setSize(width, height);
}
/**
 * Write one point record into the tooltip element as plain text. textContent
 * is used instead of innerHTML, so titles or texts containing markup (a "<",
 * a <script> tag inside an embedded document) are shown literally rather than
 * parsed as HTML. Missing title/text render as empty strings.
 * @param {HTMLElement} infoElement - the #info tooltip.
 * @param {Object} point - a record from pointsData.
 */
function renderPointInfo(infoElement, point) {
  const asText = (value) => (value === undefined || value === null ? '' : String(value));

  const heading = document.createElement('h3');
  heading.textContent = asText(point.title);

  const categoryLine = document.createElement('p');
  const categoryLabel = document.createElement('strong');
  categoryLabel.textContent = 'Category:';
  categoryLine.appendChild(categoryLabel);
  categoryLine.appendChild(document.createTextNode(` ${point.category || 'Uncategorized'}`));

  const body = document.createElement('p');
  body.textContent = asText(point.text);

  infoElement.textContent = '';
  infoElement.appendChild(heading);
  infoElement.appendChild(categoryLine);
  infoElement.appendChild(body);
}

/**
 * Per-frame hover test. Finds the shown point closest to the mouse ray
 * (within 0.1 world units) and shows its details in #info, or hides #info when
 * nothing is hovered. The DOM is only touched when the hovered point changes.
 *
 * A point counts as shown when its current x coordinate equals its entry in
 * originalPositions, because filterByCategory hides points by moving them away.
 * Unlike a category-set check, this also lets points with no category, or
 * with a category absent from the category list, be hovered.
 */
function checkIntersection() {
  if (!pointCloud || !originalPositions) return;
  raycaster.setFromCamera(mouse, camera);

  const current = pointCloud.geometry.attributes.position.array;
  const HOVER_THRESHOLD = 0.1;
  let closestIndex = null;
  let closestDistance = HOVER_THRESHOLD;
  for (let i = 0; i < points.length; i++) {
    if (current[i * 3] !== originalPositions[i * 3]) continue; // hidden by the filter
    const distance = raycaster.ray.distanceToPoint(points[i].position);
    // Strict < keeps the first of several equally close points.
    if (distance < closestDistance) {
      closestDistance = distance;
      closestIndex = points[i].index;
    }
  }

  if (closestIndex === hoveredPoint) return;
  hoveredPoint = closestIndex;

  const infoElement = document.getElementById('info');
  if (closestIndex === null) {
    infoElement.style.display = 'none';
    return;
  }
  renderPointInfo(infoElement, pointsData[closestIndex]);
  infoElement.style.display = 'block';
}
/**
 * Render loop. The next frame is requested before doing any work, so the loop
 * keeps running even if one of the steps below throws. Each frame:
 * - apply damped camera motion (controls.update),
 * - refresh the hover tooltip,
 * - draw the scene.
 */
function animate() {
  requestAnimationFrame(animate);
  controls.update();
  checkIntersection();
  renderer.render(scene, camera);
}

// Register the window-level input handlers (in this order), then start the loop.
const windowListeners = [
  ['mousemove', onMouseMove],
  ['resize', onWindowResize]
];
for (const [type, handler] of windowListeners) {
  window.addEventListener(type, handler, false);
}
animate();
</script>
</body>
</html>