Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def edgeLoss(preds_edges, edges):
    """Class-balanced binary cross-entropy between edge logits and an edge map.

    Pixels where ``edges > 0.5`` are treated as positives. Within each sample,
    positive pixels are weighted by that sample's fraction of negative pixels
    and negative pixels by its fraction of positive pixels, so the rarer class
    contributes more per pixel.

    Args:
        preds_edges: Raw edge logits, with shape [b, c, h, w].
        edges: Target edge map, with shape [b, c, h, w].

    Returns:
        Scalar tensor holding the weighted loss averaged over all elements.
    """
    is_edge = (edges > 0.5).float()
    _, channels, height, width = is_edge.shape
    pixels_per_sample = channels * height * width

    # Per-sample counts, reshaped to [b, 1, 1, 1] so they broadcast over the map.
    pos_count = is_edge.sum(dim=[1, 2, 3]).float().view(-1, 1, 1, 1)
    neg_count = pixels_per_sample - pos_count
    denom = pos_count + neg_count

    # Cross-weighting: each class is scaled by the frequency of the other one.
    balance = (neg_count / denom) * is_edge + (pos_count / denom) * (1 - is_edge)

    per_pixel = F.binary_cross_entropy_with_logits(
        preds_edges.float(),
        edges.float(),
        weight=balance,
        reduction='none',
    )
    return per_pixel.mean()
class EdgeAcc(nn.Module):
    """
    Measure the precision and recall of a predicted edge map.

    Both the prediction and the ground truth are binarised with the same
    ``threshold`` before being compared.
    """

    def __init__(self, threshold=0.5):
        super(EdgeAcc, self).__init__()
        self.threshold = threshold

    # Implemented as ``forward`` (not ``__call__``) so that nn.Module's own
    # ``__call__`` stays in charge: ``metric(pred, gt)`` still works, module
    # hooks run, and ``metric.forward(pred, gt)`` is usable too.
    def forward(self, pred_edge, gt_edge):
        """
        Args:
            pred_edge: Predicted edges, with shape [b, c, h, w]
            gt_edge: GT edges, with shape [b, c, h, w]

        Returns:
            Tuple ``(precision, recall)`` of scalar float tensors, computed
            over all elements of the batch at once.
        """
        labels = gt_edge > self.threshold
        preds = pred_edge > self.threshold

        relevant = labels.float().sum()
        selected = preds.float().sum()

        if relevant == 0 and selected == 0:
            # Nothing to find and nothing predicted counts as a perfect score.
            # Built from `relevant` so dtype and device match the other path.
            perfect = torch.ones_like(relevant)
            return perfect, perfect.clone()

        # A true positive is a pixel that is set in both binarised maps.
        true_positive = (preds & labels).float().sum()

        # The epsilon keeps the ratio finite when only one of the counts is zero.
        recall = true_positive / (relevant + 1e-8)
        precision = true_positive / (selected + 1e-8)
        return precision, recall
if __name__ == '__main__':
    # Demo: build two toy edge maps and print the per-sample pixel counts
    # followed by the class-balancing weight map derived from them.
    edge = torch.zeros([2, 1, 10, 10])  # [b, 1, h, w] -> the extracted edges
    for sample, (lo, hi) in enumerate([(2, 8), (3, 7)]):
        edge[sample, :, lo:hi, lo:hi] = 1

    mask = (edge > 0.5).float()
    _, c, h, w = mask.shape

    num_pos = mask.sum(dim=[1, 2, 3]).float()
    num_neg = c * h * w - num_pos
    print(num_pos, num_neg)

    total = num_pos + num_neg
    # Reshape the per-sample fractions to [b, 1, 1, 1]; multiplying the raw
    # 1-D fractions with the 4-D mask would broadcast along the wrong axis.
    neg_frac = (num_neg / total).view(-1, 1, 1, 1)
    pos_frac = (num_pos / total).view(-1, 1, 1, 1)
    print(neg_frac * mask + pos_frac * (1 - mask))