""" Implementation of Yolo Loss Function similar to the one in Yolov3 paper, the difference from what I can tell is I use CrossEntropy for the classes instead of BinaryCrossEntropy. """ import random import torch import torch.nn as nn import pytorch_lightning as pl from utils import intersection_over_union import config as cfg class YoloLoss(pl.LightningModule): def __init__(self): super().__init__() self.mse = nn.MSELoss() self.bce = nn.BCEWithLogitsLoss() self.entropy = nn.CrossEntropyLoss() self.sigmoid = nn.Sigmoid() # Constants signifying how much to pay for each respective part of the loss self.lambda_class = 1 self.lambda_noobj = 10 self.lambda_obj = 1 self.lambda_box = 10 self.scaled_anchors = ( torch.tensor(cfg.ANCHORS) * torch.tensor(cfg.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) ) def forward(self, predictions_list, target_list, **kwargs): anchors_list = kwargs.get('anchors_list', None) if not anchors_list: anchors_list = self.scaled_anchors anchors_list = anchors_list.to(cfg.DEVICE) box_loss = 0.0 object_loss = 0.0 no_object_loss = 0.0 class_loss = 0.0 for i in range(3): target = target_list[i] predictions = predictions_list[i] anchors = anchors_list[i] # Check where obj and noobj (we ignore if target == -1) obj = target[..., 0] == 1 # in paper this is Iobj_i noobj = target[..., 0] == 0 # in paper this is Inoobj_i # ======================= # # FOR NO OBJECT LOSS # # ======================= # no_object_loss += self.bce( (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]), ) # ==================== # # FOR OBJECT LOSS # # ==================== # anchors = anchors.reshape(1, 3, 1, 1, 2) box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1) ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach() object_loss += self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj]) # ======================== # # FOR BOX COORDINATES # # ======================== # predictions[..., 1:3] = self.sigmoid(predictions[..., 1:3]) # x,y coordinates target[..., 3:5] = torch.log( (1e-16 + target[..., 3:5] / anchors) ) # width, height coordinates box_loss += self.mse(predictions[..., 1:5][obj], target[..., 1:5][obj]) # ================== # # FOR CLASS LOSS # # ================== # class_loss += self.entropy( (predictions[..., 5:][obj]), (target[..., 5][obj].long()), ) #print("__________________________________") #print(self.lambda_box * box_loss) #print(self.lambda_obj * object_loss) #print(self.lambda_noobj * no_object_loss) #print(self.lambda_class * class_loss) #print("\n") total_loss = ( self.lambda_box * box_loss + self.lambda_obj * object_loss + self.lambda_noobj * no_object_loss + self.lambda_class * class_loss ) if kwargs.get('loss_dict'): return dict(class_loss=self.lambda_class * class_loss, no_object_loss=self.lambda_noobj * no_object_loss, object_loss=self.lambda_obj * object_loss, box_loss=self.lambda_box * box_loss, total_loss=total_loss ) else: return total_loss def check_class_accuracy(self, predictions, target, threshold): tot_class_preds, correct_class = 0, 0 tot_noobj, correct_noobj = 0, 0 tot_obj, correct_obj = 0, 0 y = target out = predictions for i in range(3): obj = y[i][..., 0] == 1 # in paper this is Iobj_i noobj = y[i][..., 0] == 0 # in paper this is Iobj_i correct_class += torch.sum( torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj] ) tot_class_preds += torch.sum(obj) obj_preds = torch.sigmoid(out[i][..., 0]) > threshold correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj]) tot_obj += torch.sum(obj) correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj]) tot_noobj += torch.sum(noobj) return dict( correct_class=correct_class, correct_noobj=correct_noobj, correct_obj=correct_obj, total_class_preds=tot_class_preds, total_noobj=tot_noobj, total_obj=tot_obj ) '''print(f"Class accuracy is: {(correct_class/(tot_class_preds+1e-16))*100:2f}%") print(f"No obj accuracy is: {(correct_noobj/(tot_noobj+1e-16))*100:2f}%") print(f"Obj accuracy is: {(correct_obj/(tot_obj+1e-16))*100:2f}%")'''