fix bug and tested locally.
- utils/model.py +211 -262
- utils/predict.py +1 -1
utils/model.py
CHANGED
@@ -157,7 +157,7 @@ class OwlViTForClassification(nn.Module):
     config_class = OwlViTConfig

     def __init__(self, owlvit_det_model, num_classes, weight_dict, device, freeze_box_heads=False, train_box_heads_only=False, network_type=None, logits_from_teacher=False, finetuned: bool = False, custom_box_head: bool = False):
-        super(OwlViTForClassification, self).__init__()
+        super().__init__()

         self.config = owlvit_det_model.config
         self.num_classes = num_classes

@@ -202,12 +202,12 @@ class OwlViTForClassification(nn.Module):
         losses += ["boxes"] if weight_dict["loss_bbox"] > 0 else []
         losses += ["labels"] if weight_dict["loss_ce"] > 0 else []

-        self.criterion = DetrLoss(
-            matcher=None,
-            num_parts=self.num_parts,
-            eos_coef=0.1,  # Following facebook/detr-resnet-50
-            losses=losses,
-        )
+        # self.criterion = DetrLoss(
+        #     matcher=None,
+        #     num_parts=self.num_parts,
+        #     eos_coef=0.1,  # Following facebook/detr-resnet-50
+        #     losses=losses,
+        # )

         self.freeze_parameters(freeze_box_heads, train_box_heads_only)
         del owlvit_det_model

@@ -417,22 +417,7 @@ class OwlViTForClassification(nn.Module):
             topk_scores, topk_idxs = torch.topk(teacher_boxes_logits, k=1, dim=1)

         else:
-
-            print(f"text_inputs_parts - input_ids: {text_inputs_parts['input_ids'].shape}. attention_mask : {text_inputs_parts['attention_mask'].shape}")
-            seq_length = text_inputs_parts['input_ids'].shape[-1]
-            position_ids = self.owlvit.text_model.embeddings.position_ids[:, :seq_length]
-            txt_embeds = self.owlvit.text_model.embeddings.token_embedding(text_inputs_parts['input_ids'])
-            print(f"position_embedding: {self.owlvit.text_model.embeddings.position_embedding(position_ids).shape}")
-            print(f"text_embeds: {txt_embeds.shape}")
-
-            device_ = txt_embeds.device
-            position_ids = position_ids.to(device_)
-            txt_embeds_size_0 = text_embeds.size(0)
-            position_embedding = position_ids.cpu().repeat(txt_embeds_size_0, 1, 1)
-            text_inputs_parts["position_ids"] = position_ids
-            print(f"position_embedding : {position_embedding.shape}")
-            print(f"pos + emb: {(txt_embeds.cpu() + position_embedding).shape}")
-            text_embeds_parts = self.owlvit.text_model.get_text_features(**text_inputs_parts)
+            text_embeds_parts = self.owlvit.get_text_features(**text_inputs_parts)

         # # Embed images and text queries
         query_mask, text_embeds_parts = self._get_text_query_mask(text_inputs_parts, text_embeds_parts, batch_size)

@@ -460,46 +445,10 @@ class OwlViTForClassification(nn.Module):
         outputs_loss["logits"] = pred_logits_parts
         outputs_loss["pred_boxes"] = pred_boxes

-        # Compute box + class losses
-        loss_dict = self.criterion(outputs_loss, targets, mapping_indices)
-
-        # Compute symmetric loss to get rid of the teacher model
-        logits_per_image = torch.softmax(pred_logits_parts, dim=1)
-        logits_per_text = torch.softmax(pred_logits_parts, dim=-1)
-
-        # For getting rid of the teacher model
-        if self.weight_dict["loss_sym_box_label"] > 0:
-            sym_loss_box_label = self.loss_symmetric(logits_per_image, logits_per_text, teacher_boxes_logits)
-            loss_dict["loss_sym_box_label"] = sym_loss_box_label
-        # ----------------------------------------------------------------------------------------
-
-        #DEBUG:
-        print(f"im_features size: {image_feats.shape}, text_embeds size: {text_embeds.shape}")
-        print(f"im_features sum: {image_feats.sum().item()}, text_embeds sum: {text_embeds.sum().item()}")
         # Predict image-level classes (batch_size, num_patches, num_queries)
         image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
-
-
-        print(f"image_text_logits sum: {image_text_logits.sum().item()}")
-
-        if self.weight_dict["loss_xclip"] > 0:
-            targets_cls = torch.tensor([target["targets_cls"] for target in targets]).unsqueeze(1).to(self.device)
-            if self.network_type == "classification":
-                one_hot = torch.zeros_like(pred_logits).scatter(1, targets_cls, 1).to(self.device)
-                cls_loss = self.ce_loss(pred_logits, one_hot)
-                loss_dict["loss_xclip"] = cls_loss
-            else:
-                # TODO: Need a linear classifier for this approach
-                # Compute symmetric loss for part-descriptor contrastive learning
-                logits_per_image = torch.softmax(image_text_logits, dim=0)
-                logits_per_text = torch.softmax(image_text_logits, dim=-1)
-                sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
-                loss_dict["loss_xclip"] = sym_loss
-
-        #DEBUG:
-        print(f"pred_logits size: {part_logits.shape}, pred_logits size: {part_logits.shape}")
-        print(f"part_logits sum: {pred_logits.sum().item()}, part_logits sum: {pred_logits.sum().item()}")
-        return pred_logits, part_logits, loss_dict
+
+        return pred_logits, part_logits

     def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
         # text/image logits (batch_size*num_boxes, num_classes*num_descs): The logits that softmax over text descriptors or boxes

@@ -537,204 +486,204 @@ class OwlViTForClassification(nn.Module):

         return sym_loss

-class DetrLoss(nn.Module):
-    ...  # (uncommented class body, not captured in this view; re-added below, commented out)
+# class DetrLoss(nn.Module):
+#     """
+#     This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
+#     we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
+#     of matched ground-truth / prediction (supervise class and box).
+
+#     A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
+#     parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
+#     the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
+#     be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
+#     (`max_obj_id` + 1). For more details on this, check the following discussion
+#     https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
+
+
+#     Args:
+#         matcher (`DetrHungarianMatcher`):
+#             Module able to compute a matching between targets and proposals.
+#         num_parts (`int`):
+#             Number of object categories, omitting the special no-object category.
+#         eos_coef (`float`):
+#             Relative classification weight applied to the no-object category.
+#         losses (`List[str]`):
+#             List of all the losses to be applied. See `get_loss` for a list of all available losses.
+#     """
+
+#     def __init__(self, matcher, num_parts, eos_coef, losses):
+#         super().__init__()
+#         self.matcher = matcher
+#         self.num_parts = num_parts
+#         self.eos_coef = eos_coef
+#         self.losses = losses
+
+#         # empty_weight = torch.ones(self.num_parts + 1)
+#         empty_weight = torch.ones(self.num_parts)
+#         empty_weight[-1] = self.eos_coef
+#         self.register_buffer("empty_weight", empty_weight)
+
+#     # removed logging parameter, which was part of the original implementation
+#     def loss_labels(self, outputs, targets, indices, num_boxes):
+#         """
+#         Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
+#         [nb_target_boxes]
+#         """
+#         if "logits" not in outputs:
+#             raise KeyError("No logits were found in the outputs")
+#         source_logits = outputs["logits"]
+
+#         idx = self._get_source_permutation_idx(indices)
+#         # target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
+#         # target_classes = torch.full(source_logits.shape[:2], self.num_parts, dtype=torch.int64, device=source_logits.device)
+#         # target_classes[idx] = target_classes_o
+
+#         source_logits = source_logits[idx].view(len(indices), -1, self.num_parts)
+#         target_classes = torch.stack([t["class_labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
+
+#         loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
+#         losses = {"loss_ce": loss_ce}
+
+#         return losses
+
+#     @torch.no_grad()
+#     def loss_cardinality(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
+
+#         This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
+#         """
+#         logits = outputs["logits"]
+#         device = logits.device
+#         target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
+#         # Count the number of predictions that are NOT "no-object" (which is the last class)
+#         card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
+#         card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
+#         losses = {"cardinality_error": card_err}
+#         return losses
+
+#     def loss_boxes(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
+
+#         Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
+#         are expected in format (center_x, center_y, w, h), normalized by the image size.
+#         """
+#         if "pred_boxes" not in outputs:
+#             raise KeyError("No predicted boxes found in outputs")
+
+#         idx = self._get_source_permutation_idx(indices)
+#         source_boxes = outputs["pred_boxes"][idx]
+#         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+#         losses = {}
+
+#         loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
+#         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+#         loss_giou = 1 - torch.diag(generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)))
+#         losses["loss_giou"] = loss_giou.sum() / num_boxes
+
+#         return losses
+
+#     def loss_masks(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the losses related to the masks: the focal loss and the dice loss.
+
+#         Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
+#         """
+#         if "pred_masks" not in outputs:
+#             raise KeyError("No predicted masks found in outputs")
+
+#         source_idx = self._get_source_permutation_idx(indices)
+#         target_idx = self._get_target_permutation_idx(indices)
+#         source_masks = outputs["pred_masks"]
+#         source_masks = source_masks[source_idx]
+#         masks = [t["masks"] for t in targets]
+
+#         # TODO use valid to mask invalid areas due to padding in loss
+#         target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+#         target_masks = target_masks.to(source_masks)
+#         target_masks = target_masks[target_idx]
+
+#         # upsample predictions to the target size
+#         source_masks = nn.functional.interpolate(
+#             source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+#         )
+#         source_masks = source_masks[:, 0].flatten(1)
+
+#         target_masks = target_masks.flatten(1)
+#         target_masks = target_masks.view(source_masks.shape)
+#         losses = {
+#             "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+#             "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
+#         }
+#         return losses
+
+#     def _get_source_permutation_idx(self, indices):
+#         # permute predictions following indices
+#         batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+#         source_idx = torch.cat([source for (source, _) in indices])
+#         return batch_idx, source_idx
+
+#     def _get_target_permutation_idx(self, indices):
+#         # permute targets following indices
+#         batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+#         target_idx = torch.cat([target for (_, target) in indices])
+#         return batch_idx, target_idx
+
+#     def get_loss(self, loss, outputs, targets, indices, num_boxes):
+#         loss_map = {
+#             "labels": self.loss_labels,
+#             "cardinality": self.loss_cardinality,
+#             "boxes": self.loss_boxes,
+#             "masks": self.loss_masks,
+#         }
+#         if loss not in loss_map:
+#             raise ValueError(f"Loss {loss} not supported")
+#         return loss_map[loss](outputs, targets, indices, num_boxes)
+
+#     def forward(self, outputs, targets, indices):
+#         """
+#         This performs the loss computation.
+
+#         Args:
+#              outputs (`dict`, *optional*):
+#                 Dictionary of tensors, see the output specification of the model for the format.
+#              targets (`List[dict]`, *optional*):
+#                 List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
+#                 losses applied, see each loss' doc.
+#         """
+#         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
+
+#         # ThangPM: Do NOT use bipartite matching --> Use the boxes selected by argmax for computing symmetric loss
+#         # Retrieve the matching between the outputs of the last layer and the targets
+#         # indices = self.matcher(outputs_without_aux, targets)
+
+#         # Compute the average number of target boxes across all nodes, for normalization purposes
+#         num_boxes = sum(len(t["class_labels"]) for t in targets)
+#         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+#         # (Niels): comment out function below, distributed training to be added
+#         # if is_dist_avail_and_initialized():
+#         #     torch.distributed.all_reduce(num_boxes)
+#         # (Niels) in original implementation, num_boxes is divided by get_world_size()
+#         num_boxes = torch.clamp(num_boxes, min=1).item()
+
+#         # Compute all the requested losses
+#         losses = {}
+#         for loss in self.losses:
+#             losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+#         # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+#         if "auxiliary_outputs" in outputs:
+#             for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
+#                 # indices = self.matcher(auxiliary_outputs, targets)
+#                 for loss in self.losses:
+#                     if loss == "masks":
+#                         # Intermediate masks losses are too costly to compute, we ignore them.
+#                         continue
+#                     l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
+#                     l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+#                     losses.update(l_dict)
+
+#         return losses
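Note on the change at old line 420: the removed block rebuilt token and position embeddings by hand, printed debug shapes, and referenced an undefined `text_embeds` variable; the replacement delegates all of that to the OWL-ViT text encoder. For reference, a minimal standalone sketch of the `get_text_features` API the new line relies on, assuming `self.owlvit` wraps a `transformers.OwlViTModel` (the checkpoint name below is illustrative, not necessarily the one this Space loads):

import torch
from transformers import OwlViTProcessor, OwlViTModel

# Illustrative checkpoint; the Space may load a different OWL-ViT variant.
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")

# Tokenize part descriptions, producing the same kind of dict
# (input_ids, attention_mask) that text_inputs_parts holds in model.py.
inputs = processor(text=[["bird beak", "bird wing", "bird tail"]], return_tensors="pt")

with torch.no_grad():
    # One embedding per text query: (num_queries, projection_dim)
    text_embeds_parts = model.get_text_features(**inputs)

print(text_embeds_parts.shape)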
utils/predict.py
CHANGED
@@ -112,7 +112,7 @@ def xclip_pred(new_desc: dict,
     image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
     image_embeds, _ = model.image_embedder(pixel_values = image_input['pixel_values'])

-    pred_logits, part_logits
+    pred_logits, part_logits = model(image_embeds, part_embeds, query_embeds, None)

     b, c, n = part_logits.shape
     mask = torch.tensor(desc_mask, dtype=float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1).to(device)
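With the model's forward pass now returning only `(pred_logits, part_logits)` (the loss computation was removed in model.py), the fixed call above feeds directly into the descriptor-masking code on the following lines. A rough sketch of that masking step with made-up shapes (`b` images, `c` classes, `n` descriptors) and a hypothetical `desc_mask`, just to show how the broadcast works:

import torch

b, c, n = 1, 200, 10               # hypothetical: 1 image, 200 classes, 10 part descriptors
part_logits = torch.randn(b, c, n) # stand-in for the model output
desc_mask = [1.0] * 7 + [0.0] * 3  # hypothetical: keep the first 7 descriptors

# Same construction as in xclip_pred: broadcast the per-descriptor mask over
# every image and class, then zero out the logits of unused descriptors.
mask = torch.tensor(desc_mask, dtype=float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1)
masked_part_logits = part_logits * mask

print(masked_part_logits.shape)    # torch.Size([1, 200, 10])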