import torch.nn as nn
import torch.nn.functional as F
import torch

from ..utils import box_utils


class MultiboxLoss(nn.Module):
    """SSD Multibox loss.

    Combines a classification (cross-entropy) loss over the priors selected by
    hard negative mining with a Smooth L1 regression loss over the positive
    priors (label > 0). Both terms are divided by the number of positive priors.
    """

    def __init__(self, priors, iou_threshold, neg_pos_ratio,
                 center_variance, size_variance, device):
        """Implement SSD Multibox Loss.

        Basically, Multibox loss combines classification loss
        and Smooth L1 regression loss.

        Args:
            priors: prior boxes; moved to ``device`` and stored as ``self.priors``.
            iou_threshold: stored on the instance (not read by ``forward``).
            neg_pos_ratio: forwarded to ``box_utils.hard_negative_mining``;
                presumably the number of negatives kept per positive — confirm
                in box_utils.
            center_variance: stored on the instance (not read by ``forward``).
            size_variance: stored on the instance (not read by ``forward``).
            device: device the priors are moved to.
        """
        super(MultiboxLoss, self).__init__()
        self.iou_threshold = iou_threshold
        self.neg_pos_ratio = neg_pos_ratio
        self.center_variance = center_variance
        self.size_variance = size_variance
        # Tensor.to() is not in-place: it returns a (possibly new) tensor,
        # so the result has to be assigned back.
        self.priors = priors.to(device)

    def forward(self, confidence, predicted_locations, labels, gt_locations):
        """Compute classification loss and smooth l1 loss.

        Args:
            confidence (batch_size, num_priors, num_classes): class predictions
                (unnormalised scores; log-softmax is applied here).
            predicted_locations (batch_size, num_priors, 4): predicted locations.
            labels (batch_size, num_priors): real labels of all the priors;
                label 0 is treated as background, labels > 0 as positives.
            gt_locations (batch_size, num_priors, 4): real boxes corresponding
                to all the priors.

        Returns:
            tuple: ``(smooth_l1_loss / num_pos, classification_loss / num_pos)``
            where ``num_pos`` is the number of positive priors, clamped to at
            least 1 so a batch without positives gives finite losses.
        """
        num_classes = confidence.size(2)
        with torch.no_grad():
            # Per-prior background loss -log p(class 0), derived from
            # cross_entropy = -sum(log(p)). It is only used to rank negatives
            # for hard negative mining, so no gradient is tracked here.
            loss = -F.log_softmax(confidence, dim=2)[:, :, 0]
            # NOTE(review): presumably a boolean (batch_size, num_priors) mask
            # selecting positives plus the hardest negatives — confirm in box_utils.
            mask = box_utils.hard_negative_mining(loss, labels, self.neg_pos_ratio)

        # Indexing happens outside no_grad so gradients reach `confidence`.
        confidence = confidence[mask, :]
        classification_loss = F.cross_entropy(
            confidence.reshape(-1, num_classes), labels[mask], reduction='sum')

        pos_mask = labels > 0
        predicted_locations = predicted_locations[pos_mask, :].reshape(-1, 4)
        gt_locations = gt_locations[pos_mask, :].reshape(-1, 4)
        smooth_l1_loss = F.smooth_l1_loss(
            predicted_locations, gt_locations, reduction='sum')

        # Guard against division by zero (nan/inf losses) when the batch
        # contains no positive priors.
        num_pos = max(gt_locations.size(0), 1)
        return smooth_l1_loss / num_pos, classification_loss / num_pos