File size: 4,664 Bytes
37c80b4
 
 
 
 
e964e52
37c80b4
feba4eb
37c80b4
 
 
feba4eb
e964e52
37c80b4
 
 
 
 
 
 
 
e964e52
 
 
 
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
37c80b4
e964e52
 
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
 
37c80b4
e964e52
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
37c80b4
e964e52
 
 
 
 
 
 
 
37c80b4
 
 
 
 
feba4eb
e964e52
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
from sklearn.cluster import KMeans
from .build_pipe import *
import spaces

# Build the pipeline once at import time; generate_terrain() calls this
# module-level object on every invocation.
# NOTE(review): build_pipe presumably comes from the `.build_pipe` star
# import above — confirm.
pipe = build_pipe()

@spaces.GPU
def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, prefix, crop_size=None):
    """Generate terrain data (RGB image and elevation map) from a text prompt.

    Args:
        prompt: Text description passed to the pipeline.
        num_inference_steps: Forwarded to the pipeline unchanged.
        guidance_scale: Forwarded to the pipeline unchanged.
        seed: Seed for the CUDA random generator; converted to ``int``.
        prefix: Optional text placed in front of ``prompt``. ``None`` or an
            empty string means no prefix; a separating space is appended
            when the prefix does not already end with one.
        crop_size: Side length of a centred square crop applied to both
            outputs. ``None`` keeps the full output. Values larger than the
            output are clamped to the output size.

    Returns:
        A tuple ``(rgb, elevation)``: the first pipeline image multiplied by
        255 and cast to ``uint8``, and the first pipeline DEM averaged over
        its last axis.

    Raises:
        ValueError: If ``crop_size`` is given but is not positive.
    """
    # Treat a missing prefix like an empty one instead of failing on None + str.
    prefix = prefix or ""
    if prefix and not prefix.endswith(' '):
        prefix += ' '  # Ensure prefix ends with a space

    full_prompt = prefix + prompt
    # manual_seed() only accepts integers; UI widgets may hand over floats.
    generator = torch.Generator("cuda").manual_seed(int(seed))
    image, dem = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )

    # NOTE(review): assumes image[0] and dem[0] are (H, W, C) numpy arrays
    # with matching H and W — confirm against the pipeline.
    rgb = image[0]
    elevation = dem[0]

    if crop_size is not None:
        crop_size = int(crop_size)
        if crop_size <= 0:
            raise ValueError(f"crop_size must be positive, got {crop_size}")

        h, w = rgb.shape[:2]
        # Clamp so an oversized crop cannot produce negative start indices,
        # which would wrap around and return a smaller, off-centre window.
        crop_h = min(crop_size, h)
        crop_w = min(crop_size, w)
        start_h = (h - crop_h) // 2
        start_w = (w - crop_w) // 2
        end_h = start_h + crop_h
        end_w = start_w + crop_w

        rgb = rgb[start_h:end_h, start_w:end_w, :]
        elevation = elevation[start_h:end_h, start_w:end_w, :]

    return (255 * rgb).astype(np.uint8), elevation.mean(-1)

def create_3d_mesh(rgb, elevation, n_clusters=1000):
    """Create a 3D mesh from RGB and elevation data.

    Every pixel becomes a point ``(x, y, elevation)``. The points are
    triangulated in the XY plane with a Delaunay triangulation.

    Args:
        rgb: Colour array that reshapes to ``(rows * cols, 3)``.
        elevation: 2D array of heights with shape ``(rows, cols)``.
        n_clusters: Number of vertices in the simplified mesh. If it is
            ``<= 0`` the full-resolution mesh is built. Otherwise the points
            are reduced to ``n_clusters`` KMeans cluster centres (capped at
            the number of points), each coloured with the average colour of
            its member points. The value is converted to ``int``.

    Returns:
        A ``trimesh.Trimesh``, or ``None`` when triangulation fails (the
        error is printed).
    """
    rows, cols = elevation.shape
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    points_2d = np.stack([x.ravel(), y.ravel()], axis=-1)
    points_3d = np.column_stack([points_2d, elevation.ravel()])
    original_colors = rgb.reshape(-1, 3)

    # KMeans rejects non-integer cluster counts; UI widgets may pass floats.
    n_clusters = int(n_clusters)

    if n_clusters <= 0:
        # Generate full mesh without clustering
        try:
            faces = Delaunay(points_2d).simplices
            return trimesh.Trimesh(vertices=points_3d, faces=faces, vertex_colors=original_colors)
        except Exception as e:
            print(f"Error during Delaunay triangulation (full mesh): {e}")
            return None

    n_clusters = min(n_clusters, len(points_3d))

    # Apply KMeans clustering for simplification; the cluster centres become
    # the vertices of the simplified mesh.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto')
    kmeans.fit(points_3d)
    simplified_vertices = kmeans.cluster_centers_
    cluster_labels = kmeans.labels_

    # Triangulate on the X and Y coordinates of the cluster centres. Delaunay
    # simplices always index into the input points, so no extra filtering of
    # the faces is needed.
    try:
        faces = Delaunay(simplified_vertices[:, :2]).simplices
    except Exception as e:
        print(f"Error during Delaunay triangulation (clustered mesh): {e}")
        return None

    # Average colour per cluster in a single pass over the points.
    counts = np.bincount(cluster_labels, minlength=n_clusters)
    color_sums = np.stack(
        [
            np.bincount(cluster_labels, weights=original_colors[:, channel], minlength=n_clusters)
            for channel in range(3)
        ],
        axis=1,
    )

    # Clusters without any member point fall back to red instead of raising.
    vertex_colors = np.empty((n_clusters, 3), dtype=np.uint8)
    vertex_colors[:] = (255, 0, 0)
    populated = counts > 0
    vertex_colors[populated] = (color_sums[populated] / counts[populated, None]).astype(np.uint8)

    # Create the trimesh object
    return trimesh.Trimesh(vertices=simplified_vertices, faces=faces, vertex_colors=vertex_colors)

def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed, crop_size, vertex_count, prefix):
    """Generate terrain for a prompt and export it as an OBJ mesh file.

    Args:
        prompt, num_inference_steps, guidance_scale, seed, prefix, crop_size:
            Forwarded to ``generate_terrain``.
        vertex_count: Passed to ``create_3d_mesh`` as ``n_clusters``.

    Returns:
        Path of a temporary ``.obj`` file containing the mesh. The file is
        not deleted automatically.

    Raises:
        RuntimeError: If ``create_3d_mesh`` returns ``None`` (mesh
            construction failed).
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, prefix, crop_size)

    # Multiplier applied to the elevation values before meshing.
    elevation_scale = 500
    mesh = create_3d_mesh(rgb, elevation_scale * elevation, n_clusters=vertex_count)
    if mesh is None:
        # Fail with a clear message before creating any temp file, instead of
        # an AttributeError on None and a leaked empty file.
        raise RuntimeError("3D mesh generation failed: no mesh could be built from the terrain")

    # Reserve a unique path, then close the handle before exporting so the
    # writer is not competing with an open file object on the same path.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name

    mesh.export(file_path)
    return file_path

def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed, prefix):
    """Produce the 2D outputs (RGB image and elevation map) for a prompt.

    Delegates to ``generate_terrain`` without a crop size and returns its
    two results as a ``(rgb, elevation)`` tuple.
    """
    terrain_rgb, terrain_elevation = generate_terrain(
        prompt,
        num_inference_steps,
        guidance_scale,
        seed,
        prefix,
    )
    return terrain_rgb, terrain_elevation