# PEEB/utils/model.py
import copy
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from transformers import OwlViTConfig
# from transformers.models.owlvit.modeling_owlvit import OwlViTVisionTransformer
class OwlViTBoxPredictionHead(nn.Module):
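    """
    Box-regression head: a five-layer MLP with GELU activations that maps each patch embedding to four
    box parameters. It is a deeper variant of the stock OwlViT box head and is used only when
    `custom_box_head=True` in `OwlViTForClassification`; the raw outputs are later offset by a per-patch
    box bias and squashed with a sigmoid in `OwlViTForClassification.box_predictor`.
    """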
def __init__(self, config: OwlViTConfig):
super().__init__()
width = config.vision_config.hidden_size
self.dense0 = nn.Linear(width, width)
self.dense1 = nn.Linear(width, width)
self.dense2 = nn.Linear(width, width)
self.dense3 = nn.Linear(width, width)
self.gelu = nn.GELU()
self.dense4 = nn.Linear(width, 4)
def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
output = self.dense0(image_features)
output = self.gelu(output)
output = self.dense1(output)
output = self.gelu(output)
output = self.dense2(output)
output = self.gelu(output)
output = self.dense3(output)
output = self.gelu(output)
output = self.dense4(output)
output = self.gelu(output)
return output
class OwlViTClassPredictionHead(nn.Module):
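    """
    Box-level (part) classification head mirroring the one in `transformers.models.owlvit.modeling_owlvit`:
    patch embeddings are projected into the text-embedding space and scored against the text queries with a
    learnable per-patch logit shift and scale. If no queries are given, zero logits are returned.
    """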
def __init__(self, config: OwlViTConfig):
super().__init__()
out_dim = config.text_config.hidden_size
self.query_dim = config.vision_config.hidden_size
self.dense0 = nn.Linear(self.query_dim, out_dim)
self.logit_shift = nn.Linear(self.query_dim, 1)
self.logit_scale = nn.Linear(self.query_dim, 1)
self.elu = nn.ELU()
def forward(
self,
image_embeds: torch.FloatTensor,
query_embeds: Optional[torch.FloatTensor],
query_mask: Optional[torch.Tensor],
) -> Tuple[torch.FloatTensor]:
image_class_embeds = self.dense0(image_embeds)
if query_embeds is None:
device = image_class_embeds.device
batch_size, num_patches = image_class_embeds.shape[:2]
pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device)
return (pred_logits, image_class_embeds)
# Normalize image and text features
image_class_embeds = F.normalize(image_class_embeds, dim=-1) + 1e-6
query_embeds = F.normalize(query_embeds, dim=-1) + 1e-6
# Get class predictions
pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
# Apply a learnable shift and scale to logits
logit_shift = self.logit_shift(image_embeds)
logit_scale = self.logit_scale(image_embeds)
logit_scale = self.elu(logit_scale) + 1
pred_logits = (pred_logits + logit_shift) * logit_scale
if query_mask is not None:
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32)
return (pred_logits, image_class_embeds)
class OwlViTPredictionHead(nn.Module):
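    """
    Image-level classification head: the top-1 patch embedding of each part is selected with a one-hot
    mask, projected into the text space by a three-layer MLP (`mlp_image`), and scored against the
    part-level text descriptors; the per-part scores of each class are summed into the final class logits.
    """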
def __init__(self, config: OwlViTConfig, num_classes: int, finetuned: bool):
super().__init__()
out_dim = config.text_config.hidden_size
self.query_dim = config.vision_config.hidden_size
self.finetuned = finetuned
self.num_classes = num_classes
self.mlp_image = nn.Sequential(
nn.Flatten(),
nn.Linear(in_features=self.query_dim, out_features=self.query_dim),
nn.GELU(),
nn.Linear(in_features=self.query_dim, out_features=self.query_dim),
nn.GELU(),
nn.Linear(in_features=self.query_dim, out_features=out_dim),
nn.GELU(),
)
# if self.finetuned:
# self.cls_head = nn.Sequential(
# nn.GELU(),
# nn.Linear(in_features=out_dim, out_features=out_dim),
# nn.GELU()
# )
def forward(self,
image_embeds: torch.FloatTensor,
query_embeds: torch.FloatTensor,
                topk_idxs: torch.LongTensor,
) -> Tuple[torch.FloatTensor]:
        # topk_idxs arrives as (batch_size, 1, n_parts); transpose to (batch_size, n_parts, 1) and scatter it
        # into a one-hot mask of shape (batch_size, n_parts, n_patches * n_patches)
        topk_idxs = torch.swapaxes(topk_idxs, 1, 2)
        one_hot = torch.zeros(topk_idxs.shape[0], topk_idxs.shape[1], image_embeds.shape[1], device=image_embeds.device).scatter_(2, topk_idxs, 1)
batch_size, n_parts = one_hot.shape[0], one_hot.shape[1]
        # (batch_size, n_parts, n_patches^2, 1) * (batch_size, 1, n_patches^2, hidden_dim) -> sum over the patch axis:
        # for each part, keep only the embedding of its top-1 patch.
image_embeds = (one_hot.unsqueeze(-1) * image_embeds.unsqueeze(1)).sum(dim=-2)
# image_embeds = self.dense0(image_embeds) # (batch_size, n_patches, 1024) --> (.., .., 768)
image_embeds = self.mlp_image(image_embeds.view(-1, image_embeds.shape[-1])).view(batch_size, n_parts, -1)
query_embeds = query_embeds.view(batch_size, -1, query_embeds.shape[-1])
# if self.finetuned:
# image_embeds = self.cls_head(image_embeds)
# query_embeds = query_embeds.view(batch_size, -1, query_embeds.shape[-1])
# Normalize image and text features
image_embeds = F.normalize(image_embeds, dim=-1) + 1e-6 # (batch_size, n_parts, 768)
query_embeds = F.normalize(query_embeds, dim=-1) + 1e-6 # (batch_size, num_classes * n_parts, 768)
# Shape: torch.Size([bs, num_boxes, num_classes * num_parts])
image_text_logits = torch.einsum('bnd, bid -> bni', image_embeds, query_embeds)
image_text_logits_reshaped = image_text_logits.view(-1, image_text_logits.shape[-1])
# Shape: (bs, num_classes * num_parts, num_boxes) --> (bs, num_classes, num_parts, num_boxes)
pred_logits = image_text_logits.swapaxes(axis0=1, axis1=2).view(batch_size, self.num_classes, n_parts, -1)
pred_logits = torch.diagonal(pred_logits, dim1=-2, dim2=-1) # --> torch.Size([bs, num_classes, 12])
        # DEBUG: try adding a sigmoid here to see if it helps. PEIJIE: It does not help.
# pred_logits = pred_logits.sigmoid()
# pred_logits = abs(pred_logits) # for debugging
final_pred_logits = torch.sum(pred_logits, dim=-1)
return (image_text_logits_reshaped, final_pred_logits, pred_logits)
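# Shape walk-through for OwlViTPredictionHead.forward (illustrative; the 3600-patch / 1024-dim figures
# assume the OwlViT ViT-L/14 config with a 60x60 patch grid and num_parts = 12):
#   image_embeds:      (bs, 3600, 1024)
#   query_embeds:      (bs, num_classes * 12, 768) after the view
#   topk_idxs:         (bs, 1, 12)  ->  one_hot: (bs, 12, 3600)
#   mlp_image output:  (bs, 12, 768)
#   image_text_logits: (bs, 12, num_classes * 12)
#   pred_logits:       (bs, num_classes, 12)   (diagonal over matching parts)
#   final_pred_logits: (bs, num_classes)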
class OwlViTForClassification(nn.Module):
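    """
    Part-based image classifier built on top of a (teacher) OwlViT object-detection model:
    the OwlViT backbone, layer norm, box head and class head are copied from the teacher, and an
    `OwlViTPredictionHead` is added on top for image-level classification from the top-1 box per part.
    """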
config_class = OwlViTConfig
    def __init__(
        self,
        owlvit_det_model,
        num_classes,
        weight_dict,
        device,
        freeze_box_heads=False,
        train_box_heads_only=False,
        network_type=None,
        logits_from_teacher=False,
        finetuned: bool = False,
        custom_box_head: bool = False,
    ):
super().__init__()
self.config = owlvit_det_model.config
self.num_classes = num_classes
self.num_parts = 12
self.device = device
self.sigmoid = nn.Sigmoid()
self.ce_loss = torch.nn.CrossEntropyLoss()
# Use CE loss for classification OR only train with contrastive loss
self.network_type = network_type
self.logits_from_teacher = logits_from_teacher
# Initialize OwlViT model from the teacher model
self.owlvit = copy.deepcopy(owlvit_det_model.owlvit)
self.layer_norm = copy.deepcopy(owlvit_det_model.layer_norm)
# For image-level classification
self.cls_head = OwlViTPredictionHead(self.config, self.num_classes, finetuned=finetuned)
# For box prediction
if custom_box_head:
self.box_head = OwlViTBoxPredictionHead(self.config)
else:
self.box_head = copy.deepcopy(owlvit_det_model.box_head)
# For box-level classification
        # Rebuild the class head and copy its weights layer by layer;
        # `copy.deepcopy(owlvit_det_model.class_head)` would likely be an equivalent shortcut.
self.class_head = OwlViTClassPredictionHead(self.config)
self.class_head.dense0.load_state_dict(owlvit_det_model.class_head.dense0.state_dict())
self.class_head.logit_shift.load_state_dict(owlvit_det_model.class_head.logit_shift.state_dict())
self.class_head.logit_scale.load_state_dict(owlvit_det_model.class_head.logit_scale.state_dict())
# OwlViT: set equal weights for the bounding box, gIoU and classification losses
# self.matcher = DetrHungarianMatcher(class_cost=1, bbox_cost=1, giou_cost=1)
# Losses for the criterion in DETR/OwlViT
self.weight_dict = weight_dict
losses = ["cardinality"]
losses += ["boxes"] if weight_dict["loss_bbox"] > 0 else []
losses += ["labels"] if weight_dict["loss_ce"] > 0 else []
# self.criterion = DetrLoss(
# matcher=None,
# num_parts=self.num_parts,
# eos_coef=0.1, # Following facebook/detr-resnet-50
# losses=losses,
# )
self.freeze_parameters(freeze_box_heads, train_box_heads_only)
        del owlvit_det_model  # drop the local reference to the teacher; the sub-modules copied above are kept
def freeze_parameters(self, freeze_box_heads, train_box_heads_only):
# OwlViT's text encoder is frozen by default
for param in self.owlvit.text_model.parameters():
param.requires_grad = False
for param in self.owlvit.text_projection.parameters():
param.requires_grad = False
# SKIP finetuning box heads
if freeze_box_heads:
for param in self.box_head.parameters():
param.requires_grad = False
for param in self.class_head.parameters():
param.requires_grad = False
# SKIP finetuning vision encoder and MLP head for classification --> Adjust weights of box heads only
if train_box_heads_only:
for param in self.owlvit.parameters():
param.requires_grad = False
for param in self.layer_norm.parameters():
param.requires_grad = False
for param in self.cls_head.parameters():
param.requires_grad = False
def update_num_classes(self, num_classes):
self.num_classes = num_classes
self.cls_head.num_classes = num_classes
def image_text_embedder(self,
input_ids: torch.Tensor,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Tuple[torch.FloatTensor]:
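        """Jointly encode text and image with OwlViT; return (text_embeds, image_embeds as a patch grid, full outputs)."""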
# Encode text and image
outputs = self.owlvit(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
# Get image embeddings
last_hidden_state = outputs.vision_model_output[0] # 0: last_hidden_state; 1: pooled_output
image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
# Resize class token
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
# Merge image embedding with class tokens
image_embeds = image_embeds[:, 1:, :] * class_token_out
image_embeds = self.layer_norm(image_embeds)
# Resize to [batch_size, num_patches, num_patches, hidden_size]
new_size = (
image_embeds.shape[0],
int(np.sqrt(image_embeds.shape[1])),
int(np.sqrt(image_embeds.shape[1])),
image_embeds.shape[-1],
)
image_embeds = image_embeds.reshape(new_size)
        text_embeds = outputs[-4]  # the `text_embeds` field of the OwlViT output tuple, indexed from the end
return (text_embeds, image_embeds, outputs)
def image_embedder(
self,
pixel_values: torch.FloatTensor
) -> Tuple[torch.FloatTensor]:
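        """Encode pixel values only; return (image_embeds as a patch grid, raw vision outputs)."""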
# Get OwlViTModel vision embeddings (same as CLIP)
vision_outputs = self.owlvit.vision_model(pixel_values=pixel_values, return_dict=True)
# Apply post_layernorm to last_hidden_state, return non-projected output
last_hidden_state = vision_outputs[0]
image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)
# Resize class token
new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
# Merge image embedding with class tokens
image_embeds = image_embeds[:, 1:, :] * class_token_out
image_embeds = self.layer_norm(image_embeds)
# Resize to [batch_size, num_patches, num_patches, hidden_size]
new_size = (
image_embeds.shape[0],
int(np.sqrt(image_embeds.shape[1])),
int(np.sqrt(image_embeds.shape[1])),
image_embeds.shape[-1],
)
image_embeds = image_embeds.reshape(new_size)
return (image_embeds, vision_outputs)
def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
# Computes normalized xy corner coordinates from feature_map.
if not feature_map.ndim == 4:
raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
device = feature_map.device
num_patches = feature_map.shape[1]
box_coordinates = np.stack(np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1).astype(np.float32)
box_coordinates /= np.array([num_patches, num_patches], np.float32)
# Flatten (h, w, 2) -> (h*w, 2)
box_coordinates = box_coordinates.reshape(box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2])
box_coordinates = torch.from_numpy(box_coordinates).to(device)
return box_coordinates
def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
# The box center is biased to its position on the feature grid
box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
# Unnormalize xy
box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
# The box size is biased to the patch size
box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
# Compute box bias
box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
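        # Example: on a 60x60 patch grid, the patch at grid cell (i, j) receives a centre bias of roughly
        # logit((j + 1) / 60), logit((i + 1) / 60) and a size bias of logit(1 / 60), so a zero output from the
        # box head decodes (after the sigmoid in `box_predictor`) to a one-patch-sized box centred on that cell.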
return box_bias
def box_predictor(
self,
image_feats: torch.FloatTensor,
feature_map: torch.FloatTensor,
) -> torch.FloatTensor:
"""
Args:
image_feats:
Features extracted from the image, returned by the `image_text_embedder` method.
feature_map:
A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
        Returns:
            pred_boxes:
                Predicted boxes of shape [batch_size, num_patches * num_patches, 4], in (cx, cy, w, h)
                format normalized to [0, 1].
"""
# Bounding box detection head [batch_size, num_boxes, 4].
pred_boxes = self.box_head(image_feats)
# Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
pred_boxes += self.compute_box_bias(feature_map)
pred_boxes = self.sigmoid(pred_boxes)
return pred_boxes
def class_predictor(
self,
image_feats: torch.FloatTensor,
query_embeds: Optional[torch.FloatTensor] = None,
query_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor]:
"""
Args:
image_feats:
Features extracted from the `image_text_embedder`.
query_embeds:
Text query embeddings.
            query_mask:
                A mask indicating which query embeddings are valid; must be provided together with `query_embeds`.
"""
(pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
return (pred_logits, image_class_embeds)
def _get_text_query_mask(self, text_inputs, text_embeds, batch_size: int):
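        """Reshape text embeddings to (batch_size, max_text_queries, hidden_dim) and mask padded queries (first token id == 0)."""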
# Embed images and text queries
input_ids = text_inputs["input_ids"]
# Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
max_text_queries = input_ids.shape[0] // batch_size
text_embeds = text_embeds.reshape(batch_size, max_text_queries, text_embeds.shape[-1])
# If first token is 0, then this is a padded query [batch_size, num_queries].
input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
query_mask = input_ids[..., 0] > 0
return query_mask, text_embeds
    def forward(self, image_inputs, text_inputs_parts, text_embeds, targets: Optional[list] = None):
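        """
        Args (shapes are indicative; num_parts defaults to 12):
            image_inputs: processor output carrying `pixel_values`, or a precomputed feature map of shape
                (batch_size, num_patches, num_patches, hidden_dim).
            text_inputs_parts: tokenized part names used for box selection.
            text_embeds: precomputed part-descriptor text embeddings used for classification.
            targets: optional list of per-image dicts; when `logits_from_teacher` is set, each dict must
                contain the teacher box "logits".
        Returns:
            pred_logits (batch_size, num_classes) and part_logits (batch_size, num_classes, num_parts).
        """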
# Store outputs for computing losses
loss_dict = {}
if not isinstance(image_inputs, torch.Tensor):
feature_map, _ = self.image_embedder(pixel_values = image_inputs['pixel_values'])
else:
feature_map = image_inputs
batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
        # Predict object boxes for every patch token (needed below regardless of where the part logits come from)
        pred_boxes = self.box_predictor(image_feats, feature_map)
        if self.logits_from_teacher:
            teacher_boxes_logits = torch.stack([target["logits"] for target in targets], dim=0).to(self.device)
            topk_scores, topk_idxs = torch.topk(teacher_boxes_logits, k=1, dim=1)
        else:
            # Embed the part-name text queries
            text_embeds_parts = self.owlvit.get_text_features(**text_inputs_parts)
            query_mask, text_embeds_parts = self._get_text_query_mask(text_inputs_parts, text_embeds_parts, batch_size)
            # Predict object classes [batch_size, num_patches, num_queries+1]
            pred_logits_parts, class_embeds = self.class_predictor(image_feats, text_embeds_parts, query_mask)
            # Get the top-1 patch per part
            scores = self.sigmoid(pred_logits_parts)
            topk_scores, topk_idxs = torch.topk(scores, k=1, dim=1)
        # Matcher-style indices: one (selected_patch_indices, part_indices) pair per image, in the format
        # expected by the (currently disabled) DETR-style criterion below.
        mapping_indices = [(selected_indices, torch.arange(self.num_parts, device=self.device)) for selected_indices in topk_idxs.squeeze(1)]
        # Gather the selected patch indices used for the predicted boxes
        selected_idxs = torch.stack([item[0].cpu() for item in mapping_indices])
loss_dict["pred_boxes"] = torch.gather(pred_boxes.cpu(), 1, selected_idxs.unsqueeze(-1).expand(*selected_idxs.shape, 4))
        if targets is not None and not self.logits_from_teacher:
            # ----------------------------------------------------------------------------------------
            # Kept for the (currently disabled) DETR-style criterion: box + class + symmetric losses
            # for box selection. `outputs_loss` is assembled here but not used by the return below;
            # the extra guard also avoids referencing `pred_logits_parts`, which is undefined when the
            # box logits come from the teacher.
            # ----------------------------------------------------------------------------------------
            outputs_loss = {}
            outputs_loss["logits"] = pred_logits_parts
            outputs_loss["pred_boxes"] = pred_boxes
        # Predict image-level classes: pred_logits (batch_size, num_classes), part_logits (batch_size, num_classes, num_parts)
image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
return pred_logits, part_logits
def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
        """
        Symmetric (image <-> text) binary cross-entropy loss.
        text_logits / image_logits (batch_size * num_boxes, num_classes * num_descs):
            logits to be normalized over text descriptors or over boxes, respectively.
        targets (batch_size, 1): ground-truth class label of each box-text pair (image classification), OR
        targets (batch_size, all_boxes, num_parts): ground-truth box-text assignment (box selection).
        box_labels (batch_size, num_boxes): 0 for no box, 1 for box.
        """
assert text_logits.shape == image_logits.shape
# For image classification
if image_logits.shape != targets.shape:
batch_size = targets.shape[0]
# get the matching labels (bs * 12, num_classes * num_parts)
default_box_labels = torch.kron(torch.ones(batch_size, self.num_classes), torch.eye(self.num_parts)).to(self.device)
if box_labels is None:
box_labels = default_box_labels.clone()
else:
# (batch_size, num_boxes) -> (bs * num_boxes, num_classes * num_parts)
box_labels = box_labels.view(-1, 1) * default_box_labels
# Create one-hot encoding of targets; matching_labels shape: (bs * 12, num_classes * num_parts)
target_one_hot = torch.zeros(batch_size, self.num_classes).to(self.device).scatter(1, targets.view(-1, 1), 1)
target_one_hot = torch.kron(target_one_hot, torch.ones(self.num_parts, self.num_parts).to(self.device))
matching_labels = target_one_hot * box_labels
else:
# For box selection: matching_labels shape: (bs, 576, num_parts)
values, indices = torch.max(targets, dim=1)
matching_labels = torch.zeros_like(targets).scatter(1, indices.unsqueeze(1), 1)
loss_i = F.binary_cross_entropy_with_logits(image_logits, matching_labels, reduction='mean')
loss_t = F.binary_cross_entropy_with_logits(text_logits, matching_labels, reduction='mean')
sym_loss = (loss_i + loss_t).mean()
return sym_loss
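# Worked example for the matching labels in `loss_symmetric` (image-classification branch), with
# batch_size = 1, num_classes = 2, num_parts = 2 and ground-truth class 1:
#   default_box_labels = kron(ones(1, 2), eye(2))              = [[1, 0, 1, 0],
#                                                                 [0, 1, 0, 1]]
#   target_one_hot (after kron with ones(2, 2))                = [[0, 0, 1, 1],
#                                                                 [0, 0, 1, 1]]
#   matching_labels = default_box_labels * target_one_hot      = [[0, 0, 1, 0],
#                                                                 [0, 0, 0, 1]]
# i.e. each selected box is matched only with its own part descriptor of the ground-truth class.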
# class DetrLoss(nn.Module):
# """
# This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
# we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
# of matched ground-truth / prediction (supervise class and box).
# A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
# parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
# the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
# be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
# (`max_obj_id` + 1). For more details on this, check the following discussion
# https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
# Args:
# matcher (`DetrHungarianMatcher`):
# Module able to compute a matching between targets and proposals.
# num_parts (`int`):
# Number of object categories, omitting the special no-object category.
# eos_coef (`float`):
# Relative classification weight applied to the no-object category.
# losses (`List[str]`):
# List of all the losses to be applied. See `get_loss` for a list of all available losses.
# """
# def __init__(self, matcher, num_parts, eos_coef, losses):
# super().__init__()
# self.matcher = matcher
# self.num_parts = num_parts
# self.eos_coef = eos_coef
# self.losses = losses
# # empty_weight = torch.ones(self.num_parts + 1)
# empty_weight = torch.ones(self.num_parts)
# empty_weight[-1] = self.eos_coef
# self.register_buffer("empty_weight", empty_weight)
# # removed logging parameter, which was part of the original implementation
# def loss_labels(self, outputs, targets, indices, num_boxes):
# """
# Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
# [nb_target_boxes]
# """
# if "logits" not in outputs:
# raise KeyError("No logits were found in the outputs")
# source_logits = outputs["logits"]
# idx = self._get_source_permutation_idx(indices)
# # target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
# # target_classes = torch.full(source_logits.shape[:2], self.num_parts, dtype=torch.int64, device=source_logits.device)
# # target_classes[idx] = target_classes_o
# source_logits = source_logits[idx].view(len(indices), -1, self.num_parts)
# target_classes = torch.stack([t["class_labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
# loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
# losses = {"loss_ce": loss_ce}
# return losses
# @torch.no_grad()
# def loss_cardinality(self, outputs, targets, indices, num_boxes):
# """
# Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
# This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
# """
# logits = outputs["logits"]
# device = logits.device
# target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
# # Count the number of predictions that are NOT "no-object" (which is the last class)
# card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
# card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
# losses = {"cardinality_error": card_err}
# return losses
# def loss_boxes(self, outputs, targets, indices, num_boxes):
# """
# Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
# Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
# are expected in format (center_x, center_y, w, h), normalized by the image size.
# """
# if "pred_boxes" not in outputs:
# raise KeyError("No predicted boxes found in outputs")
# idx = self._get_source_permutation_idx(indices)
# source_boxes = outputs["pred_boxes"][idx]
# target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
# losses = {}
# loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
# losses["loss_bbox"] = loss_bbox.sum() / num_boxes
# loss_giou = 1 - torch.diag(generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)))
# losses["loss_giou"] = loss_giou.sum() / num_boxes
# return losses
# def loss_masks(self, outputs, targets, indices, num_boxes):
# """
# Compute the losses related to the masks: the focal loss and the dice loss.
# Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
# """
# if "pred_masks" not in outputs:
# raise KeyError("No predicted masks found in outputs")
# source_idx = self._get_source_permutation_idx(indices)
# target_idx = self._get_target_permutation_idx(indices)
# source_masks = outputs["pred_masks"]
# source_masks = source_masks[source_idx]
# masks = [t["masks"] for t in targets]
# # TODO use valid to mask invalid areas due to padding in loss
# target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
# target_masks = target_masks.to(source_masks)
# target_masks = target_masks[target_idx]
# # upsample predictions to the target size
# source_masks = nn.functional.interpolate(
# source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
# )
# source_masks = source_masks[:, 0].flatten(1)
# target_masks = target_masks.flatten(1)
# target_masks = target_masks.view(source_masks.shape)
# losses = {
# "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
# "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
# }
# return losses
# def _get_source_permutation_idx(self, indices):
# # permute predictions following indices
# batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
# source_idx = torch.cat([source for (source, _) in indices])
# return batch_idx, source_idx
# def _get_target_permutation_idx(self, indices):
# # permute targets following indices
# batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
# target_idx = torch.cat([target for (_, target) in indices])
# return batch_idx, target_idx
# def get_loss(self, loss, outputs, targets, indices, num_boxes):
# loss_map = {
# "labels": self.loss_labels,
# "cardinality": self.loss_cardinality,
# "boxes": self.loss_boxes,
# "masks": self.loss_masks,
# }
# if loss not in loss_map:
# raise ValueError(f"Loss {loss} not supported")
# return loss_map[loss](outputs, targets, indices, num_boxes)
# def forward(self, outputs, targets, indices):
# """
# This performs the loss computation.
# Args:
# outputs (`dict`, *optional*):
# Dictionary of tensors, see the output specification of the model for the format.
# targets (`List[dict]`, *optional*):
# List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
# losses applied, see each loss' doc.
# """
# outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
# # ThangPM: Do NOT use bipartite matching --> Use the boxes selected by argmax for computing symmetric loss
# # Retrieve the matching between the outputs of the last layer and the targets
# # indices = self.matcher(outputs_without_aux, targets)
# # Compute the average number of target boxes across all nodes, for normalization purposes
# num_boxes = sum(len(t["class_labels"]) for t in targets)
# num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# # (Niels): comment out function below, distributed training to be added
# # if is_dist_avail_and_initialized():
# # torch.distributed.all_reduce(num_boxes)
# # (Niels) in original implementation, num_boxes is divided by get_world_size()
# num_boxes = torch.clamp(num_boxes, min=1).item()
# # Compute all the requested losses
# losses = {}
# for loss in self.losses:
# losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
# # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
# if "auxiliary_outputs" in outputs:
# for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
# # indices = self.matcher(auxiliary_outputs, targets)
# for loss in self.losses:
# if loss == "masks":
# # Intermediate masks losses are too costly to compute, we ignore them.
# continue
# l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
# l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
# losses.update(l_dict)
# return losses
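# --------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only; the checkpoint id, class/part counts and any weight_dict keys other
# than "loss_bbox"/"loss_ce" are assumptions, not requirements of this module):
#
#   from transformers import OwlViTProcessor, OwlViTForObjectDetection
#
#   processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14")
#   teacher = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14")
#   model = OwlViTForClassification(
#       teacher, num_classes=200, weight_dict={"loss_bbox": 1.0, "loss_ce": 1.0},
#       device="cuda", logits_from_teacher=False,
#   ).to("cuda")
#
#   image_inputs = processor(images=image, return_tensors="pt").to("cuda")
#   part_inputs = processor(text=[part_names], return_tensors="pt").to("cuda")
#   text_embeds = ...  # precomputed part-descriptor embeddings, e.g. (bs * num_classes * 12, 768)
#   with torch.no_grad():
#       pred_logits, part_logits = model(image_inputs, part_inputs, text_embeds)
# --------------------------------------------------------------------------------------------------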