Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
__all__ = ["DetDETRPostProcessor"] | |
from .box_revert import BoxProcessFormat, box_revert | |
def mod(a, b):
    """Return the remainder of ``a`` divided by ``b``.

    The remainder is built from floor division, multiplication and
    subtraction only, so it never invokes the ``%`` operator. A commented-out
    call site in this file tags it "for tensorrt".

    Args:
        a: Dividend; anything supporting ``//``, ``*`` and ``-`` with ``b``.
        b: Divisor.

    Returns:
        ``a - (a // b) * b``.
    """
    quotient = a // b
    return a - quotient * b
class DetDETRPostProcessor(nn.Module):
    """Select the top-scoring predictions from raw detector outputs.

    ``forward`` reads ``outputs["pred_logits"]`` and ``outputs["pred_boxes"]``,
    keeps at most ``num_top_queries`` predictions per batch element, and hands
    the selected boxes to ``box_revert``.

    Args:
        num_classes: Number of classes; used to split a flattened
            (query, class) index into a label and a query index.
        use_focal_loss: If true, score with a sigmoid over every
            (query, class) pair; otherwise softmax over classes and drop the
            last class column.
        num_top_queries: Maximum number of predictions kept per batch element.
        box_process_format: Forwarded to ``box_revert`` as ``process_fmt``.
    """

    def __init__(
        self,
        num_classes=80,
        use_focal_loss=True,
        num_top_queries=300,
        box_process_format=BoxProcessFormat.RESIZE,
    ) -> None:
        super().__init__()
        self.use_focal_loss = use_focal_loss
        self.num_top_queries = num_top_queries
        self.num_classes = int(num_classes)
        self.box_process_format = box_process_format
        # Switched on by ``deploy()``; changes the return type of ``forward``.
        self.deploy_mode = False

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's ``repr``."""
        return f"use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}"

    def forward(self, outputs, **kwargs):
        """Post-process one batch of predictions.

        Args:
            outputs: Mapping with ``"pred_logits"`` and ``"pred_boxes"``.
                Both are indexed as rank-3 tensors here (batch, query, last
                dim); boxes are passed to ``box_revert`` as normalized
                ``cxcywh``.
            **kwargs: Forwarded unchanged to ``box_revert``.

        Returns:
            In deploy mode, the tuple ``(labels, boxes, scores)``. Otherwise a
            list with one ``dict(labels=..., boxes=..., scores=...)`` per batch
            element.
        """
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]

        if self.use_focal_loss:
            # Rank every (query, class) pair jointly, per batch element.
            scores = F.sigmoid(logits).flatten(1)
            # Clamp k so topk does not raise when fewer candidates exist than
            # requested (the non-focal branch guards the same case).
            k = min(self.num_top_queries, scores.shape[1])
            scores, index = torch.topk(scores, k, dim=-1)
            # Flattened index = query * num_classes + class.
            labels = index % self.num_classes
            # labels = mod(index, self.num_classes)  # for tensorrt
            index = index // self.num_classes
            boxes = boxes.gather(
                dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])
            )
        else:
            # Normalise over the class axis explicitly: with ``dim`` omitted,
            # PyTorch's deprecated implicit choice for a 3-D input is dim 0
            # (the batch axis), which yields wrong scores.
            # The last class column is dropped (presumably a background /
            # "no object" class -- confirm against the training head).
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(
                    boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])
                )

        # ``**kwargs`` is always a dict (never None), so the box conversion has
        # always run unconditionally; the former ``if kwargs is not None``
        # guard was a no-op and is dropped without changing behaviour.
        boxes = box_revert(
            boxes,
            in_fmt="cxcywh",
            out_fmt="xyxy",
            process_fmt=self.box_process_format,
            normalized=True,
            **kwargs,
        )

        # TODO for onnx export
        if self.deploy_mode:
            return labels, boxes, scores

        return [
            dict(labels=lab, boxes=box, scores=sco)
            for lab, box, sco in zip(labels, boxes, scores)
        ]

    def deploy(
        self,
    ):
        """Switch to eval mode, make ``forward`` return a tuple, return self."""
        self.eval()
        self.deploy_mode = True
        return self