Peijie committed on
Commit 5ac407d · 1 parent: c898769

Fix bug in the text-embedding path and remove debug code; tested locally.
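In short: `OwlViTForClassification.__init__` switches to the zero-argument `super().__init__()`, the unused `DetrLoss` criterion is commented out (class definition included), the debug `print` scaffolding in `forward` is deleted, the part-text embeddings are computed with the standard `get_text_features` call, and `forward` now returns two values instead of three. The call site in `utils/predict.py` is updated to match (lines reproduced from the diff below):

-    pred_logits, part_logits, output_dict = model(image_embeds, part_embeds, query_embeds, None)
+    pred_logits, part_logits = model(image_embeds, part_embeds, query_embeds, None)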

Files changed (2)
  1. utils/model.py +211 -262
  2. utils/predict.py +1 -1
utils/model.py CHANGED

@@ -157,7 +157,7 @@ class OwlViTForClassification(nn.Module):
     config_class = OwlViTConfig
 
     def __init__(self, owlvit_det_model, num_classes, weight_dict, device, freeze_box_heads=False, train_box_heads_only=False, network_type=None, logits_from_teacher=False, finetuned: bool = False, custom_box_head: bool = False):
-        super(OwlViTForClassification, self).__init__()
+        super().__init__()
 
         self.config = owlvit_det_model.config
         self.num_classes = num_classes
@@ -202,12 +202,12 @@ class OwlViTForClassification(nn.Module):
         losses += ["boxes"] if weight_dict["loss_bbox"] > 0 else []
         losses += ["labels"] if weight_dict["loss_ce"] > 0 else []
 
-        self.criterion = DetrLoss(
-            matcher=None,
-            num_parts=self.num_parts,
-            eos_coef=0.1, # Following facebook/detr-resnet-50
-            losses=losses,
-        )
+        # self.criterion = DetrLoss(
+        #     matcher=None,
+        #     num_parts=self.num_parts,
+        #     eos_coef=0.1, # Following facebook/detr-resnet-50
+        #     losses=losses,
+        # )
 
         self.freeze_parameters(freeze_box_heads, train_box_heads_only)
         del owlvit_det_model
@@ -417,22 +417,7 @@ class OwlViTForClassification(nn.Module):
             topk_scores, topk_idxs = torch.topk(teacher_boxes_logits, k=1, dim=1)
 
         else:
-            #DEUBUG:
-            print(f"text_inputs_parts - input_ids: {text_inputs_parts['input_ids'].shape}. attention_mask : {text_inputs_parts['attention_mask'].shape}")
-            seq_length = text_inputs_parts['input_ids'].shape[-1]
-            position_ids = self.owlvit.text_model.embeddings.position_ids[:, :seq_length]
-            txt_embeds = self.owlvit.text_model.embeddings.token_embedding(text_inputs_parts['input_ids'])
-            print(f"position_embedding: {self.owlvit.text_model.embeddings.position_embedding(position_ids).shape}")
-            print(f"text_embeds: {txt_embeds.shape}")
-
-            device_ = txt_embeds.device
-            position_ids = position_ids.to(device_)
-            txt_embeds_size_0 = text_embeds.size(0)
-            position_embedding = position_ids.cpu().repeat(txt_embeds_size_0, 1, 1)
-            text_inputs_parts["position_ids"] = position_ids
-            print(f"position_embedding : {position_embedding.shape}")
-            print(f"pos + emb: {(txt_embeds.cpu() + position_embedding).shape}")
-            text_embeds_parts = self.owlvit.text_model.get_text_features(**text_inputs_parts)
+            text_embeds_parts = self.owlvit.get_text_features(**text_inputs_parts)
 
         # # Embed images and text queries
         query_mask, text_embeds_parts = self._get_text_query_mask(text_inputs_parts, text_embeds_parts, batch_size)
@@ -460,46 +445,10 @@ class OwlViTForClassification(nn.Module):
         outputs_loss["logits"] = pred_logits_parts
         outputs_loss["pred_boxes"] = pred_boxes
 
-        # Compute box + class losses
-        loss_dict = self.criterion(outputs_loss, targets, mapping_indices)
-
-        # Compute symmetric loss to get rid of the teacher model
-        logits_per_image = torch.softmax(pred_logits_parts, dim=1)
-        logits_per_text = torch.softmax(pred_logits_parts, dim=-1)
-
-        # For getting rid of the teacher model
-        if self.weight_dict["loss_sym_box_label"] > 0:
-            sym_loss_box_label = self.loss_symmetric(logits_per_image, logits_per_text, teacher_boxes_logits)
-            loss_dict["loss_sym_box_label"] = sym_loss_box_label
-        # ----------------------------------------------------------------------------------------
-
-        #DEBUG:
-        print(f"im_features size: {image_feats.shape}, text_embeds size: {text_embeds.shape}")
-        print(f"im_features sum: {image_feats.sum().item()}, text_embeds sum: {text_embeds.sum().item()}")
         # Predict image-level classes (batch_size, num_patches, num_queries)
         image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
-        print(f"topk_idxs: {topk_idxs}")
-        print(f"image_text_logits size: {image_text_logits.shape}")
-        print(f"image_text_logits sum: {image_text_logits.sum().item()}")
-
-        if self.weight_dict["loss_xclip"] > 0:
-            targets_cls = torch.tensor([target["targets_cls"] for target in targets]).unsqueeze(1).to(self.device)
-            if self.network_type == "classification":
-                one_hot = torch.zeros_like(pred_logits).scatter(1, targets_cls, 1).to(self.device)
-                cls_loss = self.ce_loss(pred_logits, one_hot)
-                loss_dict["loss_xclip"] = cls_loss
-            else:
-                # TODO: Need a linear classifier for this approach
-                # Compute symmetric loss for part-descriptor contrastive learning
-                logits_per_image = torch.softmax(image_text_logits, dim=0)
-                logits_per_text = torch.softmax(image_text_logits, dim=-1)
-                sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
-                loss_dict["loss_xclip"] = sym_loss
-
-        #DEBUG:
-        print(f"pred_logits size: {part_logits.shape}, pred_logits size: {part_logits.shape}")
-        print(f"part_logits sum: {pred_logits.sum().item()}, part_logits sum: {pred_logits.sum().item()}")
-        return pred_logits, part_logits, loss_dict
+
+        return pred_logits, part_logits
 
     def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
         # text/image logits (batch_size*num_boxes, num_classes*num_descs): The logits that softmax over text descriptors or boxes
@@ -537,204 +486,204 @@ class OwlViTForClassification(nn.Module):
 
         return sym_loss
 
-class DetrLoss(nn.Module):
-    """
-    This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
-    we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
-    of matched ground-truth / prediction (supervise class and box).
-
-    A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
-    parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
-    the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
-    be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
-    (`max_obj_id` + 1). For more details on this, check the following discussion
-    https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
-
-
-    Args:
-        matcher (`DetrHungarianMatcher`):
-            Module able to compute a matching between targets and proposals.
-        num_parts (`int`):
-            Number of object categories, omitting the special no-object category.
-        eos_coef (`float`):
-            Relative classification weight applied to the no-object category.
-        losses (`List[str]`):
-            List of all the losses to be applied. See `get_loss` for a list of all available losses.
-    """
-
-    def __init__(self, matcher, num_parts, eos_coef, losses):
-        super().__init__()
-        self.matcher = matcher
-        self.num_parts = num_parts
-        self.eos_coef = eos_coef
-        self.losses = losses
-
-        # empty_weight = torch.ones(self.num_parts + 1)
-        empty_weight = torch.ones(self.num_parts)
-        empty_weight[-1] = self.eos_coef
-        self.register_buffer("empty_weight", empty_weight)
-
-    # removed logging parameter, which was part of the original implementation
-    def loss_labels(self, outputs, targets, indices, num_boxes):
-        """
-        Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
-        [nb_target_boxes]
-        """
-        if "logits" not in outputs:
-            raise KeyError("No logits were found in the outputs")
-        source_logits = outputs["logits"]
-
-        idx = self._get_source_permutation_idx(indices)
-        # target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
-        # target_classes = torch.full(source_logits.shape[:2], self.num_parts, dtype=torch.int64, device=source_logits.device)
-        # target_classes[idx] = target_classes_o
-
-        source_logits = source_logits[idx].view(len(indices), -1, self.num_parts)
-        target_classes = torch.stack([t["class_labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
-
-        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
-        losses = {"loss_ce": loss_ce}
-
-        return losses
-
-    @torch.no_grad()
-    def loss_cardinality(self, outputs, targets, indices, num_boxes):
-        """
-        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
-
-        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
-        """
-        logits = outputs["logits"]
-        device = logits.device
-        target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
-        # Count the number of predictions that are NOT "no-object" (which is the last class)
-        card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
-        card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
-        losses = {"cardinality_error": card_err}
-        return losses
-
-    def loss_boxes(self, outputs, targets, indices, num_boxes):
-        """
-        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
-
-        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
-        are expected in format (center_x, center_y, w, h), normalized by the image size.
-        """
-        if "pred_boxes" not in outputs:
-            raise KeyError("No predicted boxes found in outputs")
-
-        idx = self._get_source_permutation_idx(indices)
-        source_boxes = outputs["pred_boxes"][idx]
-        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
-        losses = {}
-
-        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
-        losses["loss_bbox"] = loss_bbox.sum() / num_boxes
-
-        loss_giou = 1 - torch.diag(generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)))
-        losses["loss_giou"] = loss_giou.sum() / num_boxes
-
-        return losses
-
-    def loss_masks(self, outputs, targets, indices, num_boxes):
-        """
-        Compute the losses related to the masks: the focal loss and the dice loss.

-        Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
-        """
-        if "pred_masks" not in outputs:
-            raise KeyError("No predicted masks found in outputs")
-
-        source_idx = self._get_source_permutation_idx(indices)
-        target_idx = self._get_target_permutation_idx(indices)
-        source_masks = outputs["pred_masks"]
-        source_masks = source_masks[source_idx]
-        masks = [t["masks"] for t in targets]
-
-        # TODO use valid to mask invalid areas due to padding in loss
-        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
-        target_masks = target_masks.to(source_masks)
-        target_masks = target_masks[target_idx]
-
-        # upsample predictions to the target size
-        source_masks = nn.functional.interpolate(
-            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
-        )
-        source_masks = source_masks[:, 0].flatten(1)
-
-        target_masks = target_masks.flatten(1)
-        target_masks = target_masks.view(source_masks.shape)
-        losses = {
-            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
-            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
-        }
-        return losses
-
-    def _get_source_permutation_idx(self, indices):
-        # permute predictions following indices
-        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
-        source_idx = torch.cat([source for (source, _) in indices])
-        return batch_idx, source_idx
-
-    def _get_target_permutation_idx(self, indices):
-        # permute targets following indices
-        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
-        target_idx = torch.cat([target for (_, target) in indices])
-        return batch_idx, target_idx
-
-    def get_loss(self, loss, outputs, targets, indices, num_boxes):
-        loss_map = {
-            "labels": self.loss_labels,
-            "cardinality": self.loss_cardinality,
-            "boxes": self.loss_boxes,
-            "masks": self.loss_masks,
-        }
-        if loss not in loss_map:
-            raise ValueError(f"Loss {loss} not supported")
-        return loss_map[loss](outputs, targets, indices, num_boxes)
-
-    def forward(self, outputs, targets, indices):
-        """
-        This performs the loss computation.
-
-        Args:
-            outputs (`dict`, *optional*):
-                Dictionary of tensors, see the output specification of the model for the format.
-            targets (`List[dict]`, *optional*):
-                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
-                losses applied, see each loss' doc.
-        """
-        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
-
-        # ThangPM: Do NOT use bipartite matching --> Use the boxes selected by argmax for computing symmetric loss
-        # Retrieve the matching between the outputs of the last layer and the targets
-        # indices = self.matcher(outputs_without_aux, targets)
-
-        # Compute the average number of target boxes across all nodes, for normalization purposes
-        num_boxes = sum(len(t["class_labels"]) for t in targets)
-        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-        # (Niels): comment out function below, distributed training to be added
-        # if is_dist_avail_and_initialized():
-        #     torch.distributed.all_reduce(num_boxes)
-        # (Niels) in original implementation, num_boxes is divided by get_world_size()
-        num_boxes = torch.clamp(num_boxes, min=1).item()
-
-        # Compute all the requested losses
-        losses = {}
-        for loss in self.losses:
-            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
-
-        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
-        if "auxiliary_outputs" in outputs:
-            for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
-                # indices = self.matcher(auxiliary_outputs, targets)
-                for loss in self.losses:
-                    if loss == "masks":
-                        # Intermediate masks losses are too costly to compute, we ignore them.
-                        continue
-                    l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
-                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
-                    losses.update(l_dict)
-
-        return losses
+# class DetrLoss(nn.Module):
+#     """
+#     This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
+#     we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
+#     of matched ground-truth / prediction (supervise class and box).
+
+#     A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
+#     parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
+#     the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
+#     be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
+#     (`max_obj_id` + 1). For more details on this, check the following discussion
+#     https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
+
+
+#     Args:
+#         matcher (`DetrHungarianMatcher`):
+#             Module able to compute a matching between targets and proposals.
+#         num_parts (`int`):
+#             Number of object categories, omitting the special no-object category.
+#         eos_coef (`float`):
+#             Relative classification weight applied to the no-object category.
+#         losses (`List[str]`):
+#             List of all the losses to be applied. See `get_loss` for a list of all available losses.
+#     """
+
+#     def __init__(self, matcher, num_parts, eos_coef, losses):
+#         super().__init__()
+#         self.matcher = matcher
+#         self.num_parts = num_parts
+#         self.eos_coef = eos_coef
+#         self.losses = losses
+
+#         # empty_weight = torch.ones(self.num_parts + 1)
+#         empty_weight = torch.ones(self.num_parts)
+#         empty_weight[-1] = self.eos_coef
+#         self.register_buffer("empty_weight", empty_weight)
+
+#     # removed logging parameter, which was part of the original implementation
+#     def loss_labels(self, outputs, targets, indices, num_boxes):
+#         """
+#         Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
+#         [nb_target_boxes]
+#         """
+#         if "logits" not in outputs:
+#             raise KeyError("No logits were found in the outputs")
+#         source_logits = outputs["logits"]
+
+#         idx = self._get_source_permutation_idx(indices)
+#         # target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
+#         # target_classes = torch.full(source_logits.shape[:2], self.num_parts, dtype=torch.int64, device=source_logits.device)
+#         # target_classes[idx] = target_classes_o
+
+#         source_logits = source_logits[idx].view(len(indices), -1, self.num_parts)
+#         target_classes = torch.stack([t["class_labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
+
+#         loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
+#         losses = {"loss_ce": loss_ce}
+
+#         return losses
+
+#     @torch.no_grad()
+#     def loss_cardinality(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
+
+#         This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
+#         """
+#         logits = outputs["logits"]
+#         device = logits.device
+#         target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
+#         # Count the number of predictions that are NOT "no-object" (which is the last class)
+#         card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
+#         card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
+#         losses = {"cardinality_error": card_err}
+#         return losses
+
+#     def loss_boxes(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
+
+#         Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
+#         are expected in format (center_x, center_y, w, h), normalized by the image size.
+#         """
+#         if "pred_boxes" not in outputs:
+#             raise KeyError("No predicted boxes found in outputs")
+
+#         idx = self._get_source_permutation_idx(indices)
+#         source_boxes = outputs["pred_boxes"][idx]
+#         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+#         losses = {}
+
+#         loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
+#         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+#         loss_giou = 1 - torch.diag(generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)))
+#         losses["loss_giou"] = loss_giou.sum() / num_boxes
+
+#         return losses
+
+#     def loss_masks(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the losses related to the masks: the focal loss and the dice loss.
+
+#         Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
+#         """
+#         if "pred_masks" not in outputs:
+#             raise KeyError("No predicted masks found in outputs")
+
+#         source_idx = self._get_source_permutation_idx(indices)
+#         target_idx = self._get_target_permutation_idx(indices)
+#         source_masks = outputs["pred_masks"]
+#         source_masks = source_masks[source_idx]
+#         masks = [t["masks"] for t in targets]
+
+#         # TODO use valid to mask invalid areas due to padding in loss
+#         target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+#         target_masks = target_masks.to(source_masks)
+#         target_masks = target_masks[target_idx]
+
+#         # upsample predictions to the target size
+#         source_masks = nn.functional.interpolate(
+#             source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+#         )
+#         source_masks = source_masks[:, 0].flatten(1)
+
+#         target_masks = target_masks.flatten(1)
+#         target_masks = target_masks.view(source_masks.shape)
+#         losses = {
+#             "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+#             "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
+#         }
+#         return losses
+
+#     def _get_source_permutation_idx(self, indices):
+#         # permute predictions following indices
+#         batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+#         source_idx = torch.cat([source for (source, _) in indices])
+#         return batch_idx, source_idx
+
+#     def _get_target_permutation_idx(self, indices):
+#         # permute targets following indices
+#         batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+#         target_idx = torch.cat([target for (_, target) in indices])
+#         return batch_idx, target_idx
+
+#     def get_loss(self, loss, outputs, targets, indices, num_boxes):
+#         loss_map = {
+#             "labels": self.loss_labels,
+#             "cardinality": self.loss_cardinality,
+#             "boxes": self.loss_boxes,
+#             "masks": self.loss_masks,
+#         }
+#         if loss not in loss_map:
+#             raise ValueError(f"Loss {loss} not supported")
+#         return loss_map[loss](outputs, targets, indices, num_boxes)
+
+#     def forward(self, outputs, targets, indices):
+#         """
+#         This performs the loss computation.
+
+#         Args:
+#             outputs (`dict`, *optional*):
+#                 Dictionary of tensors, see the output specification of the model for the format.
+#             targets (`List[dict]`, *optional*):
+#                 List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
+#                 losses applied, see each loss' doc.
+#         """
+#         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
+
+#         # ThangPM: Do NOT use bipartite matching --> Use the boxes selected by argmax for computing symmetric loss
+#         # Retrieve the matching between the outputs of the last layer and the targets
+#         # indices = self.matcher(outputs_without_aux, targets)
+
+#         # Compute the average number of target boxes across all nodes, for normalization purposes
+#         num_boxes = sum(len(t["class_labels"]) for t in targets)
+#         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+#         # (Niels): comment out function below, distributed training to be added
+#         # if is_dist_avail_and_initialized():
+#         #     torch.distributed.all_reduce(num_boxes)
+#         # (Niels) in original implementation, num_boxes is divided by get_world_size()
+#         num_boxes = torch.clamp(num_boxes, min=1).item()
+
+#         # Compute all the requested losses
+#         losses = {}
+#         for loss in self.losses:
+#             losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+#         # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+#         if "auxiliary_outputs" in outputs:
+#             for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
+#                 # indices = self.matcher(auxiliary_outputs, targets)
+#                 for loss in self.losses:
+#                     if loss == "masks":
+#                         # Intermediate masks losses are too costly to compute, we ignore them.
+#                         continue
+#                     l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
+#                     l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+#                     losses.update(l_dict)
+
+#         return losses
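The actual bug fix is the one-liner in the `else` branch of `forward`: in Hugging Face `transformers`, `get_text_features` is exposed by the top-level OWL-ViT model, not by its `text_model` submodule, which appears to be what the deleted debug scaffolding was probing. A minimal sketch of the working call, assuming a stock `transformers` OWL-ViT checkpoint (the checkpoint name and query strings are examples, not taken from this repo):

```python
import torch
from transformers import OwlViTModel, OwlViTProcessor

# Example checkpoint; this repo wraps its own fine-tuned OWL-ViT instead
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
owlvit = OwlViTModel.from_pretrained("google/owlvit-base-patch32")

text_inputs_parts = processor(text=["bird head", "bird wing"], return_tensors="pt")
with torch.no_grad():
    # get_text_features lives on OwlViTModel; owlvit.text_model has no such
    # method, which is why the old self.owlvit.text_model.get_text_features(...)
    # call failed
    text_embeds_parts = owlvit.get_text_features(**text_inputs_parts)

print(text_embeds_parts.shape)  # (num_queries, projection_dim)
```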
utils/predict.py CHANGED

@@ -112,7 +112,7 @@ def xclip_pred(new_desc: dict,
     image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
     image_embeds, _ = model.image_embedder(pixel_values = image_input['pixel_values'])
 
-    pred_logits, part_logits, output_dict = model(image_embeds, part_embeds, query_embeds, None)
+    pred_logits, part_logits = model(image_embeds, part_embeds, query_embeds, None)
 
     b, c, n = part_logits.shape
     mask = torch.tensor(desc_mask, dtype=float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1).to(device)
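For reference, the `mask` construction on the context line above broadcasts a flat per-descriptor mask over the batch and class dimensions of `part_logits`. A toy-shape sketch of that broadcast (sizes and `desc_mask` values are hypothetical, not from the repo):

```python
import torch

b, c, n = 2, 3, 5                        # hypothetical: batch, classes, descriptors
part_logits = torch.randn(b, c, n)
desc_mask = [1, 1, 0, 1, 0]              # hypothetical: 1 keeps a descriptor, 0 drops it

# Same construction as in xclip_pred: (n,) -> (1, 1, n) -> (b, c, n)
mask = torch.tensor(desc_mask, dtype=float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1)
print(mask.shape)  # torch.Size([2, 3, 5])

# A plausible next step (not shown in the diff) would be masking the logits:
# masked_logits = part_logits * mask
```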