|
import torch.nn as nn |
|
import torch_scatter |
|
|
|
from pointcept.models.losses import build_criteria |
|
from pointcept.models.utils.structure import Point |
|
from .builder import MODELS, build_model |
|
|
|
|
|
@MODELS.register_module()
class DefaultSegmentor(nn.Module):
    """Segmentor whose backbone emits segmentation logits directly.

    No head is attached here: whatever ``self.backbone`` returns is treated as
    ``seg_logits``. The loss is ``self.criteria(seg_logits, input_dict["segment"])``.

    Args:
        backbone: config forwarded to ``build_model``.
        criteria: config forwarded to ``build_criteria``.
    """

    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        """Run the backbone and compute the loss when labels are available.

        Returns:
            ``dict(loss=...)`` in training mode (``input_dict["segment"]`` is
            required, a missing key raises ``KeyError``);
            ``dict(loss=..., seg_logits=...)`` in eval mode when ``"segment"``
            is present; otherwise ``dict(seg_logits=...)``.
        """
        if "condition" in input_dict:
            # Only the first element of "condition" is handed to the backbone.
            # NOTE(review): this overwrites the caller's dict in place, so a
            # second forward on the same dict would index into that element.
            input_dict["condition"] = input_dict["condition"][0]

        logits = self.backbone(input_dict)

        # Pure inference: no labels and not training -> logits only.
        if not self.training and "segment" not in input_dict:
            return dict(seg_logits=logits)

        loss = self.criteria(logits, input_dict["segment"])
        if self.training:
            return dict(loss=loss)
        return dict(loss=loss, seg_logits=logits)
|
|
|
|
|
@MODELS.register_module()
class DefaultSegmentorV2(nn.Module):
    """Backbone followed by a single linear segmentation head.

    Args:
        num_classes: output classes of the head. When ``<= 0`` the head is
            ``nn.Identity`` and backbone features are used as logits as-is.
        backbone_out_channels: feature width fed into the linear head.
        backbone: config forwarded to ``build_model``.
        criteria: config forwarded to ``build_criteria``.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
    ):
        super().__init__()
        # The head is created (and registered) before the backbone, which keeps
        # parameter/state_dict ordering and RNG-based init order stable.
        if num_classes > 0:
            head = nn.Linear(backbone_out_channels, num_classes)
        else:
            head = nn.Identity()
        self.seg_head = head
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        """Wrap the input in a ``Point``, run backbone + head, and compute the loss.

        Returns:
            ``dict(loss=...)`` in training mode; ``dict(loss=..., seg_logits=...)``
            in eval mode when ``"segment"`` is present; otherwise
            ``dict(seg_logits=...)``.
        """
        out = self.backbone(Point(input_dict))
        # Backbones may hand back either a Point (read its .feat) or a raw tensor.
        feat = out.feat if isinstance(out, Point) else out
        seg_logits = self.seg_head(feat)

        if not self.training and "segment" not in input_dict:
            return dict(seg_logits=seg_logits)

        loss = self.criteria(seg_logits, input_dict["segment"])
        if self.training:
            return dict(loss=loss)
        return dict(loss=loss, seg_logits=seg_logits)
|
|
|
|
|
@MODELS.register_module()
class DefaultClassifier(nn.Module):
    """Backbone followed by an MLP classification head.

    The head is ``Linear -> BN -> ReLU -> Dropout`` twice (widths 256 and 128),
    then a final ``Linear`` to ``num_classes``.

    Args:
        backbone: config forwarded to ``build_model``.
        criteria: config forwarded to ``build_criteria``.
        num_classes: number of output classes of the final linear layer.
        backbone_embed_dim: feature width fed into the head.
    """

    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim

        # Built in the same layer order as a literal Sequential, so module
        # indices 0..8 (and therefore state_dict keys) are unchanged.
        widths = (backbone_embed_dim, 256, 128)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend(
                [
                    nn.Linear(c_in, c_out),
                    nn.BatchNorm1d(c_out),
                    nn.ReLU(inplace=True),
                    nn.Dropout(p=0.5),
                ]
            )
        layers.append(nn.Linear(128, num_classes))
        self.cls_head = nn.Sequential(*layers)

    def forward(self, input_dict):
        """Run the backbone, pool per sample, classify, and compute the loss.

        Returns:
            ``dict(loss=...)`` in training mode; ``dict(loss=..., cls_logits=...)``
            in eval mode when ``"category"`` is present; otherwise
            ``dict(cls_logits=...)``.
        """
        out = self.backbone(Point(input_dict))
        if isinstance(out, Point):
            # Mean-reduce features over CSR segments. Prepending 0 to `offset`
            # forms the index pointer, so `offset` presumably holds cumulative
            # per-sample end indices — TODO confirm against Point's contract.
            pooled = torch_scatter.segment_csr(
                src=out.feat,
                indptr=nn.functional.pad(out.offset, (1, 0)),
                reduce="mean",
            )
            out.feat = pooled
            feat = pooled
        else:
            feat = out
        cls_logits = self.cls_head(feat)

        if not self.training and "category" not in input_dict:
            return dict(cls_logits=cls_logits)

        loss = self.criteria(cls_logits, input_dict["category"])
        if self.training:
            return dict(loss=loss)
        return dict(loss=loss, cls_logits=cls_logits)
|
|