File size: 6,594 Bytes
2b3faac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import numpy as np
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import torch
import trimesh
from time import time

MAX_SCORE = 1.0

def get_one_primitive(p1, p2, c=(255, 0, 0), radius=25, primitive_type='cylinder', sections=6):
    if len(c) == 1:
        c = [c[0]] * 4
    elif len(c) == 3:
        c = [*c, 255]
    elif len(c) != 4:
        raise ValueError(f'{c} is not a valid color (must have 1,3, or 4 elements).')

    p1, p2 = np.asarray(p1), np.asarray(p2)
    l = np.linalg.norm(p2 - p1)
    
    # Add check for zero-length edges
    if l < 1e-6:
        return None
        
    direction = (p2 - p1) / l

    T = np.eye(4)
    T[:3, 2] = direction
    T[:3, 3] = (p1 + p2) / 2

    b0, b1 = T[:3, 0], T[:3, 1]
    if np.abs(np.dot(b0, direction)) < np.abs(np.dot(b1, direction)):
        T[:3, 1] = -np.cross(b0, direction)
    else:
        T[:3, 0] = np.cross(b1, direction)

    if primitive_type == 'capsule':
        mesh = trimesh.primitives.Capsule(radius=radius, height=l, transform=T, sections=sections)
    elif primitive_type == 'cylinder':
        mesh = trimesh.primitives.Cylinder(radius=radius, height=l, transform=T, sections=sections)
    else:
        raise ValueError("Unknown primitive!")

    # Add vertex color initialization check
    if not hasattr(mesh.visual, 'vertex_colors') or mesh.visual.vertex_colors is None:
        mesh.visual.vertex_colors = np.ones((len(mesh.vertices), 4)) * 255
        
    mesh.visual.vertex_colors = np.ones_like(mesh.visual.vertex_colors) * c
    return mesh

def get_primitives(vertices, edges, radius=25, c=[255, 0, 0]):
    # Convert vertices to a NumPy array
    if isinstance(vertices, torch.Tensor):
        vertices = vertices.detach().cpu().numpy()
    else:
        vertices = np.asarray(vertices)
    
    # Convert edges to a NumPy array of integers
    if isinstance(edges, torch.Tensor):
        edges = edges.detach().cpu().numpy().astype(np.int64)
    else:
        edges = np.asarray(edges, dtype=np.int64)
    
    primitives = []
    for e in edges:
        # Add edge validation
        if e[0] >= len(vertices) or e[1] >= len(vertices):
            continue
        primitive = get_one_primitive(vertices[e[0]], vertices[e[1]], radius=radius, c=c)
        if primitive is not None:
            primitives.append(primitive)
    return primitives



def compute_mesh_iou_VOLUME(pd_vertices, pd_edges, gt_vertices, gt_edges, radius=20, engine='manifold'):
    # check empty 
    if len(pd_edges) == 0 or len(gt_edges) == 0:
        return 0.0

    pd_vertices = pd_vertices.detach().cpu() if isinstance(pd_vertices, torch.Tensor) else pd_vertices
    pd_edges = pd_edges.detach().cpu() if isinstance(pd_edges, torch.Tensor) else pd_edges
    gt_vertices = gt_vertices.detach().cpu() if isinstance(gt_vertices, torch.Tensor) else gt_vertices
    gt_edges = gt_edges.detach().cpu() if isinstance(gt_edges, torch.Tensor) else gt_edges

    pd_primitives = get_primitives(pd_vertices, pd_edges, radius=radius, c=[0, 255, 0])
    gt_primitives = get_primitives(gt_vertices, gt_edges, radius=radius, c=[255, 0, 0])
    # check for empty primitives
    if not pd_primitives or not gt_primitives:
        return 0.0

    # Add bounding box check to detect non-overlapping cases quickly
    pd_bounds = np.array([p.bounds for p in pd_primitives])
    gt_bounds = np.array([p.bounds for p in gt_primitives])
    
    pd_min, pd_max = np.min(pd_bounds[:, 0], axis=0), np.max(pd_bounds[:, 1], axis=0)
    gt_min, gt_max = np.min(gt_bounds[:, 0], axis=0), np.max(gt_bounds[:, 1], axis=0)
    
    # If bounding boxes don't overlap, return 0
    if np.any(pd_max < gt_min) or np.any(pd_min > gt_max):
        return 0.0
    t=time()
    mesh_pred = trimesh.boolean.union(pd_primitives, engine=engine)
    #print(f"mesh_pred union: {time() - t} {mesh_pred.is_volume}")
    t=time()
    mesh_gt= trimesh.boolean.union(gt_primitives, engine=engine)
    #print(f"mesh_gt union: {time() - t} {mesh_gt.is_volume}")

    if mesh_pred.is_volume and mesh_gt.is_volume:
        t=time()
        inter_volume = trimesh.boolean.intersection([mesh_pred, mesh_gt], engine=engine).volume
        #print(f"inter_volume: {time() - t}")
    else:
        all_inter = []
        t=time()
        for pd_prim in pd_primitives:
            pd_min, pd_max = pd_prim.bounds
            for gt_prim in gt_primitives:
                # Skip intersection calculation if bounding boxes don't overlap
                gt_min, gt_max = gt_prim.bounds
                if np.any(pd_max < gt_min) or np.any(pd_min > gt_max):
                    continue
                inter = trimesh.boolean.intersection([pd_prim, gt_prim], engine=engine)
                if inter.is_volume and inter.volume > 0:
                    all_inter.append(inter)
        inter_volume = trimesh.boolean.union(all_inter, engine=engine).volume if all_inter else 0
        #print(f"all_inter: {time() - t}")
    union_volume = mesh_pred.volume + mesh_gt.volume - inter_volume
    
    return inter_volume / union_volume if union_volume > 0 else 0.0


# ----------------- Corner F1 -----------------
def compute_ap_metrics(pd_vertices, gt_vertices, thresh=25):
    if len(pd_vertices) == 0 or len(gt_vertices) == 0:
        return 0.0

    dists = cdist(pd_vertices, gt_vertices)
    row_ind, col_ind = linear_sum_assignment(dists)

    tp = (dists[row_ind, col_ind] <= thresh).sum()
    precision = tp / len(pd_vertices) if len(pd_vertices) > 0 else 0
    recall = tp / len(gt_vertices) if len(gt_vertices) > 0 else 0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom > 0 else 0.0
    return f1

def batch_corner_f1(X, Y, distance_thresh=25):
    results = []
    for (pd_v, _), (gt_v, _) in zip(X, Y):
        results.append(compute_ap_metrics(pd_v, gt_v, thresh=distance_thresh))
    return np.array(results)

# ----------------- HSS Metric -----------------
from collections import namedtuple
HSSReturnType = namedtuple('HSSReturnType', ['hss', 'f1', 'iou'])
def hss(y_hat_v, y_hat_e, y_v, y_e, vert_thresh=0.5, edge_thresh=0.5):
    X = [(y_hat_v, y_hat_e)]
    Y = [(y_v, y_e)]
    t=time()
    f1 = np.clip(batch_corner_f1(X, Y, distance_thresh=vert_thresh)[0], 0, 1)
    #print(f"f1 {f1}: in {time() - t:.2f} sec")
    t=time()
    IoU = np.clip(compute_mesh_iou_VOLUME(y_hat_v, y_hat_e, y_v, y_e, radius=edge_thresh), 0, 1)
    #print(f"IoU: {IoU} in {time() - t:.2f} sec")
    score = 2 * f1 * IoU / (f1 + IoU) if (f1 + IoU) > 0 else 0.0
    return HSSReturnType(hss=score, f1=f1, iou=IoU)