import torch
import torch.nn as nn
from mono.utils.comm import get_func
from .__base_model__ import BaseDepthModel


class DepthModel(BaseDepthModel):
    """Depth model whose construction is delegated to ``BaseDepthModel``.

    Args:
        cfg: Model configuration, forwarded unchanged to ``BaseDepthModel``.
            NOTE(review): annotated as ``dict`` by callers below, but the
            original code used attribute access (``cfg.model.type``), so this
            is presumably an attribute-style config object -- confirm.
        criterions: Loss criterions, forwarded unchanged to ``BaseDepthModel``.
        **kwargs: Accepted for call-site compatibility; currently ignored
            (not forwarded to the base class).
    """

    def __init__(self, cfg, criterions, **kwargs):
        super().__init__(cfg, criterions)
        # Force training mode after base construction, regardless of any
        # mode the base constructor may have set.
        self.training = True


def get_monodepth_model(
    cfg: dict,
    criterions: dict,
    **kwargs
) -> nn.Module:
    """Instantiate a ``DepthModel``.

    Args:
        cfg: Model configuration passed to ``DepthModel``.
        criterions: Loss criterions passed to ``DepthModel``.
        **kwargs: Extra keyword arguments passed to ``DepthModel``
            (currently ignored by it).

    Returns:
        The constructed depth model.
    """
    model = DepthModel(cfg, criterions, **kwargs)
    # Sanity check only; DepthModel is expected to derive from nn.Module
    # through BaseDepthModel.
    assert isinstance(model, nn.Module)
    return model


def get_configured_monodepth_model(
    cfg: dict,
    criterions: dict,
) -> nn.Module:
    """Build the monocular depth model from configuration.

    No ImageNet-pretrained weight initialization is performed here.

    Args:
        cfg: Configuration for the network.
        criterions: Loss criterions used by the model.

    Returns:
        The constructed depth model.
    """
    return get_monodepth_model(cfg, criterions)