File size: 5,380 Bytes
37c80b4
 
 
 
 
e964e52
37c80b4
feba4eb
37c80b4
 
 
feba4eb
f328707
37c80b4
 
 
 
 
f328707
 
 
 
 
37c80b4
 
e964e52
 
 
 
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
37c80b4
e964e52
 
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
 
 
f328707
e964e52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
 
37c80b4
e964e52
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
37c80b4
e964e52
 
 
 
f328707
 
 
 
 
 
 
 
 
 
e964e52
 
37c80b4
 
 
f328707
 
 
 
37c80b4
f328707
e964e52
f328707
 
e964e52
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
from sklearn.cluster import KMeans
from .build_pipe import *
import spaces

# Build the generation pipeline once, at import time; generate_terrain() below
# reuses this single module-level instance on every call.
# NOTE(review): build_pipe is presumably provided by the star import from
# .build_pipe — confirm.
pipe = build_pipe()

@spaces.GPU
def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size=None):
    """Generate terrain data (RGB and elevation) from a text prompt.

    Args:
        prompt: Text description of the terrain.
        num_inference_steps: Forwarded unchanged to the pipeline.
        guidance_scale: Forwarded unchanged to the pipeline.
        seed: Seed for the CUDA generator; used only when ``random_seed``
            is falsy. Coerced to ``int``.
        random_seed: When truthy, ``seed`` is ignored and the generator is
            seeded non-deterministically.
        prefix: Text prepended to ``prompt``. A separating space is added if
            it is missing. ``None`` or an empty string means no prefix.
        crop_size: Optional side length of a centred square crop applied to
            both outputs. Values larger than the image are clamped to the
            image size. ``None`` disables cropping.

    Returns:
        A ``(rgb, elevation)`` tuple: ``rgb`` is the first pipeline image
        multiplied by 255 and cast to ``uint8``; ``elevation`` is the first
        pipeline DEM averaged over its last axis.
    """
    # Treat a missing prefix like an empty one so the concatenation below
    # cannot fail on None.
    prefix = prefix or ""
    if prefix and not prefix.endswith(' '):
        prefix += ' '  # Ensure prefix ends with a space
    full_prompt = prefix + prompt

    generator = torch.Generator("cuda")
    if random_seed:
        # A freshly constructed Generator starts from a fixed default seed,
        # so it must be re-seeded explicitly to get a different result on
        # every call.
        generator.seed()
    else:
        # manual_seed() only accepts integers.
        generator.manual_seed(int(seed))

    image, dem = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )

    rgb = image[0]
    elevation = dem[0]

    if crop_size is not None:
        # Center crop the image and dem. The crop is clamped to the image
        # size: an oversized crop would otherwise produce a negative start
        # index and silently slice the wrong region.
        h, w = rgb.shape[:2]
        crop_h = min(int(crop_size), h)
        crop_w = min(int(crop_size), w)
        start_h = (h - crop_h) // 2
        start_w = (w - crop_w) // 2
        end_h = start_h + crop_h
        end_w = start_w + crop_w

        rgb = rgb[start_h:end_h, start_w:end_w, :]
        elevation = elevation[start_h:end_h, start_w:end_w, :]

    # NOTE(review): the 255 scaling assumes the pipeline returns float images
    # in [0, 1] — confirm against build_pipe.
    return (255 * rgb).astype(np.uint8), elevation.mean(-1)

def _mean_cluster_colors(colors, labels, n_clusters):
    """Return one uint8 RGB colour per cluster: the mean colour of its members.

    Clusters with no members get a red fallback colour so the result always
    has exactly ``n_clusters`` rows.
    """
    counts = np.bincount(labels, minlength=n_clusters)
    # One weighted bincount per channel sums member colours in a single pass
    # over the points instead of scanning all labels once per cluster.
    sums = np.stack(
        [
            np.bincount(labels, weights=colors[:, channel], minlength=n_clusters)
            for channel in range(3)
        ],
        axis=1,
    )

    cluster_colors = np.tile(np.array([255, 0, 0], dtype=np.uint8), (n_clusters, 1))  # Red
    occupied = counts > 0
    cluster_colors[occupied] = (sums[occupied] / counts[occupied, None]).astype(np.uint8)
    return cluster_colors


def create_3d_mesh(rgb, elevation, n_clusters=1000):
    """Create a 3D mesh from RGB and elevation data.

    Vertices use the pixel column as x, the pixel row as y and the elevation
    value as z. Faces come from a 2D Delaunay triangulation of the x/y
    coordinates.

    Args:
        rgb: Colour data; reshaped to ``(-1, 3)`` so it must provide one RGB
            triple per elevation sample.
        elevation: 2D array of heights.
        n_clusters: If ``<= 0``, every sample becomes a vertex (full mesh).
            Otherwise the samples are reduced to at most ``n_clusters``
            vertices with KMeans, each coloured with the mean colour of its
            cluster. Coerced to ``int``.

    Returns:
        A ``trimesh.Trimesh``, or ``None`` if triangulation fails (the error
        is printed).
    """
    rows, cols = elevation.shape
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    points_2d = np.stack([x.flatten(), y.flatten()], axis=-1)
    points_3d = np.column_stack([points_2d, elevation.flatten()])
    original_colors = rgb.reshape(-1, 3)

    # KMeans rejects non-integer cluster counts, so normalise integral floats.
    n_clusters = int(n_clusters)

    if n_clusters <= 0:
        # Generate full mesh without clustering
        try:
            faces = Delaunay(points_2d).simplices
            return trimesh.Trimesh(vertices=points_3d, faces=faces, vertex_colors=original_colors)
        except Exception as e:
            print(f"Error during Delaunay triangulation (full mesh): {e}")
            return None

    # KMeans cannot produce more clusters than there are samples.
    n_clusters = min(n_clusters, len(points_3d))

    # Apply KMeans clustering for simplification; the cluster centres become
    # the simplified vertices.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto')
    kmeans.fit(points_3d)
    simplified_vertices = kmeans.cluster_centers_

    # Triangulate on the X and Y coordinates of the cluster centres. Delaunay
    # simplices always index into the points they were built from, so no
    # further filtering of the faces is needed.
    try:
        faces = Delaunay(simplified_vertices[:, :2]).simplices
    except Exception as e:
        print(f"Error during Delaunay triangulation (clustered mesh): {e}")
        return None

    # Row i is the colour for cluster centre i, including any empty cluster.
    vertex_colors = _mean_cluster_colors(original_colors, kmeans.labels_, n_clusters)

    return trimesh.Trimesh(vertices=simplified_vertices, faces=faces, vertex_colors=vertex_colors)

def create_3d_point_cloud(rgb, elevation):
    """Build a coloured point cloud with one point per elevation sample.

    Each point uses the pixel column as x, the pixel row as y and the
    elevation value as z. ``rgb`` is reshaped to ``(-1, 3)`` and supplies the
    per-point colours.

    Returns:
        A ``trimesh.PointCloud``.
    """
    n_rows, n_cols = elevation.shape
    col_grid, row_grid = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    # Row-major flattening keeps coordinates aligned with the reshaped colours.
    xyz = np.column_stack((col_grid.ravel(), row_grid.ravel(), elevation.ravel()))
    per_point_rgb = rgb.reshape(-1, 3)

    return trimesh.PointCloud(vertices=xyz, colors=per_point_rgb)

def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, crop_size, vertex_count, prefix):
    """Generate terrain from a prompt and export it as a 3D mesh file.

    Args:
        prompt, num_inference_steps, guidance_scale, seed, random_seed,
        prefix, crop_size: Forwarded to ``generate_terrain``.
        vertex_count: Forwarded to ``create_3d_mesh`` as ``n_clusters``.

    Returns:
        A ``(rgb, elevation, file_path)`` tuple. ``elevation`` is returned
        unscaled; ``file_path`` points to a temporary ``.obj`` file that is
        not deleted automatically.

    Raises:
        RuntimeError: If the mesh could not be built.
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size)

    # Elevation is scaled only for the mesh geometry.
    elevation_scale = 500
    mesh = create_3d_mesh(rgb, elevation_scale * elevation, n_clusters=vertex_count)
    if mesh is None:
        # create_3d_mesh returns None when triangulation fails. Fail here with
        # a clear message, before an empty temporary file is left behind.
        raise RuntimeError("Failed to build a 3D mesh from the generated terrain.")

    # Reserve a unique path, then close the handle before exporting: writing
    # by name to a still-open temporary file is not portable.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name
    mesh.export(file_path)

    return rgb, elevation, file_path

def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix):
    """Generate terrain from a prompt and return it without a 3D export.

    All arguments are forwarded to ``generate_terrain``; no crop size is
    passed, so its default applies.

    Returns:
        The ``(rgb, elevation)`` pair produced by ``generate_terrain``.
    """
    terrain = generate_terrain(
        prompt,
        num_inference_steps,
        guidance_scale,
        seed,
        random_seed,
        prefix,
    )
    texture, heightmap = terrain
    return texture, heightmap