MESA / src /utils.py
mikonvergence
minor update (random seed etc)
f328707
raw
history blame
5.38 kB
import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
from sklearn.cluster import KMeans
from .build_pipe import *
import spaces
# Build the generation pipeline once at import time; generate_terrain() below
# calls this shared module-level instance on every request.
# NOTE(review): build_pipe is presumably supplied by the star import from
# .build_pipe -- confirm.
pipe = build_pipe()
@spaces.GPU
def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size=None):
    """Generate terrain data (RGB and elevation) from a text prompt.

    Args:
        prompt: Text description passed to the pipeline (after ``prefix``).
        num_inference_steps: Forwarded unchanged to the pipeline call.
        guidance_scale: Forwarded unchanged to the pipeline call.
        seed: Seed for the CUDA generator; used only when ``random_seed``
            is falsy.
        random_seed: When truthy, ``seed`` is ignored and the generator is
            seeded non-deterministically.
        prefix: Text prepended to ``prompt``. A separating space is appended
            when missing; ``None`` or ``""`` means no prefix.
        crop_size: Optional side length of a centred square crop applied to
            both outputs. Values larger than the generated image are clamped
            to the image size. ``None`` disables cropping.

    Returns:
        A tuple ``(rgb, elevation)``: the first generated image multiplied by
        255 and cast to ``uint8`` (presumably the pipeline yields floats in
        [0, 1] -- confirm), and the first generated DEM averaged over its
        last axis.
    """
    prefix = prefix or ""
    if prefix and not prefix.endswith(" "):
        prefix += " "  # Ensure prefix ends with a space
    full_prompt = prefix + prompt

    generator = torch.Generator("cuda")
    if random_seed:
        # A freshly constructed Generator always starts from the same default
        # seed, so it has to be reseeded explicitly to get a different
        # terrain on every call.
        generator.seed()
    else:
        # int() tolerates numeric widgets that deliver the seed as a float.
        generator.manual_seed(int(seed))

    image, dem = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    rgb, elevation = image[0], dem[0]

    if crop_size is not None:
        # Centre-crop both outputs with the same window. Clamping keeps the
        # slice bounds non-negative when the request exceeds the image size
        # (negative starts would silently select an off-centre region).
        height, width = rgb.shape[:2]
        crop_h = max(0, min(int(crop_size), height))
        crop_w = max(0, min(int(crop_size), width))
        top = (height - crop_h) // 2
        left = (width - crop_w) // 2
        rgb = rgb[top:top + crop_h, left:left + crop_w, :]
        elevation = elevation[top:top + crop_h, left:left + crop_w, :]

    return (255 * rgb).astype(np.uint8), elevation.mean(-1)
def create_3d_mesh(rgb, elevation, n_clusters=1000):
    """Create a 3D mesh from RGB and elevation data.

    If ``n_clusters`` is 0 or negative, the full mesh is generated with one
    vertex per elevation sample. Otherwise the samples are clustered with
    KMeans on (x, y, elevation), the cluster centres become the vertices, and
    each vertex takes the average colour of its cluster members.

    Args:
        rgb: Colour array that can be reshaped to ``(-1, 3)``; expected to
            hold one colour per elevation sample in the same row-major order.
        elevation: 2-D array of heights.
        n_clusters: Target vertex count for the simplified mesh. Capped at
            the number of elevation samples; non-integer numbers are
            truncated with ``int()``.

    Returns:
        A ``trimesh.Trimesh``, or ``None`` when building the mesh fails (the
        error is printed).
    """
    rows, cols = elevation.shape
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    points_2d = np.stack([x.flatten(), y.flatten()], axis=-1)
    elevation_flat = elevation.flatten()
    points_3d = np.column_stack([points_2d, elevation_flat])
    original_colors = rgb.reshape(-1, 3)

    # UI widgets may deliver the count as a float; KMeans and the colour
    # aggregation below need a real integer.
    n_clusters = int(n_clusters)

    if n_clusters <= 0:
        # Generate full mesh without clustering
        try:
            faces = Delaunay(points_2d).simplices
            return trimesh.Trimesh(vertices=points_3d, faces=faces, vertex_colors=original_colors)
        except Exception as e:
            print(f"Error during Delaunay triangulation (full mesh): {e}")
            return None

    n_clusters = min(n_clusters, len(elevation_flat))

    # Apply KMeans clustering for simplification
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto')
    kmeans.fit(points_3d)
    simplified_vertices = kmeans.cluster_centers_
    cluster_labels = kmeans.labels_

    # Triangulate on the X and Y coordinates of the cluster centres. The
    # simplices index into the points passed in, so every face already refers
    # to a valid vertex.
    try:
        faces = Delaunay(simplified_vertices[:, :2]).simplices
    except Exception as e:
        print(f"Error during Delaunay triangulation (clustered mesh): {e}")
        return None

    # Average the member colours of every cluster. Aggregating by cluster
    # index (rather than by the labels that happen to occur) guarantees one
    # colour per vertex even if a cluster ends up without members.
    member_counts = np.bincount(cluster_labels, minlength=n_clusters)
    color_sums = np.stack(
        [
            np.bincount(cluster_labels, weights=original_colors[:, channel], minlength=n_clusters)
            for channel in range(3)
        ],
        axis=1,
    )
    occupied = member_counts > 0
    vertex_colors = np.empty((n_clusters, 3), dtype=np.uint8)
    vertex_colors[occupied] = (color_sums[occupied] / member_counts[occupied, None]).astype(np.uint8)
    vertex_colors[~occupied] = (255, 0, 0)  # Red marks clusters with no members

    # Create the trimesh object
    return trimesh.Trimesh(vertices=simplified_vertices, faces=faces, vertex_colors=vertex_colors)
def create_3d_point_cloud(rgb, elevation):
    """Build a coloured point cloud with one point per elevation sample.

    Each point is ``(column index, row index, elevation value)`` and its
    colour is the matching row of ``rgb`` reshaped to ``(-1, 3)``.

    Args:
        rgb: Colour array that can be reshaped to ``(-1, 3)``.
        elevation: 2-D array of heights.

    Returns:
        A ``trimesh.PointCloud`` holding the points and their colours.
    """
    n_rows, n_cols = elevation.shape
    # np.indices yields (row, column) grids; the column index is the x
    # coordinate and the row index is the y coordinate.
    row_idx, col_idx = np.indices((n_rows, n_cols))
    xyz = np.column_stack((col_idx.ravel(), row_idx.ravel(), elevation.ravel()))
    point_colors = rgb.reshape(-1, 3)
    return trimesh.PointCloud(vertices=xyz, colors=point_colors)
def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, crop_size, vertex_count, prefix, elevation_scale=500):
    """Generate terrain and export it as a 3D mesh file.

    Args:
        prompt, num_inference_steps, guidance_scale, seed, random_seed,
        prefix, crop_size: Forwarded to ``generate_terrain``.
        vertex_count: Passed to ``create_3d_mesh`` as ``n_clusters``.
        elevation_scale: Factor applied to the elevation before meshing
            (the returned elevation itself is not scaled). Defaults to 500.

    Returns:
        A tuple ``(rgb, elevation, file_path)`` where ``file_path`` is the
        path of the exported ``.obj`` file.

    Raises:
        RuntimeError: If ``create_3d_mesh`` could not build a mesh.
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size)
    mesh = create_3d_mesh(rgb, elevation_scale * elevation, n_clusters=vertex_count)
    if mesh is None:
        # create_3d_mesh reports failures by returning None; fail with a
        # clear message instead of an AttributeError on ``None.export``.
        raise RuntimeError("Failed to build a 3D mesh from the generated terrain.")

    # Reserve a unique path and close the handle before exporting, so the
    # exporter does not write to a file that is still open here. The file is
    # created with delete=False, so it is not removed when the handle closes.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name
    mesh.export(file_path)

    return rgb, elevation, file_path
def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix):
    """Generate terrain and return its 2D outputs.

    Thin wrapper around ``generate_terrain`` that leaves ``crop_size`` at its
    default.

    Returns:
        The ``(rgb, elevation)`` pair produced by ``generate_terrain``.
    """
    terrain_args = (prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix)
    rgb_image, height_map = generate_terrain(*terrain_args)
    return rgb_image, height_map