File size: 7,319 Bytes
2b3faac
 
527ed9f
 
 
 
 
 
 
2b3faac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527ed9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202

import matplotlib.pyplot as plt
import trimesh
import numpy as np
from copy import deepcopy
from PIL import Image

from . import color_mappings


def plot_all_modalities(ds_entry, figsize=(8, 15)):
    '''Plot every available image modality of a dataset entry as a grid.

    Each row corresponds to one view (one entry of ``ds_entry['image_ids']``)
    and each column to one modality among 'images', 'depth', 'gestalt' and
    'ade' that is present and non-empty in ``ds_entry``.

    Args:
        ds_entry: mapping holding 'image_ids' and, per modality, a sequence of
            images indexable by view.
        figsize: figure size forwarded to ``plt.subplots``.

    Returns:
        ``(fig, axes)`` where ``axes`` is always a 2D array of shape
        (rows, columns), also when there is a single row or column.

    Raises:
        ValueError: if none of the supported modalities is present.
    '''
    modalities_to_plot = ['images', 'depth', 'gestalt', 'ade']
    modalities_in_entry = [k for k in ds_entry.keys() if k in modalities_to_plot and len(ds_entry[k]) > 0]
    if not modalities_in_entry:
        raise ValueError(f"No modality to plot; expected at least one of {modalities_to_plot}")

    number_of_columns = len(modalities_in_entry)
    number_of_rows = len(ds_entry['image_ids'])
    # squeeze=False keeps `axes` two-dimensional so axes[i, j] is valid even
    # with a single view or a single modality.
    fig, axes = plt.subplots(number_of_rows, number_of_columns, figsize=figsize, squeeze=False)

    for i in range(len(ds_entry[modalities_in_entry[0]])):
        for j, modality in enumerate(modalities_in_entry):
            ax = axes[i, j]
            if modality == 'depth':
                # presumably millimetres -> metres; confirm against the dataset
                depth_image = np.array(ds_entry[modality][i]) / 1000.0
                ax.imshow(depth_image, cmap='rainbow')
            else:
                # 'images', 'gestalt' and 'ade' are shown as they are
                ax.imshow(ds_entry[modality][i])
            if i == 0:
                ax.set_title(modality)
            ax.axis('off')
            if j == 0:
                # NOTE(review): the axis is switched off just above, so this
                # label is probably not rendered - confirm.
                ax.set_ylabel(f"Image {i}")
    fig.tight_layout()
    fig.subplots_adjust(wspace=0.05, hspace=0.01)
    return fig, axes


def line(p1, p2, c=(255,0,0), resolution=10, radius=0.05):
    '''Build a uniformly coloured 3d cylinder along the segment (p1, p2).

    Args:
        p1, p2: end points of the segment (3-vectors).
        c: colour with 1 (grey level, repeated four times), 3 (RGB, alpha set
            to 255) or 4 (RGBA) components.
        resolution: accepted for backward compatibility; currently unused.
        radius: radius of the cylinder.

    Returns:
        A ``trimesh.primitives.Cylinder`` centred on the midpoint of the
        segment, with its axis along ``p2 - p1``.

    Raises:
        ValueError: if ``c`` does not have 1, 3 or 4 components.
    '''
    # check colors
    if len(c) == 1:
        c = [c[0]]*4
    elif len(c) == 3:
        c = [*c, 255]
    elif len(c) != 4:
        raise ValueError(f'{c} is not a valid color (must have 1,3, or 4 elements).')

    # compute length and direction of segment
    p1, p2 = np.asarray(p1, dtype=float), np.asarray(p2, dtype=float)
    delta = p2 - p1
    length = np.linalg.norm(delta)
    if length > 0:
        direction = delta / length
    else:
        # Degenerate segment: pick an arbitrary axis instead of dividing by
        # zero, which would put NaNs into the transform.
        direction = np.array([0.0, 0.0, 1.0])

    # Build an orthonormal, right-handed frame whose z axis is `direction`.
    # The helper is the world axis least aligned with `direction`, so the
    # cross product below is never close to zero.
    helper = np.zeros(3)
    helper[np.argmin(np.abs(direction))] = 1.0
    x_axis = np.cross(helper, direction)
    x_axis /= np.linalg.norm(x_axis)
    y_axis = np.cross(direction, x_axis)

    # rigid transform: rotate z onto the segment, move to its midpoint
    T = np.eye(4)
    T[:3, 0] = x_axis
    T[:3, 1] = y_axis
    T[:3, 2] = direction
    T[:3, 3] = (p1 + p2) / 2

    # generate and transform mesh
    mesh = trimesh.primitives.Cylinder(radius=radius, height=length, transform=T)

    # apply uniform color
    mesh.visual.vertex_colors = np.ones_like(mesh.visual.vertex_colors)*c

    return mesh

def show_wf(row, radius=10, show_vertices=False, vertex_color=(255,0,0, 255)):
    '''Turn a wireframe entry into a list of trimesh meshes.

    Every edge becomes a cylinder built by ``line``. When
    ``row['edge_semantics']`` is present and has one entry per edge, each edge
    is coloured with ``color_mappings.gestalt_color_mapping`` looked up by its
    class name; otherwise a warning is printed and all edges share one colour.

    Args:
        row: mapping with 'wf_vertices' (3d points), 'wf_edges' (pairs of
            vertex indices) and optionally 'edge_semantics' (one index into
            the edge class list per edge).
        radius: cylinder radius; vertex spheres use ``radius + 5``.
        show_vertices: also emit one sphere per vertex.
        vertex_color: colour applied to the vertex spheres.

    Returns:
        List of meshes: vertex spheres first (if requested), then the edges.
        With an empty edge list only the vertex spheres are returned.
    '''
    EDGE_CLASSES = ['eave',
                    'ridge',
                    'step_flashing',
                    'rake',
                    'flashing',
                    'post',
                    'valley',
                    'hip',
                    'transition_line']
    default_edge_color = (214, 251, 248)

    out_meshes = []
    if show_vertices:
        for center in row['wf_vertices']:
            # Colours are assigned explicitly below, so no colour keyword is
            # passed to the constructor.
            # NOTE(review): newer trimesh primitive constructors appear not to
            # accept extra keyword arguments - confirm.
            sphere = trimesh.primitives.Sphere(radius=radius+5, center=center)
            sphere.visual.vertex_colors = np.ones_like(sphere.visual.vertex_colors)*vertex_color
            out_meshes.append(sphere)

    # np.stack raises on an empty sequence; there is nothing to draw anyway.
    if len(row['wf_edges']) == 0:
        return out_meshes

    edges = np.stack(row['wf_edges'])
    # fancy indexing yields one (start, end) pair of points per edge
    segments = np.stack([*row['wf_vertices']])[edges]

    if 'edge_semantics' not in row:
        print ("Warning: edge semantics is not here, skipping")
        colors = [default_edge_color] * len(segments)
    elif len(edges) == len(row['edge_semantics']):
        colors = [color_mappings.gestalt_color_mapping[EDGE_CLASSES[cls_id]]
                  for cls_id in row['edge_semantics']]
    else:
        print ("Warning: edge semantics has different length compared to edges, skipping semantics")
        colors = [default_edge_color] * len(segments)

    out_meshes.extend([line(a, b, radius=radius, c=color)
                       for (a, b), color in zip(segments, colors)])
    return out_meshes


def show_grid(edges, meshes=None, row_length=5):
    '''
        Lay out wireframes (and optional companion meshes) on a regular grid.

        edges: list of list of meshes; each inner list is merged into a
               single mesh
        meshes: optional corresponding list of meshes (entries may be None)
        row_length: number of meshes per row

        returns trimesh.Scene() holding a translated copy of every input;
        the inputs themselves are not modified
    '''

    T = np.eye(4)
    out = []
    # merge the pieces of each wireframe into one mesh
    edges = [sum(e[1:], e[0]) for e in edges]
    if not edges:
        # max() below has nothing to work on; an empty scene is the natural result
        return trimesh.Scene(out)

    # cell size: largest extent along y / x plus a 10% margin
    row_height = 1.1 * max(e.extents[1] for e in edges)
    col_width = 1.1 * max(e.extents[0] for e in edges)

    if meshes is None:
        meshes = [None]*len(edges)

    for i, (gt, mesh) in enumerate(zip(edges, meshes)):
        if i % row_length != 0:
            # next cell in the current row
            T[0, 3] += col_width
        else:
            # start a new row (this also fires for the very first item)
            T[0, 3] = 0
            T[1, 3] += row_height

        if mesh is not None:
            # work on copies so the caller's geometry stays where it was
            mesh = deepcopy(mesh)
            mesh.apply_transform(T)
            out.append(mesh)

        gt = deepcopy(gt)
        gt.apply_transform(T)
        out.append(gt)

    return trimesh.Scene(out)


def visualize_order_images(row_order):
    '''Tile the 'ade20k', 'gestalt' and colour-mapped 'depthcm' images of an entry.

    The three lists are concatenated in that order and handed to
    ``create_image_grid`` with as many images per row as there are 'ade20k'
    images.
    '''
    tiles = row_order['ade20k'] + row_order['gestalt']
    # depth maps are converted to colour images before tiling
    tiles = tiles + [visualize_depth(depth_map) for depth_map in row_order['depthcm']]
    images_per_row = len(row_order['ade20k'])
    return create_image_grid(tiles, num_per_row=images_per_row)

def create_image_grid(images, target_length=312, num_per_row=2):
    '''Resize all images to one common tile size and paste them into a grid.

    The tile size keeps the aspect ratio of the first image and has an area of
    about ``target_length ** 2`` pixels. Tiles are placed left to right, top
    to bottom, ``num_per_row`` per row, on a new 'RGB' image.

    Returns the grid as a PIL image.
    '''
    # tile size is derived from the first image only
    reference = images[0]
    aspect_ratio = reference.width / reference.height
    tile_width = int((target_length ** 2 * aspect_ratio) ** 0.5)
    tile_height = int((target_length ** 2 / aspect_ratio) ** 0.5)
    tile_size = (tile_width, tile_height)

    # every image (not just the first) is resized to that tile size
    tiles = []
    for picture in images:
        tiles.append(picture.resize(tile_size, Image.Resampling.LANCZOS))

    # number of rows, rounding up so a partial last row still fits
    num_rows = (len(tiles) + num_per_row - 1) // num_per_row
    canvas = Image.new('RGB', (tile_width * num_per_row, tile_height * num_rows))

    for index, tile in enumerate(tiles):
        row, column = divmod(index, num_per_row)
        canvas.paste(tile, (column * tile_width, row * tile_height))

    return canvas


def visualize_depth(depth, min_depth=None, max_depth=None, cmap='rainbow'):
    '''Convert a depth map into a colour-mapped PIL image.

    Args:
        depth: depth map, anything ``np.array`` can convert.
        min_depth: value mapped to the low end of the colormap; defaults to
            the minimum of ``depth``.
        max_depth: value mapped to the high end of the colormap; defaults to
            the maximum of ``depth``.
        cmap: name of the matplotlib colormap.

    Returns:
        PIL image built from the 8-bit colormap output.
    '''
    # Work in floating point: with unsigned integer depth, subtracting an
    # explicit min_depth larger than some pixels would wrap around.
    depth = np.array(depth).astype(np.float64)

    if min_depth is None:
        min_depth = np.min(depth)
    if max_depth is None:
        max_depth = np.max(depth)

    # Normalize the depth to be between 0 and 1
    depth_range = max_depth - min_depth
    if depth_range != 0:
        normalized = (depth - min_depth) / depth_range
    else:
        # constant depth map: avoid 0/0 and map everything to the low end
        normalized = np.zeros_like(depth)
    normalized = np.clip(normalized, 0, 1)

    # Use the matplotlib colormap to convert the depth to a colour image
    colormap = plt.get_cmap(cmap)
    depth_image = (colormap(normalized) * 255).astype(np.uint8)

    # Convert the depth image to a PIL image
    return Image.fromarray(depth_image)