import sys
import os.path as osp

import numpy as np
import torch
import torch.nn as nn
from yacs.config import CfgNode as CN

# Make the local `tgs` package importable regardless of the caller's cwd.
this_dir = osp.dirname(__file__)
sys.path.insert(0, this_dir)

import tgs
from network.pointnet import PointNetEncoder

# HORT configuration: a DINOv2 image tokenizer, learnable point tokens, a 1D
# transformer backbone with cross-attention to the image tokens, and a
# Snowflake-style point-cloud upsampler.
hort_cfg = CN()

# Image tokenizer (DINOv2-large, 224x224 input).
hort_cfg.image_tokenizer_cls = "tgs.models.tokenizers.image.DINOV2SingleImageTokenizer"
hort_cfg.image_tokenizer = CN()
hort_cfg.image_tokenizer.pretrained_model_name_or_path = "facebook/dinov2-large"
hort_cfg.image_tokenizer.width = 224
hort_cfg.image_tokenizer.height = 224
hort_cfg.image_tokenizer.modulation = False
hort_cfg.image_tokenizer.modulation_zero_init = True
hort_cfg.image_tokenizer.modulation_cond_dim = 1024
hort_cfg.image_tokenizer.freeze_backbone_params = False
hort_cfg.image_tokenizer.enable_memory_efficient_attention = False
hort_cfg.image_tokenizer.enable_gradient_checkpointing = False

# Point tokenizer: 2048 point tokens plus one token for the object translation.
hort_cfg.tokenizer_cls = "tgs.models.tokenizers.point.PointLearnablePositionalEmbedding"
hort_cfg.tokenizer = CN()
hort_cfg.tokenizer.num_pcl = 2049
hort_cfg.tokenizer.num_channels = 512

# Transformer backbone.
hort_cfg.backbone_cls = "tgs.models.transformers.Transformer1D"
hort_cfg.backbone = CN()
hort_cfg.backbone.in_channels = 512
hort_cfg.backbone.num_attention_heads = 8
hort_cfg.backbone.attention_head_dim = 64
hort_cfg.backbone.num_layers = 10
hort_cfg.backbone.cross_attention_dim = 1024
hort_cfg.backbone.norm_type = "layer_norm"
hort_cfg.backbone.enable_memory_efficient_attention = False
hort_cfg.backbone.gradient_checkpointing = False

# Post-processor: project point tokens to xyz coordinates.
hort_cfg.post_processor_cls = "tgs.models.networks.PointOutLayer"
hort_cfg.post_processor = CN()
hort_cfg.post_processor.in_channels = 512
hort_cfg.post_processor.out_channels = 3

# Point-cloud upsampler: 2048 points upsampled by factors [2, 4] to 16384.
hort_cfg.pointcloud_upsampler_cls = "tgs.models.snowflake.model_spdpp.SnowflakeModelSPDPP"
hort_cfg.pointcloud_upsampler = CN()
hort_cfg.pointcloud_upsampler.input_channels = 1024
hort_cfg.pointcloud_upsampler.dim_feat = 128
hort_cfg.pointcloud_upsampler.num_p0 = 2048
hort_cfg.pointcloud_upsampler.radius = 1
hort_cfg.pointcloud_upsampler.bounding = True
hort_cfg.pointcloud_upsampler.use_fps = True
hort_cfg.pointcloud_upsampler.up_factors = [2, 4]
hort_cfg.pointcloud_upsampler.token_type = "image_token"


class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.image_tokenizer = tgs.find(hort_cfg.image_tokenizer_cls)(hort_cfg.image_tokenizer)
        # Hand encoder: 67 input channels = 3 (palm-relative xyz) + 1
        # (normalized vertex index) + 63 (offsets to the 21 hand joints).
        self.pointnet = PointNetEncoder(67, 1024)
        self.tokenizer = tgs.find(hort_cfg.tokenizer_cls)(hort_cfg.tokenizer)
        self.backbone = tgs.find(hort_cfg.backbone_cls)(hort_cfg.backbone)
        self.post_processor = tgs.find(hort_cfg.post_processor_cls)(hort_cfg.post_processor)
        # Small MLP head regressing the object translation from the last token.
        self.post_processor_trans = nn.Sequential(
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 3),
        )
        self.pointcloud_upsampler = tgs.find(hort_cfg.pointcloud_upsampler_cls)(hort_cfg.pointcloud_upsampler)

    def forward(self, input_img, metas):
        # Inference-only: the whole forward pass runs without gradients.
        with torch.no_grad():
            batch_size = input_img.shape[0]

            # Image tokens from DINOv2: B x C x Nt, transposed to B x Nt x C.
            encoder_hidden_states = self.image_tokenizer(input_img, None)
            encoder_hidden_states = encoder_hidden_states.transpose(2, 1)

            # Hand features: palm-centered MANO vertices (778 of them), each
            # augmented with its normalized vertex index and its offsets to
            # all hand joints, encoded by a PointNet.
            palm_norm_hand_verts_3d = metas['right_hand_verts_3d'] - metas['right_hand_palm'].unsqueeze(1)
            point_idx = torch.arange(778).view(1, 778, 1).expand(batch_size, -1, -1).to(input_img.device) / 778.
            palm_norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, point_idx], -1)
            tip_norm_hand_verts_3d = (metas['right_hand_verts_3d'].unsqueeze(2) - metas['right_hand_joints_3d'].unsqueeze(1)).reshape((batch_size, 778, -1))
            norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, tip_norm_hand_verts_3d], -1)
            hand_feats = self.pointnet(norm_hand_verts_3d)

            # Cross-attend the learnable point tokens to the image tokens plus
            # the global hand feature, then detokenize back to per-point form.
            tokens = self.tokenizer(batch_size)
            tokens = self.backbone(
                tokens,
                torch.cat([encoder_hidden_states, hand_feats.unsqueeze(1)], 1),
                modulation_cond=None,
            )
            tokens = self.tokenizer.detokenize(tokens)

            # First 2048 tokens -> coarse object point cloud; last token ->
            # object translation relative to the palm.
            pointclouds = self.post_processor(tokens[:, :2048, :])
            pred_obj_trans = self.post_processor_trans(tokens[:, -1, :])

            # Upsample the coarse cloud, conditioned on the image tokens,
            # camera intrinsics, and the posed hand.
            upsampling_input = {
                "input_image_tokens": encoder_hidden_states.permute(0, 2, 1),
                "intrinsic_cond": metas['cam_intr'],
                "points": pointclouds,
                "hand_points": metas["right_hand_verts_3d"],
                "trans": pred_obj_trans + metas['right_hand_palm'],
                "scale": 0.3,
            }
            up_results = self.pointcloud_upsampler(upsampling_input)
            pointclouds_up = up_results[-1]

            pc_results = {}
            pc_results['pointclouds'] = pointclouds
            pc_results['objtrans'] = pred_obj_trans
            pc_results['handpalm'] = metas['right_hand_palm']
            pc_results['pointclouds_up'] = pointclouds_up
            return pc_results


def load_hort(ckpt_path):
    """Build the HORT model and load weights from a training checkpoint."""
    hort_model = model()
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))["network"]
    # Strip the DataParallel/DDP 'module.' prefix if present.
    ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
    hort_model.load_state_dict(ckpt)
    return hort_model
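

# Minimal usage sketch (not part of the original file): the shapes below are
# inferred from forward() above -- 778 MANO vertices and 21 joints, so that
# 3 + 1 + 3*21 = 67 matches PointNetEncoder(67, 1024) -- while the checkpoint
# path and the 3x3 'cam_intr' layout are assumptions.
if __name__ == "__main__":
    hort_model = load_hort("hort_ckpt.pth").eval()  # hypothetical path
    B = 1
    input_img = torch.zeros(B, 3, 224, 224)  # normalized 224x224 RGB crop
    metas = {
        "right_hand_verts_3d": torch.zeros(B, 778, 3),  # MANO hand vertices
        "right_hand_joints_3d": torch.zeros(B, 21, 3),  # MANO hand joints
        "right_hand_palm": torch.zeros(B, 3),           # palm center
        "cam_intr": torch.eye(3).unsqueeze(0).expand(B, 3, 3),  # intrinsics (assumed 3x3)
    }
    pc_results = hort_model(input_img, metas)
    # 2048 coarse points; upsampled by factors [2, 4] to 16384 points.
    print(pc_results["pointclouds"].shape, pc_results["pointclouds_up"].shape)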