import os.path as osp
import sys

import torch
import torch.nn as nn
from yacs.config import CfgNode as CN

# Make the bundled `tgs` package and the local `network` module importable
# regardless of the caller's working directory.
this_dir = osp.dirname(__file__)
sys.path.insert(0, this_dir)

import tgs
from network.pointnet import PointNetEncoder
# HORT configuration. Each *_cls string is a dotted class path that
# tgs.find() resolves to a class; the matching CfgNode carries its arguments.
hort_cfg = CN()

# DINOv2 image tokenizer: encodes the 224x224 input crop into patch tokens.
hort_cfg.image_tokenizer_cls = "tgs.models.tokenizers.image.DINOV2SingleImageTokenizer"
hort_cfg.image_tokenizer = CN()
hort_cfg.image_tokenizer.pretrained_model_name_or_path = "facebook/dinov2-large"
hort_cfg.image_tokenizer.width = 224
hort_cfg.image_tokenizer.height = 224
hort_cfg.image_tokenizer.modulation = False
hort_cfg.image_tokenizer.modulation_zero_init = True
hort_cfg.image_tokenizer.modulation_cond_dim = 1024
hort_cfg.image_tokenizer.freeze_backbone_params = False
hort_cfg.image_tokenizer.enable_memory_efficient_attention = False
hort_cfg.image_tokenizer.enable_gradient_checkpointing = False

# Learnable point tokens: 2048 object-point tokens plus one translation token.
hort_cfg.tokenizer_cls = "tgs.models.tokenizers.point.PointLearnablePositionalEmbedding"
hort_cfg.tokenizer = CN()
hort_cfg.tokenizer.num_pcl = 2049
hort_cfg.tokenizer.num_channels = 512

# 1D transformer backbone that cross-attends the point tokens to the
# concatenated image tokens and hand feature.
hort_cfg.backbone_cls = "tgs.models.transformers.Transformer1D"
hort_cfg.backbone = CN()
hort_cfg.backbone.in_channels = 512
hort_cfg.backbone.num_attention_heads = 8
hort_cfg.backbone.attention_head_dim = 64
hort_cfg.backbone.num_layers = 10
hort_cfg.backbone.cross_attention_dim = 1024
hort_cfg.backbone.norm_type = "layer_norm"
hort_cfg.backbone.enable_memory_efficient_attention = False
hort_cfg.backbone.gradient_checkpointing = False

# Projects each point token to an (x, y, z) coordinate.
hort_cfg.post_processor_cls = "tgs.models.networks.PointOutLayer"
hort_cfg.post_processor = CN()
hort_cfg.post_processor.in_channels = 512
hort_cfg.post_processor.out_channels = 3

# Snowflake-style upsampler: refines the coarse 2048-point cloud by factors
# of 2 and then 4 (2048 -> 4096 -> 16384 points).
hort_cfg.pointcloud_upsampler_cls = "tgs.models.snowflake.model_spdpp.SnowflakeModelSPDPP"
hort_cfg.pointcloud_upsampler = CN()
hort_cfg.pointcloud_upsampler.input_channels = 1024
hort_cfg.pointcloud_upsampler.dim_feat = 128
hort_cfg.pointcloud_upsampler.num_p0 = 2048
hort_cfg.pointcloud_upsampler.radius = 1
hort_cfg.pointcloud_upsampler.bounding = True
hort_cfg.pointcloud_upsampler.use_fps = True
hort_cfg.pointcloud_upsampler.up_factors = [2, 4]
hort_cfg.pointcloud_upsampler.token_type = "image_token"
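
# Any value above can be overridden before the model is instantiated using
# the standard yacs API, e.g. (illustrative only, not part of the original
# pipeline):
#   hort_cfg.merge_from_list(["backbone.num_layers", 12])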
class model(nn.Module):
    """HORT network: predicts a hand-held object point cloud from a single
    RGB image plus MANO right-hand geometry."""

    def __init__(self):
        super().__init__()
        # Image encoder (DINOv2) producing the cross-attention context tokens.
        self.image_tokenizer = tgs.find(hort_cfg.image_tokenizer_cls)(hort_cfg.image_tokenizer)
        # Hand encoder: 67 input channels per vertex -> one 1024-d global feature.
        self.pointnet = PointNetEncoder(67, 1024)
        # Learnable query tokens (2048 object points + 1 translation token).
        self.tokenizer = tgs.find(hort_cfg.tokenizer_cls)(hort_cfg.tokenizer)
        self.backbone = tgs.find(hort_cfg.backbone_cls)(hort_cfg.backbone)
        # Heads: per-token xyz for the coarse cloud, and an MLP for translation.
        self.post_processor = tgs.find(hort_cfg.post_processor_cls)(hort_cfg.post_processor)
        self.post_processor_trans = nn.Sequential(
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 3),
        )
        self.pointcloud_upsampler = tgs.find(hort_cfg.pointcloud_upsampler_cls)(hort_cfg.pointcloud_upsampler)
    def forward(self, input_img, metas):
        with torch.no_grad():
            batch_size = input_img.shape[0]

            # Image tokens: (B, C, Nt) from DINOv2, transposed to (B, Nt, C).
            encoder_hidden_states = self.image_tokenizer(input_img, None)
            encoder_hidden_states = encoder_hidden_states.transpose(2, 1)

            # Per-vertex hand features for the 778 MANO vertices:
            # 3 palm-relative coordinates + 1 normalized vertex index
            # + flattened offsets to every hand joint (67 channels in total).
            palm_norm_hand_verts_3d = metas['right_hand_verts_3d'] - metas['right_hand_palm'].unsqueeze(1)
            point_idx = torch.arange(778).view(1, 778, 1).expand(batch_size, -1, -1).to(input_img.device) / 778.
            palm_norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, point_idx], -1)
            tip_norm_hand_verts_3d = (metas['right_hand_verts_3d'].unsqueeze(2)
                                      - metas['right_hand_joints_3d'].unsqueeze(1)).reshape((batch_size, 778, -1))
            norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, tip_norm_hand_verts_3d], -1)
            hand_feats = self.pointnet(norm_hand_verts_3d)

            # Cross-attend the learnable point tokens to [image tokens; hand feature].
            tokens = self.tokenizer(batch_size)
            tokens = self.backbone(tokens, torch.cat([encoder_hidden_states, hand_feats.unsqueeze(1)], 1), modulation_cond=None)
            tokens = self.tokenizer.detokenize(tokens)

            # First 2048 tokens -> coarse object points; last token -> translation.
            pointclouds = self.post_processor(tokens[:, :2048, :])
            pred_obj_trans = self.post_processor_trans(tokens[:, -1, :])

            # Snowflake upsampling of the coarse cloud, conditioned on the image
            # tokens, camera intrinsics, and the hand vertices.
            upsampling_input = {
                "input_image_tokens": encoder_hidden_states.permute(0, 2, 1),
                "intrinsic_cond": metas['cam_intr'],
                "points": pointclouds,
                "hand_points": metas["right_hand_verts_3d"],
                "trans": pred_obj_trans + metas['right_hand_palm'],
                "scale": 0.3,
            }
            up_results = self.pointcloud_upsampler(upsampling_input)
            pointclouds_up = up_results[-1]

            pc_results = {
                'pointclouds': pointclouds,            # coarse cloud (B, 2048, 3)
                'objtrans': pred_obj_trans,            # palm-relative object translation
                'handpalm': metas['right_hand_palm'],
                'pointclouds_up': pointclouds_up,      # upsampled cloud
            }
            return pc_results
def load_hort(ckpt_path):
    """Build the model and load weights from a checkpoint's "network" entry."""
    hort_model = model()
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))["network"]
    # Strip the 'module.' prefix added by DataParallel/DistributedDataParallel.
    ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
    hort_model.load_state_dict(ckpt)
    return hort_model
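
# Minimal usage sketch (not part of the original file). The tensor shapes are
# assumptions read off forward(): a 224x224 RGB crop, 778 MANO vertices, and
# 21 MANO joints (3 + 1 + 21*3 = 67 PointNet input channels);
# "path/to/checkpoint.pth" is a placeholder path.
if __name__ == "__main__":
    hort_model = load_hort("path/to/checkpoint.pth")
    hort_model.eval()

    batch_size = 1
    input_img = torch.zeros(batch_size, 3, 224, 224)
    metas = {
        "right_hand_verts_3d": torch.zeros(batch_size, 778, 3),
        "right_hand_joints_3d": torch.zeros(batch_size, 21, 3),
        "right_hand_palm": torch.zeros(batch_size, 3),
        "cam_intr": torch.eye(3).unsqueeze(0).repeat(batch_size, 1, 1),
    }
    pc_results = hort_model(input_img, metas)
    print(pc_results["pointclouds"].shape, pc_results["pointclouds_up"].shape)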