Spaces:
Running
Running
import torch.nn as nn | |
from .backbone.vit import ViT | |
from .head.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead | |
__all__ = ['ViTPose'] | |
class ViTPose(nn.Module):
    """Pose model: a ViT backbone followed by a top-down heatmap keypoint head.

    Args:
        cfg: Model configuration with a ``'backbone'`` section and a
            ``'keypoint_head'`` section. Each section is a mapping whose
            entries are passed as keyword arguments to :class:`ViT` and
            :class:`TopdownHeatmapSimpleHead` respectively. An optional
            ``'type'`` entry in either section is discarded first.

    Raises:
        KeyError: If ``cfg`` has no ``'backbone'`` or no ``'keypoint_head'``
            section. Both sections are looked up before either submodule is
            constructed.
    """

    def __init__(self, cfg: dict) -> None:
        super().__init__()
        # Resolve both sections up front so a malformed cfg fails before any
        # (potentially expensive) submodule construction happens.
        backbone_kwargs = self._strip_type(cfg['backbone'])
        head_kwargs = self._strip_type(cfg['keypoint_head'])
        self.backbone = ViT(**backbone_kwargs)
        self.keypoint_head = TopdownHeatmapSimpleHead(**head_kwargs)

    @staticmethod
    def _strip_type(section: dict) -> dict:
        """Return a new dict with every entry of *section* except ``'type'``.

        The input mapping is not modified.
        """
        # NOTE(review): 'type' looks like a registry-style class selector
        # (the class to build is fixed in __init__), so it is dropped before
        # the remaining entries become constructor kwargs — confirm.
        return {key: value for key, value in section.items() if key != 'type'}

    def forward_features(self, x):
        """Run only the backbone on *x* and return its output.

        NOTE(review): *x* is presumably an image batch of shape
        (N, C, H, W) — confirm against :class:`ViT`.
        """
        return self.backbone(x)

    def forward(self, x):
        """Run the backbone, then the keypoint head, and return the head output."""
        return self.keypoint_head(self.forward_features(x))