File size: 4,033 Bytes
123719b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch

def mse(img1, img2):
    """Return the per-image mean squared error between two image batches.

    The squared difference is flattened per leading (batch) index and averaged,
    so the result has shape ``(img1.shape[0], 1)``.
    """
    squared_diff = (img1 - img2) ** 2
    batch = img1.shape[0]
    return squared_diff.view(batch, -1).mean(1, keepdim=True)

def psnr(img1, img2):
    """Return the per-image PSNR (peak value 1.0) between two image batches.

    The MSE is averaged over everything but the leading (batch) dimension, so
    the result has shape ``(img1.shape[0], 1)``. Identical images give +inf
    because the MSE is 0.
    """
    n_images = img1.shape[0]
    per_image_mse = ((img1 - img2) ** 2).view(n_images, -1).mean(1, keepdim=True)
    rmse = torch.sqrt(per_image_mse)
    return 20 * torch.log10(1.0 / rmse)

def masked_psnr(img1, img2, mask):
    """Return the PSNR (peak value 1.0) restricted to the pixels selected by ``mask``.

    ``mask`` is multiplied into the squared error, so it must broadcast against
    ``img1 - img2``; it may be boolean or a float weight map. The MSE is the
    mask-weighted error sum divided by the total mask weight over the same
    broadcast shape, so a single-channel mask over an RGB image counts each
    selected pixel once per channel. Returns a 0-dim tensor. An all-zero mask
    gives NaN (0 / 0).
    """
    weighted_sq_err = ((img1 - img2) ** 2) * mask
    # Expand the mask to the shape actually summed in the numerator, so the
    # denominator counts exactly the error terms that were summed. For a
    # 1-channel mask over a 3-channel image this equals 3 * mask.sum(); it also
    # stays correct for per-channel masks and non-RGB images.
    total_weight = torch.broadcast_to(mask, weighted_sq_err.shape).sum()
    mse = weighted_sq_err.sum() / total_weight
    return 20 * torch.log10(1.0 / torch.sqrt(mse))


def accuracy_torch(gt_points, rec_points, gt_normals=None, rec_normals=None, batch_size=5000):
    """Reconstruction accuracy: nearest-GT distance for each reconstructed point.

    Queries ``rec_points`` in chunks of ``batch_size`` rows against all of
    ``gt_points`` with ``torch.cdist`` to bound memory. The code assumes 2-D
    point clouds, presumably (N, D) and (M, D); confirm with the caller.

    Returns ``(mean, median)`` of the nearest-neighbour distances. When both
    ``gt_normals`` and ``rec_normals`` are given, it also returns the mean and
    median of ``|<gt_normal[nn], rec_normal>|``, where ``nn`` is each
    reconstructed point's nearest GT point. This is |cos| if the normals are
    unit length, which is not checked here. Medians follow ``torch.median``,
    which returns the lower of the two middle values for even counts.
    """
    dist_chunks = []
    idx_chunks = []
    for start in range(0, rec_points.shape[0], batch_size):
        chunk = rec_points[start:start + batch_size]
        nearest = torch.cdist(chunk, gt_points).min(dim=1)
        dist_chunks.append(nearest.values)
        idx_chunks.append(nearest.indices)

    nn_dist = torch.cat(dist_chunks)
    nn_idx = torch.cat(idx_chunks)

    mean_dist = nn_dist.mean()
    median_dist = nn_dist.median()

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # Orientation-agnostic normal agreement: the sign of the dot product is dropped.
    abs_dot = (gt_normals[nn_idx] * rec_normals).sum(dim=-1).abs()
    return mean_dist, median_dist, abs_dot.mean(), abs_dot.median()

def completion_torch(gt_points, rec_points, gt_normals=None, rec_normals=None, batch_size=5000):
    """Reconstruction completeness: nearest-reconstruction distance for each GT point.

    Queries ``gt_points`` in chunks of ``batch_size`` rows against all of
    ``rec_points`` with ``torch.cdist`` to bound memory. The code assumes 2-D
    point clouds, presumably (M, D) and (N, D); confirm with the caller.

    Returns ``(mean, median)`` of the nearest-neighbour distances. When both
    ``gt_normals`` and ``rec_normals`` are given, it also returns the mean and
    median of ``|<gt_normal, rec_normal[nn]>|``, where ``nn`` is each GT point's
    nearest reconstructed point. This is |cos| if the normals are unit length,
    which is not checked here. Medians follow ``torch.median``, which returns
    the lower of the two middle values for even counts.
    """
    dist_chunks = []
    idx_chunks = []
    for start in range(0, gt_points.shape[0], batch_size):
        chunk = gt_points[start:start + batch_size]
        nearest = torch.cdist(chunk, rec_points).min(dim=1)
        dist_chunks.append(nearest.values)
        idx_chunks.append(nearest.indices)

    nn_dist = torch.cat(dist_chunks)
    nn_idx = torch.cat(idx_chunks)

    mean_dist = nn_dist.mean()
    median_dist = nn_dist.median()

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # Orientation-agnostic normal agreement: the sign of the dot product is dropped.
    abs_dot = (gt_normals * rec_normals[nn_idx]).sum(dim=-1).abs()
    return mean_dist, median_dist, abs_dot.mean(), abs_dot.median()

def accuracy_per_point(gt_points, rec_points, batch_size=5000):
    """Return the distance from every reconstructed point to its nearest GT point.

    ``rec_points`` is processed in chunks of ``batch_size`` rows against all of
    ``gt_points`` with ``torch.cdist`` to bound memory. The code assumes 2-D
    point clouds, presumably (N, D) and (M, D); confirm with the caller.

    Returns a 1-D tensor with one distance per row of ``rec_points``. An empty
    ``rec_points`` yields an empty tensor of the same dtype/device.
    """
    # Only the nearest distances are needed, so the argmin indices from min()
    # are not kept.
    nearest_dists = [
        torch.cdist(rec_points[start:start + batch_size], gt_points).min(dim=1).values
        for start in range(0, rec_points.shape[0], batch_size)
    ]
    if not nearest_dists:
        # torch.cat rejects an empty list; an empty cloud simply has no distances.
        return rec_points.new_empty((0,))
    return torch.cat(nearest_dists)

def completion_per_point(gt_points, rec_points, batch_size=5000):
    """Return the distance from every GT point to its nearest reconstructed point.

    ``gt_points`` is processed in chunks of ``batch_size`` rows against all of
    ``rec_points`` with ``torch.cdist`` to bound memory. The code assumes 2-D
    point clouds, presumably (M, D) and (N, D); confirm with the caller.

    Returns a 1-D tensor with one distance per row of ``gt_points``. An empty
    ``gt_points`` yields an empty tensor of the same dtype/device.
    """
    # Only the nearest distances are needed, so the argmin indices from min()
    # are not kept.
    nearest_dists = [
        torch.cdist(gt_points[start:start + batch_size], rec_points).min(dim=1).values
        for start in range(0, gt_points.shape[0], batch_size)
    ]
    if not nearest_dists:
        # torch.cat rejects an empty list; an empty cloud simply has no distances.
        return gt_points.new_empty((0,))
    return torch.cat(nearest_dists)