Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,380 Bytes
37c80b4 e964e52 37c80b4 feba4eb 37c80b4 feba4eb f328707 37c80b4 f328707 37c80b4 e964e52 37c80b4 e964e52 37c80b4 e964e52 37c80b4 e964e52 f328707 e964e52 37c80b4 e964e52 37c80b4 e964e52 37c80b4 e964e52 37c80b4 e964e52 37c80b4 e964e52 f328707 e964e52 37c80b4 f328707 37c80b4 f328707 e964e52 f328707 e964e52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
from sklearn.cluster import KMeans
from .build_pipe import *
import spaces
# Build the generation pipeline once, at import time, so every call to
# generate_terrain reuses the same instance. build_pipe is presumably supplied
# by the star import from .build_pipe — confirm.
pipe = build_pipe()
def _center_crop(arr, crop_size):
    """Return the central crop_size x crop_size window of an array indexed (row, col, ...).

    crop_size is converted to int (UI sliders may deliver floats) and clamped to
    the array's height and width. An oversize request therefore returns the whole
    array instead of a misaligned slice built from negative start indices.
    """
    height, width = arr.shape[:2]
    size = max(0, min(int(crop_size), height, width))
    top = (height - size) // 2
    left = (width - size) // 2
    return arr[top:top + size, left:left + size]


@spaces.GPU
def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size=None):
    """Generate terrain data (RGB and elevation) from a text prompt.

    Args:
        prompt: Text description of the terrain.
        num_inference_steps: Forwarded unchanged to the pipeline.
        guidance_scale: Forwarded unchanged to the pipeline.
        seed: Seed used for the CUDA generator when ``random_seed`` is false.
        random_seed: If true, seed the generator non-deterministically so each
            call can produce different terrain.
        prefix: Text prepended to ``prompt``. A separating space is added if it
            is missing. ``None`` is treated as an empty prefix.
        crop_size: If not ``None``, center-crop both outputs to a square of this
            size (clamped to the generated image's height/width).

    Returns:
        ``(rgb, elevation)`` where ``rgb`` is ``uint8`` (the pipeline image
        scaled by 255) and ``elevation`` is the first DEM averaged over its last
        axis.
    """
    prefix = prefix or ""
    if prefix and not prefix.endswith(" "):
        prefix += " "  # Ensure prefix ends with a space
    full_prompt = prefix + prompt

    generator = torch.Generator("cuda")
    if random_seed:
        # A freshly constructed torch.Generator starts from a fixed default seed.
        # Without this call every "random" request would yield the same terrain.
        generator.seed()
    else:
        generator.manual_seed(int(seed))

    image, dem = pipe(full_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator)

    # Only the first sample of the batch is used.
    rgb, elevation = image[0], dem[0]
    if crop_size is not None:
        rgb = _center_crop(rgb, crop_size)
        elevation = _center_crop(elevation, crop_size)

    # The 255 scale implies the image is presumably in [0, 1]. Clip first so any
    # slightly out-of-range value saturates instead of wrapping around in uint8.
    rgb_u8 = np.clip(255 * rgb, 0, 255).astype(np.uint8)
    return rgb_u8, elevation.mean(-1)
def _full_grid_mesh(points_2d, points_3d, colors):
    """Triangulate every grid sample without simplification; return None if Delaunay fails."""
    try:
        tri = Delaunay(points_2d)
    except Exception as e:
        print(f"Error during Delaunay triangulation (full mesh): {e}")
        return None
    return trimesh.Trimesh(vertices=points_3d, faces=tri.simplices, vertex_colors=colors)


def _cluster_mean_colors(labels, colors, n_clusters):
    """Return the mean RGB color of each cluster as an (n_clusters, 3) uint8 array.

    Uses np.bincount, so the cost is O(N) rather than one full scan per cluster.
    Any cluster with no assigned points is colored red, so every cluster index
    in range(n_clusters) gets a color.
    """
    counts = np.bincount(labels, minlength=n_clusters)
    channel_sums = np.stack(
        [np.bincount(labels, weights=colors[:, ch], minlength=n_clusters) for ch in range(3)],
        axis=-1,
    )
    result = np.empty((n_clusters, 3), dtype=np.uint8)
    result[:] = (255, 0, 0)  # Red for empty clusters
    nonempty = counts > 0
    result[nonempty] = (channel_sums[nonempty] / counts[nonempty, None]).astype(np.uint8)
    return result


def create_3d_mesh(rgb, elevation, n_clusters=1000):
    """Create a colored 3D mesh from RGB and elevation data.

    Each grid sample becomes a point (column, row, elevation) colored by the
    matching ``rgb`` pixel.

    Args:
        rgb: Image whose pixels, reshaped to (-1, 3), align row-major with
            ``elevation``.
        elevation: 2D array of heights.
        n_clusters: If <= 0, every grid sample becomes a vertex. Otherwise the
            points are simplified with KMeans to at most this many vertices
            (converted to int). Each vertex is colored with the mean color of
            its cluster.

    Returns:
        A ``trimesh.Trimesh``, or ``None`` if Delaunay triangulation fails
        (for example, with fewer than 3 non-collinear vertices).
    """
    rows, cols = elevation.shape
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    points_2d = np.stack([x.ravel(), y.ravel()], axis=-1)
    points_3d = np.column_stack([points_2d, elevation.ravel()])
    original_colors = rgb.reshape(-1, 3)

    if n_clusters <= 0:
        return _full_grid_mesh(points_2d, points_3d, original_colors)

    # KMeans rejects non-integer n_clusters, and cannot use more clusters than points.
    n_clusters = int(min(n_clusters, len(points_3d)))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto')
    kmeans.fit(points_3d)
    simplified_vertices = kmeans.cluster_centers_

    # Triangulate in the XY plane. Delaunay simplices always index into the
    # input points, so no range filtering of faces is needed.
    try:
        tri = Delaunay(simplified_vertices[:, :2])
    except Exception as e:
        print(f"Error during Delaunay triangulation (clustered mesh): {e}")
        return None

    vertex_colors = _cluster_mean_colors(kmeans.labels_, original_colors, n_clusters)
    return trimesh.Trimesh(vertices=simplified_vertices, faces=tri.simplices, vertex_colors=vertex_colors)
def create_3d_point_cloud(rgb, elevation):
    """Return a colored ``trimesh.PointCloud`` with one point per elevation sample.

    Each point is ``(column, row, elevation)``. Colors come from ``rgb``
    reshaped to ``(-1, 3)``, in the same row-major order as the points.
    """
    n_rows, n_cols = elevation.shape
    col_idx, row_idx = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    xyz = np.column_stack((col_idx.ravel(), row_idx.ravel(), elevation.ravel()))
    return trimesh.PointCloud(vertices=xyz, colors=rgb.reshape(-1, 3))
def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, crop_size, vertex_count, prefix, elevation_scale=500):
    """Generate terrain and export it as an OBJ mesh.

    Args:
        prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix,
        crop_size: Forwarded to ``generate_terrain``.
        vertex_count: Passed to ``create_3d_mesh`` as ``n_clusters``
            (<= 0 means the full, unsimplified mesh).
        elevation_scale: Multiplier applied to the elevation before meshing,
            giving the Z axis a visible relief. Defaults to 500.

    Returns:
        ``(rgb, elevation, file_path)``. ``elevation`` is unscaled.
        ``file_path`` is a temporary ``.obj`` file that is not deleted
        automatically.

    Raises:
        RuntimeError: If the mesh could not be built (triangulation failed).
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size)
    mesh = create_3d_mesh(rgb, elevation_scale * elevation, n_clusters=vertex_count)
    if mesh is None:
        # create_3d_mesh signals triangulation failure with None. Fail with a clear
        # message instead of an AttributeError on mesh.export.
        raise RuntimeError("Could not build a 3D mesh from the generated terrain; try a different vertex count or crop size.")

    # Reserve a unique path, then close the handle before exporting. Writing to a
    # path that is still open fails on some platforms (e.g. Windows).
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name
    mesh.export(file_path)
    return rgb, elevation, file_path
def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix):
    """Produce the 2D view: the uncropped RGB image and elevation map for ``prompt``."""
    terrain = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix=prefix)
    image_rgb, height_map = terrain
    return image_rgb, height_map