""" Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ import torch import torch.nn as nn from ...core import register __all__ = ["Classification", "ClassHead"] @register() class Classification(torch.nn.Module): __inject__ = ["backbone", "head"] def __init__(self, backbone: nn.Module, head: nn.Module = None): super().__init__() self.backbone = backbone self.head = head def forward(self, x): x = self.backbone(x) if self.head is not None: x = self.head(x) return x @register() class ClassHead(nn.Module): def __init__(self, hidden_dim, num_classes): super().__init__() self.pool = nn.AdaptiveAvgPool2d(1) self.proj = nn.Linear(hidden_dim, num_classes) def forward(self, x): x = x[0] if isinstance(x, (list, tuple)) else x x = self.pool(x) x = x.reshape(x.shape[0], -1) x = self.proj(x) return x