diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..a049828f0525233dcffc8a52d34dd8e2689373cd 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/test1.png filter=lfs diff=lfs merge=lfs -text
+assets/test2.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c0f55d6c64b31a249f0fe376a9038e4111cc3616
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,139 @@
+# Byte-compiled / optimized / DLL files
+__pycache__
+*.py[cod]
+*$py.class
+
+# pyc
+*.pyc
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# VSCode
+.vscode
+
+*.swp
+*.h5
+*.mp4
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e74e0a3fdc022a90a91a875481b4b5021ba0c085
--- /dev/null
+++ b/app.py
@@ -0,0 +1,202 @@
+import os
+import sys
+os.environ["PYOPENGL_PLATFORM"] = "egl"
+os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
+
+import gradio as gr
+import spaces
+import cv2
+import numpy as np
+import torch
+from ultralytics import YOLO
+from pathlib import Path
+import argparse
+import json
+from torchvision import transforms
+from typing import Dict, Optional
+from PIL import Image, ImageDraw
+from lang_sam import LangSAM
+
+from wilor.models import load_wilor
+from wilor.utils import recursive_to
+from wilor.datasets.vitdet_dataset import ViTDetDataset
+from hort.models import load_hort
+from hort.utils.renderer import Renderer, cam_crop_to_new
+from hort.utils.img_utils import process_bbox, generate_patch_image, PerspectiveCamera
+
+LIGHT_PURPLE = (0.25098039, 0.274117647, 0.65882353)
+STEEL_BLUE = (0.2745098, 0.5098039, 0.7058824)
+
+# Download and load checkpoints
+wilor_model, wilor_model_cfg = load_wilor(checkpoint_path='./pretrained_models/wilor_final.ckpt', cfg_path='./pretrained_models/model_config.yaml')
+hand_detector = YOLO('./pretrained_models/detector.pt')
+# Setup the renderer
+renderer = Renderer(wilor_model_cfg, faces=wilor_model.mano.faces)
+# Setup the SAM model
+sam_model = LangSAM(sam_type="sam2.1_hiera_large")
+# Setup the HORT model
+hort_model = load_hort("./pretrained_models/hort_final.pth.tar")
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+wilor_model = wilor_model.to(device)
+hand_detector = hand_detector.to(device)
+hort_model = hort_model.to(device)
+wilor_model.eval()
+hort_model.eval()
+
+image_transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+
+@spaces.GPU()
+def run_model(image, conf, IoU_threshold=0.5):
+ img_cv2 = image[..., ::-1]
+ img_pil = Image.fromarray(image)
+
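+    # text-prompted segmentation: keep the top box/mask LangSAM returns for the object and for the hand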
+ pred_obj = sam_model.predict([img_pil], ["manipulated object"])
+ pred_hand = sam_model.predict([img_pil], ["hand"])
+
+ bbox_obj = pred_obj[0]["boxes"][0].reshape((-1, 2))
+ mask_obj = pred_obj[0]["masks"][0]
+ bbox_hand = pred_hand[0]["boxes"][0].reshape((-1, 2))
+ mask_hand = pred_hand[0]["masks"][0]
+
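+    # the union of the object and hand boxes, padded by 10 px on each side, defines the hand-object crop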
+ tl = np.min(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
+ br = np.max(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
+ box_size = br - tl
+ bbox = np.concatenate([tl - 10, box_size + 20], axis=0)
+ ho_bbox = process_bbox(bbox)
+
+ detections = hand_detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]
+
+ bboxes = []
+ is_right = []
+ for det in detections:
+ Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
+ is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
+ bboxes.append(Bbox[:4].tolist())
+
+ if len(bboxes) == 1:
+ boxes = np.stack(bboxes)
+ right = np.stack(is_right)
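+        # the hand reconstruction assumes right hands: mirror the image, hand box and crop box when a left hand is detected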
+ if not right:
+ new_x1 = img_cv2.shape[1] - boxes[0][2]
+ new_x2 = img_cv2.shape[1] - boxes[0][0]
+ boxes[0][0] = new_x1
+ boxes[0][2] = new_x2
+ ho_bbox[0] = img_cv2.shape[1] - (ho_bbox[0] + ho_bbox[2])
+ img_cv2 = cv2.flip(img_cv2, 1)
+ right[0] = 1.
+ crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)
+
+ dataset = ViTDetDataset(wilor_model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
+
+ for batch in dataloader:
+ batch = recursive_to(batch, device)
+
+ with torch.no_grad():
+ out = wilor_model(batch)
+
+ pred_cam = out['pred_cam']
+ box_center = batch["box_center"].float()
+ box_size = batch["box_size"].float()
+ img_size = batch["img_size"].float()
+ scaled_focal_length = wilor_model_cfg.EXTRA.FOCAL_LENGTH / wilor_model_cfg.MODEL.IMAGE_SIZE * 224
+ pred_cam_t_full = cam_crop_to_new(pred_cam, box_center, box_size, img_size, torch.from_numpy(np.array(ho_bbox, dtype=np.float32))[None, :].to(img_size.device), scaled_focal_length).detach().cpu().numpy()
+
+ batch_size = batch['img'].shape[0]
+ for n in range(batch_size):
+ verts = out['pred_vertices'][n].detach().cpu().numpy()
+ joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()
+
+ is_right = batch['right'][n].cpu().numpy()
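+                # approximate the palm center as the midpoint of two MANO vertices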
+ palm = (verts[95] + verts[22]) / 2
+ cam_t = pred_cam_t_full[n]
+
+ img_input = image_transform(crop_img_cv2[:, :, ::-1]).unsqueeze(0).cuda()
+ camera = PerspectiveCamera(5000 / 256 * 224, 5000 / 256 * 224, 112, 112)
+ cam_intr = camera.intrinsics
+
+ metas = dict()
+ metas["right_hand_verts_3d"] = torch.from_numpy((verts + cam_t)[None]).cuda()
+ metas["right_hand_joints_3d"] = torch.from_numpy((joints + cam_t)[None]).cuda()
+ metas["right_hand_palm"] = torch.from_numpy((palm + cam_t)[None]).cuda()
+ metas["cam_intr"] = torch.from_numpy(cam_intr[None]).cuda()
+ with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+ pc_results = hort_model(img_input, metas)
+ objtrans = pc_results["objtrans"][0].detach().cpu().numpy()
+ pointclouds_up = pc_results["pointclouds_up"][0].detach().cpu().numpy() * 0.3
+
+ reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up, 'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length}
+
+ return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions
+    else:
+        crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)
+        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), None
+
+
+def render_reconstruction(image, conf, IoU_threshold=0.3):
+    input_img, num_dets, reconstructions = run_model(image, conf, IoU_threshold=IoU_threshold)
+ if num_dets == 1:
+ # Render front view
+ misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE, scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal'])
+ cam_view = renderer.render_rgba(reconstructions['verts'], reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'], cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args)
+
+ # Overlay image
+ input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
+ input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
+
+ return input_img_overlay, f'{num_dets} hands detected'
+ else:
+ return input_img, f'{num_dets} hands detected'
+
+
+header = ('''
+# HORT: Monocular Hand-held Objects Reconstruction with Transformers
+''')
+
+
+with gr.Blocks(title="HORT: Monocular Hand-held Objects Reconstruction with Transformers", css=".gradio-container") as demo:
+
+ gr.Markdown(header)
+
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(label="Input image", type="numpy")
+ threshold = gr.Slider(value=0.3, minimum=0.05, maximum=0.95, step=0.05, label='Detection Confidence Threshold')
+ submit = gr.Button("Submit", variant="primary")
+
+
+ with gr.Column():
+ reconstruction = gr.Image(label="Reconstructions", type="numpy")
+ hands_detected = gr.Textbox(label="Hands Detected")
+
+ submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected])
+
+ with gr.Row():
+ example_images = gr.Examples([
+        ['./assets/test1.png'],
+        ['./assets/test2.png'],
+        ['./assets/test3.jpg'],
+        ['./assets/test4.jpeg'],
+        ['./assets/test5.jpeg']
+ ],
+ inputs=input_image)
+
+demo.launch(debug=True)
diff --git a/assets/test1.png b/assets/test1.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0817a749ed4c22c8d322bf88489ced23fd16104
--- /dev/null
+++ b/assets/test1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:220310a89f9777975b10d933eb6aef34c3fe036ae2f453c2e31d537f8827111b
+size 129676
diff --git a/assets/test2.png b/assets/test2.png
new file mode 100644
index 0000000000000000000000000000000000000000..e35549741cf978f8613588d1f2848e50665fae24
--- /dev/null
+++ b/assets/test2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29e4602efe21a483442c42a50ebf1c666c9e525dc630ec801a5af1d3acee18b1
+size 132948
diff --git a/assets/test3.jpg b/assets/test3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4ea7fd1b82a40e28c004473e041e63227ee94864
Binary files /dev/null and b/assets/test3.jpg differ
diff --git a/assets/test4.jpeg b/assets/test4.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..c6a2d1b4549a89faee6ec1dfe0b13742d52371de
Binary files /dev/null and b/assets/test4.jpeg differ
diff --git a/assets/test5.jpeg b/assets/test5.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..eeacc48a79a19f104ed6759024be896bd8d5cace
Binary files /dev/null and b/assets/test5.jpeg differ
diff --git a/hort/models/__init__.py b/hort/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cc23c6b01efa0b05d7887fd2fa45c4a302b6187
--- /dev/null
+++ b/hort/models/__init__.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+import sys
+import os.path as osp
+import numpy as np
+from yacs.config import CfgNode as CN
+this_dir = osp.dirname(__file__)
+sys.path.insert(0, this_dir)
+import tgs
+from network.pointnet import PointNetEncoder
+
+hort_cfg = CN()
+hort_cfg.image_tokenizer_cls = "tgs.models.tokenizers.image.DINOV2SingleImageTokenizer"
+hort_cfg.image_tokenizer = CN()
+hort_cfg.image_tokenizer.pretrained_model_name_or_path = "facebook/dinov2-large"
+hort_cfg.image_tokenizer.width = 224
+hort_cfg.image_tokenizer.height = 224
+hort_cfg.image_tokenizer.modulation = False
+hort_cfg.image_tokenizer.modulation_zero_init = True
+hort_cfg.image_tokenizer.modulation_cond_dim = 1024
+hort_cfg.image_tokenizer.freeze_backbone_params = False
+hort_cfg.image_tokenizer.enable_memory_efficient_attention = False
+hort_cfg.image_tokenizer.enable_gradient_checkpointing = False
+
+hort_cfg.tokenizer_cls = "tgs.models.tokenizers.point.PointLearnablePositionalEmbedding"
+hort_cfg.tokenizer = CN()
+hort_cfg.tokenizer.num_pcl = 2049
+hort_cfg.tokenizer.num_channels = 512
+
+hort_cfg.backbone_cls = "tgs.models.transformers.Transformer1D"
+hort_cfg.backbone = CN()
+hort_cfg.backbone.in_channels = 512
+hort_cfg.backbone.num_attention_heads = 8
+hort_cfg.backbone.attention_head_dim = 64
+hort_cfg.backbone.num_layers = 10
+hort_cfg.backbone.cross_attention_dim = 1024
+hort_cfg.backbone.norm_type = "layer_norm"
+hort_cfg.backbone.enable_memory_efficient_attention = False
+hort_cfg.backbone.gradient_checkpointing = False
+
+hort_cfg.post_processor_cls = "tgs.models.networks.PointOutLayer"
+hort_cfg.post_processor = CN()
+hort_cfg.post_processor.in_channels = 512
+hort_cfg.post_processor.out_channels = 3
+
+hort_cfg.pointcloud_upsampler_cls = "tgs.models.snowflake.model_spdpp.SnowflakeModelSPDPP"
+hort_cfg.pointcloud_upsampler = CN()
+hort_cfg.pointcloud_upsampler.input_channels = 1024
+hort_cfg.pointcloud_upsampler.dim_feat = 128
+hort_cfg.pointcloud_upsampler.num_p0 = 2048
+hort_cfg.pointcloud_upsampler.radius = 1
+hort_cfg.pointcloud_upsampler.bounding = True
+hort_cfg.pointcloud_upsampler.use_fps = True
+hort_cfg.pointcloud_upsampler.up_factors = [2, 4]
+hort_cfg.pointcloud_upsampler.token_type = "image_token"
+
+
+class model(nn.Module):
+ def __init__(self):
+ super(model, self).__init__()
+ self.image_tokenizer = tgs.find(hort_cfg.image_tokenizer_cls)(hort_cfg.image_tokenizer)
+ self.pointnet = PointNetEncoder(67, 1024)
+ self.tokenizer = tgs.find(hort_cfg.tokenizer_cls)(hort_cfg.tokenizer)
+ self.backbone = tgs.find(hort_cfg.backbone_cls)(hort_cfg.backbone)
+ self.post_processor = tgs.find(hort_cfg.post_processor_cls)(hort_cfg.post_processor)
+ self.post_processor_trans = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 3))
+ self.pointcloud_upsampler = tgs.find(hort_cfg.pointcloud_upsampler_cls)(hort_cfg.pointcloud_upsampler)
+
+ def forward(self, input_img, metas):
+ with torch.no_grad():
+ batch_size = input_img.shape[0]
+
+ encoder_hidden_states = self.image_tokenizer(input_img, None) # B * C * Nt
+ encoder_hidden_states = encoder_hidden_states.transpose(2, 1) # B * Nt * C
+
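+            # hand conditioning: palm-centered vertices, a normalized per-vertex index channel, and vertex-to-joint offsets, encoded into a global feature by PointNet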
+ palm_norm_hand_verts_3d = metas['right_hand_verts_3d'] - metas['right_hand_palm'].unsqueeze(1)
+ point_idx = torch.arange(778).view(1, 778, 1).expand(batch_size, -1, -1).to(input_img.device) / 778.
+ palm_norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, point_idx], -1)
+ tip_norm_hand_verts_3d = (metas['right_hand_verts_3d'].unsqueeze(2) - metas['right_hand_joints_3d'].unsqueeze(1)).reshape((batch_size, 778, -1))
+ norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, tip_norm_hand_verts_3d], -1)
+ hand_feats = self.pointnet(norm_hand_verts_3d)
+
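+            # learnable point tokens cross-attend to the DINOv2 image tokens concatenated with the global hand feature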
+ tokens = self.tokenizer(batch_size)
+ tokens = self.backbone(tokens, torch.cat([encoder_hidden_states, hand_feats.unsqueeze(1)], 1), modulation_cond=None)
+ tokens = self.tokenizer.detokenize(tokens)
+
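+            # the first 2048 tokens decode to the coarse object point cloud; the last token regresses the object translation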
+ pointclouds = self.post_processor(tokens[:, :2048, :])
+ pred_obj_trans = self.post_processor_trans(tokens[:, -1, :])
+
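+            # the snowflake upsampler refines the coarse points, conditioned on image tokens, camera intrinsics and the hand vertices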
+ upsampling_input = {
+ "input_image_tokens": encoder_hidden_states.permute(0, 2, 1),
+ "intrinsic_cond": metas['cam_intr'],
+ "points": pointclouds,
+ "hand_points": metas["right_hand_verts_3d"],
+ "trans": pred_obj_trans + metas['right_hand_palm'],
+ "scale": 0.3
+ }
+ up_results = self.pointcloud_upsampler(upsampling_input)
+ pointclouds_up = up_results[-1]
+
+ pc_results = {}
+ pc_results['pointclouds'] = pointclouds
+ pc_results['objtrans'] = pred_obj_trans
+ pc_results['handpalm'] = metas['right_hand_palm']
+ pc_results['pointclouds_up'] = pointclouds_up
+
+ return pc_results
+
+def load_hort(ckpt_path):
+ hort_model = model()
+ ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))["network"]
+ ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
+ hort_model.load_state_dict(ckpt)
+ return hort_model
diff --git a/hort/models/network/pointnet.py b/hort/models/network/pointnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c81db3db7ef65e80952c674fb4d20ac30c2e25af
--- /dev/null
+++ b/hort/models/network/pointnet.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+
+
+class PointNetEncoder(nn.Module):
+ """Encoder for Pointcloud
+ """
+ def __init__(self, in_channels: int=3, output_channels: int=768):
+ super().__init__()
+
+ block_channel = [64, 128, 256, 512]
+ self.mlp = nn.Sequential(
+ nn.Linear(in_channels, block_channel[0]),
+ nn.LayerNorm(block_channel[0]),
+ nn.ReLU(),
+ nn.Linear(block_channel[0], block_channel[1]),
+ nn.LayerNorm(block_channel[1]),
+ nn.ReLU(),
+ nn.Linear(block_channel[1], block_channel[2]),
+ nn.LayerNorm(block_channel[2]),
+ nn.ReLU(),
+ nn.Linear(block_channel[2], block_channel[3]),
+ nn.LayerNorm(block_channel[3]),
+ nn.ReLU(),
+ )
+
+ self.final_projection = nn.Sequential(
+ nn.Linear(block_channel[-1], output_channels),
+ nn.LayerNorm(output_channels)
+ )
+
+ def forward(self, x):
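+        # shared per-point MLP, max-pool over the point dimension for a permutation-invariant global feature, then project to the output width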
+ x = self.mlp(x)
+ x = torch.max(x, 1)[0]
+ x = self.final_projection(x)
+ return x
\ No newline at end of file
diff --git a/hort/models/tgs/__init__.py b/hort/models/tgs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d72c1d1ab4cc2b5481d6028c001c9cab91b35ed9
--- /dev/null
+++ b/hort/models/tgs/__init__.py
@@ -0,0 +1,9 @@
+import importlib
+from tgs.utils.typing import *
+
+def find(cls_string) -> Type:
+ module_string = ".".join(cls_string.split(".")[:-1])
+ cls_name = cls_string.split(".")[-1]
+ module = importlib.import_module(module_string, package=None)
+ cls = getattr(module, cls_name)
+ return cls
\ No newline at end of file
diff --git a/hort/models/tgs/data.py b/hort/models/tgs/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ab123fcc1223232934213d8c29c17b4c1f9df50
--- /dev/null
+++ b/hort/models/tgs/data.py
@@ -0,0 +1,265 @@
+import json
+import math
+from dataclasses import dataclass, field
+
+import os
+import imageio
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from torch.utils.data import Dataset
+
+from tgs.utils.config import parse_structured
+from tgs.utils.ops import get_intrinsic_from_fov, get_ray_directions, get_rays
+from tgs.utils.typing import *
+
+
+def _parse_scene_list_single(scene_list_path: str):
+ if scene_list_path.endswith(".json"):
+ with open(scene_list_path) as f:
+ all_scenes = json.loads(f.read())
+ elif scene_list_path.endswith(".txt"):
+ with open(scene_list_path) as f:
+ all_scenes = [p.strip() for p in f.readlines()]
+ else:
+ all_scenes = [scene_list_path]
+
+ return all_scenes
+
+
+def _parse_scene_list(scene_list_path: Union[str, List[str]]):
+ all_scenes = []
+ if isinstance(scene_list_path, str):
+ scene_list_path = [scene_list_path]
+ for scene_list_path_ in scene_list_path:
+ all_scenes += _parse_scene_list_single(scene_list_path_)
+ return all_scenes
+
+@dataclass
+class CustomImageDataModuleConfig:
+ image_list: Any = ""
+ background_color: Tuple[float, float, float] = field(
+ default_factory=lambda: (1.0, 1.0, 1.0)
+ )
+
+ relative_pose: bool = False
+ cond_height: int = 512
+ cond_width: int = 512
+ cond_camera_distance: float = 1.6
+ cond_fovy_deg: float = 40.0
+ cond_elevation_deg: float = 0.0
+ cond_azimuth_deg: float = 0.0
+ num_workers: int = 16
+
+ eval_height: int = 512
+ eval_width: int = 512
+ eval_batch_size: int = 1
+ eval_elevation_deg: float = 0.0
+ eval_camera_distance: float = 1.6
+ eval_fovy_deg: float = 40.0
+ n_test_views: int = 120
+ num_views_output: int = 120
+ only_3dgs: bool = False
+
+class CustomImageOrbitDataset(Dataset):
+ def __init__(self, cfg: Any) -> None:
+ super().__init__()
+ self.cfg: CustomImageDataModuleConfig = parse_structured(CustomImageDataModuleConfig, cfg)
+
+ self.n_views = self.cfg.n_test_views
+ assert self.n_views % self.cfg.num_views_output == 0
+
+ self.all_scenes = _parse_scene_list(self.cfg.image_list)
+
+ azimuth_deg: Float[Tensor, "B"] = torch.linspace(0, 360.0, self.n_views + 1)[
+ : self.n_views
+ ]
+ elevation_deg: Float[Tensor, "B"] = torch.full_like(
+ azimuth_deg, self.cfg.eval_elevation_deg
+ )
+ camera_distances: Float[Tensor, "B"] = torch.full_like(
+ elevation_deg, self.cfg.eval_camera_distance
+ )
+
+ elevation = elevation_deg * math.pi / 180
+ azimuth = azimuth_deg * math.pi / 180
+
+ # convert spherical coordinates to cartesian coordinates
+ # right hand coordinate system, x back, y right, z up
+ # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
+ camera_positions: Float[Tensor, "B 3"] = torch.stack(
+ [
+ camera_distances * torch.cos(elevation) * torch.cos(azimuth),
+ camera_distances * torch.cos(elevation) * torch.sin(azimuth),
+ camera_distances * torch.sin(elevation),
+ ],
+ dim=-1,
+ )
+
+ # default scene center at origin
+ center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
+ # default camera up direction as +z
+ up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
+ None, :
+ ].repeat(self.n_views, 1)
+
+ fovy_deg: Float[Tensor, "B"] = torch.full_like(
+ elevation_deg, self.cfg.eval_fovy_deg
+ )
+ fovy = fovy_deg * math.pi / 180
+
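+        # build look-at camera-to-world matrices: rotation columns are (right, up, -lookat), translation is the camera position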
+ lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
+ right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
+ up = F.normalize(torch.cross(right, lookat), dim=-1)
+ c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
+ [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
+ dim=-1,
+ )
+ c2w: Float[Tensor, "B 4 4"] = torch.cat(
+ [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
+ )
+ c2w[:, 3, 3] = 1.0
+
+ # get directions by dividing directions_unit_focal by focal length
+ focal_length: Float[Tensor, "B"] = (
+ 0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
+ )
+ directions_unit_focal = get_ray_directions(
+ H=self.cfg.eval_height,
+ W=self.cfg.eval_width,
+ focal=1.0,
+ )
+ directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
+ None, :, :, :
+ ].repeat(self.n_views, 1, 1, 1)
+ directions[:, :, :, :2] = (
+ directions[:, :, :, :2] / focal_length[:, None, None, None]
+ )
+ # must use normalize=True to normalize directions here
+ rays_o, rays_d = get_rays(directions, c2w, keepdim=True)
+
+ intrinsic: Float[Tensor, "B 3 3"] = get_intrinsic_from_fov(
+ self.cfg.eval_fovy_deg * math.pi / 180,
+ H=self.cfg.eval_height,
+ W=self.cfg.eval_width,
+ bs=self.n_views,
+ )
+ intrinsic_normed: Float[Tensor, "B 3 3"] = intrinsic.clone()
+ intrinsic_normed[..., 0, 2] /= self.cfg.eval_width
+ intrinsic_normed[..., 1, 2] /= self.cfg.eval_height
+ intrinsic_normed[..., 0, 0] /= self.cfg.eval_width
+ intrinsic_normed[..., 1, 1] /= self.cfg.eval_height
+
+ self.rays_o, self.rays_d = rays_o, rays_d
+ self.intrinsic = intrinsic
+ self.intrinsic_normed = intrinsic_normed
+ self.c2w = c2w
+ self.camera_positions = camera_positions
+
+ self.background_color = torch.as_tensor(self.cfg.background_color)
+
+ # condition
+ self.intrinsic_cond = get_intrinsic_from_fov(
+ np.deg2rad(self.cfg.cond_fovy_deg),
+ H=self.cfg.cond_height,
+ W=self.cfg.cond_width,
+ )
+ self.intrinsic_normed_cond = self.intrinsic_cond.clone()
+ self.intrinsic_normed_cond[..., 0, 2] /= self.cfg.cond_width
+ self.intrinsic_normed_cond[..., 1, 2] /= self.cfg.cond_height
+ self.intrinsic_normed_cond[..., 0, 0] /= self.cfg.cond_width
+ self.intrinsic_normed_cond[..., 1, 1] /= self.cfg.cond_height
+
+
+ if self.cfg.relative_pose:
+ self.c2w_cond = torch.as_tensor(
+ [
+ [0, 0, 1, self.cfg.cond_camera_distance],
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 0, 1],
+ ]
+ ).float()
+ else:
+ cond_elevation = self.cfg.cond_elevation_deg * math.pi / 180
+ cond_azimuth = self.cfg.cond_azimuth_deg * math.pi / 180
+ cond_camera_position: Float[Tensor, "3"] = torch.as_tensor(
+ [
+ self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.cos(cond_azimuth),
+ self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.sin(cond_azimuth),
+ self.cfg.cond_camera_distance * np.sin(cond_elevation),
+ ], dtype=torch.float32
+ )
+
+ cond_center: Float[Tensor, "3"] = torch.zeros_like(cond_camera_position)
+ cond_up: Float[Tensor, "3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)
+ cond_lookat: Float[Tensor, "3"] = F.normalize(cond_center - cond_camera_position, dim=-1)
+ cond_right: Float[Tensor, "3"] = F.normalize(torch.cross(cond_lookat, cond_up), dim=-1)
+ cond_up = F.normalize(torch.cross(cond_right, cond_lookat), dim=-1)
+ cond_c2w3x4: Float[Tensor, "3 4"] = torch.cat(
+ [torch.stack([cond_right, cond_up, -cond_lookat], dim=-1), cond_camera_position[:, None]],
+ dim=-1,
+ )
+ cond_c2w: Float[Tensor, "4 4"] = torch.cat(
+ [cond_c2w3x4, torch.zeros_like(cond_c2w3x4[:1])], dim=0
+ )
+ cond_c2w[3, 3] = 1.0
+ self.c2w_cond = cond_c2w
+
+ def __len__(self):
+ if self.cfg.only_3dgs:
+ return len(self.all_scenes)
+ else:
+ return len(self.all_scenes) * self.n_views // self.cfg.num_views_output
+
+ def __getitem__(self, index):
+ if self.cfg.only_3dgs:
+ scene_index = index
+ view_index = [0]
+ else:
+ scene_index = index * self.cfg.num_views_output // self.n_views
+ view_start = index % (self.n_views // self.cfg.num_views_output)
+ view_index = list(range(self.n_views))[view_start * self.cfg.num_views_output :
+ (view_start + 1) * self.cfg.num_views_output]
+
+ img_path = self.all_scenes[scene_index]
+ img_cond = torch.from_numpy(
+ np.asarray(
+ Image.fromarray(imageio.v2.imread(img_path))
+ .convert("RGBA")
+ .resize((self.cfg.cond_width, self.cfg.cond_height))
+ )
+ / 255.0
+ ).float()
+ mask_cond: Float[Tensor, "Hc Wc 1"] = img_cond[:, :, -1:]
+ rgb_cond: Float[Tensor, "Hc Wc 3"] = img_cond[
+ :, :, :3
+ ] * mask_cond + self.background_color[None, None, :] * (1 - mask_cond)
+
+ out = {
+ "rgb_cond": rgb_cond.unsqueeze(0),
+ "c2w_cond": self.c2w_cond.unsqueeze(0),
+ "mask_cond": mask_cond.unsqueeze(0),
+ "intrinsic_cond": self.intrinsic_cond.unsqueeze(0),
+ "intrinsic_normed_cond": self.intrinsic_normed_cond.unsqueeze(0),
+ "view_index": torch.as_tensor(view_index),
+ "rays_o": self.rays_o[view_index],
+ "rays_d": self.rays_d[view_index],
+ "intrinsic": self.intrinsic[view_index],
+ "intrinsic_normed": self.intrinsic_normed[view_index],
+ "c2w": self.c2w[view_index],
+ "camera_positions": self.camera_positions[view_index],
+ }
+ out["c2w"][..., :3, 1:3] *= -1
+ out["c2w_cond"][..., :3, 1:3] *= -1
+ instance_id = os.path.split(img_path)[-1].split('.')[0]
+ out["index"] = torch.as_tensor(scene_index)
+ out["background_color"] = self.background_color
+ out["instance_id"] = instance_id
+ return out
+
+ def collate(self, batch):
+ batch = torch.utils.data.default_collate(batch)
+ batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
+ return batch
\ No newline at end of file
diff --git a/hort/models/tgs/models/__init__.py b/hort/models/tgs/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hort/models/tgs/models/image_feature.py b/hort/models/tgs/models/image_feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c6a0aa8fd3fe4b04315bb56e206153c291dee4
--- /dev/null
+++ b/hort/models/tgs/models/image_feature.py
@@ -0,0 +1,48 @@
+from dataclasses import dataclass
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+from tgs.utils.base import BaseModule
+from tgs.utils.ops import compute_distance_transform
+from tgs.utils.typing import *
+
+class ImageFeature(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ use_rgb: bool = True
+ use_feature: bool = True
+ use_mask: bool = True
+ feature_dim: int = 128
+ out_dim: int = 133
+ backbone: str = "default"
+ freeze_backbone_params: bool = True
+
+ cfg: Config
+
+ def forward(self, rgb, mask=None, feature=None):
+ B, Nv, H, W = rgb.shape[:4]
+ rgb = rearrange(rgb, "B Nv H W C -> (B Nv) C H W")
+ if mask is not None:
+ mask = rearrange(mask, "B Nv H W C -> (B Nv) C H W")
+
+ assert feature is not None
+ # reshape dino tokens to image-like size
+ feature = rearrange(feature, "B (Nv Nt) C -> (B Nv) Nt C", Nv=Nv)
+ feature = feature[:, 1:].reshape(B * Nv, H // 14, W // 14, -1).permute(0, 3, 1, 2).contiguous()
+ feature = F.interpolate(feature, size=(H, W), mode='bilinear', align_corners=False)
+
+ if mask is not None and mask.is_floating_point():
+ mask = mask > 0.5
+
+ image_features = []
+ if self.cfg.use_rgb:
+ image_features.append(rgb)
+ if self.cfg.use_feature:
+ image_features.append(feature)
+ if self.cfg.use_mask:
+ image_features += [mask, compute_distance_transform(mask)]
+
+        # note: detaching the features here caused errors when gradients are required, so they are kept attached
+ image_features = torch.cat(image_features, dim=1)#.detach()
+ return rearrange(image_features, "(B Nv) C H W -> B Nv C H W", B=B, Nv=Nv).squeeze(1)
\ No newline at end of file
diff --git a/hort/models/tgs/models/networks.py b/hort/models/tgs/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5c2bb1f4ec1a5d9e8b824926ba1e4ec51ade25a
--- /dev/null
+++ b/hort/models/tgs/models/networks.py
@@ -0,0 +1,204 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+import numpy as np
+
+from tgs.utils.base import BaseModule
+from tgs.utils.ops import get_activation
+from tgs.utils.typing import *
+
+class PointOutLayer(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ in_channels: int = 1024
+ out_channels: int = 3
+ cfg: Config
+ def configure(self) -> None:
+ super().configure()
+ self.point_layer = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ nn.init.constant_(self.point_layer.weight, 0)
+ nn.init.constant_(self.point_layer.bias, 0)
+
+ def forward(self, x):
+ return self.point_layer(x)
+
+class TriplaneUpsampleNetwork(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ in_channels: int = 1024
+ out_channels: int = 80
+
+ cfg: Config
+
+ def configure(self) -> None:
+ super().configure()
+ self.upsample = nn.ConvTranspose2d(
+ self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
+ )
+
+ def forward(
+ self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
+ ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
+ triplanes_up = rearrange(
+ self.upsample(
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
+ ),
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
+ Np=3,
+ )
+ return triplanes_up
+
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ dim_in: int,
+ dim_out: int,
+ n_neurons: int,
+ n_hidden_layers: int,
+ activation: str = "relu",
+ output_activation: Optional[str] = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ layers = [
+ self.make_linear(
+ dim_in, n_neurons, is_first=True, is_last=False, bias=bias
+ ),
+ self.make_activation(activation),
+ ]
+ for i in range(n_hidden_layers - 1):
+ layers += [
+ self.make_linear(
+ n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
+ ),
+ self.make_activation(activation),
+ ]
+ layers += [
+ self.make_linear(
+ n_neurons, dim_out, is_first=False, is_last=True, bias=bias
+ )
+ ]
+ self.layers = nn.Sequential(*layers)
+ self.output_activation = get_activation(output_activation)
+
+ def forward(self, x):
+ x = self.layers(x)
+ x = self.output_activation(x)
+ return x
+
+ def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
+ layer = nn.Linear(dim_in, dim_out, bias=bias)
+ return layer
+
+ def make_activation(self, activation):
+ if activation == "relu":
+ return nn.ReLU(inplace=True)
+ elif activation == "silu":
+ return nn.SiLU(inplace=True)
+ else:
+ raise NotImplementedError
+
+class GSProjection(nn.Module):
+ def __init__(self,
+ in_channels: int = 80,
+ sh_degree: int = 3,
+ init_scaling: float = -5.0,
+ init_density: float = 0.1) -> None:
+ super().__init__()
+
+ self.out_keys = GS_KEYS + ["shs"]
+ self.out_channels = GS_CHANNELS + [(sh_degree + 1) ** 2 * 3]
+
+ self.out_layers = nn.ModuleList()
+ for key, ch in zip(self.out_keys, self.out_channels):
+ layer = nn.Linear(in_channels, ch)
+ # initialize
+ nn.init.constant_(layer.weight, 0)
+ nn.init.constant_(layer.bias, 0)
+
+ if key == "scaling":
+ nn.init.constant_(layer.bias, init_scaling)
+ elif key == "rotation":
+ nn.init.constant_(layer.bias, 0)
+ nn.init.constant_(layer.bias[0], 1.0)
+ elif key == "opacity":
+ inverse_sigmoid = lambda x: np.log(x / (1 - x))
+ nn.init.constant_(layer.bias, inverse_sigmoid(init_density))
+
+ self.out_layers.append(layer)
+
+ def forward(self, x):
+ ret = []
+ for k, layer in zip(self.out_keys, self.out_layers):
+ v = layer(x)
+ if k == "rotation":
+ v = torch.nn.functional.normalize(v)
+ elif k == "scaling":
+ v = torch.exp(v)
+ # v = v.detach() # FIXME: for DEBUG
+ elif k == "opacity":
+ v = torch.sigmoid(v)
+ # elif k == "shs":
+ # v = torch.reshape(v, (v.shape[0], -1, 3))
+ ret.append(v)
+ ret = torch.cat(ret, dim=-1)
+ return ret
+
+def get_encoding(n_input_dims: int, config) -> nn.Module:
+ raise NotImplementedError
+
+
+def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module:
+ raise NotImplementedError
+
+
+# Resnet Blocks for pointnet
+class ResnetBlockFC(nn.Module):
+ ''' Fully connected ResNet Block class.
+
+ Args:
+ size_in (int): input dimension
+ size_out (int): output dimension
+ size_h (int): hidden dimension
+ '''
+
+ def __init__(self, size_in, size_out=None, size_h=None):
+ super().__init__()
+ # Attributes
+ if size_out is None:
+ size_out = size_in
+
+ if size_h is None:
+ size_h = min(size_in, size_out)
+
+ self.size_in = size_in
+ self.size_h = size_h
+ self.size_out = size_out
+ # Submodules
+ self.fc_0 = nn.Linear(size_in, size_h)
+ self.fc_1 = nn.Linear(size_h, size_out)
+ self.actvn = nn.ReLU()
+
+ if size_in == size_out:
+ self.shortcut = None
+ else:
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
+ # Initialization
+ nn.init.zeros_(self.fc_1.weight)
+
+ def forward(self, x):
+ net = self.fc_0(self.actvn(x))
+ dx = self.fc_1(self.actvn(net))
+
+ if self.shortcut is not None:
+ x_s = self.shortcut(x)
+ else:
+ x_s = x
+
+ return x_s + dx
\ No newline at end of file
diff --git a/hort/models/tgs/models/pointclouds/LICENSE_POINTNET b/hort/models/tgs/models/pointclouds/LICENSE_POINTNET
new file mode 100644
index 0000000000000000000000000000000000000000..5358ea96300c33c974dafa4c492e2df20b0dbb39
--- /dev/null
+++ b/hort/models/tgs/models/pointclouds/LICENSE_POINTNET
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars Mescheder, Marc Pollefeys, Andreas Geiger
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/hort/models/tgs/models/pointclouds/pointnet.py b/hort/models/tgs/models/pointclouds/pointnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..14b5cbe9ba43ab514369f06fc6db794fb88c9329
--- /dev/null
+++ b/hort/models/tgs/models/pointclouds/pointnet.py
@@ -0,0 +1,121 @@
+# modified from https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/src/encoder/pointnet.py
+from dataclasses import dataclass
+import torch
+import torch.nn as nn
+from torch_scatter import scatter_mean, scatter_max
+
+from tgs.utils.base import BaseModule
+from tgs.models.networks import ResnetBlockFC
+from tgs.utils.ops import scale_tensor
+
+class LocalPoolPointnet(BaseModule):
+    ''' PointNet-based encoder network with ResNet blocks for each point.
+        The number of input points is fixed.
+
+    Args:
+        input_channels (int): dimension of the input points
+        c_dim (int): dimension of latent code c
+        hidden_dim (int): hidden dimension of the network
+        scatter_type (str): feature aggregation used for local pooling ('max' or 'mean')
+        plane_size (int): resolution of each square feature plane
+        n_blocks (int): number of ResnetBlockFC blocks
+        radius (float): half-extent of the cube in which input points are normalized
+    '''
+
+ @dataclass
+ class Config(BaseModule.Config):
+ input_channels: int = 3
+ c_dim: int = 128
+ hidden_dim: int = 128
+ scatter_type: str = "max"
+ plane_size: int = 32
+ n_blocks: int = 5
+ radius: float = 1.
+
+ cfg: Config
+
+ def configure(self) -> None:
+ super().configure()
+ self.fc_pos = nn.Linear(self.cfg.input_channels, 2 * self.cfg.hidden_dim)
+ self.blocks = nn.ModuleList([
+ ResnetBlockFC(2 * self.cfg.hidden_dim, self.cfg.hidden_dim) for i in range(self.cfg.n_blocks)
+ ])
+ self.fc_c = nn.Linear(self.cfg.hidden_dim, self.cfg.c_dim)
+
+ self.actvn = nn.ReLU()
+
+ if self.cfg.scatter_type == 'max':
+ self.scatter = scatter_max
+ elif self.cfg.scatter_type == 'mean':
+ self.scatter = scatter_mean
+ else:
+ raise ValueError('incorrect scatter type')
+
+
+ def generate_plane_features(self, index, c):
+ # acquire indices of features in plane
+ # xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
+ # index = self.coordinate2index(x, self.cfg.plane_size)
+
+ # scatter plane features from points
+ fea_plane = c.new_zeros(index.shape[0], self.cfg.c_dim, self.cfg.plane_size ** 2)
+ c = c.permute(0, 2, 1) # B x 512 x T
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
+        fea_plane = fea_plane.reshape(index.shape[0], self.cfg.c_dim, self.cfg.plane_size, self.cfg.plane_size) # sparse matrix (B x 512 x reso x reso)
+
+ return fea_plane
+
+ def pool_local(self, xy, index, c):
+ bs, fea_dim = c.shape[0], c.shape[2]
+ keys = xy.keys()
+
+ c_out = 0
+ for key in keys:
+ # scatter plane features from points
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.cfg.plane_size ** 2)
+ if self.scatter == scatter_max:
+ fea = fea[0]
+ # gather feature back to points
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
+ c_out += fea
+ return c_out.permute(0, 2, 1)
+
+ def coordinate2index(self, x):
+ x = (x * self.cfg.plane_size).long()
+ index = x[..., 0] + self.cfg.plane_size * x[..., 1]
+ assert index.max() < self.cfg.plane_size ** 2
+ return index[:, None, :]
+
+ def forward(self, p):
+ batch_size, T, D = p.shape
+
+ # acquire the index for each point
+ coord = {}
+ index = {}
+
+ position = torch.clamp(p[..., :3], -self.cfg.radius + 1e-6, self.cfg.radius - 1e-6)
+ position_norm = scale_tensor(position, (-self.cfg.radius, self.cfg.radius), (0, 1))
+ coord["xy"] = position_norm[..., [0, 1]]
+ coord["xz"] = position_norm[..., [0, 2]]
+ coord["yz"] = position_norm[..., [1, 2]]
+ index["xy"] = self.coordinate2index(coord["xy"])
+ index["xz"] = self.coordinate2index(coord["xz"])
+ index["yz"] = self.coordinate2index(coord["yz"])
+
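+        # per-point features are refined by ResNet blocks interleaved with local pooling over the xy/xz/yz planes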
+ net = self.fc_pos(p)
+
+ net = self.blocks[0](net)
+ for block in self.blocks[1:]:
+ pooled = self.pool_local(coord, index, net)
+ net = torch.cat([net, pooled], dim=2)
+ net = block(net)
+
+ c = self.fc_c(net)
+
+ features = torch.stack([
+ self.generate_plane_features(index["xy"], c),
+ self.generate_plane_features(index["xz"], c),
+ self.generate_plane_features(index["yz"], c)
+ ], dim=1)
+
+ return features
diff --git a/hort/models/tgs/models/pointclouds/simplepoint.py b/hort/models/tgs/models/pointclouds/simplepoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fb028d07a7949e93a2116e9d77f4d669636f145
--- /dev/null
+++ b/hort/models/tgs/models/pointclouds/simplepoint.py
@@ -0,0 +1,110 @@
+from dataclasses import dataclass, field
+import torch
+from einops import rearrange
+
+import tgs
+from tgs.utils.base import BaseModule
+from tgs.utils.typing import *
+
+class SimplePointGenerator(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ camera_embedder_cls: str = ""
+ camera_embedder: dict = field(default_factory=dict)
+
+ image_tokenizer_cls: str = ""
+ image_tokenizer: dict = field(default_factory=dict)
+
+ tokenizer_cls: str = ""
+ tokenizer: dict = field(default_factory=dict)
+
+ backbone_cls: str = ""
+ backbone: dict = field(default_factory=dict)
+
+ post_processor_cls: str = ""
+ post_processor: dict = field(default_factory=dict)
+
+ pointcloud_upsampling_cls: str = ""
+ pointcloud_upsampling: dict = field(default_factory=dict)
+
+ flip_c2w_cond: bool = True
+
+ cfg: Config
+
+ def configure(self) -> None:
+ super().configure()
+
+ self.image_tokenizer = tgs.find(self.cfg.image_tokenizer_cls)(
+ self.cfg.image_tokenizer
+ )
+
+ assert self.cfg.camera_embedder_cls == 'tgs.models.networks.MLP'
+ weights = self.cfg.camera_embedder.pop("weights") if "weights" in self.cfg.camera_embedder else None
+ self.camera_embedder = tgs.find(self.cfg.camera_embedder_cls)(**self.cfg.camera_embedder)
+ if weights:
+ from tgs.utils.misc import load_module_weights
+ weights_path, module_name = weights.split(":")
+ state_dict = load_module_weights(
+ weights_path, module_name=module_name, map_location="cpu"
+ )
+ self.camera_embedder.load_state_dict(state_dict)
+
+ self.tokenizer = tgs.find(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
+
+ self.backbone = tgs.find(self.cfg.backbone_cls)(self.cfg.backbone)
+
+ self.post_processor = tgs.find(self.cfg.post_processor_cls)(
+ self.cfg.post_processor
+ )
+
+ self.pointcloud_upsampling = tgs.find(self.cfg.pointcloud_upsampling_cls)(self.cfg.pointcloud_upsampling)
+
+ def forward(self, batch, encoder_hidden_states=None, **kwargs):
+ batch_size, n_input_views = batch["rgb_cond"].shape[:2]
+
+ if encoder_hidden_states is None:
+ # Camera modulation
+ c2w_cond = batch["c2w_cond"].clone()
+ if self.cfg.flip_c2w_cond:
+ c2w_cond[..., :3, 1:3] *= -1
+ camera_extri = c2w_cond.view(*c2w_cond.shape[:-2], -1)
+ camera_intri = batch["intrinsic_normed_cond"].view(
+ *batch["intrinsic_normed_cond"].shape[:-2], -1)
+ camera_feats = torch.cat([camera_intri, camera_extri], dim=-1)
+ # camera_feats = rearrange(camera_feats, 'B Nv C -> (B Nv) C')
+
+ camera_feats = self.camera_embedder(camera_feats)
+
+ encoder_hidden_states: Float[Tensor, "B Cit Nit"] = self.image_tokenizer(
+ rearrange(batch["rgb_cond"], 'B Nv H W C -> B Nv C H W'),
+ modulation_cond=camera_feats,
+ )
+ encoder_hidden_states = rearrange(
+ encoder_hidden_states, 'B Nv C Nt -> B (Nv Nt) C', Nv=n_input_views)
+
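+        # learnable tokens are decoded by the transformer conditioned on the image tokens, then mapped to an initial point cloud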
+ tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
+
+ tokens = self.backbone(
+ tokens,
+ encoder_hidden_states=encoder_hidden_states,
+ modulation_cond=None,
+ )
+ pointclouds = self.post_processor(self.tokenizer.detokenize(tokens))
+
+ upsampling_input = {
+ "input_image_tokens": encoder_hidden_states.permute(0, 2, 1),
+ "input_image_tokens_global": encoder_hidden_states[:, :1],
+ "c2w_cond": c2w_cond,
+ "rgb_cond": batch["rgb_cond"],
+ "intrinsic_cond": batch["intrinsic_cond"],
+ "intrinsic_normed_cond": batch["intrinsic_normed_cond"],
+ "points": pointclouds.float()
+ }
+ up_results = self.pointcloud_upsampling(upsampling_input)
+ up_results.insert(0, pointclouds)
+ pointclouds = up_results[-1]
+ out = {
+ "points": pointclouds,
+ "up_results": up_results
+ }
+ return out
\ No newline at end of file
diff --git a/hort/models/tgs/models/renderer.py b/hort/models/tgs/models/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ff65178eb64f8c308acbf25cbf367a68f58f2f3
--- /dev/null
+++ b/hort/models/tgs/models/renderer.py
@@ -0,0 +1,427 @@
+from dataclasses import dataclass, field
+from collections import defaultdict
+from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
+from plyfile import PlyData, PlyElement
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import math
+
+from tgs.utils.typing import *
+from tgs.utils.base import BaseModule
+from tgs.utils.ops import trunc_exp
+from tgs.models.networks import MLP
+from tgs.utils.ops import scale_tensor
+from einops import rearrange, reduce
+
+inverse_sigmoid = lambda x: np.log(x / (1 - x))
+
+def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = R.transpose()
+ Rt[:3, 3] = t
+ Rt[3, 3] = 1.0
+
+ C2W = np.linalg.inv(Rt)
+ cam_center = C2W[:3, 3]
+ cam_center = (cam_center + translate) * scale
+ C2W[:3, 3] = cam_center
+ Rt = np.linalg.inv(C2W)
+ return np.float32(Rt)
+
+def getProjectionMatrix(znear, zfar, fovX, fovY):
+ tanHalfFovY = math.tan((fovY / 2))
+ tanHalfFovX = math.tan((fovX / 2))
+
+ top = tanHalfFovY * znear
+ bottom = -top
+ right = tanHalfFovX * znear
+ left = -right
+
+ P = torch.zeros(4, 4)
+
+ z_sign = 1.0
+
+ P[0, 0] = 2.0 * znear / (right - left)
+ P[1, 1] = 2.0 * znear / (top - bottom)
+ P[0, 2] = (right + left) / (right - left)
+ P[1, 2] = (top + bottom) / (top - bottom)
+ P[3, 2] = z_sign
+ P[2, 2] = z_sign * zfar / (zfar - znear)
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
+ return P
+
+def intrinsic_to_fov(intrinsic, w, h):
+ fx, fy = intrinsic[0, 0], intrinsic[1, 1]
+ fov_x = 2 * torch.arctan2(w, 2 * fx)
+ fov_y = 2 * torch.arctan2(h, 2 * fy)
+ return fov_x, fov_y
+
+
+class Camera:
+ def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None:
+ self.FoVx = FoVx
+ self.FoVy = FoVy
+ self.height = height
+ self.width = width
+ self.world_view_transform = w2c.transpose(0, 1)
+
+ self.zfar = 100.0
+ self.znear = 0.01
+
+ self.trans = trans
+ self.scale = scale
+
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device)
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
+
+ @staticmethod
+ def from_c2w(c2w, intrinsic, height, width):
+ w2c = torch.inverse(c2w)
+ FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device))
+ return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width)
+
+class GaussianModel(NamedTuple):
+ xyz: Tensor
+ opacity: Tensor
+ rotation: Tensor
+ scaling: Tensor
+ shs: Tensor
+
+ def construct_list_of_attributes(self):
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+ features_dc = self.shs[:, :1]
+ features_rest = self.shs[:, 1:]
+ for i in range(features_dc.shape[1]*features_dc.shape[2]):
+ l.append('f_dc_{}'.format(i))
+ for i in range(features_rest.shape[1]*features_rest.shape[2]):
+ l.append('f_rest_{}'.format(i))
+ l.append('opacity')
+ for i in range(self.scaling.shape[1]):
+ l.append('scale_{}'.format(i))
+ for i in range(self.rotation.shape[1]):
+ l.append('rot_{}'.format(i))
+ return l
+
+ def save_ply(self, path):
+
+ xyz = self.xyz.detach().cpu().numpy()
+ normals = np.zeros_like(xyz)
+ features_dc = self.shs[:, :1]
+ features_rest = self.shs[:, 1:]
+ f_dc = features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy()
+ f_rest = features_rest.detach().flatten(start_dim=1).contiguous().cpu().numpy()
+ opacities = inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3).detach().cpu().numpy())
+ scale = np.log(self.scaling.detach().cpu().numpy())
+ rotation = self.rotation.detach().cpu().numpy()
+
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
+
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, 'vertex')
+ PlyData([el]).write(path)
+
+class GSLayer(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ in_channels: int = 128
+ feature_channels: dict = field(default_factory=dict)
+ xyz_offset: bool = True
+ restrict_offset: bool = False
+ use_rgb: bool = False
+ clip_scaling: Optional[float] = None
+ init_scaling: float = -5.0
+ init_density: float = 0.1
+
+ cfg: Config
+
+ def configure(self, *args, **kwargs) -> None:
+ self.out_layers = nn.ModuleList()
+ for key, out_ch in self.cfg.feature_channels.items():
+ if key == "shs" and self.cfg.use_rgb:
+ out_ch = 3
+ layer = nn.Linear(self.cfg.in_channels, out_ch)
+
+ # initialize
+ if not (key == "shs" and self.cfg.use_rgb):
+ nn.init.constant_(layer.weight, 0)
+ nn.init.constant_(layer.bias, 0)
+ if key == "scaling":
+ nn.init.constant_(layer.bias, self.cfg.init_scaling)
+ elif key == "rotation":
+ nn.init.constant_(layer.bias, 0)
+ nn.init.constant_(layer.bias[0], 1.0)
+ elif key == "opacity":
+ nn.init.constant_(layer.bias, inverse_sigmoid(self.cfg.init_density))
+
+ self.out_layers.append(layer)
+
+ def forward(self, x, pts):
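+        # one linear head per Gaussian attribute, each with its own activation (normalize for rotation, trunc_exp for scaling, sigmoid for opacity)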
+ ret = {}
+ for k, layer in zip(self.cfg.feature_channels.keys(), self.out_layers):
+ v = layer(x)
+ if k == "rotation":
+ v = torch.nn.functional.normalize(v)
+ elif k == "scaling":
+ v = trunc_exp(v)
+ if self.cfg.clip_scaling is not None:
+ v = torch.clamp(v, min=0, max=self.cfg.clip_scaling)
+ elif k == "opacity":
+ v = torch.sigmoid(v)
+ elif k == "shs":
+ if self.cfg.use_rgb:
+ v = torch.sigmoid(v)
+ v = torch.reshape(v, (v.shape[0], -1, 3))
+ elif k == "xyz":
+ if self.cfg.restrict_offset:
+ max_step = 1.2 / 32
+ v = (torch.sigmoid(v) - 0.5) * max_step
+ v = v + pts if self.cfg.xyz_offset else pts
+ ret[k] = v
+
+ return GaussianModel(**ret)
+
+class GS3DRenderer(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ mlp_network_config: Optional[dict] = None
+ gs_out: dict = field(default_factory=dict)
+ sh_degree: int = 3
+ scaling_modifier: float = 1.0
+ random_background: bool = False
+ radius: float = 1.0
+ feature_reduction: str = "concat"
+ projection_feature_dim: int = 773
+ background_color: Tuple[float, float, float] = field(
+ default_factory=lambda: (1.0, 1.0, 1.0)
+ )
+
+ cfg: Config
+
+ def configure(self, *args, **kwargs) -> None:
+ if self.cfg.feature_reduction == "mean":
+ mlp_in = 80
+ elif self.cfg.feature_reduction == "concat":
+ mlp_in = 80 * 3
+ else:
+ raise NotImplementedError
+ mlp_in = mlp_in + self.cfg.projection_feature_dim
+ if self.cfg.mlp_network_config is not None:
+ self.mlp_net = MLP(mlp_in, self.cfg.gs_out.in_channels, **self.cfg.mlp_network_config)
+ else:
+ self.cfg.gs_out.in_channels = mlp_in
+ self.gs_net = GSLayer(self.cfg.gs_out)
+
+ def forward_gs(self, x, p):
+ if self.cfg.mlp_network_config is not None:
+ x = self.mlp_net(x)
+ return self.gs_net(x, p)
+
+ def forward_single_view(self,
+ gs: GaussianModel,
+ viewpoint_camera: Camera,
+ background_color: Optional[Float[Tensor, "3"]],
+ ret_mask: bool = True,
+ ):
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+ screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0
+ try:
+ screenspace_points.retain_grad()
+ except:
+ pass
+
+ bg_color = background_color
+ # Set up rasterization configuration
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+ raster_settings = GaussianRasterizationSettings(
+ image_height=int(viewpoint_camera.height),
+ image_width=int(viewpoint_camera.width),
+ tanfovx=tanfovx,
+ tanfovy=tanfovy,
+ bg=bg_color,
+ scale_modifier=self.cfg.scaling_modifier,
+ viewmatrix=viewpoint_camera.world_view_transform,
+ projmatrix=viewpoint_camera.full_proj_transform.float(),
+ sh_degree=self.cfg.sh_degree,
+ campos=viewpoint_camera.camera_center,
+ prefiltered=False,
+ debug=False
+ )
+
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+ means3D = gs.xyz
+ means2D = screenspace_points
+ opacity = gs.opacity
+
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+ # scaling / rotation by the rasterizer.
+ scales = None
+ rotations = None
+ cov3D_precomp = None
+ scales = gs.scaling
+ rotations = gs.rotation
+
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+ shs = None
+ colors_precomp = None
+ if self.gs_net.cfg.use_rgb:
+ colors_precomp = gs.shs.squeeze(1)
+ else:
+ shs = gs.shs
+
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ rendered_image, radii = rasterizer(
+ means3D = means3D,
+ means2D = means2D,
+ shs = shs,
+ colors_precomp = colors_precomp,
+ opacities = opacity,
+ scales = scales,
+ rotations = rotations,
+ cov3D_precomp = cov3D_precomp)
+
+ ret = {
+ "comp_rgb": rendered_image.permute(1, 2, 0),
+ "comp_rgb_bg": bg_color
+ }
+
+ if ret_mask:
+ mask_bg_color = torch.zeros(3, dtype=torch.float32, device=self.device)
+ raster_settings = GaussianRasterizationSettings(
+ image_height=int(viewpoint_camera.height),
+ image_width=int(viewpoint_camera.width),
+ tanfovx=tanfovx,
+ tanfovy=tanfovy,
+ bg=mask_bg_color,
+ scale_modifier=self.cfg.scaling_modifier,
+ viewmatrix=viewpoint_camera.world_view_transform,
+ projmatrix=viewpoint_camera.full_proj_transform.float(),
+ sh_degree=0,
+ campos=viewpoint_camera.camera_center,
+ prefiltered=False,
+ debug=False
+ )
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ rendered_mask, radii = rasterizer(
+ means3D = means3D,
+ means2D = means2D,
+ # shs = ,
+ colors_precomp = torch.ones_like(means3D),
+ opacities = opacity,
+ scales = scales,
+ rotations = rotations,
+ cov3D_precomp = cov3D_precomp)
+ ret["comp_mask"] = rendered_mask.permute(1, 2, 0)
+
+ return ret
+
+ def query_triplane(
+ self,
+ positions: Float[Tensor, "*B N 3"],
+ triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
+ ) -> Dict[str, Tensor]:
+ batched = positions.ndim == 3
+ if not batched:
+ # no batch dimension
+ triplanes = triplanes[None, ...]
+ positions = positions[None, ...]
+
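+        # project the 3D query points onto the xy/xz/yz planes and bilinearly sample each feature plane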
+ positions = scale_tensor(positions, (-self.cfg.radius, self.cfg.radius), (-1, 1))
+ indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
+ (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
+ dim=-3,
+ )
+ out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
+ rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3),
+ rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3),
+ align_corners=False,
+ mode="bilinear",
+ )
+ if self.cfg.feature_reduction == "concat":
+ out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
+ elif self.cfg.feature_reduction == "mean":
+ out = reduce(out, "(B Np) Cp () N -> B N Cp", Np=3, reduction="mean")
+ else:
+ raise NotImplementedError
+
+ if not batched:
+ out = out.squeeze(0)
+
+ return out
+
+ def forward_single_batch(
+ self,
+ gs_hidden_features: Float[Tensor, "Np Cp"],
+ query_points: Float[Tensor, "Np 3"],
+ c2ws: Float[Tensor, "Nv 4 4"],
+ intrinsics: Float[Tensor, "Nv 4 4"],
+ height: int,
+ width: int,
+ background_color: Optional[Float[Tensor, "3"]],
+ ):
+ gs: GaussianModel = self.forward_gs(gs_hidden_features, query_points)
+ out_list = []
+
+ for c2w, intrinsic in zip(c2ws, intrinsics):
+ out_list.append(self.forward_single_view(
+ gs,
+ Camera.from_c2w(c2w, intrinsic, height, width),
+ background_color
+ ))
+
+ out = defaultdict(list)
+ for out_ in out_list:
+ for k, v in out_.items():
+ out[k].append(v)
+ out = {k: torch.stack(v, dim=0) for k, v in out.items()}
+ out["3dgs"] = gs
+
+ return out
+
+ def forward(self,
+ gs_hidden_features: Float[Tensor, "B Np Cp"],
+ query_points: Float[Tensor, "B Np 3"],
+ c2w: Float[Tensor, "B Nv 4 4"],
+ intrinsic: Float[Tensor, "B Nv 4 4"],
+ height,
+ width,
+ additional_features: Optional[Float[Tensor, "B C H W"]] = None,
+ background_color: Optional[Float[Tensor, "B 3"]] = None,
+ **kwargs):
+ batch_size = gs_hidden_features.shape[0]
+ out_list = []
+ gs_hidden_features = self.query_triplane(query_points, gs_hidden_features)
+ if additional_features is not None:
+ gs_hidden_features = torch.cat([gs_hidden_features, additional_features], dim=-1)
+
+ for b in range(batch_size):
+ out_list.append(self.forward_single_batch(
+ gs_hidden_features[b],
+ query_points[b],
+ c2w[b],
+ intrinsic[b],
+ height, width,
+ background_color[b] if background_color is not None else None))
+
+ out = defaultdict(list)
+ for out_ in out_list:
+ for k, v in out_.items():
+ out[k].append(v)
+ for k, v in out.items():
+ if isinstance(v[0], torch.Tensor):
+ out[k] = torch.stack(v, dim=0)
+ else:
+ out[k] = v
+ return out
+
\ No newline at end of file
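The triplane lookup above projects each query point onto the XY, XZ and YZ planes and bilinearly samples the corresponding feature maps. A minimal standalone sketch of that lookup for reference (function name and tensor sizes are illustrative, the explicit rescaling stands in for scale_tensor, and only the 'concat' reduction is shown):

import torch
import torch.nn.functional as F
from einops import rearrange

def query_triplane_sketch(positions, triplanes, radius=1.0):
    # positions: (B, N, 3) in [-radius, radius]; triplanes: (B, 3, C, H, W)
    positions = positions / radius                                # map to [-1, 1] for grid_sample
    coords2d = torch.stack((positions[..., [0, 1]],               # XY plane
                            positions[..., [0, 2]],               # XZ plane
                            positions[..., [1, 2]]), dim=1)       # YZ plane -> (B, 3, N, 2)
    feats = F.grid_sample(
        rearrange(triplanes, "B P C H W -> (B P) C H W"),
        rearrange(coords2d, "B P N d -> (B P) 1 N d"),
        mode="bilinear", align_corners=False)                     # (B*3, C, 1, N)
    return rearrange(feats, "(B P) C 1 N -> B N (P C)", P=3)      # per-plane features concatenated

feats = query_triplane_sketch(torch.rand(2, 1024, 3) * 2 - 1, torch.rand(2, 3, 40, 64, 64))  # (2, 1024, 120)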
diff --git a/hort/models/tgs/models/snowflake/LICENSE b/hort/models/tgs/models/snowflake/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..01c2b54a0360c374c480e1c6b606b18be9e73792
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 AllenXiang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/hort/models/tgs/models/snowflake/SPD.py b/hort/models/tgs/models/snowflake/SPD.py
new file mode 100644
index 0000000000000000000000000000000000000000..eae592c8c4792cc4067bb5fa0ea97c69a250a6bb
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/SPD.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# @Author: Peng Xiang
+
+import torch
+import torch.nn as nn
+from .utils import MLP_Res, MLP_CONV
+from .skip_transformer import SkipTransformer
+
+
+class SPD(nn.Module):
+ def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
+ """Snowflake Point Deconvolution"""
+ super(SPD, self).__init__()
+ self.i = i
+ self.up_factor = up_factor
+
+ self.bounding = bounding
+ self.radius = radius
+
+ self.global_feat = global_feat
+ self.ps_dim = 32 if global_feat else 64
+
+ self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
+ self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])
+
+ self.skip_transformer = SkipTransformer(in_channel=128, dim=64)
+
+ self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
+ self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting
+
+ self.up_sampler = nn.Upsample(scale_factor=up_factor)
+ self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128)
+
+ self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])
+
+ def forward(self, pcd_prev, feat_global=None, K_prev=None):
+ """
+ Args:
+ pcd_prev: Tensor, (B, 3, N_prev)
+ feat_global: Tensor, (B, dim_feat, 1)
+ K_prev: Tensor, (B, 128, N_prev)
+
+ Returns:
+            pcd_child: Tensor, upsampled point cloud, (B, 3, N_prev * up_factor)
+ K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
+ """
+ b, _, n_prev = pcd_prev.shape
+ feat_1 = self.mlp_1(pcd_prev)
+ feat_1 = torch.cat([feat_1,
+ torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))),
+ feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1
+ Q = self.mlp_2(feat_1)
+
+ H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q)
+
+ feat_child = self.mlp_ps(H)
+ feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor)
+ H_up = self.up_sampler(H)
+ K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))
+
+ delta = self.mlp_delta(torch.relu(K_curr))
+ if self.bounding:
+ delta = torch.tanh(delta) / self.radius**self.i # (B, 3, N_prev * up_factor)
+
+ pcd_child = self.up_sampler(pcd_prev)
+ pcd_child = pcd_child + delta
+
+ return pcd_child, K_curr
\ No newline at end of file
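For orientation, one SPD step multiplies the point count by up_factor: point-wise features are split with a transposed convolution and a bounded per-point displacement is predicted. A quick shape check (a sketch; it assumes hort/models is on PYTHONPATH so the package imports as tgs.models.snowflake, and a CUDA device since SkipTransformer relies on the pointnet2_ops extension added later in this PR):

import torch
from tgs.models.snowflake.SPD import SPD  # import path assumed per this PR's layout

spd = SPD(dim_feat=512, up_factor=2, i=0, radius=1).cuda().eval()
pcd_prev    = torch.rand(2, 3, 512).cuda()          # (B, 3, N_prev)
feat_global = torch.rand(2, 512, 1).cuda()          # (B, dim_feat, 1) pooled shape code
with torch.no_grad():
    pcd_child, K_curr = spd(pcd_prev, feat_global)  # (2, 3, 1024), (2, 128, 1024)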
diff --git a/hort/models/tgs/models/snowflake/SPD_crossattn.py b/hort/models/tgs/models/snowflake/SPD_crossattn.py
new file mode 100644
index 0000000000000000000000000000000000000000..90273e9e4b58df393858caab8bf5110eccccee31
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/SPD_crossattn.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# @Author: Peng Xiang
+
+import torch
+import torch.nn as nn
+from .utils import MLP_Res, MLP_CONV
+from .skip_transformer import SkipTransformer
+from .attention import ResidualTransformerBlock
+
+class SPD_crossattn(nn.Module):
+ def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
+ """Snowflake Point Deconvolution"""
+ super().__init__()
+ self.i = i
+ self.up_factor = up_factor
+
+ self.bounding = bounding
+ self.radius = radius
+
+ self.global_feat = global_feat
+ self.ps_dim = 32 if global_feat else 64
+
+ self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
+ self.pcd_image_attn = ResidualTransformerBlock(
+ device=torch.device('cuda'),
+ dtype=torch.float32,
+ n_data=128,
+ width=128,
+ heads=8,
+ init_scale=1.0,
+ )
+
+ self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])
+
+ self.skip_transformer = SkipTransformer(in_channel=128, dim=64)
+
+ self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
+ self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting
+
+ self.up_sampler = nn.Upsample(scale_factor=up_factor)
+ self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128)
+
+ self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])
+
+ def forward(self, pcd_prev, feat_global=None, K_prev=None):
+ """
+ Args:
+ pcd_prev: Tensor, (B, 3, N_prev)
+            feat_global: Tensor, (B, 128, N_tokens) conditioning token features
+                (cross-attended by the point features rather than pooled and concatenated)
+ K_prev: Tensor, (B, 128, N_prev)
+
+ Returns:
+            pcd_child: Tensor, upsampled point cloud, (B, 3, N_prev * up_factor)
+ K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
+ """
+ b, _, n_prev = pcd_prev.shape
+ feat_1 = self.mlp_1(pcd_prev)
+ # feat_1 = torch.cat([feat_1,
+ # torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))),
+ # feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1
+ feat_1 = torch.permute(feat_1, (0, 2, 1))
+ feat_global = torch.permute(feat_global, (0, 2, 1))
+ feat_1 = self.pcd_image_attn(feat_1, feat_global)
+ Q = torch.permute(feat_1, (0, 2, 1))
+ # Q = self.mlp_2(feat_1)
+
+ H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q)
+
+ feat_child = self.mlp_ps(H)
+ feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor)
+ H_up = self.up_sampler(H)
+ K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))
+
+ delta = self.mlp_delta(torch.relu(K_curr))
+ if self.bounding:
+ delta = torch.tanh(delta) / self.radius**self.i # (B, 3, N_prev * up_factor)
+
+ pcd_child = self.up_sampler(pcd_prev)
+ pcd_child = pcd_child + delta
+
+ return pcd_child, K_curr
\ No newline at end of file
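Relative to SPD.py, the only functional change is how the query features are conditioned: instead of concatenating a pooled global feature before mlp_2, the per-point features cross-attend to the conditioning tokens. The layout juggling in isolation (a sketch with illustrative sizes; the attention block works channel-last, while the rest of the pipeline is channel-first):

import torch
from tgs.models.snowflake.attention import ResidualTransformerBlock  # import path assumed

attn = ResidualTransformerBlock(device=torch.device('cuda'), dtype=torch.float32,
                                n_data=128, width=128, heads=8, init_scale=1.0)
feat_1 = torch.rand(2, 128, 512).cuda()   # (B, C=128, N_prev) point features from mlp_1
tokens = torch.rand(2, 128, 256).cuda()   # (B, C=128, N_tokens) conditioning tokens
with torch.no_grad():
    q = attn(feat_1.permute(0, 2, 1), tokens.permute(0, 2, 1))  # attention expects (B, N, C)
Q = q.permute(0, 2, 1)                    # back to (B, 128, N_prev) for the SkipTransformer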
diff --git a/hort/models/tgs/models/snowflake/SPD_pp.py b/hort/models/tgs/models/snowflake/SPD_pp.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7a3d3511337b864d083396ae4236922d6c974e2
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/SPD_pp.py
@@ -0,0 +1,71 @@
+
+import torch
+import torch.nn as nn
+from .utils import MLP_Res, MLP_CONV
+from .skip_transformer import SkipTransformer
+
+class SPD_pp(nn.Module):
+ def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
+ """Snowflake Point Deconvolution"""
+ super(SPD_pp, self).__init__()
+ self.i = i
+ self.up_factor = up_factor
+
+ self.bounding = bounding
+ self.radius = radius
+
+ self.global_feat = global_feat
+ self.ps_dim = 32 if global_feat else 64
+
+ self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
+ self.mlp_2 = MLP_CONV(
+ in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])
+
+ self.skip_transformer = SkipTransformer(in_channel=128, dim=64)
+
+ self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
+ self.ps = nn.ConvTranspose1d(
+ self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting
+
+ self.up_sampler = nn.Upsample(scale_factor=up_factor)
+ self.mlp_delta_feature = MLP_Res(
+ in_dim=256, hidden_dim=128, out_dim=128)
+
+ self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])
+
+ def forward(self, pcd_prev, feat_cond=None, K_prev=None):
+ """
+ Args:
+ pcd_prev: Tensor, (B, 3, N_prev)
+ feat_cond: Tensor, (B, dim_feat, N_prev)
+ K_prev: Tensor, (B, 128, N_prev)
+
+ Returns:
+            pcd_child: Tensor, upsampled point cloud, (B, 3, N_prev * up_factor)
+ K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
+ """
+ b, _, n_prev = pcd_prev.shape
+ feat_1 = self.mlp_1(pcd_prev)
+ feat_1 = torch.cat([feat_1,
+ torch.max(feat_1, 2, keepdim=True)[
+ 0].repeat((1, 1, feat_1.size(2))),
+ feat_cond], 1) if self.global_feat else feat_1
+ Q = self.mlp_2(feat_1)
+
+ H = self.skip_transformer(
+ pcd_prev, K_prev if K_prev is not None else Q, Q)
+
+ feat_child = self.mlp_ps(H)
+ feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor)
+ H_up = self.up_sampler(H)
+ K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))
+
+ delta = self.mlp_delta(torch.relu(K_curr))
+ if self.bounding:
+ # (B, 3, N_prev * up_factor)
+ delta = torch.tanh(delta) / self.radius**self.i
+
+ pcd_child = self.up_sampler(pcd_prev)
+ pcd_child = pcd_child + delta
+
+ return pcd_child, K_curr
diff --git a/hort/models/tgs/models/snowflake/attention.py b/hort/models/tgs/models/snowflake/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d4d80f25182062f8fee74a708807b6e505e64f
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/attention.py
@@ -0,0 +1,239 @@
+import math
+from typing import Callable, Iterable, Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+
+def checkpoint(
+ func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
+ inputs: Sequence[torch.Tensor],
+ params: Iterable[torch.Tensor],
+ flag: bool,
+):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True)
+ for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def init_linear(l, stddev):
+ nn.init.normal_(l.weight, std=stddev)
+ if l.bias is not None:
+ nn.init.constant_(l.bias, 0.0)
+
+class MLP(nn.Module):
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
+ super().__init__()
+ self.width = width
+ self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
+ self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
+ self.gelu = nn.GELU()
+ init_linear(self.c_fc, init_scale)
+ init_linear(self.c_proj, init_scale)
+
+ def forward(self, x):
+ return self.c_proj(self.gelu(self.c_fc(x)))
+
+class QKVMultiheadCrossAttention(nn.Module):
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: int):
+ super().__init__()
+ self.device = device
+ self.dtype = dtype
+ self.heads = heads
+ self.n_data = n_data
+
+ def forward(self, q, kv):
+ _, n_ctx, _ = q.shape
+ bs, n_data, width = kv.shape
+ attn_ch = width // self.heads // 2
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
+ q = q.view(bs, n_ctx, self.heads, -1)
+ kv = kv.view(bs, n_data, self.heads, -1)
+ k, v = torch.split(kv, attn_ch, dim=-1)
+ weight = torch.einsum(
+ "bthc,bshc->bhts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ wdtype = weight.dtype
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+
+
+class QKVMultiheadAttention(nn.Module):
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
+ super().__init__()
+ self.device = device
+ self.dtype = dtype
+ self.heads = heads
+ self.n_ctx = n_ctx
+
+ def forward(self, qkv):
+ bs, n_ctx, width = qkv.shape
+ attn_ch = width // self.heads // 3
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
+ weight = torch.einsum(
+ "bthc,bshc->bhts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ wdtype = weight.dtype
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+
+
+class MultiheadCrossAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ device: torch.device,
+ dtype: torch.dtype,
+ n_data: int,
+ width: int,
+ heads: int,
+ init_scale: float,
+ data_width: Optional[int] = None,
+ ):
+ super().__init__()
+ self.n_data = n_data
+ self.width = width
+ self.heads = heads
+ self.data_width = width if data_width is None else data_width
+ self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
+ self.c_kv = nn.Linear(self.data_width, width * 2,
+ device=device, dtype=dtype)
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
+ self.attention = QKVMultiheadCrossAttention(
+ device=device, dtype=dtype, heads=heads, n_data=n_data
+ )
+ init_linear(self.c_q, init_scale)
+ init_linear(self.c_kv, init_scale)
+ init_linear(self.c_proj, init_scale)
+
+ def forward(self, x, data):
+ x = self.c_q(x)
+ data = self.c_kv(data)
+ x = checkpoint(self.attention, (x, data), (), True)
+ x = self.c_proj(x)
+ return x
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ device: torch.device,
+ dtype: torch.dtype,
+ n_ctx: int,
+ width: int,
+ heads: int,
+ init_scale: float,
+ ):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.heads = heads
+ self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
+ init_linear(self.c_qkv, init_scale)
+ init_linear(self.c_proj, init_scale)
+
+ def forward(self, x):
+ x = self.c_qkv(x)
+ x = checkpoint(self.attention, (x,), (), True)
+ x = self.c_proj(x)
+ return x
+
+
+class ResidualTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ device: torch.device,
+ dtype: torch.dtype,
+ n_data: int,
+ width: int,
+ heads: int,
+ data_width: Optional[int] = None,
+ init_scale: float = 1.0,
+ ):
+ super().__init__()
+
+ if data_width is None:
+ data_width = width
+
+ self.attn_cross = MultiheadCrossAttention(
+ device=device,
+ dtype=dtype,
+ n_data=n_data,
+ width=width,
+ heads=heads,
+ data_width=data_width,
+ init_scale=init_scale,
+ )
+ self.attn_self = MultiheadAttention(
+ device=device,
+ dtype=dtype,
+ n_ctx=n_data,
+ width=width,
+ heads=heads,
+ init_scale=init_scale,
+ )
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
+ self.mlp = MLP(device=device, dtype=dtype,
+ width=width, init_scale=init_scale)
+ self.ln_4 = nn.LayerNorm(width, device=device, dtype=dtype)
+
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
+ x = x + self.attn_cross(self.ln_1(x), self.ln_2(data))
+ x = x + self.attn_self(self.ln_3(x))
+ x = x + self.mlp(self.ln_4(x))
+ return x
\ No newline at end of file
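The block above is a standard pre-LayerNorm transformer unit: cross-attention to the conditioning tokens, then self-attention, then an MLP, each added residually, with activation checkpointing around the attention math. It is device-agnostic, and a conditioning stream of a different width can be attached via data_width. A small CPU sketch (sizes are illustrative):

import torch
from tgs.models.snowflake.attention import ResidualTransformerBlock  # import path assumed

block = ResidualTransformerBlock(device=torch.device('cpu'), dtype=torch.float32,
                                 n_data=64, width=128, heads=8,
                                 data_width=768, init_scale=0.25)
x    = torch.rand(2, 64, 128)    # queries: (B, n_ctx, width)
data = torch.rand(2, 257, 768)   # conditioning tokens, e.g. ViT patch tokens: (B, n_tokens, data_width)
y = block(x, data)               # (2, 64, 128)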
diff --git a/hort/models/tgs/models/snowflake/model_spdpp.py b/hort/models/tgs/models/snowflake/model_spdpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..8787859f784a085b79b8dbdef6cff0de63fa8c0c
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/model_spdpp.py
@@ -0,0 +1,239 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tgs.utils.base import BaseModule
+from tgs.utils.typing import *
+from dataclasses import dataclass, field
+
+from pytorch3d.renderer import (
+ AlphaCompositor,
+ NormWeightedCompositor,
+ PointsRasterizationSettings,
+ PointsRasterizer,
+ PointsRenderer)
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.structures import Pointclouds
+from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection
+
+from .utils import fps_subsample
+from einops import rearrange
+
+from .utils import MLP_CONV
+from .SPD import SPD
+from .SPD_crossattn import SPD_crossattn
+from .SPD_pp import SPD_pp
+
+SPD_BLOCK = {
+ 'SPD': SPD,
+ 'SPD_crossattn': SPD_crossattn,
+ 'SPD_PP': SPD_pp,
+}
+
+
+def homoify(points):
+ """
+ Convert a batch of points to homogeneous coordinates.
+ Args:
+ points: e.g. (B, N, 3) or (N, 3)
+ Returns:
+ homoified points: e.g., (B, N, 4)
+ """
+ points_dim = points.shape[:-1] + (1,)
+ ones = points.new_ones(points_dim)
+
+ return torch.cat([points, ones], dim=-1)
+
+
+def dehomoify(points):
+ """
+ Convert a batch of homogeneous points to cartesian coordinates.
+ Args:
+ homogeneous points: (B, N, 4/3) or (N, 4/3)
+ Returns:
+ cartesian points: (B, N, 3/2)
+ """
+ return points[..., :-1] / points[..., -1:]
+
+
+def mask_generation(points: Float[Tensor, "B Np 3"],
+ intrinsics: Float[Tensor, "B 3 3"],
+ input_img: Float[Tensor, "B C H W"],
+ raster_point_radius: float = 0.01, # point size
+ raster_points_per_pixel: int = 1, # a single point per pixel, for now
+ bin_size: int = 0):
+ """
+ points: (B, Np, 3)
+ """
+ B, C, H, W = input_img.shape
+ device = intrinsics.device
+
+ cam_R = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1)
+ cam_t = torch.zeros(3).to(device).unsqueeze(0).repeat(B, 1)
+
+ raster_settings = PointsRasterizationSettings(image_size=(H, W), radius=raster_point_radius, points_per_pixel=raster_points_per_pixel, bin_size=bin_size)
+
+ image_size = torch.as_tensor([H, W]).view(1, 2).expand(B, -1).to(device)
+ cameras = cameras_from_opencv_projection(cam_R, cam_t, intrinsics, image_size)
+
+ rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
+ fragments = rasterize(Pointclouds(points))
+
+ fragments_idx: Tensor = fragments.idx.long()
+ mask = (fragments_idx[..., 0] > -1)
+
+ return mask.float()
+
+
+def points_projection(points: Float[Tensor, "B Np 3"],
+ intrinsics: Float[Tensor, "B 3 3"],
+ local_features: Float[Tensor, "B C H W"],
+ raster_point_radius: float = 0.0075, # point size
+ raster_points_per_pixel: int = 1, # a single point per pixel, for now
+ bin_size: int = 0):
+ """
+ points: (B, Np, 3)
+ """
+ B, C, H, W = local_features.shape
+ device = local_features.device
+ cam_R = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1)
+ cam_t = torch.zeros(3).to(device).unsqueeze(0).repeat(B, 1)
+
+ raster_settings = PointsRasterizationSettings(image_size=(H, W), radius=raster_point_radius, points_per_pixel=raster_points_per_pixel, bin_size=bin_size)
+ Np = points.shape[1]
+ R = raster_settings.points_per_pixel
+ image_size = torch.as_tensor([H, W]).view(1, 2).expand(B, -1).to(device)
+ cameras = cameras_from_opencv_projection(cam_R, cam_t, intrinsics, image_size)
+ rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
+ fragments = rasterize(Pointclouds(points))
+ fragments_idx: Tensor = fragments.idx.long()
+ visible_pixels = (fragments_idx > -1) # (B, H, W, R)
+ points_to_visible_pixels = fragments_idx[visible_pixels]
+ # Reshape local features to (B, H, W, R, C)
+ local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C)
+ # Get local features corresponding to visible points
+ local_features_proj = torch.zeros(B * Np, C, device=device)
+ local_features_proj[points_to_visible_pixels] = local_features[visible_pixels]
+ local_features_proj = local_features_proj.reshape(B, Np, C)
+ return local_features_proj
+
+
+def points_projection_v2(input_xyz_points, cam_intr, feature_maps):
+ input_points = input_xyz_points.clone()
+ batch_size = input_points.shape[0]
+ xyz = input_points[:, :, :3]
+ homo_xyz = homoify(xyz)
+ homo_xyz_2d = torch.matmul(cam_intr, homo_xyz.transpose(1, 2)).transpose(1, 2)
+ xyz_2d = (homo_xyz_2d[:, :, :2] / homo_xyz_2d[:, :, [2]]).unsqueeze(2)
+ uv_2d = xyz_2d / 224 * 2 - 1
+ sample_feat = torch.nn.functional.grid_sample(feature_maps, uv_2d, align_corners=True)[:, :, :, 0].transpose(1, 2)
+ uv_2d = uv_2d.squeeze(2).reshape((-1, 2))
+ validity = (uv_2d[:, 0] >= -1.0) & (uv_2d[:, 0] <= 1.0) & (uv_2d[:, 1] >= -1.0) & (uv_2d[:, 1] <= 1.0)
+ validity = validity.unsqueeze(1)
+
+ return sample_feat
+
+
+class Decoder(nn.Module):
+ def __init__(self, input_channels=1152, dim_feat=512, num_p0=512,
+ radius=1, bounding=True, up_factors=None,
+ SPD_type='SPD',
+ token_type='image_token'
+ ):
+ super(Decoder, self).__init__()
+ # self.decoder_coarse = SeedGenerator(dim_feat=dim_feat, num_pc=num_p0)
+        if up_factors is None:
+            up_factors = [1]
+ uppers = []
+ self.num_p0 = num_p0
+ self.mlp_feat_cond = MLP_CONV(in_channel=input_channels,
+ layer_dims=[dim_feat*2, dim_feat])
+
+ for i, factor in enumerate(up_factors):
+ uppers.append(
+ SPD_BLOCK[SPD_type](dim_feat=dim_feat, up_factor=factor,
+ i=i, bounding=bounding, radius=radius))
+ self.uppers = nn.ModuleList(uppers)
+ self.token_type = token_type
+
+ def calculate_pcl_token(self, pcl_token, up_factor):
+ up_token = F.interpolate(pcl_token, scale_factor=up_factor, mode='nearest')
+ return up_token
+
+ def calculate_image_token(self, pcd, input_image_tokens, batch):
+ """
+ Args:
+ """
+ batch_size = input_image_tokens.shape[0]
+ h_cond, w_cond = 224, 224
+ input_image_tokens = input_image_tokens.permute(0, 2, 1)
+ local_features = input_image_tokens[:, 1:].reshape(batch_size, h_cond // 14, w_cond // 14, -1).permute(0, 3, 1, 2)
+ # local_features = F.interpolate(local_features, size=(h_cond, w_cond), mode='bilinear', align_corners=False)
+ local_features_proj = points_projection_v2(pcd * batch['scale'] + batch['trans'].unsqueeze(1), batch['intrinsic_cond'], local_features)
+ local_features_proj = local_features_proj.permute(0, 2, 1).contiguous()
+
+ return local_features_proj
+
+ def forward(self, x):
+ """
+ Args:
+ points: Tensor, (b, num_p0, 3)
+ feat_cond: Tensor, (b, dim_feat) dinov2: 325x768
+ # partial_coarse: Tensor, (b, n_coarse, 3)
+ """
+ points = x['points']
+ if self.token_type == 'pcl_token':
+ feat_cond = x['pcl_token']
+ elif self.token_type == 'image_token':
+ feat_cond = x['input_image_tokens']
+ feat_cond = self.mlp_feat_cond(feat_cond)
+ arr_pcd = []
+ feat_prev = None
+
+ pcd = torch.permute(points, (0, 2, 1)).contiguous()
+ pcl_up_scale = 1
+ for upper in self.uppers:
+ if self.token_type == 'pcl_token':
+ up_cond = self.calculate_pcl_token(
+ feat_cond, pcl_up_scale)
+ pcl_up_scale *= upper.up_factor
+ elif self.token_type == 'image_token':
+ up_cond = self.calculate_image_token(points, feat_cond, x)
+ pcd, feat_prev = upper(pcd, up_cond, feat_prev)
+ points = torch.permute(pcd, (0, 2, 1)).contiguous()
+ arr_pcd.append(points)
+ return arr_pcd
+
+
+class SnowflakeModelSPDPP(BaseModule):
+ """
+ apply PC^2 / PCL token to decoder
+ """
+ @dataclass
+ class Config(BaseModule.Config):
+ input_channels: int = 1152
+ dim_feat: int = 128
+ num_p0: int = 512
+ radius: float = 1
+ bounding: bool = True
+ use_fps: bool = True
+ up_factors: List[int] = field(default_factory=lambda: [2, 2])
+ image_full_token_cond: bool = False
+ SPD_type: str = 'SPD_PP'
+ token_type: str = 'pcl_token'
+ cfg: Config
+
+ def configure(self) -> None:
+ super().configure()
+ self.decoder = Decoder(input_channels=self.cfg.input_channels,
+ dim_feat=self.cfg.dim_feat, num_p0=self.cfg.num_p0,
+ radius=self.cfg.radius, up_factors=self.cfg.up_factors, bounding=self.cfg.bounding,
+ SPD_type=self.cfg.SPD_type,
+ token_type=self.cfg.token_type
+ )
+
+ def forward(self, x):
+ results = self.decoder(x)
+ return results
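End to end, the decoder takes a coarse cloud plus conditioning tokens and runs the configured SPD blocks in sequence, multiplying the point count by each up_factor and returning every intermediate stage. A usage sketch for the pcl_token path (sizes are illustrative; assumes the pointnet2_ops extension is built and the package imports as below):

import torch
from tgs.models.snowflake.model_spdpp import Decoder  # import path assumed per this PR's layout

decoder = Decoder(input_channels=1152, dim_feat=128, num_p0=512,
                  up_factors=[2, 2], SPD_type='SPD_PP', token_type='pcl_token').cuda().eval()
batch = {
    'points':    torch.rand(2, 512, 3).cuda(),      # coarse cloud (B, num_p0, 3)
    'pcl_token': torch.rand(2, 1152, 512).cuda(),   # per-point tokens (B, input_channels, num_p0)
}
with torch.no_grad():
    stages = decoder(batch)   # [(2, 1024, 3), (2, 2048, 3)]: one entry per SPD block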
diff --git a/hort/models/tgs/models/snowflake/pointnet2.py b/hort/models/tgs/models/snowflake/pointnet2.py
new file mode 100644
index 0000000000000000000000000000000000000000..01aae13aa2fd00a4f4e9a843f310a6698748f135
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule
+
+
+class PointNet2ClassificationSSG(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._build_model()
+
+ def _build_model(self):
+ self.SA_modules = nn.ModuleList()
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=512,
+ radius=0.2,
+ nsample=64,
+ mlp=[3, 64, 64, 128],
+ use_xyz=True,
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=128,
+ radius=0.4,
+ nsample=64,
+ mlp=[128, 128, 128, 256],
+ use_xyz=True,
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ mlp=[256, 256, 512, 1024], use_xyz=True,
+ )
+ )
+
+ self.fc_layer = nn.Sequential(
+ nn.Linear(1024, 512, bias=False),
+ nn.BatchNorm1d(512),
+ nn.ReLU(True),
+ nn.Linear(512, 256, bias=False),
+ nn.BatchNorm1d(256),
+ nn.ReLU(True),
+ nn.Dropout(0.5),
+ nn.Linear(256, 40),
+ )
+
+ def _break_up_pc(self, pc):
+ xyz = pc[..., 0:3].contiguous()
+ features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None
+
+ return xyz, features
+
+ def forward(self, pointcloud):
+ r"""
+ Forward pass of the network
+
+ Parameters
+ ----------
+ pointcloud: Variable(torch.cuda.FloatTensor)
+ (B, N, 3 + input_channels) tensor
+            Point cloud to run predictions on
+            Each point in the point-cloud MUST
+            be formatted as (x, y, z, features...)
+ """
+ xyz, features = self._break_up_pc(pointcloud)
+
+ for module in self.SA_modules:
+ xyz, features = module(xyz, features)
+
+ return self.fc_layer(features.squeeze(-1))
+
+
+class PointNet2SemSegSSG(PointNet2ClassificationSSG):
+ def _build_model(self):
+ self.SA_modules = nn.ModuleList()
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=256,
+ radius=0.05,
+ nsample=32,
+ mlp=[1, 32, 64],
+ use_xyz=True,
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=64,
+ radius=0.10,
+ nsample=32,
+ mlp=[64, 128, 256],
+ use_xyz=True,
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=16,
+ radius=0.20,
+ nsample=32,
+ mlp=[256, 512, 768],
+ use_xyz=True,
+ )
+ )
+
+ def forward(self, pointcloud):
+ r"""
+ Forward pass of the network
+
+ Parameters
+ ----------
+ pointcloud: Variable(torch.cuda.FloatTensor)
+ (B, N, 3 + input_channels) tensor
+            Point cloud to run predictions on
+            Each point in the point-cloud MUST
+            be formatted as (x, y, z, features...)
+ """
+ xyz, features = self._break_up_pc(pointcloud)
+
+ l_xyz, l_features = [xyz], [features]
+ for i in range(len(self.SA_modules)):
+ li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
+ l_xyz.append(li_xyz)
+ l_features.append(li_features)
+
+ return l_features[-1].transpose(2, 1)
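The semantic-segmentation variant is the one used to produce a compact point-cloud token stream: three set-abstraction levels shrink the cloud to 16 centroids with 768-dim features, returned channel-last. A quick shape check (a sketch; it requires the pointnet2_ops extension from this PR to be built and a CUDA device, and the first SA level expects exactly one extra feature channel beside xyz):

import torch
from tgs.models.snowflake.pointnet2 import PointNet2SemSegSSG  # import path assumed

net = PointNet2SemSegSSG().cuda().eval()
cloud = torch.rand(2, 2048, 4).cuda()   # (B, N, 3 + 1): xyz plus one scalar feature per point
with torch.no_grad():
    tokens = net(cloud)                 # (2, 16, 768): 16 centroids, 768-dim features each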
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd361f9abbacc218f7699b8c439902b9d1bf745
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py
@@ -0,0 +1,3 @@
+import pointnet2_ops.pointnet2_modules
+import pointnet2_ops.pointnet2_utils
+from pointnet2_ops._version import __version__
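The package re-exported here wraps the CUDA kernels below behind autograd Functions in pointnet2_ops.pointnet2_utils. For example, farthest point sampling plus gathering (a sketch; the helper names and signatures are assumed to match the upstream pointnet2_ops package that this directory vendors):

import torch
from pointnet2_ops import pointnet2_utils

xyz = torch.rand(2, 4096, 3).cuda()                     # (B, N, 3) contiguous float32 on GPU
idx = pointnet2_utils.furthest_point_sample(xyz, 512)   # (B, 512) int32 indices of the FPS subset
sub = pointnet2_utils.gather_operation(                 # (B, 3, 512) gathered, channel-first
    xyz.transpose(1, 2).contiguous(), idx)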
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h
new file mode 100644
index 0000000000000000000000000000000000000000..1bbc6389c2837ad801e991643dae0c0401ee782c
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h
@@ -0,0 +1,5 @@
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
+ const int nsample);
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..0fd5b6edcfe3e7f7a03bd75d0ffdc9fc92be25eb
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h
@@ -0,0 +1,41 @@
+#ifndef _CUDA_UTILS_H
+#define _CUDA_UTILS_H
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cmath>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+#define TOTAL_THREADS 512
+
+inline int opt_n_threads(int work_size) {
+  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
+
+ return max(min(1 << pow_2, TOTAL_THREADS), 1);
+}
+
+inline dim3 opt_block_config(int x, int y) {
+ const int x_threads = opt_n_threads(x);
+ const int y_threads =
+ max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
+ dim3 block_config(x_threads, y_threads, 1);
+
+ return block_config;
+}
+
+#define CUDA_CHECK_ERRORS() \
+ do { \
+ cudaError_t err = cudaGetLastError(); \
+ if (cudaSuccess != err) { \
+ fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
+ cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
+ __FILE__); \
+ exit(-1); \
+ } \
+ } while (0)
+
+#endif
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h
new file mode 100644
index 0000000000000000000000000000000000000000..ad20cda9e68c92fd4e05a319f31fdcf7ef1e6427
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h
@@ -0,0 +1,5 @@
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor group_points(at::Tensor points, at::Tensor idx);
+at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h
new file mode 100644
index 0000000000000000000000000000000000000000..26b34648396783c9858f8f4e10b869a6e74e6a11
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <torch/extension.h>
+#include <vector>
+
+std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
+at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
+ at::Tensor weight);
+at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
+ at::Tensor weight, const int m);
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h
new file mode 100644
index 0000000000000000000000000000000000000000..d795271200c1a65525a509282ccac058f22f2422
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h
@@ -0,0 +1,6 @@
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor gather_points(at::Tensor points, at::Tensor idx);
+at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
+at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..5f080ed1e455d3bff318e11cc893e3856c132099
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h
@@ -0,0 +1,25 @@
+#pragma once
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+
+#define CHECK_CUDA(x) \
+ do { \
+ AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \
+ } while (0)
+
+#define CHECK_CONTIGUOUS(x) \
+ do { \
+ AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \
+ } while (0)
+
+#define CHECK_IS_INT(x) \
+ do { \
+ AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \
+ #x " must be an int tensor"); \
+ } while (0)
+
+#define CHECK_IS_FLOAT(x) \
+ do { \
+ AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \
+ #x " must be a float tensor"); \
+ } while (0)
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b1797c1aeecd32ecaf0cb614c4f2a23a1be9b777
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp
@@ -0,0 +1,32 @@
+#include "ball_query.h"
+#include "utils.h"
+
+void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
+ int nsample, const float *new_xyz,
+ const float *xyz, int *idx);
+
+at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
+ const int nsample) {
+ CHECK_CONTIGUOUS(new_xyz);
+ CHECK_CONTIGUOUS(xyz);
+ CHECK_IS_FLOAT(new_xyz);
+ CHECK_IS_FLOAT(xyz);
+
+ if (new_xyz.is_cuda()) {
+ CHECK_CUDA(xyz);
+ }
+
+ at::Tensor idx =
+ torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
+ at::device(new_xyz.device()).dtype(at::ScalarType::Int));
+
+ if (new_xyz.is_cuda()) {
+ query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
+                                    radius, nsample, new_xyz.data_ptr<float>(),
+                                    xyz.data_ptr<float>(), idx.data_ptr<int>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return idx;
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8e38e9c1cfb2fd240b1b02d8ef5371b10017aaea
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu
@@ -0,0 +1,54 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+
+// input: new_xyz(b, m, 3) xyz(b, n, 3)
+// output: idx(b, m, nsample)
+__global__ void query_ball_point_kernel(int b, int n, int m, float radius,
+ int nsample,
+ const float *__restrict__ new_xyz,
+ const float *__restrict__ xyz,
+ int *__restrict__ idx) {
+ int batch_index = blockIdx.x;
+ xyz += batch_index * n * 3;
+ new_xyz += batch_index * m * 3;
+ idx += m * nsample * batch_index;
+
+ int index = threadIdx.x;
+ int stride = blockDim.x;
+
+ float radius2 = radius * radius;
+ for (int j = index; j < m; j += stride) {
+ float new_x = new_xyz[j * 3 + 0];
+ float new_y = new_xyz[j * 3 + 1];
+ float new_z = new_xyz[j * 3 + 2];
+ for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
+ float x = xyz[k * 3 + 0];
+ float y = xyz[k * 3 + 1];
+ float z = xyz[k * 3 + 2];
+ float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
+ (new_z - z) * (new_z - z);
+ if (d2 < radius2) {
+ if (cnt == 0) {
+ for (int l = 0; l < nsample; ++l) {
+ idx[j * nsample + l] = k;
+ }
+ }
+ idx[j * nsample + cnt] = k;
+ ++cnt;
+ }
+ }
+ }
+}
+
+void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
+ int nsample, const float *new_xyz,
+ const float *xyz, int *idx) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
+ b, n, m, radius, nsample, new_xyz, xyz, idx);
+
+ //CUDA_CHECK_ERRORS();
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d1916ce1d57f9728de961eaa65e59c7318c683ae
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp
@@ -0,0 +1,19 @@
+#include "ball_query.h"
+#include "group_points.h"
+#include "interpolate.h"
+#include "sampling.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("gather_points", &gather_points);
+ m.def("gather_points_grad", &gather_points_grad);
+ m.def("furthest_point_sampling", &furthest_point_sampling);
+
+ m.def("three_nn", &three_nn);
+ m.def("three_interpolate", &three_interpolate);
+ m.def("three_interpolate_grad", &three_interpolate_grad);
+
+ m.def("ball_query", &ball_query);
+
+ m.def("group_points", &group_points);
+ m.def("group_points_grad", &group_points_grad);
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..285a4bd42688aabcc829bbeb120a06daaa87a96e
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp
@@ -0,0 +1,62 @@
+#include "group_points.h"
+#include "utils.h"
+
+void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
+ const float *points, const int *idx,
+ float *out);
+
+void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
+ int nsample, const float *grad_out,
+ const int *idx, float *grad_points);
+
+at::Tensor group_points(at::Tensor points, at::Tensor idx) {
+ CHECK_CONTIGUOUS(points);
+ CHECK_CONTIGUOUS(idx);
+ CHECK_IS_FLOAT(points);
+ CHECK_IS_INT(idx);
+
+ if (points.is_cuda()) {
+ CHECK_CUDA(idx);
+ }
+
+ at::Tensor output =
+ torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
+ at::device(points.device()).dtype(at::ScalarType::Float));
+
+ if (points.is_cuda()) {
+ group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
+ idx.size(1), idx.size(2),
+                                points.data_ptr<float>(), idx.data_ptr<int>(),
+                                output.data_ptr<float>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return output;
+}
+
+at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
+ CHECK_CONTIGUOUS(grad_out);
+ CHECK_CONTIGUOUS(idx);
+ CHECK_IS_FLOAT(grad_out);
+ CHECK_IS_INT(idx);
+
+ if (grad_out.is_cuda()) {
+ CHECK_CUDA(idx);
+ }
+
+ at::Tensor output =
+ torch::zeros({grad_out.size(0), grad_out.size(1), n},
+ at::device(grad_out.device()).dtype(at::ScalarType::Float));
+
+ if (grad_out.is_cuda()) {
+ group_points_grad_kernel_wrapper(
+ grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
+        grad_out.data_ptr<float>(), idx.data_ptr<int>(),
+        output.data_ptr<float>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return output;
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..38283b72126e464f1ec89d4e05b40aecb70ede6d
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu
@@ -0,0 +1,75 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+
+// input: points(b, c, n) idx(b, npoints, nsample)
+// output: out(b, c, npoints, nsample)
+__global__ void group_points_kernel(int b, int c, int n, int npoints,
+ int nsample,
+ const float *__restrict__ points,
+ const int *__restrict__ idx,
+ float *__restrict__ out) {
+ int batch_index = blockIdx.x;
+ points += batch_index * n * c;
+ idx += batch_index * npoints * nsample;
+ out += batch_index * npoints * nsample * c;
+
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
+ const int stride = blockDim.y * blockDim.x;
+ for (int i = index; i < c * npoints; i += stride) {
+ const int l = i / npoints;
+ const int j = i % npoints;
+ for (int k = 0; k < nsample; ++k) {
+ int ii = idx[j * nsample + k];
+ out[(l * npoints + j) * nsample + k] = points[l * n + ii];
+ }
+ }
+}
+
+void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
+ const float *points, const int *idx,
+ float *out) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
+ b, c, n, npoints, nsample, points, idx, out);
+
+ //CUDA_CHECK_ERRORS();
+}
+
+// input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
+// output: grad_points(b, c, n)
+__global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
+ int nsample,
+ const float *__restrict__ grad_out,
+ const int *__restrict__ idx,
+ float *__restrict__ grad_points) {
+ int batch_index = blockIdx.x;
+ grad_out += batch_index * npoints * nsample * c;
+ idx += batch_index * npoints * nsample;
+ grad_points += batch_index * n * c;
+
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
+ const int stride = blockDim.y * blockDim.x;
+ for (int i = index; i < c * npoints; i += stride) {
+ const int l = i / npoints;
+ const int j = i % npoints;
+ for (int k = 0; k < nsample; ++k) {
+ int ii = idx[j * nsample + k];
+ atomicAdd(grad_points + l * n + ii,
+ grad_out[(l * npoints + j) * nsample + k]);
+ }
+ }
+}
+
+void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
+ int nsample, const float *grad_out,
+ const int *idx, float *grad_points) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
+ b, c, n, npoints, nsample, grad_out, idx, grad_points);
+
+ //CUDA_CHECK_ERRORS();
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cdee31ca7758278729c0dc2855ce5aba197104e3
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp
@@ -0,0 +1,99 @@
+#include "interpolate.h"
+#include "utils.h"
+
+void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
+ const float *known, float *dist2, int *idx);
+void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
+ const float *points, const int *idx,
+ const float *weight, float *out);
+void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
+ const float *grad_out,
+ const int *idx, const float *weight,
+ float *grad_points);
+
+std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
+ CHECK_CONTIGUOUS(unknowns);
+ CHECK_CONTIGUOUS(knows);
+ CHECK_IS_FLOAT(unknowns);
+ CHECK_IS_FLOAT(knows);
+
+ if (unknowns.is_cuda()) {
+ CHECK_CUDA(knows);
+ }
+
+ at::Tensor idx =
+ torch::zeros({unknowns.size(0), unknowns.size(1), 3},
+ at::device(unknowns.device()).dtype(at::ScalarType::Int));
+ at::Tensor dist2 =
+ torch::zeros({unknowns.size(0), unknowns.size(1), 3},
+ at::device(unknowns.device()).dtype(at::ScalarType::Float));
+
+ if (unknowns.is_cuda()) {
+ three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
+                            unknowns.data_ptr<float>(), knows.data_ptr<float>(),
+                            dist2.data_ptr<float>(), idx.data_ptr<int>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return {dist2, idx};
+}
+
+at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
+ at::Tensor weight) {
+ CHECK_CONTIGUOUS(points);
+ CHECK_CONTIGUOUS(idx);
+ CHECK_CONTIGUOUS(weight);
+ CHECK_IS_FLOAT(points);
+ CHECK_IS_INT(idx);
+ CHECK_IS_FLOAT(weight);
+
+ if (points.is_cuda()) {
+ CHECK_CUDA(idx);
+ CHECK_CUDA(weight);
+ }
+
+ at::Tensor output =
+ torch::zeros({points.size(0), points.size(1), idx.size(1)},
+ at::device(points.device()).dtype(at::ScalarType::Float));
+
+ if (points.is_cuda()) {
+ three_interpolate_kernel_wrapper(
+ points.size(0), points.size(1), points.size(2), idx.size(1),
+        points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
+        output.data_ptr<float>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return output;
+}
+at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
+ at::Tensor weight, const int m) {
+ CHECK_CONTIGUOUS(grad_out);
+ CHECK_CONTIGUOUS(idx);
+ CHECK_CONTIGUOUS(weight);
+ CHECK_IS_FLOAT(grad_out);
+ CHECK_IS_INT(idx);
+ CHECK_IS_FLOAT(weight);
+
+ if (grad_out.is_cuda()) {
+ CHECK_CUDA(idx);
+ CHECK_CUDA(weight);
+ }
+
+ at::Tensor output =
+ torch::zeros({grad_out.size(0), grad_out.size(1), m},
+ at::device(grad_out.device()).dtype(at::ScalarType::Float));
+
+ if (grad_out.is_cuda()) {
+ three_interpolate_grad_kernel_wrapper(
+ grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
+        grad_out.data_ptr<float>(), idx.data_ptr<int>(),
+        weight.data_ptr<float>(), output.data_ptr<float>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return output;
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ef245848ac797aef1ae08f8ce49cfb4338e92dc6
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu
@@ -0,0 +1,154 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+
+// input: unknown(b, n, 3) known(b, m, 3)
+// output: dist2(b, n, 3), idx(b, n, 3)
+__global__ void three_nn_kernel(int b, int n, int m,
+ const float *__restrict__ unknown,
+ const float *__restrict__ known,
+ float *__restrict__ dist2,
+ int *__restrict__ idx) {
+ int batch_index = blockIdx.x;
+ unknown += batch_index * n * 3;
+ known += batch_index * m * 3;
+ dist2 += batch_index * n * 3;
+ idx += batch_index * n * 3;
+
+ int index = threadIdx.x;
+ int stride = blockDim.x;
+ for (int j = index; j < n; j += stride) {
+ float ux = unknown[j * 3 + 0];
+ float uy = unknown[j * 3 + 1];
+ float uz = unknown[j * 3 + 2];
+
+ double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+ int besti1 = 0, besti2 = 0, besti3 = 0;
+ for (int k = 0; k < m; ++k) {
+ float x = known[k * 3 + 0];
+ float y = known[k * 3 + 1];
+ float z = known[k * 3 + 2];
+ float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+ if (d < best1) {
+ best3 = best2;
+ besti3 = besti2;
+ best2 = best1;
+ besti2 = besti1;
+ best1 = d;
+ besti1 = k;
+ } else if (d < best2) {
+ best3 = best2;
+ besti3 = besti2;
+ best2 = d;
+ besti2 = k;
+ } else if (d < best3) {
+ best3 = d;
+ besti3 = k;
+ }
+ }
+ dist2[j * 3 + 0] = best1;
+ dist2[j * 3 + 1] = best2;
+ dist2[j * 3 + 2] = best3;
+
+ idx[j * 3 + 0] = besti1;
+ idx[j * 3 + 1] = besti2;
+ idx[j * 3 + 2] = besti3;
+ }
+}
+
+void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
+ const float *known, float *dist2, int *idx) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
+                                                      dist2, idx);
+
+ //CUDA_CHECK_ERRORS();
+}
+
+// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
+// output: out(b, c, n)
+__global__ void three_interpolate_kernel(int b, int c, int m, int n,
+ const float *__restrict__ points,
+ const int *__restrict__ idx,
+ const float *__restrict__ weight,
+ float *__restrict__ out) {
+ int batch_index = blockIdx.x;
+ points += batch_index * m * c;
+
+ idx += batch_index * n * 3;
+ weight += batch_index * n * 3;
+
+ out += batch_index * n * c;
+
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
+ const int stride = blockDim.y * blockDim.x;
+ for (int i = index; i < c * n; i += stride) {
+ const int l = i / n;
+ const int j = i % n;
+ float w1 = weight[j * 3 + 0];
+ float w2 = weight[j * 3 + 1];
+ float w3 = weight[j * 3 + 2];
+
+ int i1 = idx[j * 3 + 0];
+ int i2 = idx[j * 3 + 1];
+ int i3 = idx[j * 3 + 2];
+
+ out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
+ points[l * m + i3] * w3;
+ }
+}
+
+void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
+ const float *points, const int *idx,
+ const float *weight, float *out) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
+ b, c, m, n, points, idx, weight, out);
+
+ //CUDA_CHECK_ERRORS();
+}
+
+// input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
+// output: grad_points(b, c, m)
+
+__global__ void three_interpolate_grad_kernel(
+ int b, int c, int n, int m, const float *__restrict__ grad_out,
+ const int *__restrict__ idx, const float *__restrict__ weight,
+ float *__restrict__ grad_points) {
+ int batch_index = blockIdx.x;
+ grad_out += batch_index * n * c;
+ idx += batch_index * n * 3;
+ weight += batch_index * n * 3;
+ grad_points += batch_index * m * c;
+
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
+ const int stride = blockDim.y * blockDim.x;
+ for (int i = index; i < c * n; i += stride) {
+ const int l = i / n;
+ const int j = i % n;
+ float w1 = weight[j * 3 + 0];
+ float w2 = weight[j * 3 + 1];
+ float w3 = weight[j * 3 + 2];
+
+ int i1 = idx[j * 3 + 0];
+ int i2 = idx[j * 3 + 1];
+ int i3 = idx[j * 3 + 2];
+
+ atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
+ atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
+ atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
+ }
+}
+
+void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
+ const float *grad_out,
+ const int *idx, const float *weight,
+ float *grad_points) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
+ b, c, n, m, grad_out, idx, weight, grad_points);
+
+ CUDA_CHECK_ERRORS();
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ddbdc11b6bbecb3aa1238f0d760bd6a47303b858
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp
@@ -0,0 +1,87 @@
+#include "sampling.h"
+#include "utils.h"
+
+void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
+ const float *points, const int *idx,
+ float *out);
+void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
+ const float *grad_out, const int *idx,
+ float *grad_points);
+
+void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
+ const float *dataset, float *temp,
+ int *idxs);
+
+at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
+ CHECK_CONTIGUOUS(points);
+ CHECK_CONTIGUOUS(idx);
+ CHECK_IS_FLOAT(points);
+ CHECK_IS_INT(idx);
+
+ if (points.is_cuda()) {
+ CHECK_CUDA(idx);
+ }
+
+ at::Tensor output =
+ torch::zeros({points.size(0), points.size(1), idx.size(1)},
+ at::device(points.device()).dtype(at::ScalarType::Float));
+
+ if (points.is_cuda()) {
+ gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
+                                 idx.size(1), points.data_ptr<float>(),
+                                 idx.data_ptr<int>(), output.data_ptr<float>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return output;
+}
+
+at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
+ const int n) {
+ CHECK_CONTIGUOUS(grad_out);
+ CHECK_CONTIGUOUS(idx);
+ CHECK_IS_FLOAT(grad_out);
+ CHECK_IS_INT(idx);
+
+ if (grad_out.is_cuda()) {
+ CHECK_CUDA(idx);
+ }
+
+ at::Tensor output =
+ torch::zeros({grad_out.size(0), grad_out.size(1), n},
+ at::device(grad_out.device()).dtype(at::ScalarType::Float));
+
+ if (grad_out.is_cuda()) {
+ gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
+                                      idx.size(1), grad_out.data_ptr<float>(),
+                                      idx.data_ptr<int>(),
+                                      output.data_ptr<float>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return output;
+}
+at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
+ CHECK_CONTIGUOUS(points);
+ CHECK_IS_FLOAT(points);
+
+ at::Tensor output =
+ torch::zeros({points.size(0), nsamples},
+ at::device(points.device()).dtype(at::ScalarType::Int));
+
+ at::Tensor tmp =
+ torch::full({points.size(0), points.size(1)}, 1e10,
+ at::device(points.device()).dtype(at::ScalarType::Float));
+
+ if (points.is_cuda()) {
+ furthest_point_sampling_kernel_wrapper(
+        points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
+        tmp.data_ptr<float>(), output.data_ptr<int>());
+ } else {
+ AT_ASSERT(false, "CPU not supported");
+ }
+
+ return output;
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d6f83aa6cca7f8acb3b288d1701c7b5507c532a7
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu
@@ -0,0 +1,229 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+
+// input: points(b, c, n) idx(b, m)
+// output: out(b, c, m)
+__global__ void gather_points_kernel(int b, int c, int n, int m,
+ const float *__restrict__ points,
+ const int *__restrict__ idx,
+ float *__restrict__ out) {
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
+ for (int l = blockIdx.y; l < c; l += gridDim.y) {
+ for (int j = threadIdx.x; j < m; j += blockDim.x) {
+ int a = idx[i * m + j];
+ out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
+ }
+ }
+ }
+}
+
+void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
+ const float *points, const int *idx,
+ float *out) {
+  gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
+                         at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
+                                                             points, idx, out);
+
+ //CUDA_CHECK_ERRORS();
+}
+
+// input: grad_out(b, c, m) idx(b, m)
+// output: grad_points(b, c, n)
+__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
+ const float *__restrict__ grad_out,
+ const int *__restrict__ idx,
+ float *__restrict__ grad_points) {
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
+ for (int l = blockIdx.y; l < c; l += gridDim.y) {
+ for (int j = threadIdx.x; j < m; j += blockDim.x) {
+ int a = idx[i * m + j];
+ atomicAdd(grad_points + (i * c + l) * n + a,
+ grad_out[(i * c + l) * m + j]);
+ }
+ }
+ }
+}
+
+void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
+ const float *grad_out, const int *idx,
+ float *grad_points) {
+  gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
+                              at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, npoints, grad_out, idx, grad_points);
+
+ //CUDA_CHECK_ERRORS();
+}
+
+__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
+ int idx1, int idx2) {
+ const float v1 = dists[idx1], v2 = dists[idx2];
+ const int i1 = dists_i[idx1], i2 = dists_i[idx2];
+ dists[idx1] = max(v1, v2);
+ dists_i[idx1] = v2 > v1 ? i2 : i1;
+}
+
+// Input dataset: (b, n, 3), tmp: (b, n)
+// Output idxs (b, m)
+template <unsigned int block_size>
+__global__ void furthest_point_sampling_kernel(
+ int b, int n, int m, const float *__restrict__ dataset,
+ float *__restrict__ temp, int *__restrict__ idxs) {
+ if (m <= 0) return;
+ __shared__ float dists[block_size];
+ __shared__ int dists_i[block_size];
+
+ int batch_index = blockIdx.x;
+ dataset += batch_index * n * 3;
+ temp += batch_index * n;
+ idxs += batch_index * m;
+
+ int tid = threadIdx.x;
+ const int stride = block_size;
+
+ int old = 0;
+ if (threadIdx.x == 0) idxs[0] = old;
+
+ __syncthreads();
+ for (int j = 1; j < m; j++) {
+ int besti = 0;
+ float best = -1;
+ float x1 = dataset[old * 3 + 0];
+ float y1 = dataset[old * 3 + 1];
+ float z1 = dataset[old * 3 + 2];
+ for (int k = tid; k < n; k += stride) {
+ float x2, y2, z2;
+ x2 = dataset[k * 3 + 0];
+ y2 = dataset[k * 3 + 1];
+ z2 = dataset[k * 3 + 2];
+ float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
+ if (mag <= 1e-3) continue;
+
+ float d =
+ (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
+
+ float d2 = min(d, temp[k]);
+ temp[k] = d2;
+ besti = d2 > best ? k : besti;
+ best = d2 > best ? d2 : best;
+ }
+ dists[tid] = best;
+ dists_i[tid] = besti;
+ __syncthreads();
+
+ if (block_size >= 512) {
+ if (tid < 256) {
+ __update(dists, dists_i, tid, tid + 256);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 256) {
+ if (tid < 128) {
+ __update(dists, dists_i, tid, tid + 128);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 128) {
+ if (tid < 64) {
+ __update(dists, dists_i, tid, tid + 64);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 64) {
+ if (tid < 32) {
+ __update(dists, dists_i, tid, tid + 32);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 32) {
+ if (tid < 16) {
+ __update(dists, dists_i, tid, tid + 16);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 16) {
+ if (tid < 8) {
+ __update(dists, dists_i, tid, tid + 8);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 8) {
+ if (tid < 4) {
+ __update(dists, dists_i, tid, tid + 4);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 4) {
+ if (tid < 2) {
+ __update(dists, dists_i, tid, tid + 2);
+ }
+ __syncthreads();
+ }
+ if (block_size >= 2) {
+ if (tid < 1) {
+ __update(dists, dists_i, tid, tid + 1);
+ }
+ __syncthreads();
+ }
+
+ old = dists_i[0];
+ if (tid == 0) idxs[j] = old;
+ }
+}
+
+void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
+ const float *dataset, float *temp,
+ int *idxs) {
+ unsigned int n_threads = opt_n_threads(n);
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ switch (n_threads) {
+ case 512:
+ furthest_point_sampling_kernel<512>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 256:
+ furthest_point_sampling_kernel<256>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 128:
+ furthest_point_sampling_kernel<128>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 64:
+ furthest_point_sampling_kernel<64>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 32:
+ furthest_point_sampling_kernel<32>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 16:
+ furthest_point_sampling_kernel<16>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 8:
+ furthest_point_sampling_kernel<8>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 4:
+ furthest_point_sampling_kernel<4>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 2:
+ furthest_point_sampling_kernel<2>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ case 1:
+ furthest_point_sampling_kernel<1>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ break;
+ default:
+ furthest_point_sampling_kernel<512>
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+ }
+
+ //CUDA_CHECK_ERRORS();
+}
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..528787cfc8ad81ed41822a8104b60b4896632906
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py
@@ -0,0 +1 @@
+__version__ = "3.0.0"
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ad4f6bc23f54ca2d61454e657a6f533e9b875c
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py
@@ -0,0 +1,209 @@
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pointnet2_ops import pointnet2_utils
+
+
+def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
+ layers = []
+ for i in range(1, len(mlp_spec)):
+ layers.append(
+ nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
+ )
+ if bn:
+ layers.append(nn.BatchNorm2d(mlp_spec[i]))
+ layers.append(nn.ReLU(True))
+
+ return nn.Sequential(*layers)
+
+
+class _PointnetSAModuleBase(nn.Module):
+ def __init__(self):
+ super(_PointnetSAModuleBase, self).__init__()
+ self.npoint = None
+ self.groupers = None
+ self.mlps = None
+
+ def forward(
+ self, xyz: torch.Tensor, features: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ (B, N, 3) tensor of the xyz coordinates of the features
+ features : torch.Tensor
+ (B, C, N) tensor of the descriptors of the features
+
+ Returns
+ -------
+ new_xyz : torch.Tensor
+ (B, npoint, 3) tensor of the new features' xyz
+ new_features : torch.Tensor
+ (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
+ """
+
+ new_features_list = []
+
+ xyz_flipped = xyz.transpose(1, 2).contiguous()
+ new_xyz = (
+ pointnet2_utils.gather_operation(
+ xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+ )
+ .transpose(1, 2)
+ .contiguous()
+ if self.npoint is not None
+ else None
+ )
+
+ for i in range(len(self.groupers)):
+ new_features = self.groupers[i](
+ xyz, new_xyz, features
+ ) # (B, C, npoint, nsample)
+
+ new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
+ new_features = F.max_pool2d(
+ new_features, kernel_size=[1, new_features.size(3)]
+ ) # (B, mlp[-1], npoint, 1)
+ new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
+
+ new_features_list.append(new_features)
+
+ return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+ r"""Pointnet set abstrction layer with multiscale grouping
+
+ Parameters
+ ----------
+ npoint : int
+ Number of features
+ radii : list of float32
+ list of radii to group with
+ nsamples : list of int32
+ Number of samples in each ball query
+ mlps : list of list of int32
+ Spec of the pointnet before the global max_pool for each scale
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
+ # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
+ super(PointnetSAModuleMSG, self).__init__()
+
+ assert len(radii) == len(nsamples) == len(mlps)
+
+ self.npoint = npoint
+ self.groupers = nn.ModuleList()
+ self.mlps = nn.ModuleList()
+ for i in range(len(radii)):
+ radius = radii[i]
+ nsample = nsamples[i]
+ self.groupers.append(
+ pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+ if npoint is not None
+ else pointnet2_utils.GroupAll(use_xyz)
+ )
+ mlp_spec = mlps[i]
+ if use_xyz:
+ mlp_spec[0] += 3
+
+ self.mlps.append(build_shared_mlp(mlp_spec, bn))
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+ r"""Pointnet set abstrction layer
+
+ Parameters
+ ----------
+ npoint : int
+ Number of features
+ radius : float
+ Radius of ball
+ nsample : int
+ Number of samples in the ball query
+ mlp : list
+ Spec of the pointnet before the global max_pool
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(
+ self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
+ ):
+ # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
+ super(PointnetSAModule, self).__init__(
+ mlps=[mlp],
+ npoint=npoint,
+ radii=[radius],
+ nsamples=[nsample],
+ bn=bn,
+ use_xyz=use_xyz,
+ )
+
+
+class PointnetFPModule(nn.Module):
+ r"""Propigates the features of one set to another
+
+ Parameters
+ ----------
+ mlp : list
+ Pointnet module parameters
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(self, mlp, bn=True):
+ # type: (PointnetFPModule, List[int], bool) -> None
+ super(PointnetFPModule, self).__init__()
+ self.mlp = build_shared_mlp(mlp, bn=bn)
+
+ def forward(self, unknown, known, unknow_feats, known_feats):
+ # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+ Parameters
+ ----------
+ unknown : torch.Tensor
+ (B, n, 3) tensor of the xyz positions of the unknown features
+ known : torch.Tensor
+ (B, m, 3) tensor of the xyz positions of the known features
+ unknow_feats : torch.Tensor
+ (B, C1, n) tensor of the features to be propagated to
+ known_feats : torch.Tensor
+ (B, C2, m) tensor of features to be propagated
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, mlp[-1], n) tensor of the features of the unknown features
+ """
+
+ if known is not None:
+ dist, idx = pointnet2_utils.three_nn(unknown, known)
+ dist_recip = 1.0 / (dist + 1e-8)
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
+ weight = dist_recip / norm
+
+ interpolated_feats = pointnet2_utils.three_interpolate(
+ known_feats, idx, weight
+ )
+ else:
+ interpolated_feats = known_feats.expand(
+ *(known_feats.size()[0:2] + (unknown.size(1),))  # concatenate tuples; adding a list to torch.Size raises TypeError
+ )
+
+ if unknow_feats is not None:
+ new_features = torch.cat(
+ [interpolated_feats, unknow_feats], dim=1
+ ) # (B, C2 + C1, n)
+ else:
+ new_features = interpolated_feats
+
+ new_features = new_features.unsqueeze(-1)
+ new_features = self.mlp(new_features)
+
+ return new_features.squeeze(-1)
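+
+
+# Usage sketch (illustrative only; shapes follow the docstrings above, and the
+# channel count / MLP widths are arbitrary example values). A CUDA tensor is
+# assumed because the underlying extension has no CPU kernels.
+#
+#   sa = PointnetSAModule(mlp=[6, 64, 128], npoint=512, radius=0.2, nsample=32).cuda()
+#   fp = PointnetFPModule(mlp=[128 + 6, 128]).cuda()
+#
+#   xyz = torch.rand(2, 4096, 3, device="cuda")       # (B, N, 3)
+#   feats = torch.rand(2, 6, 4096, device="cuda")     # (B, C, N)
+#
+#   new_xyz, new_feats = sa(xyz, feats)               # (B, 512, 3), (B, 128, 512)
+#   up_feats = fp(xyz, new_xyz, feats, new_feats)     # (B, 128, N)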
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2972e2e7793ae73f5f52a533f938f14db8a0341a
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py
@@ -0,0 +1,391 @@
+import torch
+import torch.nn as nn
+import warnings
+from torch.autograd import Function
+from typing import *
+
+try:
+ import pointnet2_ops._ext as _ext
+except ImportError:
+ from torch.utils.cpp_extension import load
+ import glob
+ import os.path as osp
+ import os
+
+ warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
+
+ _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
+ _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
+ osp.join(_ext_src_root, "src", "*.cu")
+ )
+ _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
+
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
+ _ext = load(
+ "_ext",
+ sources=_ext_sources,
+ extra_include_paths=[osp.join(_ext_src_root, "include")],
+ extra_cflags=["-O3"],
+ extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
+ with_cuda=True,
+ )
+
+
+class FurthestPointSampling(Function):
+ @staticmethod
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+ def forward(ctx, xyz, npoint):
+ # type: (Any, torch.Tensor, int) -> torch.Tensor
+ r"""
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
+ minimum distance
+
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ (B, N, 3) tensor where N > npoint
+ npoint : int32
+ number of features in the sampled set
+
+ Returns
+ -------
+ torch.Tensor
+ (B, npoint) tensor containing the set
+ """
+ out = _ext.furthest_point_sampling(xyz, npoint)
+
+ ctx.mark_non_differentiable(out)
+
+ return out
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad_out):
+ return ()
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+ @staticmethod
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+ def forward(ctx, features, idx):
+ # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, C, N) tensor
+
+ idx : torch.Tensor
+ (B, npoint) tensor of the features to gather
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, npoint) tensor
+ """
+
+ ctx.save_for_backward(idx, features)
+
+ return _ext.gather_points(features, idx)
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad_out):
+ idx, features = ctx.saved_tensors
+ N = features.size(2)
+
+ grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
+ return grad_features, None
+
+
+gather_operation = GatherOperation.apply
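+
+# Usage sketch (illustrative only; a CUDA tensor is assumed because the
+# extension asserts on CPU input):
+#
+#   xyz = torch.rand(2, 4096, 3, device="cuda")                    # (B, N, 3)
+#   idx = furthest_point_sample(xyz, 1024)                         # (B, 1024) int32
+#   sub = gather_operation(xyz.transpose(1, 2).contiguous(), idx)  # (B, 3, 1024)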
+
+
+class ThreeNN(Function):
+ @staticmethod
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+ def forward(ctx, unknown, known):
+ # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
+ r"""
+ Find the three nearest neighbors of unknown in known
+ Parameters
+ ----------
+ unknown : torch.Tensor
+ (B, n, 3) tensor of the unknown features (query points)
+ known : torch.Tensor
+ (B, m, 3) tensor of the known features (reference points)
+
+ Returns
+ -------
+ dist : torch.Tensor
+ (B, n, 3) l2 distance to the three nearest neighbors
+ idx : torch.Tensor
+ (B, n, 3) index of 3 nearest neighbors
+ """
+ dist2, idx = _ext.three_nn(unknown, known)
+ dist = torch.sqrt(dist2)
+
+ ctx.mark_non_differentiable(dist, idx)
+
+ return dist, idx
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad_dist, grad_idx):
+ return ()
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+ @staticmethod
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+ def forward(ctx, features, idx, weight):
+ # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+ Performs weighted linear interpolation on 3 features
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, c, m) Features descriptors to be interpolated from
+ idx : torch.Tensor
+ (B, n, 3) three nearest neighbors of the target features in features
+ weight : torch.Tensor
+ (B, n, 3) weights
+
+ Returns
+ -------
+ torch.Tensor
+ (B, c, n) tensor of the interpolated features
+ """
+ ctx.save_for_backward(idx, weight, features)
+
+ return _ext.three_interpolate(features, idx, weight)
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad_out):
+ # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ grad_out : torch.Tensor
+ (B, c, n) tensor with gradients of outputs
+
+ Returns
+ -------
+ grad_features : torch.Tensor
+ (B, c, m) tensor with gradients of features
+
+ None
+
+ None
+ """
+ idx, weight, features = ctx.saved_tensors
+ m = features.size(2)
+
+ grad_features = _ext.three_interpolate_grad(
+ grad_out.contiguous(), idx, weight, m
+ )
+
+ return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
+
+
+three_interpolate = ThreeInterpolate.apply
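+
+# Usage sketch (illustrative only): three_nn + three_interpolate implement the
+# inverse-distance feature propagation used by the FP module in
+# pointnet2_modules.py.
+#
+#   unknown = torch.rand(2, 4096, 3, device="cuda")     # dense positions (B, n, 3)
+#   known = torch.rand(2, 512, 3, device="cuda")        # sparse positions (B, m, 3)
+#   feats = torch.rand(2, 128, 512, device="cuda")      # sparse features (B, c, m)
+#
+#   dist, idx = three_nn(unknown, known)                # both (B, n, 3)
+#   weight = 1.0 / (dist + 1e-8)
+#   weight = weight / weight.sum(dim=2, keepdim=True)
+#   dense_feats = three_interpolate(feats, idx, weight)  # (B, c, n)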
+
+
+class GroupingOperation(Function):
+ @staticmethod
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+ def forward(ctx, features, idx):
+ # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, C, N) tensor of features to group
+ idx : torch.Tensor
+ (B, npoint, nsample) tensor containing the indices of features to group with
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, npoint, nsample) tensor
+ """
+ ctx.save_for_backward(idx, features)
+
+ return _ext.group_points(features, idx)
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad_out):
+ # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
+ r"""
+
+ Parameters
+ ----------
+ grad_out : torch.Tensor
+ (B, C, npoint, nsample) tensor of the gradients of the output from forward
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, N) gradient of the features
+ None
+ """
+ idx, features = ctx.saved_tensors
+ N = features.size(2)
+
+ grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
+
+ return grad_features, torch.zeros_like(idx)
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+ @staticmethod
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+ def forward(ctx, radius, nsample, xyz, new_xyz):
+ # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ radius : float
+ radius of the balls
+ nsample : int
+ maximum number of features in the balls
+ xyz : torch.Tensor
+ (B, N, 3) xyz coordinates of the features
+ new_xyz : torch.Tensor
+ (B, npoint, 3) centers of the ball query
+
+ Returns
+ -------
+ torch.Tensor
+ (B, npoint, nsample) tensor with the indices of the features that form the query balls
+ """
+ output = _ext.ball_query(new_xyz, xyz, radius, nsample)
+
+ ctx.mark_non_differentiable(output)
+
+ return output
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad_out):
+ return ()
+
+
+ball_query = BallQuery.apply
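+
+# Usage sketch (illustrative only): ball_query + grouping_operation is the
+# pattern wrapped by QueryAndGroup below.
+#
+#   xyz = torch.rand(2, 4096, 3, device="cuda")                          # (B, N, 3)
+#   new_xyz = torch.rand(2, 512, 3, device="cuda")                       # (B, npoint, 3)
+#   idx = ball_query(0.2, 32, xyz, new_xyz)                              # (B, 512, 32)
+#   grouped = grouping_operation(xyz.transpose(1, 2).contiguous(), idx)  # (B, 3, 512, 32)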
+
+
+class QueryAndGroup(nn.Module):
+ r"""
+ Groups with a ball query of radius
+
+ Parameters
+ ---------
+ radius : float32
+ Radius of ball
+ nsample : int32
+ Maximum number of features to gather in the ball
+ """
+
+ def __init__(self, radius, nsample, use_xyz=True):
+ # type: (QueryAndGroup, float, int, bool) -> None
+ super(QueryAndGroup, self).__init__()
+ self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+ def forward(self, xyz, new_xyz, features=None):
+ # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ xyz coordinates of the features (B, N, 3)
+ new_xyz : torch.Tensor
+ centroids (B, npoint, 3)
+ features : torch.Tensor
+ Descriptors of the features (B, C, N)
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, 3 + C, npoint, nsample) tensor
+ """
+
+ idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+ xyz_trans = xyz.transpose(1, 2).contiguous()
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
+ grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ if self.use_xyz:
+ new_features = torch.cat(
+ [grouped_xyz, grouped_features], dim=1
+ ) # (B, C + 3, npoint, nsample)
+ else:
+ new_features = grouped_features
+ else:
+ assert (
+ self.use_xyz
+ ), "Cannot have not features and not use xyz as a feature!"
+ new_features = grouped_xyz
+
+ return new_features
+
+
+class GroupAll(nn.Module):
+ r"""
+ Groups all features
+
+ Parameters
+ ---------
+ """
+
+ def __init__(self, use_xyz=True):
+ # type: (GroupAll, bool) -> None
+ super(GroupAll, self).__init__()
+ self.use_xyz = use_xyz
+
+ def forward(self, xyz, new_xyz, features=None):
+ # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ xyz coordinates of the features (B, N, 3)
+ new_xyz : torch.Tensor
+ Ignored
+ features : torch.Tensor
+ Descriptors of the features (B, C, N)
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, C + 3, 1, N) tensor
+ """
+
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+ if features is not None:
+ grouped_features = features.unsqueeze(2)
+ if self.use_xyz:
+ new_features = torch.cat(
+ [grouped_xyz, grouped_features], dim=1
+ ) # (B, 3 + C, 1, N)
+ else:
+ new_features = grouped_features
+ else:
+ new_features = grouped_xyz
+
+ return new_features
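+
+
+# Comparison sketch (illustrative only): QueryAndGroup builds one local
+# neighbourhood per sampled centre, while GroupAll treats the whole cloud as a
+# single group.
+#
+#   qg = QueryAndGroup(radius=0.2, nsample=32)
+#   ga = GroupAll()
+#   xyz = torch.rand(2, 1024, 3, device="cuda")
+#   new_xyz = torch.rand(2, 128, 3, device="cuda")
+#   feats = torch.rand(2, 64, 1024, device="cuda")
+#   qg(xyz, new_xyz, feats).shape   # (2, 64 + 3, 128, 32)
+#   ga(xyz, None, feats).shape      # (2, 64 + 3, 1, 1024)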
diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5cbe1ca5d45acab7bfcd880a743f1f883cd7002
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py
@@ -0,0 +1,41 @@
+import glob
+import os
+import os.path as osp
+
+from setuptools import find_packages, setup
+import torch
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+this_dir = osp.dirname(osp.abspath(__file__))
+_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
+_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
+ osp.join(_ext_src_root, "src", "*.cu")
+)
+_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
+
+requirements = ["torch>=1.4"]
+
+exec(open(osp.join("pointnet2_ops", "_version.py")).read())
+
+# os.environ["TORCH_CUDA_ARCH_LIST"] = ".".join(map(str, torch.cuda.get_device_capability()))
+os.environ["TORCH_CUDA_ARCH_LIST"] = "5.0;6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+setup(
+ name="pointnet2_ops",
+ version=__version__,
+ author="Erik Wijmans",
+ packages=find_packages(),
+ install_requires=requirements,
+ ext_modules=[
+ CUDAExtension(
+ name="pointnet2_ops._ext",
+ sources=_ext_sources,
+ extra_compile_args={
+ "cxx": ["-O3"],
+ "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
+ },
+ include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
+ )
+ ],
+ cmdclass={"build_ext": BuildExtension},
+ include_package_data=True,
+)
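+
+# Build note (an assumption about the intended workflow, mirroring the JIT
+# fallback in pointnet2_ops/pointnet2_utils.py): the extension is normally
+# built ahead of time with either
+#
+#   pip install ./pointnet2_ops_lib
+#
+# or, from this directory,
+#
+#   python setup.py build_ext --inplace
+#
+# A CUDA toolkit compatible with the installed PyTorch build is required.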
diff --git a/hort/models/tgs/models/snowflake/skip_transformer.py b/hort/models/tgs/models/snowflake/skip_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..36462f7242dd3dbe58495c19a61d8cf0cc54857e
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/skip_transformer.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# @Author: Peng Xiang
+
+import torch
+from torch import nn, einsum
+from .utils import MLP_Res, grouping_operation, query_knn
+
+
+class SkipTransformer(nn.Module):
+ def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4):
+ super(SkipTransformer, self).__init__()
+ self.mlp_v = MLP_Res(in_dim=in_channel*2, hidden_dim=in_channel, out_dim=in_channel)
+ self.n_knn = n_knn
+ self.conv_key = nn.Conv1d(in_channel, dim, 1)
+ self.conv_query = nn.Conv1d(in_channel, dim, 1)
+ self.conv_value = nn.Conv1d(in_channel, dim, 1)
+
+ self.pos_mlp = nn.Sequential(
+ nn.Conv2d(3, pos_hidden_dim, 1),
+ nn.BatchNorm2d(pos_hidden_dim),
+ nn.ReLU(),
+ nn.Conv2d(pos_hidden_dim, dim, 1)
+ )
+
+ self.attn_mlp = nn.Sequential(
+ nn.Conv2d(dim, dim * attn_hidden_multiplier, 1),
+ nn.BatchNorm2d(dim * attn_hidden_multiplier),
+ nn.ReLU(),
+ nn.Conv2d(dim * attn_hidden_multiplier, dim, 1)
+ )
+
+ self.conv_end = nn.Conv1d(dim, in_channel, 1)
+
+ def forward(self, pos, key, query, include_self=True):
+ """
+ Args:
+ pos: (B, 3, N)
+ key: (B, in_channel, N)
+ query: (B, in_channel, N)
+ include_self: boolean
+
+ Returns:
+ Tensor: (B, in_channel, N), shape context feature
+ """
+ value = self.mlp_v(torch.cat([key, query], 1))
+ identity = value
+ key = self.conv_key(key)
+ query = self.conv_query(query)
+ value = self.conv_value(value)
+ b, dim, n = value.shape
+
+ pos_flipped = pos.permute(0, 2, 1).contiguous()
+ idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped, include_self=include_self)
+
+ key = grouping_operation(key, idx_knn) # b, dim, n, n_knn
+ qk_rel = query.reshape((b, -1, n, 1)) - key
+
+ pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn
+ pos_embedding = self.pos_mlp(pos_rel)
+
+ attention = self.attn_mlp(qk_rel + pos_embedding) # b, dim, n, n_knn
+ attention = torch.softmax(attention, -1)
+
+ value = value.reshape((b, -1, n, 1)) + pos_embedding
+
+ agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n
+ y = self.conv_end(agg)
+
+ return y + identity
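+
+
+# Shape sketch (illustrative only; dimensions are arbitrary example values and
+# a CUDA device is assumed because grouping_operation has no CPU kernel):
+#
+#   st = SkipTransformer(in_channel=128, dim=64, n_knn=16).cuda()
+#   pos = torch.rand(2, 3, 512, device="cuda")       # point positions
+#   key = torch.rand(2, 128, 512, device="cuda")     # features from the previous stage
+#   query = torch.rand(2, 128, 512, device="cuda")   # current features
+#   out = st(pos, key, query)                        # (2, 128, 512)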
diff --git a/hort/models/tgs/models/snowflake/utils.py b/hort/models/tgs/models/snowflake/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6661e38da308f5a05714274c8b9ade378be61486
--- /dev/null
+++ b/hort/models/tgs/models/snowflake/utils.py
@@ -0,0 +1,741 @@
+# -*- coding: utf-8 -*-
+# @Author: Peng Xiang
+
+import types
+import torch
+import torch.nn.functional as F
+import numpy as np
+from torch import nn, einsum
+from pointnet2_ops.pointnet2_utils import furthest_point_sample, \
+ gather_operation, ball_query, three_nn, three_interpolate, grouping_operation
+
+class Conv1d(nn.Module):
+ def __init__(self, in_channel, out_channel, kernel_size=1, stride=1, if_bn=True, activation_fn=torch.relu):
+ super(Conv1d, self).__init__()
+ self.conv = nn.Conv1d(in_channel, out_channel, kernel_size, stride=stride)
+ self.if_bn = if_bn
+ self.bn = nn.BatchNorm1d(out_channel)
+ self.activation_fn = activation_fn
+
+ def forward(self, input):
+ out = self.conv(input)
+ if self.if_bn:
+ out = self.bn(out)
+
+ if self.activation_fn is not None:
+ out = self.activation_fn(out)
+
+ return out
+
+class Conv2d(nn.Module):
+ def __init__(self, in_channel, out_channel, kernel_size=(1, 1), stride=(1, 1), if_bn=True, activation_fn=torch.relu):
+ super(Conv2d, self).__init__()
+ self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride)
+ self.if_bn = if_bn
+ self.bn = nn.BatchNorm2d(out_channel)
+ self.activation_fn = activation_fn
+
+ def forward(self, input):
+ out = self.conv(input)
+ if self.if_bn:
+ out = self.bn(out)
+
+ if self.activation_fn is not None:
+ out = self.activation_fn(out)
+
+ return out
+
+class MLP(nn.Module):
+ def __init__(self, in_channel, layer_dims, bn=None):
+ super(MLP, self).__init__()
+ layers = []
+ last_channel = in_channel
+ for out_channel in layer_dims[:-1]:
+ layers.append(nn.Linear(last_channel, out_channel))
+ if bn:
+ layers.append(nn.BatchNorm1d(out_channel))
+ layers.append(nn.ReLU())
+ last_channel = out_channel
+ layers.append(nn.Linear(last_channel, layer_dims[-1]))
+ self.mlp = nn.Sequential(*layers)
+
+ def forward(self, inputs):
+ return self.mlp(inputs)
+
+class MLP_CONV(nn.Module):
+ def __init__(self, in_channel, layer_dims, bn=None):
+ super(MLP_CONV, self).__init__()
+ layers = []
+ last_channel = in_channel
+ for out_channel in layer_dims[:-1]:
+ layers.append(nn.Conv1d(last_channel, out_channel, 1))
+ if bn:
+ layers.append(nn.BatchNorm1d(out_channel))
+ layers.append(nn.ReLU())
+ last_channel = out_channel
+ layers.append(nn.Conv1d(last_channel, layer_dims[-1], 1))
+ self.mlp = nn.Sequential(*layers)
+
+ def forward(self, inputs):
+ return self.mlp(inputs)
+
+class MLP_Res(nn.Module):
+ def __init__(self, in_dim=128, hidden_dim=None, out_dim=128):
+ super(MLP_Res, self).__init__()
+ if hidden_dim is None:
+ hidden_dim = in_dim
+ self.conv_1 = nn.Conv1d(in_dim, hidden_dim, 1)
+ self.conv_2 = nn.Conv1d(hidden_dim, out_dim, 1)
+ self.conv_shortcut = nn.Conv1d(in_dim, out_dim, 1)
+
+ def forward(self, x):
+ """
+ Args:
+ x: (B, out_dim, n)
+ """
+ shortcut = self.conv_shortcut(x)
+ out = self.conv_2(torch.relu(self.conv_1(x))) + shortcut
+ return out
+
+
+def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True):
+ """
+ Args:
+ xyz: Tensor, (B, 3, N)
+ points: Tensor, (B, f, N)
+ npoint: int
+ nsample: int
+ radius: float
+ use_xyz: boolean
+
+ Returns:
+ new_xyz: Tensor, (B, 3, npoint)
+ new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
+ idx_local: Tensor, (B, npoint, nsample)
+ grouped_xyz: Tensor, (B, 3, npoint, nsample)
+
+ """
+ xyz_flipped = xyz.permute(0, 2, 1).contiguous() # (B, N, 3)
+ new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint)) # (B, 3, npoint)
+
+ idx = ball_query(radius, nsample, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous()) # (B, npoint, nsample)
+ grouped_xyz = grouping_operation(xyz, idx) # (B, 3, npoint, nsample)
+ grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, nsample)
+
+ if points is not None:
+ grouped_points = grouping_operation(points, idx) # (B, f, npoint, nsample)
+ if use_xyz:
+ new_points = torch.cat([grouped_xyz, grouped_points], 1)
+ else:
+ new_points = grouped_points
+ else:
+ new_points = grouped_xyz
+
+ return new_xyz, new_points, idx, grouped_xyz
+
+
+def sample_and_group_all(xyz, points, use_xyz=True):
+ """
+ Args:
+ xyz: Tensor, (B, 3, nsample)
+ points: Tensor, (B, f, nsample)
+ use_xyz: boolean
+
+ Returns:
+ new_xyz: Tensor, (B, 3, 1)
+ new_points: Tensor, (B, f|f+3|3, 1, nsample)
+ idx: Tensor, (B, 1, nsample)
+ grouped_xyz: Tensor, (B, 3, 1, nsample)
+ """
+ b, _, nsample = xyz.shape
+ device = xyz.device
+ new_xyz = torch.zeros((1, 3, 1), dtype=torch.float, device=device).repeat(b, 1, 1)
+ grouped_xyz = xyz.reshape((b, 3, 1, nsample))
+ idx = torch.arange(nsample, device=device).reshape(1, 1, nsample).repeat(b, 1, 1)
+ if points is not None:
+ if use_xyz:
+ new_points = torch.cat([xyz, points], 1)
+ else:
+ new_points = points
+ new_points = new_points.unsqueeze(2)
+ else:
+ new_points = grouped_xyz
+
+ return new_xyz, new_points, idx, grouped_xyz
+
+
+class PointNet_SA_Module(nn.Module):
+ def __init__(self, npoint, nsample, radius, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True):
+ """
+ Args:
+ npoint: int, number of points to sample
+ nsample: int, number of points in each local region
+ radius: float
+ in_channel: int, input channel of features(points)
+ mlp: list of int,
+ """
+ super(PointNet_SA_Module, self).__init__()
+ self.npoint = npoint
+ self.nsample = nsample
+ self.radius = radius
+ self.mlp = mlp
+ self.group_all = group_all
+ self.use_xyz = use_xyz
+ if use_xyz:
+ in_channel += 3
+
+ last_channel = in_channel
+ self.mlp_conv = []
+ for out_channel in mlp:
+ self.mlp_conv.append(Conv2d(last_channel, out_channel, if_bn=if_bn))
+ last_channel = out_channel
+
+ self.mlp_conv = nn.Sequential(*self.mlp_conv)
+
+ def forward(self, xyz, points):
+ """
+ Args:
+ xyz: Tensor, (B, 3, N)
+ points: Tensor, (B, f, N)
+
+ Returns:
+ new_xyz: Tensor, (B, 3, npoint)
+ new_points: Tensor, (B, mlp[-1], npoint)
+ """
+ if self.group_all:
+ new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz)
+ else:
+ new_xyz, new_points, idx, grouped_xyz = sample_and_group(xyz, points, self.npoint, self.nsample, self.radius, self.use_xyz)
+
+ new_points = self.mlp_conv(new_points)
+ new_points = torch.max(new_points, 3)[0]
+
+ return new_xyz, new_points
+
+
+class PointNet_FP_Module(nn.Module):
+ def __init__(self, in_channel, mlp, use_points1=False, in_channel_points1=None, if_bn=True):
+ """
+ Args:
+ in_channel: int, input channel of points2
+ mlp: list of int
+ use_points1: boolean, if use points
+ in_channel_points1: int, input channel of points1
+ """
+ super(PointNet_FP_Module, self).__init__()
+ self.use_points1 = use_points1
+
+ if use_points1:
+ in_channel += in_channel_points1
+
+ last_channel = in_channel
+ self.mlp_conv = []
+ for out_channel in mlp:
+ self.mlp_conv.append(Conv1d(last_channel, out_channel, if_bn=if_bn))
+ last_channel = out_channel
+
+ self.mlp_conv = nn.Sequential(*self.mlp_conv)
+
+ def forward(self, xyz1, xyz2, points1, points2):
+ """
+ Args:
+ xyz1: Tensor, (B, 3, N)
+ xyz2: Tensor, (B, 3, M)
+ points1: Tensor, (B, in_channel, N)
+ points2: Tensor, (B, in_channel, M)
+
+ Returns:
+ new_points: Tensor, (B, mlp[-1], N)
+ """
+ dist, idx = three_nn(xyz1.permute(0, 2, 1).contiguous(), xyz2.permute(0, 2, 1).contiguous())
+ dist = torch.clamp_min(dist, 1e-10) # (B, N, 3)
+ recip_dist = 1.0/dist
+ norm = torch.sum(recip_dist, 2, keepdim=True).repeat((1, 1, 3))
+ weight = recip_dist / norm
+ interpolated_points = three_interpolate(points2, idx, weight) # B, in_channel, N
+
+ if self.use_points1:
+ new_points = torch.cat([interpolated_points, points1], 1)
+ else:
+ new_points = interpolated_points
+
+ new_points = self.mlp_conv(new_points)
+ return new_points
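+
+
+# Shape sketch (illustrative only). Unlike the modules in pointnet2_modules.py,
+# these variants take xyz channel-first, i.e. (B, 3, N):
+#
+#   sa = PointNet_SA_Module(npoint=512, nsample=32, radius=0.2,
+#                           in_channel=64, mlp=[64, 128]).cuda()
+#   fp = PointNet_FP_Module(in_channel=128, mlp=[128, 64],
+#                           use_points1=True, in_channel_points1=64).cuda()
+#   xyz = torch.rand(2, 3, 4096, device="cuda")
+#   feats = torch.rand(2, 64, 4096, device="cuda")
+#   new_xyz, new_feats = sa(xyz, feats)            # (2, 3, 512), (2, 128, 512)
+#   up = fp(xyz, new_xyz, feats, new_feats)        # (2, 64, 4096)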
+
+
+def square_distance(src, dst):
+ """
+ Calculate Euclid distance between each two points.
+
+ src^T * dst = xn * xm + yn * ym + zn * zm;
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
+
+ Input:
+ src: source points, [B, N, C]
+ dst: target points, [B, M, C]
+ Output:
+ dist: per-point square distance, [B, N, M]
+ """
+ B, N, _ = src.shape
+ _, M, _ = dst.shape
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) # B, N, M
+ dist += torch.sum(src ** 2, -1).view(B, N, 1)
+ dist += torch.sum(dst ** 2, -1).view(B, 1, M)
+ return dist
+
+
+def query_knn(nsample, xyz, new_xyz, include_self=True):
+ """Find k-NN of new_xyz in xyz"""
+ pad = 0 if include_self else 1
+ sqrdists = square_distance(new_xyz, xyz) # B, S, N
+ idx = torch.argsort(sqrdists, dim=-1, descending=False)[:, :, pad: nsample+pad]
+ return idx.int()
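+
+# Example (illustrative only): query_knn is a pure-PyTorch kNN that returns
+# int32 indices compatible with grouping_operation.
+#
+#   xyz = torch.rand(2, 1024, 3, device="cuda")                         # (B, N, 3)
+#   idx = query_knn(16, xyz, xyz)                                       # (B, 1024, 16), self included
+#   nbrs = grouping_operation(xyz.permute(0, 2, 1).contiguous(), idx)   # (B, 3, 1024, 16)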
+
+
+def sample_and_group_knn(xyz, points, npoint, k, use_xyz=True, idx=None):
+ """
+ Args:
+ xyz: Tensor, (B, 3, N)
+ points: Tensor, (B, f, N)
+ npoint: int
+ k: int
+ use_xyz: boolean
+ idx: Tensor, optional precomputed (B, npoint, k) neighbor indices
+
+ Returns:
+ new_xyz: Tensor, (B, 3, npoint)
+ new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
+ idx_local: Tensor, (B, npoint, nsample)
+ grouped_xyz: Tensor, (B, 3, npoint, nsample)
+
+ """
+ xyz_flipped = xyz.permute(0, 2, 1).contiguous() # (B, N, 3)
+ new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint)) # (B, 3, npoint)
+ if idx is None:
+ idx = query_knn(k, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous())
+ grouped_xyz = grouping_operation(xyz, idx) # (B, 3, npoint, nsample)
+ grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, k)
+
+ if points is not None:
+ grouped_points = grouping_operation(points, idx) # (B, f, npoint, nsample)
+ if use_xyz:
+ new_points = torch.cat([grouped_xyz, grouped_points], 1)
+ else:
+ new_points = grouped_points
+ else:
+ new_points = grouped_xyz
+
+ return new_xyz, new_points, idx, grouped_xyz
+
+
+class PointNet_SA_Module_KNN(nn.Module):
+ def __init__(self, npoint, nsample, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True, if_idx=False):
+ """
+ Args:
+ npoint: int, number of points to sample
+ nsample: int, number of points in each local region
+ radius: float
+ in_channel: int, input channel of features(points)
+ mlp: list of int,
+ """
+ super(PointNet_SA_Module_KNN, self).__init__()
+ self.npoint = npoint
+ self.nsample = nsample
+ self.mlp = mlp
+ self.group_all = group_all
+ self.use_xyz = use_xyz
+ self.if_idx = if_idx
+ if use_xyz:
+ in_channel += 3
+
+ last_channel = in_channel
+ self.mlp_conv = []
+ for out_channel in mlp[:-1]:
+ self.mlp_conv.append(Conv2d(last_channel, out_channel, if_bn=if_bn))
+ last_channel = out_channel
+ self.mlp_conv.append(Conv2d(last_channel, mlp[-1], if_bn=False, activation_fn=None))
+ self.mlp_conv = nn.Sequential(*self.mlp_conv)
+
+ def forward(self, xyz, points, idx=None):
+ """
+ Args:
+ xyz: Tensor, (B, 3, N)
+ points: Tensor, (B, f, N)
+
+ Returns:
+ new_xyz: Tensor, (B, 3, npoint)
+ new_points: Tensor, (B, mlp[-1], npoint)
+ """
+ if self.group_all:
+ new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz)
+ else:
+ new_xyz, new_points, idx, grouped_xyz = sample_and_group_knn(xyz, points, self.npoint, self.nsample, self.use_xyz, idx=idx)
+
+ new_points = self.mlp_conv(new_points)
+ new_points = torch.max(new_points, 3)[0]
+
+ if self.if_idx:
+ return new_xyz, new_points, idx
+ else:
+ return new_xyz, new_points
+
+
+def fps_subsample(pcd, n_points=2048):
+ """
+ Args
+ pcd: (b, 16384, 3)
+
+ returns
+ new_pcd: (b, n_points, 3)
+ """
+ new_pcd = gather_operation(pcd.permute(0, 2, 1).contiguous(), furthest_point_sample(pcd, n_points))
+ new_pcd = new_pcd.permute(0, 2, 1).contiguous()
+ return new_pcd
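+
+# Example (illustrative only): fps_subsample keeps the (B, N, 3) layout on both
+# sides, wrapping furthest_point_sample + gather_operation.
+#
+#   pcd = torch.rand(2, 16384, 3, device="cuda")
+#   pcd_2048 = fps_subsample(pcd, n_points=2048)   # (2, 2048, 3)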
+
+
+class Transformer(nn.Module):
+ def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4):
+ super(Transformer, self).__init__()
+ self.n_knn = n_knn
+ self.conv_key = nn.Conv1d(dim, dim, 1)
+ self.conv_query = nn.Conv1d(dim, dim, 1)
+ self.conv_value = nn.Conv1d(dim, dim, 1)
+
+ self.pos_mlp = nn.Sequential(
+ nn.Conv2d(3, pos_hidden_dim, 1),
+ nn.BatchNorm2d(pos_hidden_dim),
+ nn.ReLU(),
+ nn.Conv2d(pos_hidden_dim, dim, 1)
+ )
+
+ self.attn_mlp = nn.Sequential(
+ nn.Conv2d(dim, dim * attn_hidden_multiplier, 1),
+ nn.BatchNorm2d(dim * attn_hidden_multiplier),
+ nn.ReLU(),
+ nn.Conv2d(dim * attn_hidden_multiplier, dim, 1)
+ )
+
+ self.linear_start = nn.Conv1d(in_channel, dim, 1)
+ self.linear_end = nn.Conv1d(dim, in_channel, 1)
+
+ def forward(self, x, pos):
+ """feed forward of transformer
+ Args:
+ x: Tensor of features, (B, in_channel, n)
+ pos: Tensor of positions, (B, 3, n)
+
+ Returns:
+ y: Tensor of features with attention, (B, in_channel, n)
+ """
+
+ identity = x
+
+ x = self.linear_start(x)
+ b, dim, n = x.shape
+
+ pos_flipped = pos.permute(0, 2, 1).contiguous()
+ idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped)
+ key = self.conv_key(x)
+ value = self.conv_value(x)
+ query = self.conv_query(x)
+
+ key = grouping_operation(key, idx_knn) # b, dim, n, n_knn
+ qk_rel = query.reshape((b, -1, n, 1)) - key
+
+ pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn
+ pos_embedding = self.pos_mlp(pos_rel) # b, dim, n, n_knn
+
+ attention = self.attn_mlp(qk_rel + pos_embedding)
+ attention = torch.softmax(attention, -1)
+
+ value = value.reshape((b, -1, n, 1)) + pos_embedding
+
+ agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n
+ y = self.linear_end(agg)
+
+ return y+identity
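+
+# Shape sketch (illustrative only): the block preserves the input feature
+# width, so it can sit between any two (B, C, N) stages.
+#
+#   tr = Transformer(in_channel=128, dim=64, n_knn=16).cuda()
+#   feats = torch.rand(2, 128, 512, device="cuda")
+#   pos = torch.rand(2, 3, 512, device="cuda")
+#   out = tr(feats, pos)                           # (2, 128, 512)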
+
+
+class CouplingLayer(nn.Module):
+
+ def __init__(self, d, intermediate_dim, swap=False):
+ nn.Module.__init__(self)
+ self.d = d - (d // 2)
+ self.swap = swap
+ self.net_s_t = nn.Sequential(
+ nn.Linear(self.d, intermediate_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(intermediate_dim, intermediate_dim),
+ nn.ReLU(inplace=True),
+ nn.Linear(intermediate_dim, (d - self.d) * 2),
+ )
+
+ def forward(self, x, logpx=None, reverse=False):
+
+ if self.swap:
+ x = torch.cat([x[:, self.d:], x[:, :self.d]], 1)
+
+ in_dim = self.d
+ out_dim = x.shape[1] - self.d
+
+ s_t = self.net_s_t(x[:, :in_dim])
+ scale = torch.sigmoid(s_t[:, :out_dim] + 2.)
+ shift = s_t[:, out_dim:]
+
+ logdetjac = torch.sum(torch.log(scale).view(scale.shape[0], -1), 1, keepdim=True)
+
+ if not reverse:
+ y1 = x[:, self.d:] * scale + shift
+ delta_logp = -logdetjac
+ else:
+ y1 = (x[:, self.d:] - shift) / scale
+ delta_logp = logdetjac
+
+ y = torch.cat([x[:, :self.d], y1], 1) if not self.swap else torch.cat([y1, x[:, :self.d]], 1)
+
+ if logpx is None:
+ return y
+ else:
+ return y, logpx + delta_logp
+
+
+class SequentialFlow(nn.Module):
+ """A generalized nn.Sequential container for normalizing flows.
+ """
+
+ def __init__(self, layersList):
+ super(SequentialFlow, self).__init__()
+ self.chain = nn.ModuleList(layersList)
+
+ def forward(self, x, logpx=None, reverse=False, inds=None):
+ if inds is None:
+ if reverse:
+ inds = range(len(self.chain) - 1, -1, -1)
+ else:
+ inds = range(len(self.chain))
+
+ if logpx is None:
+ for i in inds:
+ x = self.chain[i](x, reverse=reverse)
+ return x
+ else:
+ for i in inds:
+ x, logpx = self.chain[i](x, logpx, reverse=reverse)
+ return x, logpx
+
+
+def build_latent_flow(args):
+ chain = []
+ for i in range(args.latent_flow_depth):
+ chain.append(CouplingLayer(args.latent_dim, args.latent_flow_hidden_dim, swap=(i % 2 == 0)))
+ return SequentialFlow(chain)
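+
+# Minimal sketch (illustrative only; the `args` fields are the ones read by
+# build_latent_flow, their values are arbitrary):
+#
+#   from argparse import Namespace
+#   args = Namespace(latent_flow_depth=14, latent_dim=256,
+#                    latent_flow_hidden_dim=256)
+#   flow = build_latent_flow(args)
+#   z = torch.randn(8, 256)
+#   w, delta_logp = flow(z, logpx=torch.zeros(8, 1))  # forward + accumulated log-det
+#   z_rec = flow(w, reverse=True)                     # inverse pass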
+
+
+##################
+## SpectralNorm ##
+##################
+
+POWER_ITERATION_FN = "spectral_norm_power_iteration"
+
+
+class SpectralNorm(object):
+ def __init__(self, name='weight', dim=0, eps=1e-12):
+ self.name = name
+ self.dim = dim
+ self.eps = eps
+
+ def compute_weight(self, module, n_power_iterations):
+ if n_power_iterations < 0:
+ raise ValueError(
+ 'Expected n_power_iterations to be non-negative, but '
+ 'got n_power_iterations={}'.format(n_power_iterations)
+ )
+
+ weight = getattr(module, self.name + '_orig')
+ u = getattr(module, self.name + '_u')
+ v = getattr(module, self.name + '_v')
+ weight_mat = weight
+ if self.dim != 0:
+ # permute dim to front
+ weight_mat = weight_mat.permute(self.dim, * [d for d in range(weight_mat.dim()) if d != self.dim])
+ height = weight_mat.size(0)
+ weight_mat = weight_mat.reshape(height, -1)
+ with torch.no_grad():
+ for _ in range(n_power_iterations):
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
+ # are the first left and right singular vectors.
+ # This power iteration produces approximations of `u` and `v`.
+ v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
+ u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
+ setattr(module, self.name + '_u', u)
+ setattr(module, self.name + '_v', v)
+
+ sigma = torch.dot(u, torch.matmul(weight_mat, v))
+ weight = weight / sigma
+ setattr(module, self.name, weight)
+
+ def remove(self, module):
+ weight = getattr(module, self.name)
+ delattr(module, self.name)
+ delattr(module, self.name + '_u')
+ delattr(module, self.name + '_orig')
+ module.register_parameter(self.name, torch.nn.Parameter(weight))
+
+ def get_update_method(self, module):
+ def update_fn(module, n_power_iterations):
+ self.compute_weight(module, n_power_iterations)
+
+ return update_fn
+
+ def __call__(self, module, unused_inputs):
+ del unused_inputs
+ self.compute_weight(module, n_power_iterations=0)
+
+ # requires_grad might be either True or False during inference.
+ if not module.training:
+ r_g = getattr(module, self.name + '_orig').requires_grad
+ setattr(module, self.name, getattr(module, self.name).detach().requires_grad_(r_g))
+
+ @staticmethod
+ def apply(module, name, dim, eps):
+ fn = SpectralNorm(name, dim, eps)
+ weight = module._parameters[name]
+ height = weight.size(dim)
+
+ u = F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
+ v = F.normalize(weight.new_empty(int(weight.numel() / height)).normal_(0, 1), dim=0, eps=fn.eps)
+ delattr(module, fn.name)
+ module.register_parameter(fn.name + "_orig", weight)
+ # We still need to assign weight back as fn.name because all sorts of
+ # things may assume that it exists, e.g., when initializing weights.
+ # However, we can't directly assign as it could be an nn.Parameter and
+ # gets added as a parameter. Instead, we register weight.data as a
+ # buffer, which will cause weight to be included in the state dict
+ # and also supports nn.init due to shared storage.
+ module.register_buffer(fn.name, weight.data)
+ module.register_buffer(fn.name + "_u", u)
+ module.register_buffer(fn.name + "_v", v)
+
+ setattr(module, POWER_ITERATION_FN, types.MethodType(fn.get_update_method(module), module))
+
+ module.register_forward_pre_hook(fn)
+ return fn
+
+
+def inplace_spectral_norm(module, name='weight', dim=None, eps=1e-12):
+ r"""Applies spectral normalization to a parameter in the given module.
+ .. math::
+ \mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+ Spectral normalization stabilizes the training of discriminators (critics)
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+ with spectral norm :math:`\sigma` of the weight matrix calculated using
+ power iteration method. If the dimension of the weight tensor is greater
+ than 2, it is reshaped to 2D in power iteration method to get spectral
+ norm. This is implemented via a hook that calculates spectral norm and
+ rescales weight before every :meth:`~Module.forward` call.
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+ Args:
+ module (nn.Module): containing module
+ name (str, optional): name of weight parameter
+ n_power_iterations (int, optional): number of power iterations to
+ calculate spectral norm
+ dim (int, optional): dimension corresponding to number of outputs,
+ the default is 0, except for modules that are instances of
+ ConvTranspose1/2/3d, when it is 1
+ eps (float, optional): epsilon for numerical stability in
+ calculating norms
+ Returns:
+ The original module with the spectral norm hook
+ Example::
+ >>> m = spectral_norm(nn.Linear(20, 40))
+ Linear (20 -> 40)
+ >>> m.weight_u.size()
+ torch.Size([20])
+ """
+ if dim is None:
+ if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
+ dim = 1
+ else:
+ dim = 0
+ SpectralNorm.apply(module, name, dim=dim, eps=eps)
+ return module
+
+
+def remove_spectral_norm(module, name='weight'):
+ r"""Removes the spectral normalization reparameterization from a module.
+ Args:
+ module (nn.Module): containing module
+ name (str, optional): name of weight parameter
+ Example:
+ >>> m = spectral_norm(nn.Linear(40, 10))
+ >>> remove_spectral_norm(m)
+ """
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ return module
+
+ raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
+
+
+def add_spectral_norm(model, logger=None):
+ """Applies spectral norm to all modules within the scope of a CNF."""
+
+ def apply_spectral_norm(module):
+ if 'weight' in module._parameters:
+ if logger: logger.info("Adding spectral norm to {}".format(module))
+ inplace_spectral_norm(module, 'weight')
+
+ def find_coupling_layer(module):
+ if isinstance(module, CouplingLayer):
+ module.apply(apply_spectral_norm)
+ else:
+ for child in module.children():
+ find_coupling_layer(child)
+
+ find_coupling_layer(model)
+
+
+def spectral_norm_power_iteration(model, n_power_iterations=1):
+
+ def recursive_power_iteration(module):
+ if hasattr(module, POWER_ITERATION_FN):
+ getattr(module, POWER_ITERATION_FN)(n_power_iterations)
+
+ model.apply(recursive_power_iteration)
+
+def reparameterize_gaussian(mean, logvar):
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn(std.size()).to(mean)
+ return mean + std * eps
+
+
+
+def gaussian_entropy(logvar):
+ const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2))
+ ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const
+ return ent
+
+
+def standard_normal_logprob(z):
+ dim = z.size(-1)
+ log_z = -0.5 * dim * np.log(2 * np.pi)
+ return log_z - z.pow(2) / 2
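+
+# Example (illustrative only): these helpers are the usual VAE / latent-flow
+# ingredients.
+#
+#   mean = torch.zeros(8, 256)
+#   logvar = torch.zeros(8, 256)
+#   z = reparameterize_gaussian(mean, logvar)     # z ~ N(mean, exp(logvar))
+#   h_q = gaussian_entropy(logvar)                # (8,) entropy of the posterior
+#   logp_terms = standard_normal_logprob(z)       # (8, 256) per-dimension terms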
+
+def truncated_normal_(tensor, mean=0, std=1, trunc_std=2):
+ """
+ Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
+ """
+ size = tensor.shape
+ tmp = tensor.new_empty(size + (4,)).normal_()
+ valid = (tmp < trunc_std) & (tmp > -trunc_std)
+ ind = valid.max(-1, keepdim=True)[1]
+ tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
+ tensor.data.mul_(std).add_(mean)
+ return tensor
\ No newline at end of file
diff --git a/hort/models/tgs/models/tokenizers/dinov2.py b/hort/models/tgs/models/tokenizers/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..64f939b2d1f0d31c64ba21328fdecf72c280c6fa
--- /dev/null
+++ b/hort/models/tgs/models/tokenizers/dinov2.py
@@ -0,0 +1,1179 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch DINOv2 model."""
+
+
+import collections.abc
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+from dataclasses import dataclass
+
+import torch
+import torch.utils.checkpoint
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import (
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.utils.backbone_utils import BackboneMixin
+from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
+
+from tgs.models.transformers import MemoryEfficientAttentionMixin
+from tgs.utils.typing import *
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "Dinov2Config"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
+
+
+DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/dinov2-base",
+ # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
+]
+
+
+class Dinov2Embeddings(nn.Module):
+ """
+ Construct the CLS token, mask token, position and patch embeddings.
+ """
+
+ def __init__(self, config: Dinov2Config) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ # register as mask token as it's not used in optimization
+ # to avoid the use of find_unused_parameters_true
+ # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
+ self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
+ self.patch_embeddings = Dinov2PatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(
+ torch.randn(1, num_patches + 1, config.hidden_size)
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def interpolate_pos_encoding(
+ self, embeddings: torch.Tensor, height: int, width: int
+ ) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
+ resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ height = height // self.config.patch_size
+ width = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ height, width = height + 0.1, width + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(
+ height / math.sqrt(num_positions),
+ width / math.sqrt(num_positions),
+ ),
+ mode="bicubic",
+ align_corners=False,
+ )
+ if (
+ int(height) != patch_pos_embed.shape[-2]
+ or int(width) != patch_pos_embed.shape[-1]
+ ):
+ raise ValueError(
+ "Width or height does not match with the interpolated position embeddings"
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(
+ self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ patch_embeddings = self.patch_embeddings(pixel_values)
+ embeddings = patch_embeddings
+
+ if bool_masked_pos is not None:
+ embeddings = torch.where(
+ bool_masked_pos.unsqueeze(-1),
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
+ embeddings,
+ )
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(
+ embeddings, height, width
+ )
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class Dinov2PatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = (
+ image_size
+ if isinstance(image_size, collections.abc.Iterable)
+ else (image_size, image_size)
+ )
+ patch_size = (
+ patch_size
+ if isinstance(patch_size, collections.abc.Iterable)
+ else (patch_size, patch_size)
+ )
+ num_patches = (image_size[1] // patch_size[1]) * (
+ image_size[0] // patch_size[0]
+ )
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
+ )
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
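+
+
+# Worked shape example (comment only): with the dinov2-base defaults implied by
+# _EXPECTED_OUTPUT_SHAPE above, a 224x224 input and a 14x14 patch size give
+# (224 // 14) ** 2 = 256 patch tokens; adding the [CLS] token yields the 257
+# tokens of the [1, 257, 768] output shape.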
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
+class Dinov2SelfAttention(nn.Module):
+ def __init__(self, config: Dinov2Config) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+ config, "embedding_size"
+ ):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+ self.query = nn.Linear(
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
+ )
+ self.key = nn.Linear(
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
+ )
+ self.value = nn.Linear(
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
+ )
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ self.use_memory_efficient_attention_xformers: bool = False
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ if self.use_memory_efficient_attention_xformers:
+ import xformers
+ assert head_mask is None and not output_attentions
+ new_size = hidden_states.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ key_layer = self.key(hidden_states).view(new_size)
+ value_layer = self.value(hidden_states).view(new_size)
+ query_layer = mixed_query_layer.view(new_size)
+ context_layer = xformers.ops.memory_efficient_attention(
+ query_layer, key_layer, value_layer, p=self.attention_probs_dropout_prob
+ )
+ context_layer = context_layer.view(*hidden_states.size()[:-1], -1)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ try:
+ context_layer = F.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=head_mask, dropout_p=(self.dropout.p if self.training else 0.0), scale=1/math.sqrt(self.attention_head_size))
+ except Exception:
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
+ )
+
+ return outputs
+
+ def set_use_memory_efficient_attention_xformers(
+ self, valid: bool, attention_op: Optional[Callable] = None
+ ):
+ self.use_memory_efficient_attention_xformers = valid
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
+class Dinov2SelfOutput(nn.Module):
+ """
+ The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: Dinov2Config) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+ ) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
+class Dinov2Attention(nn.Module):
+ def __init__(self, config: Dinov2Config) -> None:
+ super().__init__()
+ self.attention = Dinov2SelfAttention(config)
+ self.output = Dinov2SelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads,
+ self.attention.num_attention_heads,
+ self.attention.attention_head_size,
+ self.pruned_heads,
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
+ heads
+ )
+ self.attention.all_head_size = (
+ self.attention.attention_head_size * self.attention.num_attention_heads
+ )
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[
+ 1:
+ ] # add attentions if we output them
+ return outputs
+
+
+class Dinov2LayerScale(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ self.lambda1 = nn.Parameter(
+ config.layerscale_value * torch.ones(config.hidden_size)
+ )
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ return hidden_state * self.lambda1
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
+) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (
+ input.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=input.dtype, device=input.device
+ )
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
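+# Sanity-check sketch for drop_path (not part of the original module): with
+# drop_prob=0.25 and training=True, each sample's residual branch survives with
+# probability 0.75 and survivors are rescaled by 1/0.75, preserving the expectation:
+#   x = torch.ones(8, 16, 32)
+#   y = drop_path(x, drop_prob=0.25, training=True)
+#   # each sample in y is either all zeros or all 1/0.75 ≈ 1.333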
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath
+class Dinov2DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class Dinov2MLP(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+ if isinstance(config.hidden_act, str):
+ self.activation = ACT2FN[config.hidden_act]
+ else:
+ self.activation = config.hidden_act
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
+
+
+class Dinov2SwiGLUFFN(nn.Module):
+ def __init__(self, config) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.weights_in(hidden_state)
+ x1, x2 = hidden_state.chunk(2, dim=-1)
+ hidden = nn.functional.silu(x1) * x2
+ return self.weights_out(hidden)
+
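+# Shape sketch for the SwiGLU FFN above (illustrative numbers, not taken from any
+# particular config): with hidden_size=768 and mlp_ratio=4, hidden_features starts
+# at 3072 and becomes (int(3072*2/3)+7)//8*8 = 2048. weights_in then maps 768 -> 4096,
+# the result is chunked into two 2048-dim halves combined as silu(x1)*x2, and
+# weights_out maps 2048 -> 768.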
+
+class Dinov2Layer(nn.Module, MemoryEfficientAttentionMixin):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(self, config: Dinov2Config) -> None:
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.norm1_modulation = None
+ self.attention = Dinov2Attention(config)
+ self.layer_scale1 = Dinov2LayerScale(config)
+ self.drop_path1 = (
+ Dinov2DropPath(config.drop_path_rate)
+ if config.drop_path_rate > 0.0
+ else nn.Identity()
+ )
+
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.norm2_modulation = None
+
+ if config.use_swiglu_ffn:
+ self.mlp = Dinov2SwiGLUFFN(config)
+ else:
+ self.mlp = Dinov2MLP(config)
+ self.layer_scale2 = Dinov2LayerScale(config)
+ self.drop_path2 = (
+ Dinov2DropPath(config.drop_path_rate)
+ if config.drop_path_rate > 0.0
+ else nn.Identity()
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ modulation_cond: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ hidden_states_norm = self.norm1(hidden_states)
+ if self.norm1_modulation is not None:
+ assert modulation_cond is not None
+ hidden_states_norm = self.norm1_modulation(
+ hidden_states_norm, modulation_cond
+ )
+ self_attention_outputs = self.attention(
+ hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+
+ attention_output = self.layer_scale1(attention_output)
+ outputs = self_attention_outputs[
+ 1:
+ ] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in Dinov2, layernorm is also applied after self-attention
+ layer_output = self.norm2(hidden_states)
+ if self.norm2_modulation is not None:
+ assert modulation_cond is not None
+ layer_output = self.norm2_modulation(layer_output, modulation_cond)
+ layer_output = self.mlp(layer_output)
+ layer_output = self.layer_scale2(layer_output)
+
+ # second residual connection
+ layer_output = layer_output + hidden_states
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
+ self.norm1_modulation = norm1_mod
+ self.norm2_modulation = norm2_mod
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
+class Dinov2Encoder(nn.Module, MemoryEfficientAttentionMixin):
+ def __init__(self, config: Dinov2Config) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ modulation_cond: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ layer_head_mask,
+ modulation_cond,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states, layer_head_mask, modulation_cond, output_attentions
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
+ if v is not None
+ )
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class Dinov2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Dinov2Config
+ base_model_prefix = "dinov2"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Dinov2Embeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.cls_token.dtype)
+
+ def _set_gradient_checkpointing(
+ self, module: Dinov2Encoder, value: bool = False
+ ) -> None:
+ if isinstance(module, Dinov2Encoder):
+ module.gradient_checkpointing = value
+
+
+DINOV2_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DINOV2_BASE_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`BitImageProcessor.preprocess`] for details.
+
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
+ pre-training.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+DINOV2_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`BitImageProcessor.preprocess`] for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+@dataclass
+class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
+ patch_embeddings: Optional[torch.FloatTensor] = None
+
+
+@add_start_docstrings(
+ "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
+ DINOV2_START_DOCSTRING,
+)
+class Dinov2Model(Dinov2PreTrainedModel, MemoryEfficientAttentionMixin):
+ def __init__(self, config: Dinov2Config):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = Dinov2Embeddings(config)
+ self.encoder = Dinov2Encoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+ base class PreTrainedModel for details.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ modulation_cond: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ modulation_cond=modulation_cond,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = sequence_output[:, 0, :]
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output)
+ return head_outputs + encoder_outputs[1:]
+
+ return CustomBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ patch_embeddings=embedding_output
+ )
+
+ def set_gradient_checkpointing(self, value: bool = False) -> None:
+ self._set_gradient_checkpointing(self.encoder, value)
+
+
+@add_start_docstrings(
+ """
+ Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
+ of the [CLS] token) e.g. for ImageNet.
+ """,
+ DINOV2_START_DOCSTRING,
+)
+class Dinov2ForImageClassification(Dinov2PreTrainedModel):
+ def __init__(self, config: Dinov2Config) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.dinov2 = Dinov2Model(config)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.hidden_size * 2, config.num_labels)
+ if config.num_labels > 0
+ else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=ImageClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ outputs = self.dinov2(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
+
+ cls_token = sequence_output[:, 0]
+ patch_tokens = sequence_output[:, 1:]
+
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+
+ logits = self.classifier(linear_input)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (
+ labels.dtype == torch.long or labels.dtype == torch.int
+ ):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
+ """,
+ DINOV2_START_DOCSTRING,
+)
+class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.num_features = [
+ config.hidden_size for _ in range(config.num_hidden_layers + 1)
+ ]
+ self.embeddings = Dinov2Embeddings(config)
+ self.encoder = Dinov2Encoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BackboneOutput:
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
+ >>> model = AutoBackbone.from_pretrained(
+ ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+ ... )
+
+ >>> inputs = processor(image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> feature_maps = outputs.feature_maps
+ >>> list(feature_maps[-1].shape)
+ [1, 768, 16, 16]
+ ```"""
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+
+ embedding_output = self.embeddings(pixel_values)
+
+ outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=True,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ feature_maps = ()
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ if self.config.apply_layernorm:
+ hidden_state = self.layernorm(hidden_state)
+ if self.config.reshape_hidden_states:
+ batch_size, _, height, width = pixel_values.shape
+ patch_size = self.config.patch_size
+ hidden_state = hidden_state[:, 1:, :].reshape(
+ batch_size, width // patch_size, height // patch_size, -1
+ )
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ feature_maps += (hidden_state,)
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (feature_maps,) + outputs[1:]
+ else:
+ output = (feature_maps,) + outputs[2:]
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions if output_attentions else None,
+ )
+
+
+
+class CustomPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int):
+ super().__init__()
+
+ image_size = (
+ image_size
+ if isinstance(image_size, collections.abc.Iterable)
+ else (image_size, image_size)
+ )
+ patch_size = (
+ patch_size
+ if isinstance(patch_size, collections.abc.Iterable)
+ else (patch_size, patch_size)
+ )
+ num_patches = (image_size[1] // patch_size[1]) * (
+ image_size[0] // patch_size[0]
+ )
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
+ )
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
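+# Example of the tensor shapes produced by CustomPatchEmbeddings (illustrative values,
+# not tied to a specific checkpoint): with image_size=224, patch_size=14, num_channels=3
+# and hidden_size=768, num_patches = (224//14)**2 = 256, and forward maps
+# (B, 3, 224, 224) -> Conv2d -> (B, 768, 16, 16) -> flatten/transpose -> (B, 256, 768).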
+
+class CustomEmbeddings(nn.Module):
+ """
+ Construct the CLS token, mask token, position and patch embeddings.
+ """
+
+ def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int) -> None:
+ super().__init__()
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.hidden_size = hidden_size
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
+
+ self.patch_embeddings = CustomPatchEmbeddings(image_size, patch_size, num_channels, hidden_size)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(
+ torch.randn(1, num_patches + 1, self.hidden_size)
+ )
+
+ def interpolate_pos_encoding(
+ self, embeddings: torch.Tensor, height: int, width: int
+ ) -> torch.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on
+ higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ height = height // self.patch_size
+ width = width // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ height, width = height + 0.1, width + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(
+ height / math.sqrt(num_positions),
+ width / math.sqrt(num_positions),
+ ),
+ mode="bicubic",
+ align_corners=False,
+ )
+ if (
+ int(height) != patch_pos_embed.shape[-2]
+ or int(width) != patch_pos_embed.shape[-1]
+ ):
+ raise ValueError(
+ "Width or height does not match with the interpolated position embeddings"
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(
+ self, pixel_values: torch.Tensor,
+ ) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ patch_embeddings = self.patch_embeddings(pixel_values)
+ embeddings = patch_embeddings
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.interpolate_pos_encoding(
+ embeddings, height, width
+ )
+
+ return embeddings
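+# Putting the pieces together (a sketch with assumed sizes, not a fixed API): for
+# pixel_values of shape (B, 3, 224, 224) and patch_size=14, forward returns
+# (B, 257, hidden_size), i.e. one [CLS] token plus 256 patch tokens, with the position
+# embeddings interpolated by interpolate_pos_encoding whenever the input resolution
+# differs from the one the embeddings were created for.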
diff --git a/hort/models/tgs/models/tokenizers/image.py b/hort/models/tgs/models/tokenizers/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..548aee829200cc37e8e44ac363556de776bc9e74
--- /dev/null
+++ b/hort/models/tgs/models/tokenizers/image.py
@@ -0,0 +1,123 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from tgs.utils.base import BaseModule
+from tgs.models.tokenizers.dinov2 import Dinov2Model
+from tgs.models.transformers import Modulation
+from tgs.utils.typing import *
+
+class DINOV2SingleImageTokenizer(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ pretrained_model_name_or_path: str = "facebook/dinov2-base"
+ width: int = 224
+ height: int = 224
+ modulation: bool = False
+ modulation_zero_init: bool = False
+ modulation_single_layer: bool = False
+ modulation_cond_dim: int = 16
+ freeze_backbone_params: bool = True
+ enable_memory_efficient_attention: bool = False
+ enable_gradient_checkpointing: bool = False
+ use_patch_embeddings: bool = False
+ patch_embeddings_aggr_method: str = 'concat'
+
+ cfg: Config
+
+ def configure(self) -> None:
+ super().configure()
+ model: Dinov2Model
+
+ if self.cfg.freeze_backbone_params:
+ # freeze dino backbone parameters
+ self.register_non_module(
+ "model",
+ Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path).to(
+ self.device
+ ),
+ )
+
+ model = self.non_module("model")
+ for p in model.parameters():
+ p.requires_grad_(False)
+ model.eval()
+ else:
+ self.model = Dinov2Model.from_pretrained(
+ self.cfg.pretrained_model_name_or_path
+ ).to(self.device)
+ model = self.model
+
+ model.set_use_memory_efficient_attention_xformers(
+ self.cfg.enable_memory_efficient_attention
+ )
+ model.set_gradient_checkpointing(self.cfg.enable_gradient_checkpointing)
+
+ # add modulation
+ if self.cfg.modulation:
+ modulations = []
+ for layer in model.encoder.layer:
+ norm1_modulation = Modulation(
+ model.config.hidden_size,
+ self.cfg.modulation_cond_dim,
+ zero_init=self.cfg.modulation_zero_init,
+ single_layer=self.cfg.modulation_single_layer,
+ )
+ norm2_modulation = Modulation(
+ model.config.hidden_size,
+ self.cfg.modulation_cond_dim,
+ zero_init=self.cfg.modulation_zero_init,
+ single_layer=self.cfg.modulation_single_layer,
+ )
+ layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
+ modulations += [norm1_modulation, norm2_modulation]
+ self.modulations = nn.ModuleList(modulations)
+
+ def forward(
+ self,
+ images: Float[Tensor, "B *N C H W"],
+ modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
+ ) -> Float[Tensor, "B *N Ct Nt"]:
+ model: Dinov2Model
+ if self.cfg.freeze_backbone_params:
+ model = self.non_module("model")
+ else:
+ model = self.model
+
+ packed = False
+ if images.ndim == 4:
+ packed = True
+ images = images.unsqueeze(1)
+ if modulation_cond is not None:
+ assert modulation_cond.ndim == 2
+ modulation_cond = modulation_cond.unsqueeze(1)
+
+ batch_size, n_input_views = images.shape[:2]
+ out = model(
+ rearrange(images, "B N C H W -> (B N) C H W"),
+ modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
+ if modulation_cond is not None
+ else None,
+ )
+ local_features, global_features = out.last_hidden_state, out.pooler_output
+ if self.cfg.use_patch_embeddings:
+ patch_embeddings = out.patch_embeddings
+ if self.cfg.patch_embeddings_aggr_method == 'concat':
+ local_features = torch.cat([local_features, patch_embeddings], dim=1)
+ elif self.cfg.patch_embeddings_aggr_method == 'add':
+ local_features = local_features + patch_embeddings
+ else:
+ raise NotImplementedError
+ local_features = local_features.permute(0, 2, 1)
+ local_features = rearrange(
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
+ )
+ if packed:
+ local_features = local_features.squeeze(1)
+
+ return local_features
+
+ def detokenize(self, *args, **kwargs):
+ raise NotImplementedError
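+# Usage sketch (shapes assume the default facebook/dinov2-base backbone with patch
+# size 14, hidden size 768, and use_patch_embeddings=False; adjust for other settings):
+#   tokenizer = DINOV2SingleImageTokenizer(cfg)
+#   feats = tokenizer(images, modulation_cond=None)
+#   # images: (B, 3, 224, 224)   -> feats: (B, 768, 257), channels-first tokens
+#   # images: (B, N, 3, 224, 224) -> feats: (B, N, 768, 257)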
diff --git a/hort/models/tgs/models/tokenizers/point.py b/hort/models/tgs/models/tokenizers/point.py
new file mode 100644
index 0000000000000000000000000000000000000000..99afb6cc6f2736dd5a24b3c7cdf337ea860bc5ec
--- /dev/null
+++ b/hort/models/tgs/models/tokenizers/point.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+import torch.nn as nn
+from tgs.utils.base import BaseModule
+from tgs.utils.typing import *
+import torch
+
+class PointLearnablePositionalEmbedding(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ num_pcl: int = 2048
+ num_channels: int = 512
+
+ cfg: Config
+
+ def configure(self) -> None:
+ super().configure()
+ self.pcl_embeddings = nn.Embedding(
+ self.cfg.num_pcl, self.cfg.num_channels
+ )
+
+ def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
+ range_ = torch.arange(self.cfg.num_pcl, device=self.device)
+ embeddings = self.pcl_embeddings(range_).unsqueeze(0).repeat((batch_size,1,1))
+ return torch.permute(embeddings, (0,2,1))
+
+ def detokenize(
+ self, tokens: Float[Tensor, "B Ct Nt"]
+ ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
+ return torch.permute(tokens, (0,2,1))
\ No newline at end of file
diff --git a/hort/models/tgs/models/tokenizers/triplane.py b/hort/models/tgs/models/tokenizers/triplane.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bd6774c5028d25308878c5218ee4664cede6e63
--- /dev/null
+++ b/hort/models/tgs/models/tokenizers/triplane.py
@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from tgs.utils.base import BaseModule
+from tgs.utils.typing import *
+
+
+class TriplaneLearnablePositionalEmbedding(BaseModule):
+ @dataclass
+ class Config(BaseModule.Config):
+ plane_size: int = 32
+ num_channels: int = 1024
+
+ cfg: Config
+
+ def configure(self) -> None:
+ super().configure()
+ self.embeddings = nn.Parameter(
+ torch.randn(
+ (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
+ dtype=torch.float32,
+ )
+ * 1
+ / math.sqrt(self.cfg.num_channels)
+ )
+
+ def forward(self, batch_size: int, cond_embeddings: Float[Tensor, "B Ct"] = None) -> Float[Tensor, "B Ct Nt"]:
+ embeddings = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size)
+ if cond_embeddings is not None:
+ embeddings = embeddings + cond_embeddings
+ return rearrange(
+ embeddings,
+ "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
+ )
+
+ def detokenize(
+ self, tokens: Float[Tensor, "B Ct Nt"]
+ ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
+ batch_size, Ct, Nt = tokens.shape
+ assert Nt == self.cfg.plane_size**2 * 3
+ assert Ct == self.cfg.num_channels
+ return rearrange(
+ tokens,
+ "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
+ Np=3,
+ Hp=self.cfg.plane_size,
+ Wp=self.cfg.plane_size,
+ )
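+# Shape round-trip for the triplane tokens (using the defaults above, plane_size=32
+# and num_channels=1024):
+#   forward(B)         -> (B, 1024, 3*32*32) = (B, 1024, 3072)
+#   detokenize(tokens) -> (B, 3, 1024, 32, 32)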
diff --git a/hort/models/tgs/models/transformers.py b/hort/models/tgs/models/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c7ce3b82f3c06a44ba97986307ddc30284163d0
--- /dev/null
+++ b/hort/models/tgs/models/transformers.py
@@ -0,0 +1,908 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import Attention
+from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
+
+from dataclasses import dataclass
+from tgs.utils.base import BaseModule
+from tgs.utils.typing import *
+
+
+class MemoryEfficientAttentionMixin:
+ def enable_xformers_memory_efficient_attention(
+ self, attention_op: Optional[Callable] = None
+ ):
+ r"""
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this
+ option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed
+ up during training is not guaranteed.
+
+
+
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
+ precedence.
+
+
+
+ Parameters:
+ attention_op (`Callable`, *optional*):
+ Override the default `None` operator for use as `op` argument to the
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+ function of xFormers.
+
+ Examples:
+
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+ >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+ >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+ >>> # Workaround for not accepting attention shape using VAE for Flash Attention
+ >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
+ ```
+ """
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
+ """
+ self.set_use_memory_efficient_attention_xformers(False)
+
+ def set_use_memory_efficient_attention_xformers(
+ self, valid: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ # Recursively walk through all the children.
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
+ # gets the message
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
+
+ for child in module.children():
+ fn_recursive_set_mem_eff(child)
+
+ for module in self.children():
+ if isinstance(module, torch.nn.Module):
+ fn_recursive_set_mem_eff(module)
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = (
+ x
+ + self.alpha_attn.tanh()
+ * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ )
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
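+# Note on the gating above: alpha_attn and alpha_dense are initialized to 0, and
+# tanh(0) = 0, so a freshly constructed GatedSelfAttentionDense acts as an identity
+# mapping over x; the gates open gradually as the parameters are trained.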
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module, MemoryEfficientAttentionMixin):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (`int`, *optional*):
+ The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ cond_dim_ada_norm_continuous: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (
+ num_embeds_ada_norm is not None
+ ) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (
+ num_embeds_ada_norm is not None
+ ) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_continuous = (
+ cond_dim_ada_norm_continuous is not None
+ ) and norm_type == "ada_norm_continuous"
+
+ assert (
+ int(self.use_ada_layer_norm)
+ + int(self.use_ada_layer_norm_continuous)
+ + int(self.use_ada_layer_norm_zero)
+ <= 1
+ )
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_continuous:
+ self.norm1 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ if self.use_ada_layer_norm:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_continuous:
+ self.norm2 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
+ else:
+ self.norm2 = nn.LayerNorm(
+ dim, elementwise_affine=norm_elementwise_affine
+ )
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim
+ if not double_self_attention
+ else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if self.use_ada_layer_norm_continuous:
+ self.norm3 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
+ else:
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ )
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
+ )
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ modulation_cond: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_continuous:
+ norm_hidden_states = self.norm1(hidden_states, modulation_cond)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Retrieve lora scale.
+ lora_scale = (
+ cross_attention_kwargs.get("scale", 1.0)
+ if cross_attention_kwargs is not None
+ else 1.0
+ )
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = (
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ )
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states
+ if self.only_cross_attention
+ else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+ # 2.5 ends
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.use_ada_layer_norm_continuous:
+ norm_hidden_states = self.norm2(hidden_states, modulation_cond)
+ else:
+ norm_hidden_states = self.norm2(hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ if self.use_ada_layer_norm_continuous:
+ norm_hidden_states = self.norm3(hidden_states, modulation_cond)
+ else:
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = (
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ )
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(
+ num_chunks, dim=self._chunk_dim
+ )
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ linear_cls = nn.Linear
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(linear_cls(inner_dim, dim_out))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
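+# Worked example of the FeedForward wiring (illustrative sizes only): with dim=1024,
+# mult=4 and activation_fn="geglu", inner_dim=4096; GEGLU projects 1024 -> 8192,
+# splits the result into value/gate halves of 4096 each and gates them, dropout is
+# applied, and the final Linear maps 4096 -> 1024.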
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
+ dtype=gate.dtype
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ linear_cls = nn.Linear
+
+ self.proj = linear_cls(dim_in, dim_out * 2)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states, scale: float = 1.0):
+ args = ()
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ r"""
+ The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
+ https://arxiv.org/abs/1606.08415.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+class AdaLayerNorm(nn.Module):
+ r"""
+ Norm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the dictionary of embeddings.
+ """
+
+ def __init__(self, embedding_dim: int, num_embeddings: int):
+ super().__init__()
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
+ emb = self.linear(self.silu(self.emb(timestep)))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+ return x
+
+
+class AdaLayerNormContinuous(nn.Module):
+ r"""
+ Norm layer modified to incorporate arbitrary continuous embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ condition_dim (`int`): The size of the conditioning vector used to produce the scale and shift.
+ """
+
+ def __init__(self, embedding_dim: int, condition_dim: int):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear1 = nn.Linear(condition_dim, condition_dim)
+ self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
+ emb = self.linear2(self.silu(self.linear1(condition)))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+ return x
+
+
+class Modulation(nn.Module):
+ def __init__(self, embedding_dim: int, condition_dim: int, zero_init: bool = False, single_layer: bool = False):
+ super().__init__()
+ self.silu = nn.SiLU()
+ if single_layer:
+ self.linear1 = nn.Identity()
+ else:
+ self.linear1 = nn.Linear(condition_dim, condition_dim)
+
+ self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
+
+ # Only zero init the last linear layer
+ if zero_init:
+ nn.init.zeros_(self.linear2.weight)
+ nn.init.zeros_(self.linear2.bias)
+
+ def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
+ emb = self.linear2(self.silu(self.linear1(condition)))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+ return x
+
+
+class AdaLayerNormZero(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the dictionary of embeddings.
+ """
+
+ def __init__(self, embedding_dim: int, num_embeddings: int):
+ super().__init__()
+
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timestep: torch.Tensor,
+ class_labels: torch.LongTensor,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = self.linear(
+ self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))
+ )
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(
+ 6, dim=1
+ )
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class AdaGroupNorm(nn.Module):
+ r"""
+ GroupNorm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+        out_dim (`int`): The number of output channels.
+ num_groups (`int`): The number of groups to separate the channels into.
+ act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
+ eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ out_dim: int,
+ num_groups: int,
+ act_fn: Optional[str] = None,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.num_groups = num_groups
+ self.eps = eps
+
+ if act_fn is None:
+ self.act = None
+ else:
+ self.act = get_activation(act_fn)
+
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ if self.act:
+ emb = self.act(emb)
+ emb = self.linear(emb)
+ emb = emb[:, :, None, None]
+ scale, shift = emb.chunk(2, dim=1)
+
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
+ x = x * (1 + scale) + shift
+ return x
+
+class Transformer1D(BaseModule, MemoryEfficientAttentionMixin):
+ """
+ A 1D Transformer model for sequence data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+            During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ @dataclass
+ class Config(BaseModule.Config):
+ num_attention_heads: int = 16
+ attention_head_dim: int = 88
+ in_channels: Optional[int] = None
+ out_channels: Optional[int] = None
+ num_layers: int = 1
+ dropout: float = 0.0
+ norm_num_groups: int = 32
+ cross_attention_dim: Optional[int] = None
+ attention_bias: bool = False
+ activation_fn: str = "geglu"
+ num_embeds_ada_norm: Optional[int] = None
+ cond_dim_ada_norm_continuous: Optional[int] = None
+ only_cross_attention: bool = False
+ double_self_attention: bool = False
+ upcast_attention: bool = False
+ norm_type: str = "layer_norm"
+ norm_elementwise_affine: bool = True
+ attention_type: str = "default"
+ enable_memory_efficient_attention: bool = False
+ gradient_checkpointing: bool = False
+
+ cfg: Config
+
+ def configure(self) -> None:
+ super().configure()
+
+ self.num_attention_heads = self.cfg.num_attention_heads
+ self.attention_head_dim = self.cfg.attention_head_dim
+ inner_dim = self.num_attention_heads * self.attention_head_dim
+
+ linear_cls = nn.Linear
+
+ if self.cfg.norm_type == "layer_norm" and (
+ self.cfg.num_embeds_ada_norm is not None
+ or self.cfg.cond_dim_ada_norm_continuous is not None
+ ):
+            raise ValueError(
+                "norm_type 'layer_norm' cannot be combined with num_embeds_ada_norm or cond_dim_ada_norm_continuous."
+            )
+
+ # 2. Define input layers
+ self.in_channels = self.cfg.in_channels
+
+ self.norm = torch.nn.GroupNorm(
+ num_groups=self.cfg.norm_num_groups,
+ num_channels=self.cfg.in_channels,
+ eps=1e-6,
+ affine=True,
+ )
+ self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ self.num_attention_heads,
+ self.attention_head_dim,
+ dropout=self.cfg.dropout,
+ cross_attention_dim=self.cfg.cross_attention_dim,
+ activation_fn=self.cfg.activation_fn,
+ num_embeds_ada_norm=self.cfg.num_embeds_ada_norm,
+ cond_dim_ada_norm_continuous=self.cfg.cond_dim_ada_norm_continuous,
+ attention_bias=self.cfg.attention_bias,
+ only_cross_attention=self.cfg.only_cross_attention,
+ double_self_attention=self.cfg.double_self_attention,
+ upcast_attention=self.cfg.upcast_attention,
+ norm_type=self.cfg.norm_type,
+ norm_elementwise_affine=self.cfg.norm_elementwise_affine,
+ attention_type=self.cfg.attention_type,
+ )
+ for d in range(self.cfg.num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = (
+ self.cfg.in_channels
+ if self.cfg.out_channels is None
+ else self.cfg.out_channels
+ )
+
+ self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
+
+ self.gradient_checkpointing = self.cfg.gradient_checkpointing
+
+ self.set_use_memory_efficient_attention_xformers(
+ self.cfg.enable_memory_efficient_attention
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ modulation_cond: Optional[torch.FloatTensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ """
+        The [`Transformer1D`] forward method.
+
+ Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, sequence length)`):
+                Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerNormZero`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+
+        Returns:
+            `torch.Tensor`: The output hidden states of shape `(batch size, channel, sequence length)`.
+ """
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
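+            # illustrative example: a keep/discard mask [[1, 1, 0]] becomes the additive bias
+            # [[[0.0, 0.0, -10000.0]]] after the unsqueeze above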
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
+ ) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch, _, seq_len = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(
+ batch, seq_len, inner_dim
+ )
+ hidden_states = self.proj_in(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ if self.training and self.gradient_checkpointing:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ modulation_cond,
+ cross_attention_kwargs,
+ class_labels,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ modulation_cond=modulation_cond,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, seq_len, inner_dim)
+ .permute(0, 2, 1)
+ .contiguous()
+ )
+
+ output = hidden_states + residual
+
+ return output
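+
+
+# Minimal usage sketch (illustrative; field names follow the Config dataclass above):
+#   cfg = {"in_channels": 512, "num_attention_heads": 8, "attention_head_dim": 64,
+#          "num_layers": 2, "norm_num_groups": 32}
+#   model = Transformer1D(cfg)
+#   tokens = torch.randn(2, 512, 1024)  # (batch, channels, sequence length)
+#   out = model(tokens)                 # same shape as the input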
diff --git a/hort/models/tgs/utils/__init__.py b/hort/models/tgs/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hort/models/tgs/utils/base.py b/hort/models/tgs/utils/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..36cee1fad8741cb8e0c06f9126a6aef51ceb206e
--- /dev/null
+++ b/hort/models/tgs/utils/base.py
@@ -0,0 +1,129 @@
+from dataclasses import dataclass
+
+import torch.nn as nn
+
+from tgs.utils.config import parse_structured
+from tgs.utils.misc import get_device, load_module_weights
+from tgs.utils.typing import *
+
+
+class Configurable:
+ @dataclass
+ class Config:
+ pass
+
+ def __init__(self, cfg: Optional[dict] = None) -> None:
+ super().__init__()
+ self.cfg = parse_structured(self.Config, cfg)
+
+
+class Updateable:
+ def do_update_step(
+ self, epoch: int, global_step: int, on_load_weights: bool = False
+ ):
+ for attr in self.__dir__():
+ if attr.startswith("_"):
+ continue
+ try:
+ module = getattr(self, attr)
+            except Exception:
+                continue  # ignore attributes (e.g. properties) that can't be retrieved with getattr
+ if isinstance(module, Updateable):
+ module.do_update_step(
+ epoch, global_step, on_load_weights=on_load_weights
+ )
+ self.update_step(epoch, global_step, on_load_weights=on_load_weights)
+
+ def do_update_step_end(self, epoch: int, global_step: int):
+ for attr in self.__dir__():
+ if attr.startswith("_"):
+ continue
+ try:
+ module = getattr(self, attr)
+            except Exception:
+                continue  # ignore attributes (e.g. properties) that can't be retrieved with getattr
+ if isinstance(module, Updateable):
+ module.do_update_step_end(epoch, global_step)
+ self.update_step_end(epoch, global_step)
+
+ def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
+ # override this method to implement custom update logic
+ # if on_load_weights is True, you should be careful doing things related to model evaluations,
+        # as the models and tensors are not guaranteed to be on the same device
+ pass
+
+ def update_step_end(self, epoch: int, global_step: int):
+ pass
+
+
+def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
+ if isinstance(module, Updateable):
+ module.do_update_step(epoch, global_step)
+
+
+def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
+ if isinstance(module, Updateable):
+ module.do_update_step_end(epoch, global_step)
+
+
+class BaseObject(Updateable):
+ @dataclass
+ class Config:
+ pass
+
+ cfg: Config # add this to every subclass of BaseObject to enable static type checking
+
+ def __init__(
+ self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
+ ) -> None:
+ super().__init__()
+ self.cfg = parse_structured(self.Config, cfg)
+ self.device = get_device()
+ self.configure(*args, **kwargs)
+
+ def configure(self, *args, **kwargs) -> None:
+ pass
+
+
+class BaseModule(nn.Module, Updateable):
+ @dataclass
+ class Config:
+ weights: Optional[str] = None
+ freeze: Optional[bool] = False
+
+ cfg: Config # add this to every subclass of BaseModule to enable static type checking
+
+ def __init__(
+ self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
+ ) -> None:
+ super().__init__()
+ self.cfg = parse_structured(self.Config, cfg)
+ self.device = get_device()
+ self._non_modules = {}
+ self.configure(*args, **kwargs)
+ if self.cfg.weights is not None:
+ # format: path/to/weights:module_name
+ weights_path, module_name = self.cfg.weights.split(":")
+ state_dict = load_module_weights(
+ weights_path, module_name=module_name, map_location="cpu"
+ )
+ self.load_state_dict(state_dict, strict=False)
+ # self.do_update_step(
+ # epoch, global_step, on_load_weights=True
+ # ) # restore states
+
+ if self.cfg.freeze:
+ for params in self.parameters():
+ params.requires_grad = False
+
+ def configure(self, *args, **kwargs) -> None:
+ pass
+
+ def register_non_module(self, name: str, module: nn.Module) -> None:
+ # non-modules won't be treated as model parameters
+ if name in self._non_modules:
+ raise ValueError(f"Non-module {name} already exists!")
+ self._non_modules[name] = module
+
+ def non_module(self, name: str):
+ return self._non_modules.get(name, None)
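+
+# Illustrative sketch (not part of the original code): subclasses declare their own Config
+# dataclass and build submodules in configure(), e.g.
+#   class MyHead(BaseModule):
+#       @dataclass
+#       class Config(BaseModule.Config):
+#           in_dim: int = 128
+#       cfg: Config
+#       def configure(self) -> None:
+#           self.proj = nn.Linear(self.cfg.in_dim, 3)
+#   head = MyHead({"in_dim": 256})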
diff --git a/hort/models/tgs/utils/config.py b/hort/models/tgs/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fe2c623f5a0458896d958e00e8e8327bfafd38b
--- /dev/null
+++ b/hort/models/tgs/utils/config.py
@@ -0,0 +1,74 @@
+import os
+from dataclasses import dataclass, field
+
+from omegaconf import OmegaConf
+
+from tgs.utils.typing import *
+
+# ============ Register OmegaConf Resolvers ============= #
+OmegaConf.register_new_resolver(
+ "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)
+)
+OmegaConf.register_new_resolver("add", lambda a, b: a + b)
+OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
+OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
+OmegaConf.register_new_resolver("div", lambda a, b: a / b)
+OmegaConf.register_new_resolver("idiv", lambda a, b: a // b)
+OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p))
+OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub))
+OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)])
+OmegaConf.register_new_resolver("gt0", lambda s: s > 0)
+OmegaConf.register_new_resolver("not", lambda s: not s)
+OmegaConf.register_new_resolver("shsdim", lambda sh_degree: (sh_degree + 1) ** 2 * 3)
+# ======================================================= #
+
+# ============== Automatic Name Resolvers =============== #
+def get_naming_convention(cfg):
+ # TODO
+ name = f"tgs_{cfg.system.backbone.num_layers}"
+ return name
+
+# ======================================================= #
+
+@dataclass
+class ExperimentConfig:
+ n_gpus: int = 1
+ data: dict = field(default_factory=dict)
+ system: dict = field(default_factory=dict)
+
+def load_config(
+ *yamls: str, cli_args: list = [], from_string=False, makedirs=True, **kwargs
+) -> Any:
+ if from_string:
+ parse_func = OmegaConf.create
+ else:
+ parse_func = OmegaConf.load
+ yaml_confs = []
+ for y in yamls:
+ conf = parse_func(y)
+ extends = conf.pop("extends", None)
+ if extends:
+ assert os.path.exists(extends), f"File {extends} does not exist."
+ yaml_confs.append(OmegaConf.load(extends))
+ yaml_confs.append(conf)
+ cli_conf = OmegaConf.from_cli(cli_args)
+ cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
+ OmegaConf.resolve(cfg)
+ assert isinstance(cfg, DictConfig)
+ scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg)
+
+ return scfg
+
+
+def config_to_primitive(config, resolve: bool = True) -> Any:
+ return OmegaConf.to_container(config, resolve=resolve)
+
+
+def dump_config(path: str, config) -> None:
+ with open(path, "w") as fp:
+ OmegaConf.save(config=config, f=fp)
+
+
+def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
+ scfg = OmegaConf.structured(fields(**cfg))
+ return scfg
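+
+# Illustrative usage sketch (file paths and keys are hypothetical): a YAML config may `extends`
+# a base file, and dotted CLI overrides are merged on top before resolution:
+#   cfg = load_config("configs/my_exp.yaml", cli_args=["system.backbone.num_layers=12"])
+#   print(cfg.system["backbone"]["num_layers"])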
diff --git a/hort/models/tgs/utils/misc.py b/hort/models/tgs/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..85442996c7324d86aa94fe3661cee44d8b404a25
--- /dev/null
+++ b/hort/models/tgs/utils/misc.py
@@ -0,0 +1,88 @@
+import os
+import re
+
+import torch
+from packaging import version
+
+from tgs.utils.typing import *
+
+
+def parse_version(ver: str):
+ return version.parse(ver)
+
+
+def get_rank():
+ # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
+ # therefore LOCAL_RANK needs to be checked first
+ rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
+ for key in rank_keys:
+ rank = os.environ.get(key)
+ if rank is not None:
+ return int(rank)
+ return 0
+
+
+def get_device():
+ return torch.device(f"cuda:{get_rank()}")
+
+
+def load_module_weights(
+ path, module_name=None, ignore_modules=None, map_location=None
+) -> dict:
+ if module_name is not None and ignore_modules is not None:
+ raise ValueError("module_name and ignore_modules cannot be both set")
+ if map_location is None:
+ map_location = get_device()
+
+ ckpt = torch.load(path, map_location=map_location)
+ state_dict = ckpt["state_dict"]
+ state_dict_to_load = state_dict
+
+ if ignore_modules is not None:
+ state_dict_to_load = {}
+ for k, v in state_dict.items():
+ ignore = any(
+ [k.startswith(ignore_module + ".") for ignore_module in ignore_modules]
+ )
+ if ignore:
+ continue
+ state_dict_to_load[k] = v
+
+ if module_name is not None:
+ state_dict_to_load = {}
+ for k, v in state_dict.items():
+ m = re.match(rf"^{module_name}\.(.*)$", k)
+ if m is None:
+ continue
+ state_dict_to_load[m.group(1)] = v
+
+ return state_dict_to_load
+
+# convert a function into recursive style to handle nested dict/list/tuple variables
+def make_recursive_func(func):
+ def wrapper(vars, *args, **kwargs):
+ if isinstance(vars, list):
+ return [wrapper(x, *args, **kwargs) for x in vars]
+ elif isinstance(vars, tuple):
+ return tuple([wrapper(x, *args, **kwargs) for x in vars])
+ elif isinstance(vars, dict):
+ return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()}
+ else:
+ return func(vars, *args, **kwargs)
+
+ return wrapper
+
+@make_recursive_func
+def todevice(vars, device="cuda"):
+ if isinstance(vars, torch.Tensor):
+ return vars.to(device)
+ elif isinstance(vars, str):
+ return vars
+ elif isinstance(vars, bool):
+ return vars
+ elif isinstance(vars, float):
+ return vars
+ elif isinstance(vars, int):
+ return vars
+ else:
+        raise NotImplementedError("invalid input type {} for todevice".format(type(vars)))
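+
+# Illustrative usage sketch: thanks to make_recursive_func, todevice walks nested containers,
+# moving tensors and passing primitives through unchanged:
+#   batch = {"img": torch.zeros(1, 3, 224, 224), "meta": {"name": "sample", "scale": 1.0}}
+#   batch = todevice(batch, device="cuda")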
diff --git a/hort/models/tgs/utils/ops.py b/hort/models/tgs/utils/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..531474faf282a42db0a06e1440dd8b58d441c8bb
--- /dev/null
+++ b/hort/models/tgs/utils/ops.py
@@ -0,0 +1,279 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.amp import custom_bwd, custom_fwd
+from pytorch3d import io
+from pytorch3d.renderer import (
+ PointsRasterizationSettings,
+ PointsRasterizer)
+from pytorch3d.structures import Pointclouds
+from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection
+import cv2
+
+from tgs.utils.typing import *
+
+ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
+
+def scale_tensor(
+ dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
+):
+ if inp_scale is None:
+ inp_scale = (0, 1)
+ if tgt_scale is None:
+ tgt_scale = (0, 1)
+ if isinstance(tgt_scale, Tensor):
+ assert dat.shape[-1] == tgt_scale.shape[-1]
+ dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
+ dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
+ return dat
+
+
+class _TruncExp(Function): # pylint: disable=abstract-method
+ # Implementation from torch-ngp:
+ # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32, device_type="cuda")
+ def forward(ctx, x): # pylint: disable=arguments-differ
+ ctx.save_for_backward(x)
+ return torch.exp(x)
+
+ @staticmethod
+ @custom_bwd(device_type="cuda")
+ def backward(ctx, g): # pylint: disable=arguments-differ
+ x = ctx.saved_tensors[0]
+ return g * torch.exp(torch.clamp(x, max=15))
+
+
+trunc_exp = _TruncExp.apply
+
+
+def get_activation(name) -> Callable:
+ if name is None:
+ return lambda x: x
+ name = name.lower()
+ if name == "none":
+ return lambda x: x
+ elif name == "lin2srgb":
+ return lambda x: torch.where(
+ x > 0.0031308,
+ torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
+ 12.92 * x,
+ ).clamp(0.0, 1.0)
+ elif name == "exp":
+ return lambda x: torch.exp(x)
+ elif name == "shifted_exp":
+ return lambda x: torch.exp(x - 1.0)
+ elif name == "trunc_exp":
+ return trunc_exp
+ elif name == "shifted_trunc_exp":
+ return lambda x: trunc_exp(x - 1.0)
+ elif name == "sigmoid":
+ return lambda x: torch.sigmoid(x)
+ elif name == "tanh":
+ return lambda x: torch.tanh(x)
+ elif name == "shifted_softplus":
+ return lambda x: F.softplus(x - 1.0)
+ elif name == "scale_-11_01":
+ return lambda x: x * 0.5 + 0.5
+ else:
+ try:
+ return getattr(F, name)
+ except AttributeError:
+ raise ValueError(f"Unknown activation function: {name}")
+
+def get_ray_directions(
+ H: int,
+ W: int,
+ focal: Union[float, Tuple[float, float]],
+ principal: Optional[Tuple[float, float]] = None,
+ use_pixel_centers: bool = True,
+) -> Float[Tensor, "H W 3"]:
+ """
+ Get ray directions for all pixels in camera coordinate.
+    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ray-tracing-generating-camera-rays/standard-coordinate-systems
+
+ Inputs:
+        H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether to use pixel centers
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ pixel_center = 0.5 if use_pixel_centers else 0
+
+ if isinstance(focal, float):
+ fx, fy = focal, focal
+ cx, cy = W / 2, H / 2
+ else:
+ fx, fy = focal
+ assert principal is not None
+ cx, cy = principal
+
+ i, j = torch.meshgrid(
+ torch.arange(W, dtype=torch.float32) + pixel_center,
+ torch.arange(H, dtype=torch.float32) + pixel_center,
+ indexing="xy",
+ )
+
+ directions: Float[Tensor, "H W 3"] = torch.stack(
+ [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1
+ )
+
+ return directions
+
+
+def get_rays(
+ directions: Float[Tensor, "... 3"],
+ c2w: Float[Tensor, "... 4 4"],
+ keepdim=False,
+ noise_scale=0.0,
+) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]:
+ # Rotate ray directions from camera coordinate to the world coordinate
+ assert directions.shape[-1] == 3
+
+ if directions.ndim == 2: # (N_rays, 3)
+ if c2w.ndim == 2: # (4, 4)
+ c2w = c2w[None, :, :]
+ assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
+ rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
+ rays_o = c2w[:, :3, 3].expand(rays_d.shape)
+ elif directions.ndim == 3: # (H, W, 3)
+ assert c2w.ndim in [2, 3]
+ if c2w.ndim == 2: # (4, 4)
+ rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
+ -1
+ ) # (H, W, 3)
+ rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
+ elif c2w.ndim == 3: # (B, 4, 4)
+ rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
+ -1
+ ) # (B, H, W, 3)
+ rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
+ elif directions.ndim == 4: # (B, H, W, 3)
+ assert c2w.ndim == 3 # (B, 4, 4)
+ rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
+ -1
+ ) # (B, H, W, 3)
+ rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
+
+    # add camera noise to avoid grid-like artifacts
+ # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
+ if noise_scale > 0:
+ rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
+ rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale
+
+ rays_d = F.normalize(rays_d, dim=-1)
+ if not keepdim:
+ rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
+
+ return rays_o, rays_d
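+
+# Illustrative usage sketch: directions depend only on the intrinsics and can be reused for any pose:
+#   dirs = get_ray_directions(H=256, W=256, focal=300.0)          # (H, W, 3) in the camera frame
+#   rays_o, rays_d = get_rays(dirs, torch.eye(4), keepdim=True)   # both (H, W, 3) in the world frame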
+
+
+def get_projection_matrix(
+ fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float
+) -> Float[Tensor, "*B 4 4"]:
+ if isinstance(fovy, float):
+ proj_mtx = torch.zeros(4, 4, dtype=torch.float32)
+ proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh)
+ proj_mtx[1, 1] = -1.0 / math.tan(
+ fovy / 2.0
+ ) # add a negative sign here as the y axis is flipped in nvdiffrast output
+ proj_mtx[2, 2] = -(far + near) / (far - near)
+ proj_mtx[2, 3] = -2.0 * far * near / (far - near)
+ proj_mtx[3, 2] = -1.0
+ else:
+ batch_size = fovy.shape[0]
+ proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
+ proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
+ proj_mtx[:, 1, 1] = -1.0 / torch.tan(
+ fovy / 2.0
+ ) # add a negative sign here as the y axis is flipped in nvdiffrast output
+ proj_mtx[:, 2, 2] = -(far + near) / (far - near)
+ proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
+ proj_mtx[:, 3, 2] = -1.0
+ return proj_mtx
+
+
+def get_mvp_matrix(
+ c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"]
+) -> Float[Tensor, "*B 4 4"]:
+ # calculate w2c from c2w: R' = Rt, t' = -Rt * t
+ # mathematically equivalent to (c2w)^-1
+ if c2w.ndim == 2:
+ assert proj_mtx.ndim == 2
+ w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
+ w2c[:3, :3] = c2w[:3, :3].permute(1, 0)
+ w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:]
+ w2c[3, 3] = 1.0
+ else:
+ w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
+ w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
+ w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
+ w2c[:, 3, 3] = 1.0
+ # calculate mvp matrix by proj_mtx @ w2c (mv_mtx)
+ mvp_mtx = proj_mtx @ w2c
+ return mvp_mtx
+
+def get_intrinsic_from_fov(fov, H, W, bs=-1):
+ focal_length = 0.5 * H / np.tan(0.5 * fov)
+ intrinsic = np.identity(3, dtype=np.float32)
+ intrinsic[0, 0] = focal_length
+ intrinsic[1, 1] = focal_length
+ intrinsic[0, 2] = W / 2.0
+ intrinsic[1, 2] = H / 2.0
+
+ if bs > 0:
+ intrinsic = intrinsic[None].repeat(bs, axis=0)
+
+ return torch.from_numpy(intrinsic)
+
+def points_projection(points: Float[Tensor, "B Np 3"],
+ c2ws: Float[Tensor, "B 4 4"],
+ intrinsics: Float[Tensor, "B 3 3"],
+ local_features: Float[Tensor, "B C H W"],
+ # Rasterization settings
+ raster_point_radius: float = 0.0075, # point size
+ raster_points_per_pixel: int = 1, # a single point per pixel, for now
+                      bin_size: int = 0):
+    """Sample image-space features for each 3D point: rasterize the point clouds with the given
+    cameras and gather, for every visible point, the feature of the pixel it lands on."""
+    B, C, H, W = local_features.shape
+ device = local_features.device
+ raster_settings = PointsRasterizationSettings(
+ image_size=(H, W),
+ radius=raster_point_radius,
+ points_per_pixel=raster_points_per_pixel,
+ bin_size=bin_size,
+ )
+ Np = points.shape[1]
+ R = raster_settings.points_per_pixel
+
+ w2cs = torch.inverse(c2ws)
+ image_size = torch.as_tensor([H, W]).view(1, 2).expand(w2cs.shape[0], -1).to(device)
+ cameras = cameras_from_opencv_projection(w2cs[:, :3, :3], w2cs[:, :3, 3], intrinsics, image_size)
+
+ rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
+ fragments = rasterize(Pointclouds(points))
+ fragments_idx: Tensor = fragments.idx.long()
+ visible_pixels = (fragments_idx > -1) # (B, H, W, R)
+ points_to_visible_pixels = fragments_idx[visible_pixels]
+
+ # Reshape local features to (B, H, W, R, C)
+ local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C)
+
+ # Get local features corresponding to visible points
+ local_features_proj = torch.zeros(B * Np, C, device=device)
+ local_features_proj[points_to_visible_pixels] = local_features[visible_pixels]
+ local_features_proj = local_features_proj.reshape(B, Np, C)
+
+ return local_features_proj
+
+def compute_distance_transform(mask: torch.Tensor):
+    # per-pixel distance to the nearest mask pixel (0 inside the mask), normalized by half the image size and clipped to [0, 1]
+    image_size = mask.shape[-1]
+ distance_transform = torch.stack([
+ torch.from_numpy(cv2.distanceTransform(
+ (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
+ ) / (image_size / 2))
+ for m in mask.squeeze(1).detach().cpu().numpy().astype(np.uint8)
+ ]).unsqueeze(1).clip(0, 1).to(mask.device)
+ return distance_transform
diff --git a/hort/models/tgs/utils/saving.py b/hort/models/tgs/utils/saving.py
new file mode 100644
index 0000000000000000000000000000000000000000..a497823b74a6cf9bc852f2223e95b3064a4a64ef
--- /dev/null
+++ b/hort/models/tgs/utils/saving.py
@@ -0,0 +1,315 @@
+import os
+import re
+import shutil
+
+import cv2
+import imageio
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib import cm
+from matplotlib.colors import LinearSegmentedColormap
+from PIL import Image, ImageDraw
+
+import tgs
+from tgs.utils.typing import *
+
+class SaverMixin:
+ _save_dir: Optional[str] = None
+
+ def set_save_dir(self, save_dir: str):
+ self._save_dir = save_dir
+
+ def get_save_dir(self):
+ if self._save_dir is None:
+ raise ValueError("Save dir is not set")
+ return self._save_dir
+
+ def convert_data(self, data):
+ if data is None:
+ return None
+ elif isinstance(data, np.ndarray):
+ return data
+ elif isinstance(data, torch.Tensor):
+ return data.detach().cpu().numpy()
+ elif isinstance(data, list):
+ return [self.convert_data(d) for d in data]
+ elif isinstance(data, dict):
+ return {k: self.convert_data(v) for k, v in data.items()}
+ else:
+            raise TypeError(
+                f"Data must be of type numpy.ndarray, torch.Tensor, list or dict, got {type(data)}"
+            )
+
+ def get_save_path(self, filename):
+ save_path = os.path.join(self.get_save_dir(), filename)
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ return save_path
+
+ DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
+ DEFAULT_UV_KWARGS = {
+ "data_format": "HWC",
+ "data_range": (0, 1),
+ "cmap": "checkerboard",
+ }
+ DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"}
+ DEFAULT_GRID_KWARGS = {"align": "max"}
+
+ def get_rgb_image_(self, img, data_format, data_range, rgba=False):
+ img = self.convert_data(img)
+ assert data_format in ["CHW", "HWC"]
+ if data_format == "CHW":
+ img = img.transpose(1, 2, 0)
+ if img.dtype != np.uint8:
+ img = img.clip(min=data_range[0], max=data_range[1])
+ img = (
+ (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0
+ ).astype(np.uint8)
+ nc = 4 if rgba else 3
+ imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)]
+ imgs = [
+ img_
+ if img_.shape[-1] == nc
+ else np.concatenate(
+ [
+ img_,
+ np.zeros(
+ (img_.shape[0], img_.shape[1], nc - img_.shape[2]),
+ dtype=img_.dtype,
+ ),
+ ],
+ axis=-1,
+ )
+ for img_ in imgs
+ ]
+ img = np.concatenate(imgs, axis=1)
+ if rgba:
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ return img
+
+ def _save_rgb_image(
+ self,
+ filename,
+ img,
+ data_format,
+ data_range
+ ):
+ img = self.get_rgb_image_(img, data_format, data_range)
+ cv2.imwrite(filename, img)
+
+ def save_rgb_image(
+ self,
+ filename,
+ img,
+ data_format=DEFAULT_RGB_KWARGS["data_format"],
+ data_range=DEFAULT_RGB_KWARGS["data_range"],
+ ) -> str:
+ save_path = self.get_save_path(filename)
+ self._save_rgb_image(save_path, img, data_format, data_range)
+ return save_path
+
+ def get_grayscale_image_(self, img, data_range, cmap):
+ img = self.convert_data(img)
+ img = np.nan_to_num(img)
+ if data_range is None:
+ img = (img - img.min()) / (img.max() - img.min())
+ else:
+ img = img.clip(data_range[0], data_range[1])
+ img = (img - data_range[0]) / (data_range[1] - data_range[0])
+ assert cmap in [None, "jet", "magma", "spectral"]
+        if cmap is None:
+ img = (img * 255.0).astype(np.uint8)
+ img = np.repeat(img[..., None], 3, axis=2)
+ elif cmap == "jet":
+ img = (img * 255.0).astype(np.uint8)
+ img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
+ elif cmap == "magma":
+ img = 1.0 - img
+ base = cm.get_cmap("magma")
+ num_bins = 256
+ colormap = LinearSegmentedColormap.from_list(
+ f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins
+ )(np.linspace(0, 1, num_bins))[:, :3]
+ a = np.floor(img * 255.0)
+ b = (a + 1).clip(max=255.0)
+ f = img * 255.0 - a
+ a = a.astype(np.uint16).clip(0, 255)
+ b = b.astype(np.uint16).clip(0, 255)
+ img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None]
+ img = (img * 255.0).astype(np.uint8)
+ elif cmap == "spectral":
+ colormap = plt.get_cmap("Spectral")
+
+ def blend_rgba(image):
+ image = image[..., :3] * image[..., -1:] + (
+ 1.0 - image[..., -1:]
+ ) # blend A to RGB
+ return image
+
+ img = colormap(img)
+ img = blend_rgba(img)
+ img = (img * 255).astype(np.uint8)
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ return img
+
+ def _save_grayscale_image(
+ self,
+ filename,
+ img,
+ data_range,
+ cmap,
+ ):
+ img = self.get_grayscale_image_(img, data_range, cmap)
+ cv2.imwrite(filename, img)
+
+ def save_grayscale_image(
+ self,
+ filename,
+ img,
+ data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
+ cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
+ ) -> str:
+ save_path = self.get_save_path(filename)
+ self._save_grayscale_image(save_path, img, data_range, cmap)
+ return save_path
+
+ def get_image_grid_(self, imgs, align):
+ if isinstance(imgs[0], list):
+ return np.concatenate(
+ [self.get_image_grid_(row, align) for row in imgs], axis=0
+ )
+ cols = []
+ for col in imgs:
+ assert col["type"] in ["rgb", "uv", "grayscale"]
+ if col["type"] == "rgb":
+ rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy()
+ rgb_kwargs.update(col["kwargs"])
+ cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs))
+ elif col["type"] == "uv":
+ uv_kwargs = self.DEFAULT_UV_KWARGS.copy()
+ uv_kwargs.update(col["kwargs"])
+ cols.append(self.get_uv_image_(col["img"], **uv_kwargs))
+ elif col["type"] == "grayscale":
+ grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy()
+ grayscale_kwargs.update(col["kwargs"])
+ cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs))
+
+ if align == "max":
+ h = max([col.shape[0] for col in cols])
+ w = max([col.shape[1] for col in cols])
+ elif align == "min":
+ h = min([col.shape[0] for col in cols])
+ w = min([col.shape[1] for col in cols])
+ elif isinstance(align, int):
+ h = align
+ w = align
+ elif (
+ isinstance(align, tuple)
+ and isinstance(align[0], int)
+ and isinstance(align[1], int)
+ ):
+ h, w = align
+ else:
+ raise ValueError(
+ f"Unsupported image grid align: {align}, should be min, max, int or (int, int)"
+ )
+
+ for i in range(len(cols)):
+ if cols[i].shape[0] != h or cols[i].shape[1] != w:
+ cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_LINEAR)
+ return np.concatenate(cols, axis=1)
+
+ def save_image_grid(
+ self,
+ filename,
+ imgs,
+ align=DEFAULT_GRID_KWARGS["align"],
+ texts: Optional[List[float]] = None,
+ ):
+ save_path = self.get_save_path(filename)
+ img = self.get_image_grid_(imgs, align=align)
+
+ if texts is not None:
+ img = Image.fromarray(img)
+ draw = ImageDraw.Draw(img)
+ black, white = (0, 0, 0), (255, 255, 255)
+ for i, text in enumerate(texts):
+ draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
+ draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
+ draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
+ draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
+ draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black)
+ img = np.asarray(img)
+
+ cv2.imwrite(save_path, img)
+ return save_path
+
+ def save_image(self, filename, img) -> str:
+ save_path = self.get_save_path(filename)
+ img = self.convert_data(img)
+ assert img.dtype == np.uint8 or img.dtype == np.uint16
+ if img.ndim == 3 and img.shape[-1] == 3:
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ elif img.ndim == 3 and img.shape[-1] == 4:
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
+ cv2.imwrite(save_path, img)
+ return save_path
+
+ def save_img_sequence(
+ self,
+ filename,
+ img_dir,
+ matcher,
+ save_format="mp4",
+ fps=30,
+ ) -> str:
+ assert save_format in ["gif", "mp4"]
+ if not filename.endswith(save_format):
+ filename += f".{save_format}"
+ save_path = self.get_save_path(filename)
+ matcher = re.compile(matcher)
+ img_dir = os.path.join(self.get_save_dir(), img_dir)
+ imgs = []
+ for f in os.listdir(img_dir):
+ if matcher.search(f):
+ imgs.append(f)
+ imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0]))
+ imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs]
+
+ if save_format == "gif":
+ imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
+ imageio.mimsave(save_path, imgs, fps=fps, palettesize=256)
+ elif save_format == "mp4":
+ imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
+ imageio.mimsave(save_path, imgs, fps=fps)
+ return save_path
+
+ def save_img_sequences(
+ self,
+ seq_dir,
+ matcher,
+ save_format="mp4",
+ fps=30,
+ delete=True
+ ):
+ seq_dir_ = os.path.join(self.get_save_dir(), seq_dir)
+ for f in os.listdir(seq_dir_):
+ img_dir_ = os.path.join(seq_dir_, f)
+ if not os.path.isdir(img_dir_):
+ continue
+ try:
+ self.save_img_sequence(
+ os.path.join(seq_dir, f),
+ os.path.join(seq_dir, f),
+ matcher,
+ save_format=save_format,
+ fps=fps
+ )
+            except Exception:
+ tgs.warn(f"Video saving for directory {seq_dir_} failed!")
+
+ if delete:
+ shutil.rmtree(img_dir_)
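+
+# Illustrative usage sketch (paths are hypothetical): mix SaverMixin into a system class,
+# point it at an output directory, then call the save_* helpers:
+#   class MySystem(SaverMixin):
+#       pass
+#   saver = MySystem()
+#   saver.set_save_dir("outputs/run1")
+#   saver.save_rgb_image("vis/frame0000.png", pred_rgb)  # pred_rgb: HWC float image in [0, 1]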
diff --git a/hort/models/tgs/utils/typing.py b/hort/models/tgs/utils/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee9f967c21f94db1ad939d7dead156d86748752
--- /dev/null
+++ b/hort/models/tgs/utils/typing.py
@@ -0,0 +1,40 @@
+"""
+This module contains type annotations for the project, using
+1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
+2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
+
+Two types of typing checking can be used:
+1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
+2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
+"""
+
+# Basic types
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ NamedTuple,
+ NewType,
+ Optional,
+ Sized,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+# Tensor dtype
+# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
+from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
+
+# Config type
+from omegaconf import DictConfig
+
+# PyTorch Tensor type
+from torch import Tensor
+
+# Runtime type checking decorator
+from typeguard import typechecked as typechecker
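+
+# Illustrative example of the jaxtyping annotation style used throughout the project:
+#   def normalize(x: Float[Tensor, "B N 3"]) -> Float[Tensor, "B N 3"]: ...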
diff --git a/hort/utils/__init__.py b/hort/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fa5d45defd84f4b5acb11e17f4b440b6a780168
--- /dev/null
+++ b/hort/utils/__init__.py
@@ -0,0 +1,3 @@
+import torch
+from typing import Any
+from .renderer import Renderer
\ No newline at end of file
diff --git a/hort/utils/img_utils.py b/hort/utils/img_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..71c1dad27a7d52566aef20203fea97e4c187ad34
--- /dev/null
+++ b/hort/utils/img_utils.py
@@ -0,0 +1,181 @@
+import numpy as np
+import cv2
+
+
+def process_bbox(bbox, factor=1.25):
+ # aspect ratio preserving bbox
+ w = bbox[2]
+ h = bbox[3]
+ c_x = bbox[0] + w / 2.
+ c_y = bbox[1] + h / 2.
+ aspect_ratio = 1.
+ if w > aspect_ratio * h:
+ h = w / aspect_ratio
+ elif w < aspect_ratio * h:
+ w = h * aspect_ratio
+ bbox[2] = w * factor
+ bbox[3] = h * factor
+ bbox[0] = c_x - bbox[2] / 2.
+ bbox[1] = c_y - bbox[3] / 2.
+
+ return bbox
+
+
+def generate_patch_image(cvimg, bbox, input_shape, do_flip, scale, rot):
+ """
+ @description: Modified from https://github.com/mks0601/3DMPPE_ROOTNET_RELEASE/blob/master/data/dataset.py.
+ generate the patch image from the bounding box and other parameters.
+ ---------
+ @param: input image, bbox(x1, y1, w, h), dest image shape, do_flip, scale factor, rotation degrees.
+ -------
+ @Returns: processed image, affine_transform matrix to get the processed image.
+ -------
+ """
+
+ img = cvimg.copy()
+ img_height, img_width, _ = img.shape
+
+ bb_c_x = float(bbox[0] + 0.5 * bbox[2])
+ bb_c_y = float(bbox[1] + 0.5 * bbox[3])
+ bb_width = float(bbox[2])
+ bb_height = float(bbox[3])
+
+ if do_flip:
+ img = img[:, ::-1, :]
+ bb_c_x = img_width - bb_c_x - 1
+
+ trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, input_shape[1], input_shape[0], scale, rot, inv=False)
+ img_patch = cv2.warpAffine(img, trans, (int(input_shape[1]), int(input_shape[0])), flags=cv2.INTER_LINEAR)
+ new_trans = np.zeros((3, 3), dtype=np.float32)
+ new_trans[:2, :] = trans
+ new_trans[2, 2] = 1
+
+ return img_patch, new_trans
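+
+# Illustrative usage sketch (bbox values are hypothetical; `img` is an HxWx3 image array):
+#   bbox = process_bbox([100., 120., 80., 60.])  # squared and padded by the 1.25 factor
+#   patch, trans = generate_patch_image(img, bbox, (224, 224), False, 1.0, 0.)
+#   # patch: (224, 224, 3) crop; trans: 3x3 affine mapping original-image coords to the patch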
+
+
+def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False):
+ """
+ @description: Modified from https://github.com/mks0601/3DMPPE_ROOTNET_RELEASE/blob/master/data/dataset.py.
+ get affine transform matrix
+ ---------
+ @param: image center, original image size, desired image size, scale factor, rotation degree, whether to get inverse transformation.
+ -------
+ @Returns: affine transformation matrix
+ -------
+ """
+
+ def rotate_2d(pt_2d, rot_rad):
+ x = pt_2d[0]
+ y = pt_2d[1]
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ xx = x * cs - y * sn
+ yy = x * sn + y * cs
+ return np.array([xx, yy], dtype=np.float32)
+
+ # augment size with scale
+ src_w = src_width * scale
+ src_h = src_height * scale
+ src_center = np.array([c_x, c_y], dtype=np.float32)
+
+ # augment rotation
+ rot_rad = np.pi * rot / 180
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
+
+ dst_w = dst_width
+ dst_h = dst_height
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = src_center
+ src[1, :] = src_center + src_downdir
+ src[2, :] = src_center + src_rightdir
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = dst_center
+ dst[1, :] = dst_center + dst_downdir
+ dst[2, :] = dst_center + dst_rightdir
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+
+class PerspectiveCamera:
+ def __init__(self, fx, fy, cx, cy, R=np.eye(3), t=np.zeros(3)):
+ self.K = np.array([[fx, 0, cx, 0], [0, fy, cy, 0], [0, 0, 1, 0]], dtype=np.float32)
+
+ self.R = np.array(R, dtype=np.float32).copy()
+ assert self.R.shape == (3, 3)
+
+ self.t = np.array(t, dtype=np.float32).copy()
+ assert self.t.size == 3
+ self.t = self.t.reshape(3, 1)
+
+ def update_virtual_camera_after_crop(self, bbox, option='same'):
+ left, upper, width, height = bbox
+ new_img_center = np.array([left + width / 2, upper + height / 2, 1], dtype=np.float32).reshape(3, 1)
+ new_cam_center = np.linalg.inv(self.K[:3, :3]).dot(new_img_center)
+ self.K[0, 2], self.K[1, 2] = width / 2, height / 2
+
+ x, y, z = new_cam_center[0], new_cam_center[1], new_cam_center[2]
+ sin_theta = -y / np.sqrt(1 + x ** 2 + y ** 2)
+ cos_theta = np.sqrt(1 + x ** 2) / np.sqrt(1 + x ** 2 + y ** 2)
+ R_x = np.array([[1, 0, 0], [0, cos_theta, -sin_theta], [0, sin_theta, cos_theta]], dtype=np.float32)
+ sin_phi = x / np.sqrt(1 + x ** 2)
+ cos_phi = 1 / np.sqrt(1 + x ** 2)
+ R_y = np.array([[cos_phi, 0, sin_phi], [0, 1, 0], [-sin_phi, 0, cos_phi]], dtype=np.float32)
+ self.R = R_y @ R_x
+
+ # update focal length for virtual camera; please refer to the paper "PCLs: Geometry-aware Neural Reconstruction of 3D Pose with Perspective Crop Layers" for more details.
+ if option == 'length':
+ self.K[0, 0] = self.K[0, 0] * np.sqrt(1 + x ** 2 + y ** 2)
+ self.K[1, 1] = self.K[1, 1] * np.sqrt(1 + x ** 2 + y ** 2)
+
+ if option == 'scale':
+ self.K[0, 0] = self.K[0, 0] * np.sqrt(1 + x ** 2 + y ** 2) * np.sqrt(1 + x ** 2)
+ self.K[1, 1] = self.K[1, 1] * (1 + x ** 2 + y ** 2)/ np.sqrt(1 + x ** 2)
+
+ def update_intrinsics_after_crop(self, bbox):
+ left, upper, _, _ = bbox
+
+ cx, cy = self.K[0, 2], self.K[1, 2]
+
+ new_cx = cx - left
+ new_cy = cy - upper
+
+ self.K[0, 2], self.K[1, 2] = new_cx, new_cy
+
+ def update_intrinsics_after_resize(self, image_shape, new_image_shape):
+ height, width = image_shape
+ new_height, new_width = new_image_shape
+
+ fx, fy, cx, cy = self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2]
+
+ new_fx = fx * (new_width / width)
+ new_fy = fy * (new_height / height)
+ new_cx = cx * (new_width / width)
+ new_cy = cy * (new_height / height)
+
+ self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] = new_fx, new_fy, new_cx, new_cy
+
+ def update_intrinsics_after_scale(self, scale_factor):
+ self.K[0, 0] /= scale_factor
+ self.K[1, 1] /= scale_factor
+
+ @property
+ def projection(self):
+ return self.K.dot(self.extrinsics)
+
+ @property
+ def intrinsics(self):
+ return self.K
+
+ @property
+ def extrinsics(self):
+ return np.hstack([self.R, self.t])
\ No newline at end of file
diff --git a/hort/utils/renderer.py b/hort/utils/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f6ff18ac90900d936fb75eb7bfbb56e815e8c98
--- /dev/null
+++ b/hort/utils/renderer.py
@@ -0,0 +1,375 @@
+import os
+if 'PYOPENGL_PLATFORM' not in os.environ:
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
+import torch
+import numpy as np
+import pyrender
+import trimesh
+import cv2
+from yacs.config import CfgNode
+from typing import List, Optional
+
+def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.):
+ # Convert cam_bbox to full image
+ img_w, img_h = img_size[:, 0], img_size[:, 1]
+ cx, cy, b = box_center[:, 0], box_center[:, 1], box_size
+ w_2, h_2 = img_w / 2., img_h / 2.
+ bs = b * cam_bbox[:, 0] + 1e-9
+ tz = 2 * focal_length / bs
+ tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1]
+ ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2]
+ full_cam = torch.stack([tx, ty, tz], dim=-1)
+ return full_cam
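+
+# Illustrative note: given a weak-perspective crop camera (s, tx, ty) predicted on a crop of
+# size `box_size`, the depth of the full-image translation is tz = 2 * focal_length / (s * box_size),
+# while tx/ty are shifted by the offset of the box center from the image center.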
+
+def cam_crop_to_new(cam_bbox, box_center, box_size, img_size, new_bbox, focal_length=5000.):
+ img_w, img_h = img_size[:, 0], img_size[:, 1]
+ cx, cy, b = box_center[:, 0], box_center[:, 1], box_size
+ w_2, h_2 = img_w / 2., img_h / 2.
+ bs = b * cam_bbox[:, 0] + 1e-9
+ tz = 2 * focal_length / bs
+
+ x0, y0, new_w, new_h = new_bbox[:, 0], new_bbox[:, 1], new_bbox[:, 2], new_bbox[:, 3]
+ tz = tz * new_w / 224
+ new_cx, new_cy = x0 + new_w / 2., y0 + new_h / 2.
+ tx = (2 * (cx - new_cx) / bs) + cam_bbox[:, 1]
+ ty = (2 * (cy - new_cy) / bs) + cam_bbox[:, 2]
+
+ full_cam_resized = torch.stack([tx, ty, tz], dim=-1)
+ return full_cam_resized
+
+def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12):
+ # get lights in a circle around origin at elevation
+ thetas = elevation * np.ones(n_lights)
+ phis = 2 * np.pi * np.arange(n_lights) / n_lights
+ poses = []
+ trans = make_translation(torch.tensor([0, 0, dist]))
+ for phi, theta in zip(phis, thetas):
+ rot = make_rotation(rx=-theta, ry=phi, order="xyz")
+ poses.append((rot @ trans).numpy())
+ return poses
+
+def make_translation(t):
+ return make_4x4_pose(torch.eye(3), t)
+
+def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
+ Rx = rotx(rx)
+ Ry = roty(ry)
+ Rz = rotz(rz)
+ if order == "xyz":
+ R = Rz @ Ry @ Rx
+ elif order == "xzy":
+ R = Ry @ Rz @ Rx
+ elif order == "yxz":
+ R = Rz @ Rx @ Ry
+ elif order == "yzx":
+ R = Rx @ Rz @ Ry
+ elif order == "zyx":
+ R = Rx @ Ry @ Rz
+ elif order == "zxy":
+ R = Ry @ Rx @ Rz
+ return make_4x4_pose(R, torch.zeros(3))
+
+def make_4x4_pose(R, t):
+ """
+ :param R (*, 3, 3)
+ :param t (*, 3)
+ return (*, 4, 4)
+ """
+ dims = R.shape[:-2]
+ pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1)
+ bottom = (
+ torch.tensor([0, 0, 0, 1], device=R.device)
+ .reshape(*(1,) * len(dims), 1, 4)
+ .expand(*dims, 1, 4)
+ )
+ return torch.cat([pose_3x4, bottom], dim=-2)
+
+
+def rotx(theta):
+ return torch.tensor(
+ [
+ [1, 0, 0],
+ [0, np.cos(theta), -np.sin(theta)],
+ [0, np.sin(theta), np.cos(theta)],
+ ],
+ dtype=torch.float32,
+ )
+
+
+def roty(theta):
+ return torch.tensor(
+ [
+ [np.cos(theta), 0, np.sin(theta)],
+ [0, 1, 0],
+ [-np.sin(theta), 0, np.cos(theta)],
+ ],
+ dtype=torch.float32,
+ )
+
+
+def rotz(theta):
+ return torch.tensor(
+ [
+ [np.cos(theta), -np.sin(theta), 0],
+ [np.sin(theta), np.cos(theta), 0],
+ [0, 0, 1],
+ ],
+ dtype=torch.float32,
+ )
+
+
+def create_raymond_lights() -> List[pyrender.Node]:
+ """
+ Return raymond light nodes for the scene.
+ """
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
+
+ nodes = []
+
+ for phi, theta in zip(phis, thetas):
+ xp = np.sin(theta) * np.cos(phi)
+ yp = np.sin(theta) * np.sin(phi)
+ zp = np.cos(theta)
+
+ z = np.array([xp, yp, zp])
+ z = z / np.linalg.norm(z)
+ x = np.array([-z[1], z[0], 0.0])
+ if np.linalg.norm(x) == 0:
+ x = np.array([1.0, 0.0, 0.0])
+ x = x / np.linalg.norm(x)
+ y = np.cross(z, x)
+
+ matrix = np.eye(4)
+ matrix[:3,:3] = np.c_[x,y,z]
+ nodes.append(pyrender.Node(
+ light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
+ matrix=matrix
+ ))
+
+ return nodes
+
+class Renderer:
+
+ def __init__(self, cfg: CfgNode, faces: np.array):
+ """
+ Wrapper around the pyrender renderer to render MANO meshes.
+ Args:
+ cfg (CfgNode): Model config file.
+ faces (np.array): Array of shape (F, 3) containing the mesh faces.
+ """
+ self.cfg = cfg
+ self.focal_length = cfg.EXTRA.FOCAL_LENGTH
+ self.img_res = cfg.MODEL.IMAGE_SIZE
+
+ # add faces that make the hand mesh watertight
+ faces_new = np.array([[92, 38, 234],
+ [234, 38, 239],
+ [38, 122, 239],
+ [239, 122, 279],
+ [122, 118, 279],
+ [279, 118, 215],
+ [118, 117, 215],
+ [215, 117, 214],
+ [117, 119, 214],
+ [214, 119, 121],
+ [119, 120, 121],
+ [121, 120, 78],
+ [120, 108, 78],
+ [78, 108, 79]])
+ faces = np.concatenate([faces, faces_new], axis=0)
+
+ self.camera_center = [self.img_res // 2, self.img_res // 2]
+ self.faces = faces
+ self.faces_left = self.faces[:,[0,2,1]]
+
+ def __call__(self,
+ vertices: np.array,
+ camera_translation: np.array,
+ image: torch.Tensor,
+ full_frame: bool = False,
+ imgname: Optional[str] = None,
+ side_view=False, rot_angle=90,
+ mesh_base_color=(1.0, 1.0, 0.9),
+ scene_bg_color=(0,0,0),
+ return_rgba=False,
+ ) -> np.array:
+ """
+ Render meshes on input image
+ Args:
+ vertices (np.array): Array of shape (V, 3) containing the mesh vertices.
+ camera_translation (np.array): Array of shape (3,) with the camera translation.
+ image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values.
+ full_frame (bool): If True, then render on the full image.
+            imgname (Optional[str]): Contains the original image filename. Used only if full_frame == True.
+ """
+
+ if full_frame:
+ image = cv2.imread(imgname).astype(np.float32)[:, :, ::-1] / 255.
+ else:
+ image = image.clone() * torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3,1,1)
+ image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3,1,1)
+ image = image.permute(1, 2, 0).cpu().numpy()
+
+ renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
+ viewport_height=image.shape[0],
+ point_size=1.0)
+ material = pyrender.MetallicRoughnessMaterial(
+ metallicFactor=0.0,
+ alphaMode='OPAQUE',
+ baseColorFactor=(*mesh_base_color, 1.0))
+
+ camera_translation[0] *= -1.
+
+ mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
+ if side_view:
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(rot_angle), [0, 1, 0])
+ mesh.apply_transform(rot)
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(180), [1, 0, 0])
+ mesh.apply_transform(rot)
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
+ ambient_light=(0.3, 0.3, 0.3))
+ scene.add(mesh, 'mesh')
+
+ camera_pose = np.eye(4)
+ camera_pose[:3, 3] = camera_translation
+ camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
+ camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length,
+ cx=camera_center[0], cy=camera_center[1], zfar=1e12)
+ scene.add(camera, pose=camera_pose)
+
+
+ light_nodes = create_raymond_lights()
+ for node in light_nodes:
+ scene.add_node(node)
+
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+ color = color.astype(np.float32) / 255.0
+ renderer.delete()
+
+ if return_rgba:
+ return color
+
+ valid_mask = (color[:, :, -1])[:, :, np.newaxis]
+ if not side_view:
+ output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image)
+ else:
+ output_img = color[:, :, :3]
+
+ output_img = output_img.astype(np.float32)
+ return output_img
+
+ def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9),
+ rot_axis=[1,0,0], rot_angle=0, is_right=1):
+ vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0])
+ if is_right:
+ mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors)
+ else:
+ mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces_left.copy(), vertex_colors=vertex_colors)
+
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(rot_angle), rot_axis)
+ mesh.apply_transform(rot)
+
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(180), [1, 0, 0])
+ mesh.apply_transform(rot)
+ return mesh
+
+ def render_rgba(
+ self,
+ vertices: np.array,
+ points: np.array,
+ cam_t = None,
+ rot=None,
+ rot_axis=[1,0,0],
+ rot_angle=0,
+ camera_z=3,
+ mesh_base_color=(1.0, 1.0, 0.9),
+ point_base_color=(1.0, 1.0, 0.9),
+ scene_bg_color=(0,0,0),
+ render_res=[256, 256],
+ focal_length=None,
+ is_right=None,
+ ):
+
+ renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0], viewport_height=render_res[1], point_size=3)
+ focal_length = focal_length if focal_length is not None else self.focal_length
+
+ if cam_t is not None:
+ camera_translation = cam_t.copy()
+ camera_translation[0] *= -1.
+ else:
+ camera_translation = np.array([0, 0, camera_z * focal_length/render_res[1]])
+
+ mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle, is_right=is_right)
+ mesh = pyrender.Mesh.from_trimesh(mesh)
+ rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
+ points = (rot[:3, :3] @ points.transpose(1, 0)).transpose(1, 0)
+ point_cloud = pyrender.Mesh.from_points(points=points, colors=point_base_color)
+
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], ambient_light=(0.3, 0.3, 0.3))
+ scene.add(mesh, 'mesh')
+ scene.add(point_cloud, 'pc')
+
+ camera_pose = np.eye(4)
+ camera_pose[:3, 3] = camera_translation
+ camera_center = [render_res[0] / 2., render_res[1] / 2.]
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, cx=camera_center[0], cy=camera_center[1], zfar=1e12)
+
+ # Create camera node and add it to pyRender scene
+ camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
+ scene.add_node(camera_node)
+ self.add_point_lighting(scene, camera_node)
+ self.add_lighting(scene, camera_node)
+
+ light_nodes = create_raymond_lights()
+ for node in light_nodes:
+ scene.add_node(node)
+
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+ color = color.astype(np.float32) / 255.0
+ renderer.delete()
+
+ return color
+
+ def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
+ # from phalp.visualize.py_renderer import get_light_poses
+ light_poses = get_light_poses()
+ light_poses.append(np.eye(4))
+ cam_pose = scene.get_pose(cam_node)
+ for i, pose in enumerate(light_poses):
+ matrix = cam_pose @ pose
+ node = pyrender.Node(
+ name=f"light-{i:02d}",
+ light=pyrender.DirectionalLight(color=color, intensity=intensity),
+ matrix=matrix,
+ )
+ if scene.has_node(node):
+ continue
+ scene.add_node(node)
+
+ def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
+ # from phalp.visualize.py_renderer import get_light_poses
+ light_poses = get_light_poses(dist=0.5)
+ light_poses.append(np.eye(4))
+ cam_pose = scene.get_pose(cam_node)
+ for i, pose in enumerate(light_poses):
+ matrix = cam_pose @ pose
+ # node = pyrender.Node(
+ # name=f"light-{i:02d}",
+ # light=pyrender.DirectionalLight(color=color, intensity=intensity),
+ # matrix=matrix,
+ # )
+ node = pyrender.Node(
+ name=f"plight-{i:02d}",
+ light=pyrender.PointLight(color=color, intensity=intensity),
+ matrix=matrix,
+ )
+ if scene.has_node(node):
+ continue
+ scene.add_node(node)
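+
+# Illustrative usage sketch (not part of the original module; variable names are
+# placeholders): once a Renderer has been built with the model config and the MANO
+# faces, a predicted hand can be composited onto its input crop via __call__ above:
+#
+#     rgb = renderer(pred_vertices, pred_cam_t, batch['img'][0], full_frame=False)
+#     cv2.imwrite('overlay.png', (255 * rgb[:, :, ::-1]).astype(np.uint8))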
diff --git a/mano_data/closed_fmano.npy b/mano_data/closed_fmano.npy
new file mode 100644
index 0000000000000000000000000000000000000000..ee3ee84dc7ffc1f13da6a772633f24d3d4ff7b8c
--- /dev/null
+++ b/mano_data/closed_fmano.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e6842048a9a1bce51b07d9f148b676699d4ff607508fe6b104268a15f3560bf
+size 37376
diff --git a/mano_data/mano/MANO_RIGHT.pkl b/mano_data/mano/MANO_RIGHT.pkl
new file mode 100755
index 0000000000000000000000000000000000000000..8e7ac7faf64ad51096ec1da626ea13757ed7f665
--- /dev/null
+++ b/mano_data/mano/MANO_RIGHT.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45d60aa3b27ef9107a7afd4e00808f307fd91111e1cfa35afd5c4a62de264767
+size 3821356
diff --git a/mano_data/mano_mean_params.npz b/mano_data/mano_mean_params.npz
new file mode 100644
index 0000000000000000000000000000000000000000..dc294b01fb78a9cd6636c87a69b59cf82d28d15b
--- /dev/null
+++ b/mano_data/mano_mean_params.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988
+size 1178
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e7dbe446102d6b63616219acf114e844b6c266ee
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,12 @@
+libglfw3-dev
+libgles2-mesa-dev
+libgl1
+freeglut3-dev
+unzip
+ffmpeg
+libsm6
+libxext6
+libgl1-mesa-dri
+libegl1-mesa
+libgbm1
+build-essential
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..25b2e53fafafb660ff38f893370cb12f88a68f1a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,46 @@
+opencv-python
+numpy
+trimesh
+plyfile
+pyyaml
+scikit-image
+scikit-learn
+tensorboard
+kornia
+loguru
+pycocotools
+yacs
+lmdb
+fire
+setuptools
+einops
+tqdm
+ipython
+gym
+transformers==4.44.2
+OmegaConf
+matplotlib
+gradio==5.0.2
+diffusers
+rembg
+segment_anything
+jaxtyping
+imageio
+iopath
+timm
+open3d
+pyrender
+pytorch-lightning
+smplx==0.1.28
+chumpy @ git+https://github.com/zerchen/chumpy
+xtcocotools
+pandas
+hydra-core
+hydra-submitit-launcher
+hydra-colorlog
+pyrootutils
+rich
+webdataset
+ultralytics
+dill
diff --git a/wilor/configs/__init__.py b/wilor/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bd9dcf90daa9627cb482d2c2e602288e09b1f1b
--- /dev/null
+++ b/wilor/configs/__init__.py
@@ -0,0 +1,114 @@
+import os
+from typing import Dict
+from yacs.config import CfgNode as CN
+
+CACHE_DIR_PRETRAINED = "./pretrained_models/"
+
+def to_lower(x: Dict) -> Dict:
+ """
+ Convert all dictionary keys to lowercase
+ Args:
+ x (dict): Input dictionary
+ Returns:
+ dict: Output dictionary with all keys converted to lowercase
+ """
+ return {k.lower(): v for k, v in x.items()}
+
+_C = CN(new_allowed=True)
+
+_C.GENERAL = CN(new_allowed=True)
+_C.GENERAL.RESUME = True
+_C.GENERAL.TIME_TO_RUN = 3300
+_C.GENERAL.VAL_STEPS = 100
+_C.GENERAL.LOG_STEPS = 100
+_C.GENERAL.CHECKPOINT_STEPS = 20000
+_C.GENERAL.CHECKPOINT_DIR = "checkpoints"
+_C.GENERAL.SUMMARY_DIR = "tensorboard"
+_C.GENERAL.NUM_GPUS = 1
+_C.GENERAL.NUM_WORKERS = 4
+_C.GENERAL.MIXED_PRECISION = True
+_C.GENERAL.ALLOW_CUDA = True
+_C.GENERAL.PIN_MEMORY = False
+_C.GENERAL.DISTRIBUTED = False
+_C.GENERAL.LOCAL_RANK = 0
+_C.GENERAL.USE_SYNCBN = False
+_C.GENERAL.WORLD_SIZE = 1
+
+_C.TRAIN = CN(new_allowed=True)
+_C.TRAIN.NUM_EPOCHS = 100
+_C.TRAIN.BATCH_SIZE = 32
+_C.TRAIN.SHUFFLE = True
+_C.TRAIN.WARMUP = False
+_C.TRAIN.NORMALIZE_PER_IMAGE = False
+_C.TRAIN.CLIP_GRAD = False
+_C.TRAIN.CLIP_GRAD_VALUE = 1.0
+_C.LOSS_WEIGHTS = CN(new_allowed=True)
+
+_C.DATASETS = CN(new_allowed=True)
+
+_C.MODEL = CN(new_allowed=True)
+_C.MODEL.IMAGE_SIZE = 224
+
+_C.EXTRA = CN(new_allowed=True)
+_C.EXTRA.FOCAL_LENGTH = 5000
+
+_C.DATASETS.CONFIG = CN(new_allowed=True)
+_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
+_C.DATASETS.CONFIG.ROT_FACTOR = 30
+_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
+_C.DATASETS.CONFIG.COLOR_SCALE = 0.2
+_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
+_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
+_C.DATASETS.CONFIG.DO_FLIP = False
+_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
+_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
+
+def default_config() -> CN:
+ """
+ Get a yacs CfgNode object with the default config values.
+ """
+ # Return a clone so that the defaults will not be altered
+ # This is for the "local variable" use pattern
+ return _C.clone()
+
+def dataset_config(name='datasets_tar.yaml') -> CN:
+ """
+ Get dataset config file
+ Returns:
+ CfgNode: Dataset config as a yacs CfgNode object.
+ """
+ cfg = CN(new_allowed=True)
+ config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
+ cfg.merge_from_file(config_file)
+ cfg.freeze()
+ return cfg
+
+def dataset_eval_config() -> CN:
+ return dataset_config('datasets_eval.yaml')
+
+def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
+ """
+ Read a config file and optionally merge it with the default config file.
+ Args:
+ config_file (str): Path to config file.
+ merge (bool): Whether to merge with the default config or not.
+ Returns:
+ CfgNode: Config as a yacs CfgNode object.
+ """
+ if merge:
+ cfg = default_config()
+ else:
+ cfg = CN(new_allowed=True)
+ cfg.merge_from_file(config_file)
+
+ if update_cachedir:
+ def update_path(path: str) -> str:
+ if os.path.isabs(path):
+ return path
+ return os.path.join(CACHE_DIR_PRETRAINED, path)
+
+ cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH)
+ cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS)
+
+ cfg.freeze()
+ return cfg
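+
+# Illustrative usage (the YAML path below is a placeholder, not shipped in this diff):
+#
+#     cfg = default_config()              # clone of the defaults above, still mutable
+#     cfg.MODEL.IMAGE_SIZE = 256          # override a default in code
+#
+#     cfg = get_config('pretrained_models/model_config.yaml', merge=True)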
\ No newline at end of file
diff --git a/wilor/datasets/utils.py b/wilor/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b09722ef538f25f736ab387e15df259966ce7b7
--- /dev/null
+++ b/wilor/datasets/utils.py
@@ -0,0 +1,994 @@
+"""
+Parts of the code are taken or adapted from
+https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
+"""
+import torch
+import numpy as np
+from skimage.transform import rotate, resize
+from skimage.filters import gaussian
+import random
+import cv2
+from typing import List, Dict, Tuple
+from yacs.config import CfgNode
+
+def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
+ """Increase the size of the bounding box to match the target shape."""
+ if target_aspect_ratio is None:
+ return input_shape
+
+ try:
+ w , h = input_shape
+ except (ValueError, TypeError):
+ return input_shape
+
+ w_t, h_t = target_aspect_ratio
+ if h / w < h_t / w_t:
+ h_new = max(w * h_t / w_t, h)
+ w_new = w
+ else:
+ h_new = h
+ w_new = max(h * w_t / h_t, w)
+ if h_new < h or w_new < w:
+ breakpoint()
+ return np.array([w_new, h_new])
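+
+# Example (illustrative): a 200x150 box expanded to the 192:256 aspect ratio used by
+# the ViT backbone keeps its width and grows only in height:
+#
+#     expand_to_aspect_ratio(np.array([200., 150.]), target_aspect_ratio=[192, 256])
+#     # -> array([200., 266.666...])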
+
+def do_augmentation(aug_config: CfgNode) -> Tuple:
+ """
+ Compute random augmentation parameters.
+ Args:
+ aug_config (CfgNode): Config containing augmentation parameters.
+ Returns:
+ scale (float): Box rescaling factor.
+ rot (float): Random image rotation.
+ do_flip (bool): Whether to flip image or not.
+ do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
+ extreme_crop_lvl (int): Extreme cropping level.
+ color_scale (List): Color rescaling factors for the three channels.
+ tx (float): Random translation along the x axis.
+ ty (float): Random translation along the y axis.
+ """
+
+ tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
+ ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
+ rot = np.clip(np.random.randn(), -2.0,
+ 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
+ do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
+ do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
+ extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
+ # extreme_crop_lvl = 0
+ c_up = 1.0 + aug_config.COLOR_SCALE
+ c_low = 1.0 - aug_config.COLOR_SCALE
+ color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
+ return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
+
+def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
+ """
+ Rotate a 2D point on the x-y plane.
+ Args:
+ pt_2d (np.array): Input 2D point with shape (2,).
+ rot_rad (float): Rotation angle in radians.
+ Returns:
+ np.array: Rotated 2D point.
+ """
+ x = pt_2d[0]
+ y = pt_2d[1]
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ xx = x * cs - y * sn
+ yy = x * sn + y * cs
+ return np.array([xx, yy], dtype=np.float32)
+
+
+def gen_trans_from_patch_cv(c_x: float, c_y: float,
+ src_width: float, src_height: float,
+ dst_width: float, dst_height: float,
+ scale: float, rot: float) -> np.array:
+ """
+ Create transformation matrix for the bounding box crop.
+ Args:
+ c_x (float): Bounding box center x coordinate in the original image.
+ c_y (float): Bounding box center y coordinate in the original image.
+ src_width (float): Bounding box width.
+ src_height (float): Bounding box height.
+ dst_width (float): Output box width.
+ dst_height (float): Output box height.
+ scale (float): Rescaling factor for the bounding box (augmentation).
+ rot (float): Random rotation applied to the box.
+ Returns:
+ trans (np.array): Target geometric transformation.
+ """
+ # augment size with scale
+ src_w = src_width * scale
+ src_h = src_height * scale
+ src_center = np.zeros(2)
+ src_center[0] = c_x
+ src_center[1] = c_y
+ # augment rotation
+ rot_rad = np.pi * rot / 180
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
+
+ dst_w = dst_width
+ dst_h = dst_height
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = src_center
+ src[1, :] = src_center + src_downdir
+ src[2, :] = src_center + src_rightdir
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = dst_center
+ dst[1, :] = dst_center + dst_downdir
+ dst[2, :] = dst_center + dst_rightdir
+
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
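+
+# Illustrative sketch (not used elsewhere in this module): crop a hypothetical
+# 200x200 box centred at (320, 240) into a 256x256 patch without scaling or rotation:
+#
+#     trans = gen_trans_from_patch_cv(320, 240, 200, 200, 256, 256, 1.0, 0.0)
+#     patch = cv2.warpAffine(img, trans, (256, 256), flags=cv2.INTER_LINEAR)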
+
+
+def trans_point2d(pt_2d: np.array, trans: np.array):
+ """
+ Transform a 2D point using the transformation matrix trans.
+ Args:
+ pt_2d (np.array): Input 2D point with shape (2,).
+ trans (np.array): Transformation matrix.
+ Returns:
+ np.array: Transformed 2D point.
+ """
+ src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
+ dst_pt = np.dot(trans, src_pt)
+ return dst_pt[0:2]
+
+def get_transform(center, scale, res, rot=0):
+ """Generate transformation matrix."""
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
+ h = 200 * scale
+ t = np.zeros((3, 3))
+ t[0, 0] = float(res[1]) / h
+ t[1, 1] = float(res[0]) / h
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
+ t[2, 2] = 1
+ if not rot == 0:
+ rot = -rot # To match direction of rotation from cropping
+ rot_mat = np.zeros((3, 3))
+ rot_rad = rot * np.pi / 180
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0, :2] = [cs, -sn]
+ rot_mat[1, :2] = [sn, cs]
+ rot_mat[2, 2] = 1
+ # Need to rotate around center
+ t_mat = np.eye(3)
+ t_mat[0, 2] = -res[1] / 2
+ t_mat[1, 2] = -res[0] / 2
+ t_inv = t_mat.copy()
+ t_inv[:2, 2] *= -1
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
+ return t
+
+
+def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
+ """Transform pixel location to different reference."""
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
+ t = get_transform(center, scale, res, rot=rot)
+ if invert:
+ t = np.linalg.inv(t)
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
+ new_pt = np.dot(t, new_pt)
+ if as_int:
+ new_pt = new_pt.astype(int)
+ return new_pt[:2] + 1
+
+def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
+ c_x = (ul[0] + br[0])/2
+ c_y = (ul[1] + br[1])/2
+ bb_width = patch_width = br[0] - ul[0]
+ bb_height = patch_height = br[1] - ul[1]
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=border_mode,
+ borderValue=border_value
+ )
+
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ )
+
+ return img_patch
+
+def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
+ bb_width: float, bb_height: float,
+ patch_width: float, patch_height: float,
+ do_flip: bool, scale: float, rot: float,
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
+ """
+ Crop image according to the supplied bounding box.
+ Args:
+ img (np.array): Input image of shape (H, W, 3)
+ c_x (float): Bounding box center x coordinate in the original image.
+ c_y (float): Bounding box center y coordinate in the original image.
+ bb_width (float): Bounding box width.
+ bb_height (float): Bounding box height.
+ patch_width (float): Output box width.
+ patch_height (float): Output box height.
+ do_flip (bool): Whether to flip image or not.
+ scale (float): Rescaling factor for the bounding box (augmentation).
+ rot (float): Random rotation applied to the box.
+ Returns:
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_width, 3)
+ trans (np.array): Transformation matrix.
+ """
+
+ img_height, img_width, img_channels = img.shape
+ if do_flip:
+ img = img[:, ::-1, :]
+ c_x = img_width - c_x - 1
+
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
+
+ #img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)
+
+ # skimage
+ center = np.zeros(2)
+ center[0] = c_x
+ center[1] = c_y
+ res = np.zeros(2)
+ res[0] = patch_width
+ res[1] = patch_height
+ # assumes bb_width = bb_height
+ # assumes patch_width = patch_height
+ assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
+ assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
+ scale1 = scale*bb_width/200.
+
+ # Upper left point
+ ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
+ # Bottom right point
+ br = np.array(transform([res[0] + 1,
+ res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1
+
+ # Padding so that when rotated proper amount of context is included
+ try:
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
+ except:
+ breakpoint()
+ if not rot == 0:
+ ul -= pad
+ br += pad
+
+
+ if False:
+ # Old way of cropping image
+ ul_int = ul.astype(int)
+ br_int = br.astype(int)
+ new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
+ if len(img.shape) > 2:
+ new_shape += [img.shape[2]]
+ new_img = np.zeros(new_shape)
+
+ # Range to fill new array
+ new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
+ new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
+ # Range to sample from original image
+ old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
+ old_y = max(0, ul_int[1]), min(len(img), br_int[1])
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
+ old_x[0]:old_x[1]]
+
+ # New way of cropping image
+ new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)
+
+ # print(f'{new_img.shape=}')
+ # print(f'{new_img1.shape=}')
+ # print(f'{np.allclose(new_img, new_img1)=}')
+ # print(f'{img.dtype=}')
+
+
+ if not rot == 0:
+ # Remove padding
+
+ new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
+ new_img = new_img[pad:-pad, pad:-pad]
+
+ if new_img.shape[0] < 1 or new_img.shape[1] < 1:
+ print(f'{img.shape=}')
+ print(f'{new_img.shape=}')
+ print(f'{ul=}')
+ print(f'{br=}')
+ print(f'{pad=}')
+ print(f'{rot=}')
+
+ breakpoint()
+
+ # resize image
+ new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
+
+ new_img = np.clip(new_img, 0, 255).astype(np.uint8)
+
+ return new_img, trans
+
+
+def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
+ bb_width: float, bb_height: float,
+ patch_width: float, patch_height: float,
+ do_flip: bool, scale: float, rot: float,
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
+ """
+ Crop the input image and return the crop and the corresponding transformation matrix.
+ Args:
+ img (np.array): Input image of shape (H, W, 3)
+ c_x (float): Bounding box center x coordinate in the original image.
+ c_y (float): Bounding box center y coordinate in the original image.
+ bb_width (float): Bounding box width.
+ bb_height (float): Bounding box height.
+ patch_width (float): Output box width.
+ patch_height (float): Output box height.
+ do_flip (bool): Whether to flip image or not.
+ scale (float): Rescaling factor for the bounding box (augmentation).
+ rot (float): Random rotation applied to the box.
+ Returns:
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_width, 3)
+ trans (np.array): Transformation matrix.
+ """
+
+ img_height, img_width, img_channels = img.shape
+ if do_flip:
+ img = img[:, ::-1, :]
+ c_x = img_width - c_x - 1
+
+
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
+
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=border_mode,
+ borderValue=border_value,
+ )
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ )
+
+ return img_patch, trans
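+
+# Illustrative call (mirrors ViTDetDataset.__getitem__ later in this diff; the
+# bounding-box values are placeholders):
+#
+#     img_patch, trans = generate_image_patch_cv2(cvimg, cx, cy, bbox_size, bbox_size,
+#                                                 256, 256, do_flip=False, scale=1.0,
+#                                                 rot=0, border_mode=cv2.BORDER_CONSTANT)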
+
+
+def convert_cvimg_to_tensor(cvimg: np.array):
+ """
+ Convert image from HWC to CHW format.
+ Args:
+ cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV.
+ Returns:
+ np.array: Output image of shape (3, H, W).
+ """
+ # from h,w,c(OpenCV) to c,h,w
+ img = cvimg.copy()
+ img = np.transpose(img, (2, 0, 1))
+ # from int to float
+ img = img.astype(np.float32)
+ return img
+
+def fliplr_params(mano_params: Dict, has_mano_params: Dict) -> Tuple[Dict, Dict]:
+ """
+ Flip MANO parameters when flipping the image.
+ Args:
+ mano_params (Dict): MANO parameter annotations.
+ has_mano_params (Dict): Whether MANO annotations are valid.
+ Returns:
+ Dict, Dict: Flipped MANO parameters and valid flags.
+ """
+ global_orient = mano_params['global_orient'].copy()
+ hand_pose = mano_params['hand_pose'].copy()
+ betas = mano_params['betas'].copy()
+ has_global_orient = has_mano_params['global_orient'].copy()
+ has_hand_pose = has_mano_params['hand_pose'].copy()
+ has_betas = has_mano_params['betas'].copy()
+
+ global_orient[1::3] *= -1
+ global_orient[2::3] *= -1
+ hand_pose[1::3] *= -1
+ hand_pose[2::3] *= -1
+
+ mano_params = {'global_orient': global_orient.astype(np.float32),
+ 'hand_pose': hand_pose.astype(np.float32),
+ 'betas': betas.astype(np.float32)
+ }
+
+ has_mano_params = {'global_orient': has_global_orient,
+ 'hand_pose': has_hand_pose,
+ 'betas': has_betas
+ }
+
+ return mano_params, has_mano_params
+
+
+def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
+ """
+ Flip 2D or 3D keypoints.
+ Args:
+ joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
+ width (float): Image width, used to mirror the x coordinates.
+ flip_permutation (List): Permutation to apply after flipping.
+ Returns:
+ np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
+ """
+ joints = joints.copy()
+ # Flip horizontal
+ joints[:, 0] = width - joints[:, 0] - 1
+ joints = joints[flip_permutation, :]
+
+ return joints
+
+def keypoint_3d_processing(keypoints_3d: np.array, flip_permutation: List[int], rot: float, do_flip: bool) -> np.array:
+ """
+ Process 3D keypoints (rotation/flipping).
+ Args:
+ keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
+ flip_permutation (List): Permutation to apply after flipping.
+ rot (float): Random rotation applied to the keypoints.
+ do_flip (bool): Whether to flip keypoints or not.
+ Returns:
+ np.array: Transformed 3D keypoints with shape (N, 4).
+ """
+ if do_flip:
+ keypoints_3d = fliplr_keypoints(keypoints_3d, 1, flip_permutation)
+ # in-plane rotation
+ rot_mat = np.eye(3)
+ if not rot == 0:
+ rot_rad = -rot * np.pi / 180
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0,:2] = [cs, -sn]
+ rot_mat[1,:2] = [sn, cs]
+ keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
+ # flip the x coordinates
+ keypoints_3d = keypoints_3d.astype('float32')
+ return keypoints_3d
+
+def rot_aa(aa: np.array, rot: float) -> np.array:
+ """
+ Rotate axis angle parameters.
+ Args:
+ aa (np.array): Axis-angle vector of shape (3,).
+ rot (float): Rotation angle in degrees.
+ Returns:
+ np.array: Rotated axis-angle vector.
+ """
+ # pose parameters
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
+ [0, 0, 1]])
+ # find the rotation of the hand in camera frame
+ per_rdg, _ = cv2.Rodrigues(aa)
+ # apply the global rotation to the global orientation
+ resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
+ aa = (resrot.T)[0]
+ return aa.astype(np.float32)
+
+def mano_param_processing(mano_params: Dict, has_mano_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
+ """
+ Apply random augmentations to the MANO parameters.
+ Args:
+ mano_params (Dict): MANO parameter annotations.
+ has_mano_params (Dict): Whether mano annotations are valid.
+ rot (float): Random rotation applied to the keypoints.
+ do_flip (bool): Whether to flip keypoints or not.
+ Returns:
+ Dict, Dict: Transformed MANO parameters and valid flags.
+ """
+ if do_flip:
+ mano_params, has_mano_params = fliplr_params(mano_params, has_mano_params)
+ mano_params['global_orient'] = rot_aa(mano_params['global_orient'], rot)
+ return mano_params, has_mano_params
+
+
+
+def get_example(img_path: str|np.ndarray, center_x: float, center_y: float,
+ width: float, height: float,
+ keypoints_2d: np.array, keypoints_3d: np.array,
+ mano_params: Dict, has_mano_params: Dict,
+ flip_kp_permutation: List[int],
+ patch_width: int, patch_height: int,
+ mean: np.array, std: np.array,
+ do_augment: bool, is_right: bool, augm_config: CfgNode,
+ is_bgr: bool = True,
+ use_skimage_antialias: bool = False,
+ border_mode: int = cv2.BORDER_CONSTANT,
+ return_trans: bool = False) -> Tuple:
+ """
+ Get an example from the dataset and (possibly) apply random augmentations.
+ Args:
+ img_path (str or np.ndarray): Image filename or an already-loaded image array.
+ center_x (float): Bounding box center x coordinate in the original image.
+ center_y (float): Bounding box center y coordinate in the original image.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
+ keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
+ mano_params (Dict): MANO parameter annotations.
+ has_mano_params (Dict): Whether MANO annotations are valid.
+ flip_kp_permutation (List): Permutation to apply to the keypoints after flipping.
+ patch_width (float): Output box width.
+ patch_height (float): Output box height.
+ mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
+ std (np.array): Array of shape (3,) containing the std for normalizing the input image.
+ do_augment (bool): Whether to apply data augmentation or not.
+ augm_config (CfgNode): Config containing augmentation parameters.
+ Returns:
+ img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_width).
+ keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
+ keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
+ mano_params (Dict): Transformed MANO parameters.
+ has_mano_params (Dict): Valid flag for transformed MANO parameters.
+ img_size (np.array): Image size of the original image.
+ """
+ if isinstance(img_path, str):
+ # 1. load image
+ cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
+ if not isinstance(cvimg, np.ndarray):
+ raise IOError("Fail to read %s" % img_path)
+ elif isinstance(img_path, np.ndarray):
+ cvimg = img_path
+ else:
+ raise TypeError('img_path must be either a string or a numpy array')
+ img_height, img_width, img_channels = cvimg.shape
+
+ img_size = np.array([img_height, img_width])
+
+ # 2. get augmentation params
+ if do_augment:
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
+ else:
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0.
+
+ # if it's a left hand, we flip
+ if not is_right:
+ do_flip = True
+
+ if width < 1 or height < 1:
+ breakpoint()
+
+ if do_extreme_crop:
+ if extreme_crop_lvl == 0:
+ center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
+ elif extreme_crop_lvl == 1:
+ center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d)
+
+ THRESH = 4
+ if width1 < THRESH or height1 < THRESH:
+ # print(f'{do_extreme_crop=}')
+ # print(f'width: {width}, height: {height}')
+ # print(f'width1: {width1}, height1: {height1}')
+ # print(f'center_x: {center_x}, center_y: {center_y}')
+ # print(f'center_x1: {center_x1}, center_y1: {center_y1}')
+ # print(f'keypoints_2d: {keypoints_2d}')
+ # print(f'\n\n', flush=True)
+ # breakpoint()
+ pass
+ # print(f'skip ==> width1: {width1}, height1: {height1}, width: {width}, height: {height}')
+ else:
+ center_x, center_y, width, height = center_x1, center_y1, width1, height1
+
+ center_x += width * tx
+ center_y += height * ty
+
+ # Process 3D keypoints
+ keypoints_3d = keypoint_3d_processing(keypoints_3d, flip_kp_permutation, rot, do_flip)
+
+ # 3. generate image patch
+ if use_skimage_antialias:
+ # Blur image to avoid aliasing artifacts
+ downsampling_factor = (patch_width / (width*scale))
+ if downsampling_factor > 1.1:
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True, truncate=3.0)
+
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
+ center_x, center_y,
+ width, height,
+ patch_width, patch_height,
+ do_flip, scale, rot,
+ border_mode=border_mode)
+
+ # img_patch_cv, trans = generate_image_patch_skimage(cvimg,
+ # center_x, center_y,
+ # width, height,
+ # patch_width, patch_height,
+ # do_flip, scale, rot,
+ # border_mode=border_mode)
+
+ image = img_patch_cv.copy()
+ if is_bgr:
+ image = image[:, :, ::-1]
+ img_patch_cv = image.copy()
+ img_patch = convert_cvimg_to_tensor(image)
+
+
+ mano_params, has_mano_params = mano_param_processing(mano_params, has_mano_params, rot, do_flip)
+
+ # apply normalization
+ for n_c in range(min(img_channels, 3)):
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
+ if mean is not None and std is not None:
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
+ if do_flip:
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, flip_kp_permutation)
+
+
+ for n_jt in range(len(keypoints_2d)):
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
+
+ if not return_trans:
+ return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size
+ else:
+ return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size, trans
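+
+# Typical call pattern (illustrative; every value here is a placeholder that a
+# dataset's __getitem__ would supply from its annotations):
+#
+#     img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size = \
+#         get_example(img_path, cx, cy, w, h, keypoints_2d, keypoints_3d,
+#                     mano_params, has_mano_params, flip_kp_permutation,
+#                     patch_width=256, patch_height=256, mean=mean, std=std,
+#                     do_augment=True, is_right=True, augm_config=cfg.DATASETS.CONFIG)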
+
+def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
+ """
+ Extreme cropping: Crop the box up to the hip locations.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+
+def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
+ """
+ Extreme cropping: Crop the box up to the shoulder locations.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.2 * scale[0]
+ height = 1.2 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
+ """
+ Extreme cropping: Crop the box and keep only the head.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]]
+ keypoints_2d[lower_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.3 * scale[0]
+ height = 1.3 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
+ """
+ Extreme cropping: Crop the box and keep only the torso.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]]
+ keypoints_2d[nontorso_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
+ """
+ Extreme cropping: Crop the box and keep only the right arm.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonrightarm_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
+ """
+ Extreme cropping: Crop the box and keep only the left arm.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonleftarm_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
+ """
+ Extreme cropping: Crop the box and keep only the legs.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
+ keypoints_2d[nonlegs_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
+ """
+ Extreme cropping: Crop the box and keep only the right leg.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonrightleg_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
+ """
+ Extreme cropping: Crop the box and keep only the left leg.
+ Args:
+ center_x (float): x coordinate of the bounding box center.
+ center_y (float): y coordinate of the bounding box center.
+ width (float): Bounding box width.
+ height (float): Bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of the new bounding box center.
+ center_y (float): y coordinate of the new bounding box center.
+ width (float): New bounding box width.
+ height (float): New bounding box height.
+ """
+ keypoints_2d = keypoints_2d.copy()
+ nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
+ keypoints_2d[nonleftleg_body_keypoints, :] = 0
+ if keypoints_2d[:, -1].sum() > 1:
+ center, scale = get_bbox(keypoints_2d)
+ center_x = center[0]
+ center_y = center[1]
+ width = 1.1 * scale[0]
+ height = 1.1 * scale[1]
+ return center_x, center_y, width, height
+
+def full_body(keypoints_2d: np.array) -> bool:
+ """
+ Check if all main body joints are visible.
+ Args:
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ bool: True if all main body joints are visible.
+ """
+
+ body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
+ body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
+ return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints)
+
+def upper_body(keypoints_2d: np.array):
+ """
+ Check if all upper body joints are visible.
+ Args:
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ bool: True if no lower body joints are visible and at least two upper body joints are.
+ """
+ lower_body_keypoints_openpose = [10, 11, 13, 14]
+ lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
+ upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
+ upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18]
+ return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\
+ and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
+
+def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
+ """
+ Get center and scale for bounding box from openpose detections.
+ Args:
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
+ Returns:
+ center (np.array): Array of shape (2,) containing the new bounding box center.
+ scale (np.array): Array of shape (2,) with the rescaled bounding box size.
+ """
+ valid = keypoints_2d[:,-1] > 0
+ valid_keypoints = keypoints_2d[valid][:,:-1]
+ center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
+ bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
+ # adjust bounding box tightness
+ scale = bbox_size
+ scale *= rescale
+ return center, scale
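+
+# Worked example (illustrative): two confident keypoints at (10, 20) and (50, 80)
+# give a center of (30, 50), a raw size of (40, 60) and, with rescale=1.2, a scale
+# of (48, 72):
+#
+#     get_bbox(np.array([[10., 20., 1.], [50., 80., 1.], [30., 40., 0.]]))
+#     # -> (array([30., 50.]), array([48., 72.]))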
+
+def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
+ """
+ Perform extreme cropping
+ Args:
+ center_x (float): x coordinate of bounding box center.
+ center_y (float): y coordinate of bounding box center.
+ width (float): bounding box width.
+ height (float): bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of bounding box center.
+ center_y (float): y coordinate of bounding box center.
+ width (float): bounding box width.
+ height (float): bounding box height.
+ """
+ p = torch.rand(1).item()
+ if full_body(keypoints_2d):
+ if p < 0.7:
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.9:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+ elif upper_body(keypoints_2d):
+ if p < 0.9:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+
+ return center_x, center_y, max(width, height), max(width, height)
+
+def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
+ """
+ Perform aggressive extreme cropping
+ Args:
+ center_x (float): x coordinate of bounding box center.
+ center_y (float): y coordinate of bounding box center.
+ width (float): bounding box width.
+ height (float): bounding box height.
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
+ Returns:
+ center_x (float): x coordinate of bounding box center.
+ center_y (float): y coordinate of bounding box center.
+ width (float): bounding box width.
+ height (float): bounding box height.
+ """
+ p = torch.rand(1).item()
+ if full_body(keypoints_2d):
+ if p < 0.2:
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.3:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.4:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.5:
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.6:
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.7:
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.8:
+ center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.9:
+ center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
+ elif upper_body(keypoints_2d):
+ if p < 0.2:
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.4:
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.6:
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
+ elif p < 0.8:
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
+ else:
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
+ return center_x, center_y, max(width, height), max(width, height)
diff --git a/wilor/datasets/vitdet_dataset.py b/wilor/datasets/vitdet_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fd6e8b9bc52d927c98959174bc556bf796f7162
--- /dev/null
+++ b/wilor/datasets/vitdet_dataset.py
@@ -0,0 +1,95 @@
+from typing import Dict
+
+import cv2
+import numpy as np
+from skimage.filters import gaussian
+from yacs.config import CfgNode
+import torch
+
+from .utils import (convert_cvimg_to_tensor,
+ expand_to_aspect_ratio,
+ generate_image_patch_cv2)
+
+DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
+DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
+
+class ViTDetDataset(torch.utils.data.Dataset):
+
+ def __init__(self,
+ cfg: CfgNode,
+ img_cv2: np.array,
+ boxes: np.array,
+ right: np.array,
+ rescale_factor=2.5,
+ train: bool = False,
+ **kwargs):
+ super().__init__()
+ self.cfg = cfg
+ self.img_cv2 = img_cv2
+ # self.boxes = boxes
+
+ assert train == False, "ViTDetDataset is only for inference"
+ self.train = train
+ self.img_size = cfg.MODEL.IMAGE_SIZE
+ self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
+ self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
+
+ # Preprocess annotations
+ boxes = boxes.astype(np.float32)
+ self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
+ self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
+ self.personid = np.arange(len(boxes), dtype=np.int32)
+ self.right = right.astype(np.float32)
+
+ def __len__(self) -> int:
+ return len(self.personid)
+
+ def __getitem__(self, idx: int) -> Dict[str, np.array]:
+
+ center = self.center[idx].copy()
+ center_x = center[0]
+ center_y = center[1]
+
+ scale = self.scale[idx]
+ BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
+ bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
+
+ patch_width = patch_height = self.img_size
+
+ right = self.right[idx].copy()
+ flip = right == 0
+
+ # 3. generate image patch
+ # if use_skimage_antialias:
+ cvimg = self.img_cv2.copy()
+ if True:
+ # Blur image to avoid aliasing artifacts
+ downsampling_factor = ((bbox_size*1.0) / patch_width)
+ #print(f'{downsampling_factor=}')
+ downsampling_factor = downsampling_factor / 2.0
+ if downsampling_factor > 1.1:
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True)
+
+
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
+ center_x, center_y,
+ bbox_size, bbox_size,
+ patch_width, patch_height,
+ flip, 1.0, 0,
+ border_mode=cv2.BORDER_CONSTANT)
+ img_patch_cv = img_patch_cv[:, :, ::-1]
+ img_patch = convert_cvimg_to_tensor(img_patch_cv)
+
+ # apply normalization
+ for n_c in range(min(self.img_cv2.shape[2], 3)):
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
+
+ item = {
+ 'img': img_patch,
+ 'personid': int(self.personid[idx]),
+ }
+ item['box_center'] = self.center[idx].copy()
+ item['box_size'] = bbox_size
+ item['img_size'] = 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]])
+ item['right'] = self.right[idx].copy()
+ return item
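+
+# Illustrative usage (model_cfg, img_cv2, boxes, right and model are placeholders
+# produced elsewhere, e.g. by the demo app):
+#
+#     dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
+#     loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
+#     for batch in loader:
+#         out = model(batch)   # forward pass on the normalized hand crops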
diff --git a/wilor/models/__init__.py b/wilor/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94187b99b6c63ce2113e72d5b2ffa906ab673cc6
--- /dev/null
+++ b/wilor/models/__init__.py
@@ -0,0 +1,36 @@
+from .mano_wrapper import MANO
+from .wilor import WiLoR
+
+from .discriminator import Discriminator
+
+def load_wilor(checkpoint_path, cfg_path):
+ from pathlib import Path
+ from wilor.configs import get_config
+ print('Loading ', checkpoint_path)
+ model_cfg = get_config(cfg_path, update_cachedir=True)
+
+ # Override some config values, to crop bbox correctly
+ if ('vit' in model_cfg.MODEL.BACKBONE.TYPE) and ('BBOX_SHAPE' not in model_cfg.MODEL):
+
+ model_cfg.defrost()
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
+ model_cfg.MODEL.BBOX_SHAPE = [192,256]
+ model_cfg.freeze()
+
+ # Update config to be compatible with demo
+ if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE):
+ model_cfg.defrost()
+ model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
+ model_cfg.freeze()
+
+ # Update config to be compatible with demo
+
+ if ('DATA_DIR' in model_cfg.MANO):
+ model_cfg.defrost()
+ model_cfg.MANO.DATA_DIR = './mano_data/'
+ model_cfg.MANO.MODEL_PATH = './mano_data/mano/'
+ model_cfg.MANO.MEAN_PARAMS = './mano_data/mano_mean_params.npz'
+ model_cfg.freeze()
+
+ model = WiLoR.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
+ return model, model_cfg
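+
+# Illustrative usage (checkpoint and config paths are placeholders):
+#
+#     model, model_cfg = load_wilor(checkpoint_path='./pretrained_models/wilor_final.ckpt',
+#                                   cfg_path='./pretrained_models/model_config.yaml')
+#     model = model.eval().to('cuda')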
\ No newline at end of file
diff --git a/wilor/models/backbones/__init__.py b/wilor/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f4305a01fadf44319fd331cf7fd987f33a23510
--- /dev/null
+++ b/wilor/models/backbones/__init__.py
@@ -0,0 +1,17 @@
+from .vit import vit
+
+def create_backbone(cfg):
+ if cfg.MODEL.BACKBONE.TYPE == 'vit':
+ return vit(cfg)
+ elif cfg.MODEL.BACKBONE.TYPE == 'fast_vit':
+ import torch
+ import sys
+ from timm.models import create_model
+ #from models.modules.mobileone import reparameterize_model
+ fast_vit = create_model("fastvit_ma36", drop_path_rate=0.2)
+ checkpoint = torch.load('./pretrained_models/fastvit_ma36.pt')
+ fast_vit.load_state_dict(checkpoint['state_dict'])
+ return fast_vit
+
+ else:
+ raise NotImplementedError('Backbone type is not implemented')
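+
+# Note (illustrative): with a config whose MODEL.BACKBONE.TYPE is 'vit', this returns
+# the ViT defined in vit.py (ViT-H/16 sized: 32 blocks, 1280-dim embeddings, 16 heads);
+# the 'fast_vit' branch instead expects a local FastViT-MA36 checkpoint:
+#
+#     backbone = create_backbone(model_cfg)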
diff --git a/wilor/models/backbones/vit.py b/wilor/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..704e4922f6d9688d9434fd8598acbcda4ca6c0e7
--- /dev/null
+++ b/wilor/models/backbones/vit.py
@@ -0,0 +1,410 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import numpy as np
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+def vit(cfg):
+ return ViT(
+ img_size=(256, 192),
+ patch_size=16,
+ embed_dim=1280,
+ depth=32,
+ num_heads=16,
+ ratio=1,
+ use_checkpoint=False,
+ mlp_ratio=4,
+ qkv_bias=True,
+ drop_path_rate=0.55,
+ cfg = cfg
+ )
+
+def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
+ """
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+ dimension for the original embeddings.
+ Args:
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+ h, w (int): height and width of the desired token grid.
+ ori_h, ori_w (int): height and width of the token grid the stored embeddings correspond to.
+
+ Returns:
+ Absolute positional embeddings after processing with shape (1, H, W, C)
+ """
+ cls_token = None
+ B, L, C = abs_pos.shape
+ if has_cls_token:
+ cls_token = abs_pos[:, 0:1]
+ abs_pos = abs_pos[:, 1:]
+
+ if ori_h != h or ori_w != w:
+ new_abs_pos = F.interpolate(
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
+ size=(h, w),
+ mode="bicubic",
+ align_corners=False,
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
+
+ else:
+ new_abs_pos = abs_pos
+
+ if cls_token is not None:
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
+ return new_abs_pos
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self):
+ return 'p={}'.format(self.drop_prob)
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None,):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.dim = dim
+
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm, attn_head_dim=None
+ ):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
+ )
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
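+# Minimal shape sketch for a transformer Block: tokens keep their (B, N, C)
+# layout. dim=64, 4 heads and 50 tokens are assumptions for illustration only.
+def _demo_block():
+    blk = Block(dim=64, num_heads=4, mlp_ratio=4.0, qkv_bias=True)
+    tokens = torch.randn(2, 50, 64)
+    assert blk(tokens).shape == (2, 50, 64)
+
+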
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ x = self.proj(x)
+ Hp, Wp = x.shape[2], x.shape[3]
+
+ x = x.flatten(2).transpose(1, 2)
+ return x, (Hp, Wp)
+
+
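+# Minimal shape sketch for PatchEmbed with the 256x192 crop used by the vit()
+# factory above: 16x16 patches give a 16x12 token grid. embed_dim=64 is an
+# assumption to keep the example light.
+def _demo_patch_embed():
+    embed = PatchEmbed(img_size=(256, 192), patch_size=16, in_chans=3, embed_dim=64)
+    tokens, (Hp, Wp) = embed(torch.randn(1, 3, 256, 192))
+    assert (Hp, Wp) == (16, 12) and tokens.shape == (1, 16 * 12, 64)
+
+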
+class HybridEmbed(nn.Module):
+ """ CNN Feature Map Embedding
+ Extract feature map from CNN, flatten, project to embedding dim.
+ """
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+ super().__init__()
+ assert isinstance(backbone, nn.Module)
+ img_size = to_2tuple(img_size)
+ self.img_size = img_size
+ self.backbone = backbone
+ if feature_size is None:
+ with torch.no_grad():
+ training = backbone.training
+ if training:
+ backbone.eval()
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+ feature_size = o.shape[-2:]
+ feature_dim = o.shape[1]
+ backbone.train(training)
+ else:
+ feature_size = to_2tuple(feature_size)
+ feature_dim = self.backbone.feature_info.channels()[-1]
+ self.num_patches = feature_size[0] * feature_size[1]
+ self.proj = nn.Linear(feature_dim, embed_dim)
+
+ def forward(self, x):
+ x = self.backbone(x)[-1]
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class ViT(nn.Module):
+
+ def __init__(self,
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
+ frozen_stages=-1, ratio=1, last_norm=True,
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,cfg=None,
+ ):
+ # Protect mutable default arguments
+ super(ViT, self).__init__()
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.frozen_stages = frozen_stages
+ self.use_checkpoint = use_checkpoint
+ self.patch_padding = patch_padding
+ self.freeze_attn = freeze_attn
+ self.freeze_ffn = freeze_ffn
+ self.depth = depth
+
+ if hybrid_backbone is not None:
+ self.patch_embed = HybridEmbed(
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
+ num_patches = self.patch_embed.num_patches
+
+ ##########################################
+ self.cfg = cfg
+ self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
+ self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
+ npose = self.joint_rep_dim * (cfg.MANO.NUM_HAND_JOINTS + 1)
+ self.npose = npose
+ mean_params = np.load(cfg.MANO.MEAN_PARAMS)
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
+ self.register_buffer('init_cam', init_cam)
+ init_hand_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
+ self.register_buffer('init_hand_pose', init_hand_pose)
+ self.register_buffer('init_betas', init_betas)
+
+        self.pose_emb = nn.Linear(self.joint_rep_dim, embed_dim)
+        self.shape_emb = nn.Linear(10, embed_dim)
+        self.cam_emb = nn.Linear(3, embed_dim)
+
+ self.decpose = nn.Linear(self.num_features, 6)
+ self.decshape = nn.Linear(self.num_features, 10)
+ self.deccam = nn.Linear(self.num_features, 3)
+ if cfg.MODEL.MANO_HEAD.get('INIT_DECODER_XAVIER', False):
+ # True by default in MLP. False by default in Transformer
+ nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
+ nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
+ nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
+
+
+ ##########################################
+
+ # since the pretraining model has class token
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ )
+ for i in range(depth)])
+
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = self.blocks[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ if self.freeze_attn:
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.attn.eval()
+ m.norm1.eval()
+ for param in m.attn.parameters():
+ param.requires_grad = False
+ for param in m.norm1.parameters():
+ param.requires_grad = False
+
+ if self.freeze_ffn:
+ self.pos_embed.requires_grad = False
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.mlp.eval()
+ m.norm2.eval()
+ for param in m.mlp.parameters():
+ param.requires_grad = False
+ for param in m.norm2.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ self.apply(_init_weights)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward_features(self, x):
+ B, C, H, W = x.shape
+ x, (Hp, Wp) = self.patch_embed(x)
+
+ if self.pos_embed is not None:
+ # fit for multiple GPU training
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
+ # X [B, 192, 1280]
+ # x cat [ mean_pose, mean_shape, mean_cam] tokens
+ pose_tokens = self.pose_emb(self.init_hand_pose.reshape(1, self.cfg.MANO.NUM_HAND_JOINTS + 1, self.joint_rep_dim)).repeat(B, 1, 1)
+ shape_tokens = self.shape_emb(self.init_betas).unsqueeze(1).repeat(B, 1, 1)
+ cam_tokens = self.cam_emb(self.init_cam).unsqueeze(1).repeat(B, 1, 1)
+
+ x = torch.cat([pose_tokens, shape_tokens, cam_tokens, x], 1)
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ x = self.last_norm(x)
+
+
+ pose_feat = x[:, :(self.cfg.MANO.NUM_HAND_JOINTS + 1)]
+ shape_feat = x[:, (self.cfg.MANO.NUM_HAND_JOINTS + 1):1+(self.cfg.MANO.NUM_HAND_JOINTS + 1)]
+ cam_feat = x[:, 1+(self.cfg.MANO.NUM_HAND_JOINTS + 1):2+(self.cfg.MANO.NUM_HAND_JOINTS + 1)]
+
+ #print(pose_feat.shape, shape_feat.shape, cam_feat.shape)
+ pred_hand_pose = self.decpose(pose_feat).reshape(B, -1) + self.init_hand_pose #B , 96
+ pred_betas = self.decshape(shape_feat).reshape(B, -1) + self.init_betas #B , 10
+ pred_cam = self.deccam(cam_feat).reshape(B, -1) + self.init_cam #B , 3
+
+ pred_mano_feats = {}
+ pred_mano_feats['hand_pose'] = pred_hand_pose
+ pred_mano_feats['betas'] = pred_betas
+ pred_mano_feats['cam'] = pred_cam
+
+
+ joint_conversion_fn = {
+ '6d': rot6d_to_rotmat,
+ 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
+ }[self.joint_rep_type]
+
+ pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3)
+ pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
+ 'hand_pose': pred_hand_pose[:, 1:],
+ 'betas': pred_betas}
+
+ img_feat = x[:, 2+(self.cfg.MANO.NUM_HAND_JOINTS + 1):].reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2)
+ return pred_mano_params, pred_cam, pred_mano_feats, img_feat
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
diff --git a/wilor/models/discriminator.py b/wilor/models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1cb2d1a21fbab47e8fa10dcc603b3d2012686a7
--- /dev/null
+++ b/wilor/models/discriminator.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+
+class Discriminator(nn.Module):
+
+ def __init__(self):
+ """
+ Pose + Shape discriminator proposed in HMR
+ """
+ super(Discriminator, self).__init__()
+
+ self.num_joints = 15
+ # poses_alone
+ self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1)
+ nn.init.xavier_uniform_(self.D_conv1.weight)
+ nn.init.zeros_(self.D_conv1.bias)
+ self.relu = nn.ReLU(inplace=True)
+ self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1)
+ nn.init.xavier_uniform_(self.D_conv2.weight)
+ nn.init.zeros_(self.D_conv2.bias)
+ pose_out = []
+ for i in range(self.num_joints):
+ pose_out_temp = nn.Linear(32, 1)
+ nn.init.xavier_uniform_(pose_out_temp.weight)
+ nn.init.zeros_(pose_out_temp.bias)
+ pose_out.append(pose_out_temp)
+ self.pose_out = nn.ModuleList(pose_out)
+
+ # betas
+ self.betas_fc1 = nn.Linear(10, 10)
+ nn.init.xavier_uniform_(self.betas_fc1.weight)
+ nn.init.zeros_(self.betas_fc1.bias)
+ self.betas_fc2 = nn.Linear(10, 5)
+ nn.init.xavier_uniform_(self.betas_fc2.weight)
+ nn.init.zeros_(self.betas_fc2.bias)
+ self.betas_out = nn.Linear(5, 1)
+ nn.init.xavier_uniform_(self.betas_out.weight)
+ nn.init.zeros_(self.betas_out.bias)
+
+ # poses_joint
+ self.D_alljoints_fc1 = nn.Linear(32*self.num_joints, 1024)
+ nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
+ nn.init.zeros_(self.D_alljoints_fc1.bias)
+ self.D_alljoints_fc2 = nn.Linear(1024, 1024)
+ nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
+ nn.init.zeros_(self.D_alljoints_fc2.bias)
+ self.D_alljoints_out = nn.Linear(1024, 1)
+ nn.init.xavier_uniform_(self.D_alljoints_out.weight)
+ nn.init.zeros_(self.D_alljoints_out.bias)
+
+
+ def forward(self, poses: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the discriminator.
+ Args:
+            poses (torch.Tensor): Tensor of shape (B, 15, 3, 3) containing a batch of MANO hand poses (excluding the global orientation).
+            betas (torch.Tensor): Tensor of shape (B, 10) containing a batch of MANO beta coefficients.
+        Returns:
+            torch.Tensor: Discriminator output of shape (B, 17): one score per hand joint, one for the betas and one for the full pose.
+ """
+ #bn = poses.shape[0]
+ # poses B x 207
+ #poses = poses.reshape(bn, -1)
+ # poses B x num_joints x 1 x 9
+ poses = poses.reshape(-1, self.num_joints, 1, 9)
+ bn = poses.shape[0]
+ # poses B x 9 x num_joints x 1
+ poses = poses.permute(0, 3, 1, 2).contiguous()
+
+ # poses_alone
+ poses = self.D_conv1(poses)
+ poses = self.relu(poses)
+ poses = self.D_conv2(poses)
+ poses = self.relu(poses)
+
+ poses_out = []
+ for i in range(self.num_joints):
+ poses_out_ = self.pose_out[i](poses[:, :, i, 0])
+ poses_out.append(poses_out_)
+ poses_out = torch.cat(poses_out, dim=1)
+
+ # betas
+ betas = self.betas_fc1(betas)
+ betas = self.relu(betas)
+ betas = self.betas_fc2(betas)
+ betas = self.relu(betas)
+ betas_out = self.betas_out(betas)
+
+ # poses_joint
+ poses = poses.reshape(bn,-1)
+ poses_all = self.D_alljoints_fc1(poses)
+ poses_all = self.relu(poses_all)
+ poses_all = self.D_alljoints_fc2(poses_all)
+ poses_all = self.relu(poses_all)
+ poses_all_out = self.D_alljoints_out(poses_all)
+
+ disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
+ return disc_out
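+
+
+# Minimal usage sketch: with the 15 MANO hand joints used here the output is
+# one score per joint, one for the betas and one for the full pose, i.e. (B, 17).
+# The random inputs are placeholders for illustration only.
+def _demo_discriminator():
+    disc = Discriminator()
+    poses = torch.randn(4, 15, 3, 3)   # rotation-matrix hand pose, excluding the global orientation
+    betas = torch.randn(4, 10)
+    assert disc(poses, betas).shape == (4, 17)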
diff --git a/wilor/models/heads/__init__.py b/wilor/models/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..40279d65814d344cd6a6f452356c4ac4e6a633b3
--- /dev/null
+++ b/wilor/models/heads/__init__.py
@@ -0,0 +1 @@
+from .refinement_net import RefineNet
\ No newline at end of file
diff --git a/wilor/models/heads/refinement_net.py b/wilor/models/heads/refinement_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..98cd8ef17617b2f5b5efbd0731aab103be77a361
--- /dev/null
+++ b/wilor/models/heads/refinement_net.py
@@ -0,0 +1,204 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
+from typing import Optional
+
+def make_linear_layers(feat_dims, relu_final=True, use_bn=False):
+ layers = []
+ for i in range(len(feat_dims)-1):
+ layers.append(nn.Linear(feat_dims[i], feat_dims[i+1]))
+
+ # Do not use ReLU for final estimation
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and relu_final):
+ if use_bn:
+ layers.append(nn.BatchNorm1d(feat_dims[i+1]))
+ layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*layers)
+
+def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
+ layers = []
+ for i in range(len(feat_dims)-1):
+ layers.append(
+ nn.Conv2d(
+ in_channels=feat_dims[i],
+ out_channels=feat_dims[i+1],
+ kernel_size=kernel,
+ stride=stride,
+ padding=padding
+ ))
+ # Do not use BN and ReLU for final estimation
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
+ layers.append(nn.BatchNorm2d(feat_dims[i+1]))
+ layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*layers)
+
+def make_deconv_layers(feat_dims, bnrelu_final=True):
+ layers = []
+ for i in range(len(feat_dims)-1):
+ layers.append(
+ nn.ConvTranspose2d(
+ in_channels=feat_dims[i],
+ out_channels=feat_dims[i+1],
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ output_padding=0,
+ bias=False))
+
+ # Do not use BN and ReLU for final estimation
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
+ layers.append(nn.BatchNorm2d(feat_dims[i+1]))
+ layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*layers)
+
+def sample_joint_features(img_feat, joint_xy):
+ height, width = img_feat.shape[2:]
+ x = joint_xy[:, :, 0] / (width - 1) * 2 - 1
+ y = joint_xy[:, :, 1] / (height - 1) * 2 - 1
+ grid = torch.stack((x, y), 2)[:, :, None, :]
+ img_feat = F.grid_sample(img_feat, grid, align_corners=True)[:, :, :, 0] # batch_size, channel_dim, joint_num
+ img_feat = img_feat.permute(0, 2, 1).contiguous() # batch_size, joint_num, channel_dim
+ return img_feat
+
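+# Minimal shape sketch for sample_joint_features: per-vertex features are read
+# from a feature map at pixel coordinates. 256 channels, a 32x32 map and the
+# 778 MANO vertices are assumptions for illustration.
+def _demo_sample_joint_features():
+    feat = torch.randn(2, 256, 32, 32)      # (B, C, H, W)
+    xy = torch.rand(2, 778, 2) * 31.0       # pixel coordinates inside the 32x32 map
+    assert sample_joint_features(feat, xy).shape == (2, 778, 256)
+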
+def perspective_projection(points: torch.Tensor,
+ translation: torch.Tensor,
+ focal_length: torch.Tensor,
+ camera_center: Optional[torch.Tensor] = None,
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Computes the perspective projection of a set of 3D points.
+ Args:
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
+ Returns:
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
+ """
+ batch_size = points.shape[0]
+ if rotation is None:
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ if camera_center is None:
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
+ # Populate intrinsic camera matrix K.
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
+ K[:,0,0] = focal_length[:,0]
+ K[:,1,1] = focal_length[:,1]
+ K[:,2,2] = 1.
+ K[:,:-1, -1] = camera_center
+ # Transform points
+ points = torch.einsum('bij,bkj->bki', rotation, points)
+ points = points + translation.unsqueeze(1)
+
+ # Apply perspective distortion
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
+
+ # Apply camera intrinsics
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
+
+ return projected_points[:, :, :-1]
+
+class DeConvNet(nn.Module):
+ def __init__(self, feat_dim=768, upscale=4):
+ super(DeConvNet, self).__init__()
+ self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False)
+ self.deconv = nn.ModuleList([])
+ for i in range(int(math.log2(upscale))+1):
+ if i==0:
+ self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4]))
+ elif i==1:
+ self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8]))
+ elif i==2:
+ self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8, feat_dim//8]))
+
+ def forward(self, img_feat):
+
+ face_img_feats = []
+ img_feat = self.first_conv(img_feat)
+ face_img_feats.append(img_feat)
+ for i, deconv in enumerate(self.deconv):
+ scale = 2**i
+ img_feat_i = deconv(img_feat)
+ face_img_feat = img_feat_i
+ face_img_feats.append(face_img_feat)
+ return face_img_feats[::-1] # high resolution -> low resolution
+
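+# Minimal shape sketch for DeConvNet with upscale=3 (the setting RefineNet uses
+# below): three feature maps are returned, high to low resolution. feat_dim=64
+# and the 16x12 input grid are assumptions for illustration.
+def _demo_deconvnet():
+    net = DeConvNet(feat_dim=64, upscale=3)
+    feats = net(torch.randn(1, 64, 16, 12))
+    assert [tuple(f.shape[1:]) for f in feats] == [(8, 64, 48), (16, 32, 24), (32, 16, 12)]
+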
+class DeConvNet_v2(nn.Module):
+ def __init__(self, feat_dim=768):
+ super(DeConvNet_v2, self).__init__()
+ self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False)
+ self.deconv = nn.Sequential(*[nn.ConvTranspose2d(in_channels=feat_dim//2, out_channels=feat_dim//4, kernel_size=4, stride=4, padding=0, output_padding=0, bias=False),
+ nn.BatchNorm2d(feat_dim//4),
+ nn.ReLU(inplace=True)])
+
+ def forward(self, img_feat):
+
+ face_img_feats = []
+ img_feat = self.first_conv(img_feat)
+ img_feat = self.deconv(img_feat)
+
+ return [img_feat]
+
+class RefineNet(nn.Module):
+ def __init__(self, cfg, feat_dim=1280, upscale=3):
+ super(RefineNet, self).__init__()
+ #self.deconv = DeConvNet_v2(feat_dim=feat_dim)
+ #self.out_dim = feat_dim//4
+
+ self.deconv = DeConvNet(feat_dim=feat_dim, upscale=upscale)
+ self.out_dim = feat_dim//8 + feat_dim//4 + feat_dim//2
+ self.dec_pose = nn.Linear(self.out_dim, 96)
+ self.dec_cam = nn.Linear(self.out_dim, 3)
+ self.dec_shape = nn.Linear(self.out_dim, 10)
+
+ self.cfg = cfg
+ self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
+ self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
+
+ def forward(self, img_feat, verts_3d, pred_cam, pred_mano_feats, focal_length):
+ B = img_feat.shape[0]
+
+ img_feats = self.deconv(img_feat)
+
+ img_feat_sizes = [img_feat.shape[2] for img_feat in img_feats]
+
+ temp_cams = [torch.stack([pred_cam[:, 1], pred_cam[:, 2],
+ 2*focal_length[:, 0]/(img_feat_size * pred_cam[:, 0] +1e-9)],dim=-1) for img_feat_size in img_feat_sizes]
+
+ verts_2d = [perspective_projection(verts_3d,
+ translation=temp_cams[i],
+ focal_length=focal_length / img_feat_sizes[i]) for i in range(len(img_feat_sizes))]
+
+ vert_feats = [sample_joint_features(img_feats[i], verts_2d[i]).max(1).values for i in range(len(img_feat_sizes))]
+
+ vert_feats = torch.cat(vert_feats, dim=-1)
+
+ delta_pose = self.dec_pose(vert_feats)
+ delta_betas = self.dec_shape(vert_feats)
+ delta_cam = self.dec_cam(vert_feats)
+
+
+ pred_hand_pose = pred_mano_feats['hand_pose'] + delta_pose
+ pred_betas = pred_mano_feats['betas'] + delta_betas
+ pred_cam = pred_mano_feats['cam'] + delta_cam
+
+ joint_conversion_fn = {
+ '6d': rot6d_to_rotmat,
+ 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
+ }[self.joint_rep_type]
+
+ pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3)
+
+ pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
+ 'hand_pose': pred_hand_pose[:, 1:],
+ 'betas': pred_betas}
+
+ return pred_mano_params, pred_cam
+
+
\ No newline at end of file
diff --git a/wilor/models/losses.py b/wilor/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6e493c081a4d99b97b5641e85152c4d56072a58
--- /dev/null
+++ b/wilor/models/losses.py
@@ -0,0 +1,92 @@
+import torch
+import torch.nn as nn
+
+class Keypoint2DLoss(nn.Module):
+
+ def __init__(self, loss_type: str = 'l1'):
+ """
+ 2D keypoint loss module.
+ Args:
+ loss_type (str): Choose between l1 and l2 losses.
+ """
+ super(Keypoint2DLoss, self).__init__()
+ if loss_type == 'l1':
+ self.loss_fn = nn.L1Loss(reduction='none')
+ elif loss_type == 'l2':
+ self.loss_fn = nn.MSELoss(reduction='none')
+ else:
+ raise NotImplementedError('Unsupported loss function')
+
+ def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
+ """
+ Compute 2D reprojection loss on the keypoints.
+ Args:
+            pred_keypoints_2d (torch.Tensor): Tensor of shape [B, N, 2] containing the projected 2D keypoints (B: batch size, N: number of keypoints).
+            gt_keypoints_2d (torch.Tensor): Tensor of shape [B, N, 3] containing the ground truth 2D keypoints and their confidence.
+ Returns:
+ torch.Tensor: 2D keypoint loss.
+ """
+ conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
+ batch_size = conf.shape[0]
+ loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1,2))
+ return loss.sum()
+
+
+class Keypoint3DLoss(nn.Module):
+
+ def __init__(self, loss_type: str = 'l1'):
+ """
+ 3D keypoint loss module.
+ Args:
+ loss_type (str): Choose between l1 and l2 losses.
+ """
+ super(Keypoint3DLoss, self).__init__()
+ if loss_type == 'l1':
+ self.loss_fn = nn.L1Loss(reduction='none')
+ elif loss_type == 'l2':
+ self.loss_fn = nn.MSELoss(reduction='none')
+ else:
+ raise NotImplementedError('Unsupported loss function')
+
+ def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0):
+ """
+ Compute 3D keypoint loss.
+ Args:
+            pred_keypoints_3d (torch.Tensor): Tensor of shape [B, N, 3] containing the predicted 3D keypoints (B: batch size, N: number of keypoints).
+            gt_keypoints_3d (torch.Tensor): Tensor of shape [B, N, 4] containing the ground truth 3D keypoints and their confidence.
+ Returns:
+ torch.Tensor: 3D keypoint loss.
+ """
+ batch_size = pred_keypoints_3d.shape[0]
+ gt_keypoints_3d = gt_keypoints_3d.clone()
+ pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1)
+ gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1)
+ conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
+ gt_keypoints_3d = gt_keypoints_3d[:, :, :-1]
+ loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1,2))
+ return loss.sum()
+
+class ParameterLoss(nn.Module):
+
+ def __init__(self):
+ """
+ MANO parameter loss module.
+ """
+ super(ParameterLoss, self).__init__()
+ self.loss_fn = nn.MSELoss(reduction='none')
+
+ def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor):
+ """
+ Compute MANO parameter loss.
+ Args:
+ pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas)
+ gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth MANO parameters.
+ Returns:
+            torch.Tensor: L2 parameter loss.
+ """
+ batch_size = pred_param.shape[0]
+ num_dims = len(pred_param.shape)
+ mask_dimension = [batch_size] + [1] * (num_dims-1)
+ has_param = has_param.type(pred_param.type()).view(*mask_dimension)
+ loss_param = (has_param * self.loss_fn(pred_param, gt_param))
+ return loss_param.sum()
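+
+
+# Minimal usage sketch for the confidence-weighted 2D keypoint loss: keypoints
+# with zero confidence contribute nothing. The batch of 2 with 21 hand
+# keypoints is an assumption for illustration.
+def _demo_keypoint_2d_loss():
+    loss_2d = Keypoint2DLoss(loss_type='l1')
+    pred = torch.zeros(2, 21, 2)
+    gt = torch.cat([torch.ones(2, 21, 2), torch.zeros(2, 21, 1)], dim=-1)  # last channel: confidence = 0
+    assert loss_2d(pred, gt).item() == 0.0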
diff --git a/wilor/models/mano_wrapper.py b/wilor/models/mano_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f0cc336098e9303d2514c571307c56baf3bc86
--- /dev/null
+++ b/wilor/models/mano_wrapper.py
@@ -0,0 +1,40 @@
+import torch
+import numpy as np
+import pickle
+from typing import Optional
+import smplx
+from smplx.lbs import vertices2joints
+from smplx.utils import MANOOutput, to_tensor
+from smplx.vertex_ids import vertex_ids
+
+
+class MANO(smplx.MANOLayer):
+ def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
+ """
+ Extension of the official MANO implementation to support more joints.
+ Args:
+ Same as MANOLayer.
+ joint_regressor_extra (str): Path to extra joint regressor.
+ """
+ super(MANO, self).__init__(*args, **kwargs)
+ mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
+
+ #2, 3, 5, 4, 1
+ if joint_regressor_extra is not None:
+ self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
+ self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
+ self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))
+
+ def forward(self, *args, **kwargs) -> MANOOutput:
+ """
+        Run the forward pass. Same as the original MANO implementation, but also appends an extra set of joints if joint_regressor_extra is specified.
+ """
+ mano_output = super(MANO, self).forward(*args, **kwargs)
+ extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
+ joints = torch.cat([mano_output.joints, extra_joints], dim=1)
+ joints = joints[:, self.joint_map, :]
+ if hasattr(self, 'joint_regressor_extra'):
+ extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
+ joints = torch.cat([joints, extra_joints], dim=1)
+ mano_output.joints = joints
+ return mano_output
diff --git a/wilor/models/wilor.py b/wilor/models/wilor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5306376229a56931a444e693a6b0d070cc75bfd
--- /dev/null
+++ b/wilor/models/wilor.py
@@ -0,0 +1,376 @@
+import torch
+import pytorch_lightning as pl
+from typing import Any, Dict, Mapping, Tuple
+
+from yacs.config import CfgNode
+
+from ..utils import SkeletonRenderer, MeshRenderer
+from ..utils.geometry import aa_to_rotmat, perspective_projection
+from ..utils.pylogger import get_pylogger
+from .backbones import create_backbone
+from .heads import RefineNet
+from .discriminator import Discriminator
+from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss
+from . import MANO
+
+log = get_pylogger(__name__)
+
+class WiLoR(pl.LightningModule):
+
+ def __init__(self, cfg: CfgNode, init_renderer: bool = True):
+ """
+ Setup WiLoR model
+ Args:
+ cfg (CfgNode): Config file as a yacs CfgNode
+ """
+ super().__init__()
+
+ # Save hyperparameters
+ self.save_hyperparameters(logger=False, ignore=['init_renderer'])
+
+ self.cfg = cfg
+ # Create backbone feature extractor
+ self.backbone = create_backbone(cfg)
+ if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
+ log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
+ self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'], strict = False)
+
+ # Create RefineNet head
+ self.refine_net = RefineNet(cfg, feat_dim=1280, upscale=3)
+
+ # Create discriminator
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
+ self.discriminator = Discriminator()
+
+ # Define loss functions
+ self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
+ self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
+ self.mano_parameter_loss = ParameterLoss()
+
+ # Instantiate MANO model
+ mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()}
+ self.mano = MANO(**mano_cfg)
+
+        # Buffer that shows whether we need to initialize ActNorm layers
+ self.register_buffer('initialized', torch.tensor(False))
+ # Setup renderer for visualization
+ if init_renderer:
+ self.renderer = SkeletonRenderer(self.cfg)
+ self.mesh_renderer = MeshRenderer(self.cfg, faces=self.mano.faces)
+ else:
+ self.renderer = None
+ self.mesh_renderer = None
+
+
+ # Disable automatic optimization since we use adversarial training
+ self.automatic_optimization = False
+
+ def on_after_backward(self):
+ for name, param in self.named_parameters():
+ if param.grad is None:
+ print(param.shape)
+ print(name)
+
+
+ def get_parameters(self):
+ #all_params = list(self.mano_head.parameters())
+ all_params = list(self.backbone.parameters())
+ return all_params
+
+ def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
+ """
+        Set up the model and discriminator optimizers.
+ Returns:
+ Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
+ """
+ param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
+
+ optimizer = torch.optim.AdamW(params=param_groups,
+ # lr=self.cfg.TRAIN.LR,
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
+ optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
+ lr=self.cfg.TRAIN.LR,
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
+
+ return optimizer, optimizer_disc
+
+ def forward_step(self, batch: Dict, train: bool = False) -> Dict:
+ """
+ Run a forward step of the network
+ Args:
+ batch (Dict): Dictionary containing batch data
+ train (bool): Flag indicating whether it is training or validation mode
+ Returns:
+ Dict: Dictionary containing the regression output
+ """
+ # Use RGB image as input
+ x = batch['img']
+ batch_size = x.shape[0]
+ # Compute conditioning features using the backbone
+ # if using ViT backbone, we need to use a different aspect ratio
+ temp_mano_params, pred_cam, pred_mano_feats, vit_out = self.backbone(x[:,:,:,32:-32]) # B, 1280, 16, 12
+
+
+ # Compute camera translation
+ device = temp_mano_params['hand_pose'].device
+ dtype = temp_mano_params['hand_pose'].dtype
+ focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
+
+
+ ## Temp MANO
+ temp_mano_params['global_orient'] = temp_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
+ temp_mano_params['hand_pose'] = temp_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
+ temp_mano_params['betas'] = temp_mano_params['betas'].reshape(batch_size, -1)
+ temp_mano_output = self.mano(**{k: v.float() for k,v in temp_mano_params.items()}, pose2rot=False)
+ #temp_keypoints_3d = temp_mano_output.joints
+ temp_vertices = temp_mano_output.vertices
+
+ pred_mano_params, pred_cam = self.refine_net(vit_out, temp_vertices, pred_cam, pred_mano_feats, focal_length)
+ # Store useful regression outputs to the output dict
+
+
+ output = {}
+ output['pred_cam'] = pred_cam
+ output['pred_mano_params'] = {k: v.clone() for k,v in pred_mano_params.items()}
+
+ pred_cam_t = torch.stack([pred_cam[:, 1],
+ pred_cam[:, 2],
+ 2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] +1e-9)],dim=-1)
+ output['pred_cam_t'] = pred_cam_t
+ output['focal_length'] = focal_length
+
+ # Compute model vertices, joints and the projected joints
+ pred_mano_params['global_orient'] = pred_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
+ pred_mano_params['hand_pose'] = pred_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
+ pred_mano_params['betas'] = pred_mano_params['betas'].reshape(batch_size, -1)
+ mano_output = self.mano(**{k: v.float() for k,v in pred_mano_params.items()}, pose2rot=False)
+ pred_keypoints_3d = mano_output.joints
+ pred_vertices = mano_output.vertices
+
+ output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
+ output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
+ pred_cam_t = pred_cam_t.reshape(-1, 3)
+ focal_length = focal_length.reshape(-1, 2)
+
+ pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
+ translation=pred_cam_t,
+ focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
+ output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
+
+ return output
+
+ def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
+ """
+ Compute losses given the input batch and the regression output
+ Args:
+ batch (Dict): Dictionary containing batch data
+ output (Dict): Dictionary containing the regression output
+ train (bool): Flag indicating whether it is training or validation mode
+ Returns:
+ torch.Tensor : Total loss for current batch
+ """
+
+ pred_mano_params = output['pred_mano_params']
+ pred_keypoints_2d = output['pred_keypoints_2d']
+ pred_keypoints_3d = output['pred_keypoints_3d']
+
+
+ batch_size = pred_mano_params['hand_pose'].shape[0]
+ device = pred_mano_params['hand_pose'].device
+ dtype = pred_mano_params['hand_pose'].dtype
+
+ # Get annotations
+ gt_keypoints_2d = batch['keypoints_2d']
+ gt_keypoints_3d = batch['keypoints_3d']
+ gt_mano_params = batch['mano_params']
+ has_mano_params = batch['has_mano_params']
+ is_axis_angle = batch['mano_params_is_axis_angle']
+
+ # Compute 3D keypoint loss
+ loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
+ loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
+
+ # Compute loss on MANO parameters
+ loss_mano_params = {}
+ for k, pred in pred_mano_params.items():
+ gt = gt_mano_params[k].view(batch_size, -1)
+ if is_axis_angle[k].all():
+ gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
+ has_gt = has_mano_params[k]
+ loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt)
+
+ loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\
+ self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\
+ sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])
+
+
+ losses = dict(loss=loss.detach(),
+ loss_keypoints_2d=loss_keypoints_2d.detach(),
+ loss_keypoints_3d=loss_keypoints_3d.detach())
+
+ for k, v in loss_mano_params.items():
+ losses['loss_' + k] = v.detach()
+
+ output['losses'] = losses
+
+ return loss
+
+    # Tensorboard logging should run from the first rank only
+ @pl.utilities.rank_zero.rank_zero_only
+ def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None:
+ """
+ Log results to Tensorboard
+ Args:
+ batch (Dict): Dictionary containing batch data
+ output (Dict): Dictionary containing the regression output
+ step_count (int): Global training step count
+ train (bool): Flag indicating whether it is training or validation mode
+ """
+
+ mode = 'train' if train else 'val'
+ batch_size = batch['keypoints_2d'].shape[0]
+ images = batch['img']
+ images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1)
+ images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1)
+ #images = 255*images.permute(0, 2, 3, 1).cpu().numpy()
+
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
+ pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
+ focal_length = output['focal_length'].detach().reshape(batch_size, 2)
+ gt_keypoints_3d = batch['keypoints_3d']
+ gt_keypoints_2d = batch['keypoints_2d']
+
+ losses = output['losses']
+ pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)
+ if write_to_summary_writer:
+ summary_writer = self.logger.experiment
+ for loss_name, val in losses.items():
+ summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count)
+ num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
+
+ gt_keypoints_3d = batch['keypoints_3d']
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
+
+ # We render the skeletons instead of the full mesh because rendering a lot of meshes will make the training slow.
+ #predictions = self.renderer(pred_keypoints_3d[:num_images],
+ # gt_keypoints_3d[:num_images],
+ # 2 * gt_keypoints_2d[:num_images],
+ # images=images[:num_images],
+ # camera_translation=pred_cam_t[:num_images])
+ predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
+ pred_cam_t[:num_images].cpu().numpy(),
+ images[:num_images].cpu().numpy(),
+ pred_keypoints_2d[:num_images].cpu().numpy(),
+ gt_keypoints_2d[:num_images].cpu().numpy(),
+ focal_length=focal_length[:num_images].cpu().numpy())
+ if write_to_summary_writer:
+ summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
+
+ return predictions
+
+ def forward(self, batch: Dict) -> Dict:
+ """
+ Run a forward step of the network in val mode
+ Args:
+ batch (Dict): Dictionary containing batch data
+ Returns:
+ Dict: Dictionary containing the regression output
+ """
+ return self.forward_step(batch, train=False)
+
+ def training_step_discriminator(self, batch: Dict,
+ hand_pose: torch.Tensor,
+ betas: torch.Tensor,
+ optimizer: torch.optim.Optimizer) -> torch.Tensor:
+ """
+ Run a discriminator training step
+ Args:
+ batch (Dict): Dictionary containing mocap batch data
+ hand_pose (torch.Tensor): Regressed hand pose from current step
+ betas (torch.Tensor): Regressed betas from current step
+ optimizer (torch.optim.Optimizer): Discriminator optimizer
+ Returns:
+ torch.Tensor: Discriminator loss
+ """
+ batch_size = hand_pose.shape[0]
+ gt_hand_pose = batch['hand_pose']
+ gt_betas = batch['betas']
+ gt_rotmat = aa_to_rotmat(gt_hand_pose.view(-1,3)).view(batch_size, -1, 3, 3)
+ disc_fake_out = self.discriminator(hand_pose.detach(), betas.detach())
+ loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
+ disc_real_out = self.discriminator(gt_rotmat, gt_betas)
+ loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
+ loss_disc = loss_fake + loss_real
+ loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
+ optimizer.zero_grad()
+ self.manual_backward(loss)
+ optimizer.step()
+ return loss_disc.detach()
+
+ def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
+ """
+ Run a full training step
+ Args:
+ joint_batch (Dict): Dictionary containing image and mocap batch data
+            batch_idx (int): Unused.
+ Returns:
+ Dict: Dictionary containing regression output.
+ """
+ batch = joint_batch['img']
+ mocap_batch = joint_batch['mocap']
+ optimizer = self.optimizers(use_pl_optimizer=True)
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
+ optimizer, optimizer_disc = optimizer
+
+ batch_size = batch['img'].shape[0]
+ output = self.forward_step(batch, train=True)
+ pred_mano_params = output['pred_mano_params']
+ if self.cfg.get('UPDATE_GT_SPIN', False):
+ self.update_batch_gt_spin(batch, output)
+ loss = self.compute_loss(batch, output, train=True)
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
+ disc_out = self.discriminator(pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1))
+ loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
+ loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv
+
+        # Raise an error if the loss is NaN
+ if torch.isnan(loss):
+ raise ValueError('Loss is NaN')
+
+ optimizer.zero_grad()
+ self.manual_backward(loss)
+ # Clip gradient
+ if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
+ gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
+ self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+ optimizer.step()
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
+ loss_disc = self.training_step_discriminator(mocap_batch, pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1), optimizer_disc)
+ output['losses']['loss_gen'] = loss_adv
+ output['losses']['loss_disc'] = loss_disc
+
+ if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
+ self.tensorboard_logging(batch, output, self.global_step, train=True)
+
+ self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False)
+
+ return output
+
+ def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
+ """
+ Run a validation step and log to Tensorboard
+ Args:
+ batch (Dict): Dictionary containing batch data
+ batch_idx (int): Unused.
+ Returns:
+ Dict: Dictionary containing regression output.
+ """
+ # batch_size = batch['img'].shape[0]
+ output = self.forward_step(batch, train=False)
+ loss = self.compute_loss(batch, output, train=False)
+ output['loss'] = loss
+ self.tensorboard_logging(batch, output, self.global_step, train=False)
+
+ return output
diff --git a/wilor/utils/__init__.py b/wilor/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..09e47cdf8cdb303432d64902fbe58b256273f88a
--- /dev/null
+++ b/wilor/utils/__init__.py
@@ -0,0 +1,25 @@
+import torch
+from typing import Any
+
+from .renderer import Renderer
+from .mesh_renderer import MeshRenderer
+from .skeleton_renderer import SkeletonRenderer
+from .pose_utils import eval_pose, Evaluator
+
+def recursive_to(x: Any, target: torch.device):
+ """
+ Recursively transfer a batch of data to the target device
+ Args:
+ x (Any): Batch of data.
+ target (torch.device): Target device.
+ Returns:
+        Batch of data where all tensors are transferred to the target device.
+ """
+ if isinstance(x, dict):
+ return {k: recursive_to(v, target) for k, v in x.items()}
+ elif isinstance(x, torch.Tensor):
+ return x.to(target)
+ elif isinstance(x, list):
+ return [recursive_to(i, target) for i in x]
+ else:
+ return x
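+
+
+# Minimal usage sketch for recursive_to: nested dicts/lists of tensors are
+# moved as one unit and non-tensor leaves pass through untouched. The CPU
+# target and the toy batch are assumptions so the example runs anywhere.
+def _demo_recursive_to():
+    batch = {'img': torch.zeros(1, 3, 256, 256), 'meta': [torch.zeros(1), 'frame_0001']}
+    moved = recursive_to(batch, torch.device('cpu'))
+    assert moved['meta'][1] == 'frame_0001'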
diff --git a/wilor/utils/geometry.py b/wilor/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..7929ef52608618a4682788487008e73c5736101b
--- /dev/null
+++ b/wilor/utils/geometry.py
@@ -0,0 +1,102 @@
+from typing import Optional
+import torch
+from torch.nn import functional as F
+
+def aa_to_rotmat(theta: torch.Tensor):
+ """
+ Convert axis-angle representation to rotation matrix.
+ Works by first converting it to a quaternion.
+ Args:
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat_to_rotmat(quat)
+
+def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
+ """
+ Convert quaternion representation to rotation matrix.
+ Args:
+        quat (torch.Tensor): Tensor of shape (B, 4) containing quaternions in (w, x, y, z) order.
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+
+def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Args:
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
+ """
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
+
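+# Minimal sanity sketch for rot6d_to_rotmat: the 6D representation of the
+# identity rotation (its first two columns) maps back to the 3x3 identity.
+def _demo_rot6d_to_rotmat():
+    sixd = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
+    assert torch.allclose(rot6d_to_rotmat(sixd), torch.eye(3).unsqueeze(0), atol=1e-6)
+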
+def perspective_projection(points: torch.Tensor,
+ translation: torch.Tensor,
+ focal_length: torch.Tensor,
+ camera_center: Optional[torch.Tensor] = None,
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Computes the perspective projection of a set of 3D points.
+ Args:
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
+ Returns:
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
+ """
+ batch_size = points.shape[0]
+ if rotation is None:
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
+ if camera_center is None:
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
+ # Populate intrinsic camera matrix K.
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
+ K[:,0,0] = focal_length[:,0]
+ K[:,1,1] = focal_length[:,1]
+ K[:,2,2] = 1.
+ K[:,:-1, -1] = camera_center
+
+ # Transform points
+ points = torch.einsum('bij,bkj->bki', rotation, points)
+ points = points + translation.unsqueeze(1)
+
+ # Apply perspective distortion
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
+
+ # Apply camera intrinsics
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
+
+ return projected_points[:, :, :-1]
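+
+
+# Minimal usage sketch for perspective_projection: a point on the optical axis,
+# pushed 2 m in front of the camera, projects to the (normalized) image centre.
+# The focal length value is an arbitrary assumption for illustration.
+def _demo_perspective_projection():
+    points = torch.tensor([[[0.0, 0.0, 0.0]]])        # (B=1, N=1, 3)
+    translation = torch.tensor([[0.0, 0.0, 2.0]])
+    focal_length = torch.full((1, 2), 19.5)
+    uv = perspective_projection(points, translation, focal_length)
+    assert torch.allclose(uv, torch.zeros(1, 1, 2))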
\ No newline at end of file
diff --git a/wilor/utils/mesh_renderer.py b/wilor/utils/mesh_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb3e8ed2e9aed8157ec852d06d5f13e8f4ff7c54
--- /dev/null
+++ b/wilor/utils/mesh_renderer.py
@@ -0,0 +1,149 @@
+import os
+if 'PYOPENGL_PLATFORM' not in os.environ:
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
+import torch
+from torchvision.utils import make_grid
+import numpy as np
+import pyrender
+import trimesh
+import cv2
+import torch.nn.functional as F
+
+from .render_openpose import render_openpose
+
+def create_raymond_lights():
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
+
+ nodes = []
+
+ for phi, theta in zip(phis, thetas):
+ xp = np.sin(theta) * np.cos(phi)
+ yp = np.sin(theta) * np.sin(phi)
+ zp = np.cos(theta)
+
+ z = np.array([xp, yp, zp])
+ z = z / np.linalg.norm(z)
+ x = np.array([-z[1], z[0], 0.0])
+ if np.linalg.norm(x) == 0:
+ x = np.array([1.0, 0.0, 0.0])
+ x = x / np.linalg.norm(x)
+ y = np.cross(z, x)
+
+ matrix = np.eye(4)
+ matrix[:3,:3] = np.c_[x,y,z]
+ nodes.append(pyrender.Node(
+ light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
+ matrix=matrix
+ ))
+
+ return nodes
+
+class MeshRenderer:
+
+ def __init__(self, cfg, faces=None):
+ self.cfg = cfg
+ self.focal_length = cfg.EXTRA.FOCAL_LENGTH
+ self.img_res = cfg.MODEL.IMAGE_SIZE
+ self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
+ viewport_height=self.img_res,
+ point_size=1.0)
+
+ self.camera_center = [self.img_res // 2, self.img_res // 2]
+ self.faces = faces
+
+ def visualize(self, vertices, camera_translation, images, focal_length=None, nrow=3, padding=2):
+ images_np = np.transpose(images, (0,2,3,1))
+ rend_imgs = []
+ for i in range(vertices.shape[0]):
+ fl = self.focal_length
+ rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2,0,1))).float()
+ rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2,0,1))).float()
+ rend_imgs.append(torch.from_numpy(images[i]))
+ rend_imgs.append(rend_img)
+ rend_imgs.append(rend_img_side)
+ rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
+ return rend_imgs
+
+ def visualize_tensorboard(self, vertices, camera_translation, images, pred_keypoints, gt_keypoints, focal_length=None, nrow=5, padding=2):
+ images_np = np.transpose(images, (0,2,3,1))
+ rend_imgs = []
+ pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
+ pred_keypoints = self.img_res * (pred_keypoints + 0.5)
+ gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5)
+ #keypoint_matches = [(1, 12), (2, 8), (3, 7), (4, 6), (5, 9), (6, 10), (7, 11), (8, 14), (9, 2), (10, 1), (11, 0), (12, 3), (13, 4), (14, 5)]
+ for i in range(vertices.shape[0]):
+ fl = self.focal_length
+ rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2,0,1))).float()
+ rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2,0,1))).float()
+ hand_keypoints = pred_keypoints[i, :21]
+ #extra_keypoints = pred_keypoints[i, -19:]
+ #for pair in keypoint_matches:
+ # hand_keypoints[pair[0], :] = extra_keypoints[pair[1], :]
+ pred_keypoints_img = render_openpose(255 * images_np[i].copy(), hand_keypoints) / 255
+ hand_keypoints = gt_keypoints[i, :21]
+ #extra_keypoints = gt_keypoints[i, -19:]
+ #for pair in keypoint_matches:
+ # if extra_keypoints[pair[1], -1] > 0 and hand_keypoints[pair[0], -1] == 0:
+ # hand_keypoints[pair[0], :] = extra_keypoints[pair[1], :]
+ gt_keypoints_img = render_openpose(255*images_np[i].copy(), hand_keypoints) / 255
+ rend_imgs.append(torch.from_numpy(images[i]))
+ rend_imgs.append(rend_img)
+ rend_imgs.append(rend_img_side)
+ rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2,0,1))
+ rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2,0,1))
+ rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
+ return rend_imgs
+
+ def __call__(self, vertices, camera_translation, image, focal_length=5000, text=None, resize=None, side_view=False, baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90):
+ renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
+ viewport_height=image.shape[0],
+ point_size=1.0)
+ material = pyrender.MetallicRoughnessMaterial(
+ metallicFactor=0.0,
+ alphaMode='OPAQUE',
+ baseColorFactor=baseColorFactor)
+
+ camera_translation[0] *= -1.
+
+ mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
+ if side_view:
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(rot_angle), [0, 1, 0])
+ mesh.apply_transform(rot)
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(180), [1, 0, 0])
+ mesh.apply_transform(rot)
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+
+ scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
+ ambient_light=(0.3, 0.3, 0.3))
+ scene.add(mesh, 'mesh')
+
+ camera_pose = np.eye(4)
+ camera_pose[:3, 3] = camera_translation
+ camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
+ cx=camera_center[0], cy=camera_center[1])
+ scene.add(camera, pose=camera_pose)
+
+
+ light_nodes = create_raymond_lights()
+ for node in light_nodes:
+ scene.add_node(node)
+
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+ color = color.astype(np.float32) / 255.0
+ valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
+ if not side_view:
+ output_img = (color[:, :, :3] * valid_mask +
+ (1 - valid_mask) * image)
+ else:
+ output_img = color[:, :, :3]
+ if resize is not None:
+ output_img = cv2.resize(output_img, resize)
+
+ output_img = output_img.astype(np.float32)
+ renderer.delete()
+ return output_img
diff --git a/wilor/utils/misc.py b/wilor/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffcfe784872b305c264ce6ef67fd0a9e9ad3390f
--- /dev/null
+++ b/wilor/utils/misc.py
@@ -0,0 +1,203 @@
+import time
+import warnings
+from importlib.util import find_spec
+from pathlib import Path
+from typing import Callable, List, Optional
+
+import hydra
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+from pytorch_lightning.utilities import rank_zero_only
+
+from . import pylogger, rich_utils
+
+log = pylogger.get_pylogger(__name__)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that wraps the task function in extra utilities.
+
+ Makes multirun more resistant to failure.
+
+ Utilities:
+ - Calling the `utils.extras()` before the task is started
+ - Calling the `utils.close_loggers()` after the task is finished
+ - Logging the exception if occurs
+ - Logging the task total execution time
+ - Logging the output dir
+ """
+
+ def wrap(cfg: DictConfig):
+
+ # apply extra utilities
+ extras(cfg)
+
+ # execute the task
+ try:
+ start_time = time.time()
+ ret = task_func(cfg=cfg)
+ except Exception as ex:
+ log.exception("") # save exception to `.log` file
+ raise ex
+ finally:
+ path = Path(cfg.paths.output_dir, "exec_time.log")
+ content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
+ save_file(path, content) # save task execution time (even if exception occurs)
+ close_loggers() # close loggers (even if exception occurs so multirun won't fail)
+
+ log.info(f"Output dir: {cfg.paths.output_dir}")
+
+ return ret
+
+ return wrap
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+ """
+
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ rich_utils.enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+@rank_zero_only
+def save_file(path: str, content: str) -> None:
+ """Save file in rank zero mode (only on one process in multi-GPU setup)."""
+ with open(path, "w+") as file:
+ file.write(content)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("Callbacks config is empty.")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
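+# Minimal usage sketch for instantiate_callbacks with a one-entry config.
+# The ModelCheckpoint target is just an example of a callback class that
+# Hydra can instantiate; any class path with a no-argument constructor works.
+def _demo_instantiate_callbacks():
+    cfg = OmegaConf.create(
+        {"checkpoint": {"_target_": "pytorch_lightning.callbacks.ModelCheckpoint"}}
+    )
+    assert len(instantiate_callbacks(cfg)) == 1
+
+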
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("Logger config is empty.")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ for k in cfg.keys():
+ hparams[k] = cfg.get(k)
+
+ # Resolve all interpolations
+ def _resolve(_cfg):
+ if isinstance(_cfg, DictConfig):
+ _cfg = OmegaConf.to_container(_cfg, resolve=True)
+ return _cfg
+
+ hparams = {k: _resolve(v) for k, v in hparams.items()}
+
+ # send hparams to all loggers
+ trainer.logger.log_hyperparams(hparams)
+
+
+def get_metric_value(metric_dict: dict, metric_name: Optional[str]) -> Optional[float]:
+ """Safely retrieves value of the metric logged in LightningModule."""
+
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
+
+
+def close_loggers() -> None:
+ """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
+
+ log.info("Closing loggers...")
+
+ if find_spec("wandb"): # if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
diff --git a/wilor/utils/pose_utils.py b/wilor/utils/pose_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b386e39aca64ce73ae10fea9ba2b767ce9e25b2
--- /dev/null
+++ b/wilor/utils/pose_utils.py
@@ -0,0 +1,352 @@
+"""
+Code adapted from: https://github.com/akanazawa/hmr/blob/master/src/benchmark/eval_util.py
+"""
+
+import torch
+import numpy as np
+from typing import Optional, Dict, List, Tuple
+
+def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
+ """
+ Computes a similarity transform (sR, t) in a batched way that maps
+ a set of 3D points S1 (B, N, 3) as close as possible to a set of 3D points S2 (B, N, 3),
+ where R is a 3x3 rotation matrix, t a 3x1 translation and s a scale factor,
+ i.e. it solves the orthogonal Procrustes problem.
+ Args:
+ S1 (torch.Tensor): First set of points of shape (B, N, 3).
+ S2 (torch.Tensor): Second set of points of shape (B, N, 3).
+ Returns:
+ (torch.Tensor): The first set of points after applying the similarity transformation.
+ """
+
+ batch_size = S1.shape[0]
+ S1 = S1.permute(0, 2, 1)
+ S2 = S2.permute(0, 2, 1)
+ # 1. Remove mean.
+ mu1 = S1.mean(dim=2, keepdim=True)
+ mu2 = S2.mean(dim=2, keepdim=True)
+ X1 = S1 - mu1
+ X2 = S2 - mu2
+
+ # 2. Compute variance of X1 used for scale.
+ var1 = (X1**2).sum(dim=(1,2))
+
+ # 3. The outer product of X1 and X2.
+ K = torch.matmul(X1, X2.permute(0, 2, 1))
+
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
+ U, s, V = torch.svd(K)
+ Vh = V.permute(0, 2, 1)
+
+ # Construct Z that fixes the orientation of R to get det(R)=1.
+ Z = torch.eye(U.shape[1], device=U.device).unsqueeze(0).repeat(batch_size, 1, 1)
+ Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U, Vh)))
+
+ # Construct R.
+ R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1))
+
+ # 5. Recover scale.
+ trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1)
+ scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1)
+
+ # 6. Recover translation.
+ t = mu2 - scale*torch.matmul(R, mu1)
+
+ # 7. Error:
+ S1_hat = scale*torch.matmul(R, S1) + t
+
+ return S1_hat.permute(0, 2, 1)
+
+def reconstruction_error(S1, S2) -> np.array:
+ """
+ Computes the mean Euclidean distance of 2 set of points S1, S2 after performing Procrustes alignment.
+ Args:
+ S1 (torch.Tensor): First set of points of shape (B, N, 3).
+ S2 (torch.Tensor): Second set of points of shape (B, N, 3).
+ Returns:
+ (np.array): Reconstruction error.
+ """
+ S1_hat = compute_similarity_transform(S1, S2)
+ re = torch.sqrt(((S1_hat - S2) ** 2).sum(dim=-1)).mean(dim=-1)
+ return re
+
+def eval_pose(pred_joints, gt_joints) -> Tuple[np.array, np.array]:
+ """
+ Compute joint errors in mm before and after Procrustes alignment.
+ Args:
+ pred_joints (torch.Tensor): Predicted 3D joints of shape (B, N, 3).
+ gt_joints (torch.Tensor): Ground truth 3D joints of shape (B, N, 3).
+ Returns:
+ Tuple[np.array, np.array]: Joint errors in mm before and after alignment.
+ """
+ # Absolute error (MPJPE)
+ mpjpe = torch.sqrt(((pred_joints - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
+
+ # Reconstruction error (PA-MPJPE, after Procrustes alignment)
+ r_error = reconstruction_error(pred_joints, gt_joints).cpu().numpy()
+ return 1000 * mpjpe, 1000 * r_error
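+
+# Example (sketch): MPJPE / PA-MPJPE in mm for dummy joint sets; the inputs are
+# assumed to be metric 3D joints of shape (B, N, 3).
+#
+#   pred = torch.randn(4, 21, 3)
+#   gt = pred + 0.001 * torch.randn(4, 21, 3)
+#   mpjpe_mm, pa_mpjpe_mm = eval_pose(pred, gt)   # two arrays of shape (4,)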
+
+class Evaluator:
+
+ def __init__(self,
+ dataset_length: int,
+ dataset: str,
+ keypoint_list: List,
+ pelvis_ind: int,
+ metrics: List = ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re'],
+ preds: List = ['vertices', 'keypoints_3d'],
+ pck_thresholds: Optional[List] = None):
+ """
+ Class used for evaluating trained models on different 3D pose datasets.
+ Args:
+ dataset_length (int): Total dataset length.
+ dataset (str): Name of the evaluated dataset.
+ keypoint_list (List): List of keypoints used for evaluation.
+ pelvis_ind (int): Index of pelvis keypoint; used for aligning the predictions and ground truth.
+ metrics (List): List of evaluation metrics to record.
+ preds (List): List of predictions to store.
+ pck_thresholds (Optional[List]): PCK thresholds to evaluate; if given, PCK metrics are also recorded.
+ """
+ self.dataset_length = dataset_length
+ self.dataset = dataset
+ self.keypoint_list = keypoint_list
+ self.pelvis_ind = pelvis_ind
+ self.metrics = metrics
+ self.preds = preds
+ if self.metrics is not None:
+ for metric in self.metrics:
+ setattr(self, metric, np.zeros((dataset_length,)))
+ if self.preds is not None:
+ for pred in self.preds:
+ if pred == 'vertices':
+ self.vertices = np.zeros((dataset_length, 778, 3))
+ if pred == 'keypoints_3d':
+ self.keypoints_3d = np.zeros((dataset_length, 21, 3))
+ self.counter = 0
+ if pck_thresholds is None:
+ self.pck_evaluator = None
+ else:
+ self.pck_evaluator = EvaluatorPCK(pck_thresholds)
+
+ def log(self):
+ """
+ Print current evaluation metrics
+ """
+ if self.counter == 0:
+ print('Evaluation has not started')
+ return
+ print(f'{self.counter} / {self.dataset_length} samples')
+ if self.pck_evaluator is not None:
+ self.pck_evaluator.log()
+ if self.metrics is not None:
+ for metric in self.metrics:
+ if metric in ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re']:
+ unit = 'mm'
+ else:
+ unit = ''
+ print(f'{metric}: {getattr(self, metric)[:self.counter].mean()} {unit}')
+ print('***')
+
+ def get_metrics_dict(self) -> Dict:
+ """
+ Returns:
+ Dict: Dictionary of evaluation metrics.
+ """
+ d1 = {metric: getattr(self, metric)[:self.counter].mean() for metric in self.metrics}
+ if self.pck_evaluator is not None:
+ d2 = self.pck_evaluator.get_metrics_dict()
+ d1.update(d2)
+ return d1
+
+ def get_preds_dict(self) -> Dict:
+ """
+ Returns:
+ Dict: Dictionary of evaluation preds.
+ """
+ d1 = {pred: getattr(self, pred)[:self.counter] for pred in self.preds}
+ return d1
+
+ def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None):
+ """
+ Evaluate current batch.
+ Args:
+ output (Dict): Regression output.
+ batch (Dict): Dictionary containing images and their corresponding annotations.
+ opt_output (Dict): Optimization output.
+ """
+ if self.pck_evaluator is not None:
+ self.pck_evaluator(output, batch, opt_output)
+
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach()
+ pred_keypoints_3d = pred_keypoints_3d[:,None,:,:]
+ batch_size = pred_keypoints_3d.shape[0]
+ num_samples = pred_keypoints_3d.shape[1]
+ gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1)
+ pred_vertices = output['pred_vertices'].detach()
+
+ # Align predictions and ground truth such that the pelvis location is at the origin
+ pred_keypoints_3d -= pred_keypoints_3d[:, :, [self.pelvis_ind]]
+ gt_keypoints_3d -= gt_keypoints_3d[:, :, [self.pelvis_ind]]
+
+ # Compute joint errors
+ mpjpe, re = eval_pose(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3)[:, self.keypoint_list], gt_keypoints_3d.reshape(batch_size * num_samples, -1 ,3)[:, self.keypoint_list])
+ mpjpe = mpjpe.reshape(batch_size, num_samples)
+ re = re.reshape(batch_size, num_samples)
+
+ # Compute 2d keypoint errors
+ bbox_expand_factor = batch['bbox_expand_factor'][:,None,None,None].detach()
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach()
+ pred_keypoints_2d = pred_keypoints_2d[:,None,:,:]*bbox_expand_factor
+ gt_keypoints_2d = batch['keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1)*bbox_expand_factor
+ conf = gt_keypoints_2d[:, :, :, -1].clone()
+ kp_err = torch.nn.functional.mse_loss(
+ pred_keypoints_2d,
+ gt_keypoints_2d[:, :, :, :-1],
+ reduction='none'
+ ).sum(dim=3)
+ kp_l2_loss = (conf * kp_err).mean(dim=2)
+ kp_l2_loss = kp_l2_loss.detach().cpu().numpy()
+
+ # Compute joint errors after optimization, if available.
+ if opt_output is not None:
+ opt_keypoints_3d = opt_output['model_joints']
+ opt_keypoints_3d -= opt_keypoints_3d[:, [self.pelvis_ind]]
+ opt_mpjpe, opt_re = eval_pose(opt_keypoints_3d[:, self.keypoint_list], gt_keypoints_3d[:, 0, self.keypoint_list])
+
+ # The 0-th sample always corresponds to the mode
+ if hasattr(self, 'mode_mpjpe'):
+ mode_mpjpe = mpjpe[:, 0]
+ self.mode_mpjpe[self.counter:self.counter+batch_size] = mode_mpjpe
+ if hasattr(self, 'mode_re'):
+ mode_re = re[:, 0]
+ self.mode_re[self.counter:self.counter+batch_size] = mode_re
+ if hasattr(self, 'mode_kpl2'):
+ mode_kpl2 = kp_l2_loss[:, 0]
+ self.mode_kpl2[self.counter:self.counter+batch_size] = mode_kpl2
+ if hasattr(self, 'min_mpjpe'):
+ min_mpjpe = mpjpe.min(axis=-1)
+ self.min_mpjpe[self.counter:self.counter+batch_size] = min_mpjpe
+ if hasattr(self, 'min_re'):
+ min_re = re.min(axis=-1)
+ self.min_re[self.counter:self.counter+batch_size] = min_re
+ if hasattr(self, 'min_kpl2'):
+ min_kpl2 = kp_l2_loss.min(axis=-1)
+ self.min_kpl2[self.counter:self.counter+batch_size] = min_kpl2
+ if hasattr(self, 'opt_mpjpe'):
+ self.opt_mpjpe[self.counter:self.counter+batch_size] = opt_mpjpe
+ if hasattr(self, 'opt_re'):
+ self.opt_re[self.counter:self.counter+batch_size] = opt_re
+ if hasattr(self, 'vertices'):
+ self.vertices[self.counter:self.counter+batch_size] = pred_vertices.cpu().numpy()
+ if hasattr(self, 'keypoints_3d'):
+ if self.dataset == 'HO3D-VAL':
+ pred_keypoints_3d = pred_keypoints_3d[:,:,[0,5,6,7,9,10,11,17,18,19,13,14,15,1,2,3,4,8,12,16,20]]
+ self.keypoints_3d[self.counter:self.counter+batch_size] = pred_keypoints_3d.squeeze().cpu().numpy()
+
+ self.counter += batch_size
+
+ if hasattr(self, 'mode_mpjpe') and hasattr(self, 'mode_re'):
+ return {
+ 'mode_mpjpe': mode_mpjpe,
+ 'mode_re': mode_re,
+ }
+ else:
+ return {}
+
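+# Example (sketch): intended evaluation-loop usage; the `output`/`batch` dicts are
+# assumed to carry the keys accessed in `Evaluator.__call__` above (e.g.
+# 'pred_keypoints_3d', 'keypoints_3d', 'bbox_expand_factor', ...).
+#
+#   evaluator = Evaluator(dataset_length=len(dataset), dataset='HO3D-VAL',
+#                         keypoint_list=list(range(21)), pelvis_ind=0)
+#   for batch in dataloader:
+#       output = model(batch)
+#       evaluator(output, batch)
+#   evaluator.log()
+#   metrics = evaluator.get_metrics_dict()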
+
+class EvaluatorPCK:
+
+ def __init__(self, thresholds: List = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5],):
+ """
+ Class used for evaluating trained models on different 3D pose datasets.
+ Args:
+ thresholds [List]: List of PCK thresholds to evaluate.
+ metrics [List]: List of evaluation metrics to record.
+ """
+ self.thresholds = thresholds
+ self.pred_kp_2d = []
+ self.gt_kp_2d = []
+ self.gt_conf_2d = []
+ self.scale = []
+ self.counter = 0
+
+ def log(self):
+ """
+ Print current evaluation metrics
+ """
+ if self.counter == 0:
+ print('Evaluation has not started')
+ return
+ print(f'{self.counter} samples')
+ metrics_dict = self.get_metrics_dict()
+ for metric in metrics_dict:
+ print(f'{metric}: {metrics_dict[metric]}')
+ print('***')
+
+ def get_metrics_dict(self) -> Dict:
+ """
+ Returns:
+ Dict: Dictionary of evaluation metrics.
+ """
+ pcks = self.compute_pcks()
+ metrics = {}
+ for thr, (acc,avg_acc,cnt) in zip(self.thresholds, pcks):
+ metrics.update({f'kp{i}_pck_{thr}': float(a) for i, a in enumerate(acc) if a>=0})
+ metrics.update({f'kpAvg_pck_{thr}': float(avg_acc)})
+ return metrics
+
+ def compute_pcks(self):
+ pred_kp_2d = np.concatenate(self.pred_kp_2d, axis=0)
+ gt_kp_2d = np.concatenate(self.gt_kp_2d, axis=0)
+ gt_conf_2d = np.concatenate(self.gt_conf_2d, axis=0)
+ scale = np.concatenate(self.scale, axis=0)
+ assert pred_kp_2d.shape == gt_kp_2d.shape
+ assert pred_kp_2d[..., 0].shape == gt_conf_2d.shape
+ assert pred_kp_2d.shape[1] == 1 # num_samples
+ assert scale.shape[0] == gt_conf_2d.shape[0] # num_samples
+
+ pcks = [
+ self.keypoint_pck_accuracy(
+ pred_kp_2d[:, 0, :, :],
+ gt_kp_2d[:, 0, :, :],
+ gt_conf_2d[:, 0, :]>0.5,
+ thr=thr,
+ scale = scale[:,None]
+ )
+ for thr in self.thresholds
+ ]
+ return pcks
+
+ def keypoint_pck_accuracy(self, pred, gt, conf, thr, scale):
+ dist = np.sqrt(np.sum((pred-gt)**2, axis=2))
+ all_joints = conf>0.5
+ correct_joints = np.logical_and(dist<=scale*thr, all_joints)
+ pck = correct_joints.sum(axis=0)/all_joints.sum(axis=0)
+ return pck, pck.mean(), pck.shape[0]
+
+ def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None):
+ """
+ Evaluate current batch.
+ Args:
+ output (Dict): Regression output.
+ batch (Dict): Dictionary containing images and their corresponding annotations.
+ opt_output (Dict): Optimization output.
+ """
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach()
+ num_samples = 1
+ batch_size = pred_keypoints_2d.shape[0]
+
+ right = batch['right'].detach()
+ pred_keypoints_2d[:,:,0] = (2*right[:,None]-1)*pred_keypoints_2d[:,:,0]
+ box_size = batch['box_size'].detach()
+ box_center = batch['box_center'].detach()
+ bbox_expand_factor = batch['bbox_expand_factor'].detach()
+ scale = box_size/bbox_expand_factor
+ bbox_expand_factor = bbox_expand_factor[:,None,None,None]
+ pred_keypoints_2d = pred_keypoints_2d*box_size[:,None,None]+box_center[:,None]
+ pred_keypoints_2d = pred_keypoints_2d[:,None,:,:]
+ gt_keypoints_2d = batch['orig_keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1)
+
+ self.pred_kp_2d.append(pred_keypoints_2d[:, :, :, :2].detach().cpu().numpy())
+ self.gt_conf_2d.append(gt_keypoints_2d[:, :, :, -1].detach().cpu().numpy())
+ self.gt_kp_2d.append(gt_keypoints_2d[:, :, :, :2].detach().cpu().numpy())
+ self.scale.append(scale.detach().cpu().numpy())
+
+ self.counter += batch_size
\ No newline at end of file
diff --git a/wilor/utils/pylogger.py b/wilor/utils/pylogger.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ffa71893ec20acde65e44d899334a38d8d1333
--- /dev/null
+++ b/wilor/utils/pylogger.py
@@ -0,0 +1,17 @@
+import logging
+
+from pytorch_lightning.utilities import rank_zero_only
+
+
+def get_pylogger(name=__name__) -> logging.Logger:
+ """Initializes multi-GPU-friendly python command line logger."""
+
+ logger = logging.getLogger(name)
+
+ # this ensures all logging levels get marked with the rank zero decorator
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+ logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+ for level in logging_levels:
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+ return logger
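+
+# Example (sketch): modules in this package create their logger once at import
+# time and use it like the standard logging API; on multi-GPU runs only rank
+# zero emits the messages.
+#
+#   log = get_pylogger(__name__)
+#   log.info("Starting training...")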
diff --git a/wilor/utils/render_openpose.py b/wilor/utils/render_openpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e51ee8e15f40b85e2766e1f9da42183da0d3d46
--- /dev/null
+++ b/wilor/utils/render_openpose.py
@@ -0,0 +1,191 @@
+"""
+Render OpenPose keypoints.
+Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp
+"""
+import cv2
+import math
+import numpy as np
+from typing import List, Tuple
+
+def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
+ """
+ Compute rectangle enclosing keypoints above the threshold.
+ Args:
+ keypoints (np.array): Keypoint array of shape (N, 3).
+ threshold (float): Confidence visualization threshold.
+ Returns:
+ Tuple[float, float, float]: Rectangle width, height and area.
+ """
+ valid_ind = keypoints[:, -1] > threshold
+ if valid_ind.sum() > 0:
+ valid_keypoints = keypoints[valid_ind][:, :-1]
+ max_x = valid_keypoints[:,0].max()
+ max_y = valid_keypoints[:,1].max()
+ min_x = valid_keypoints[:,0].min()
+ min_y = valid_keypoints[:,1].min()
+ width = max_x - min_x
+ height = max_y - min_y
+ area = width * height
+ return width, height, area
+ else:
+ return 0,0,0
+
+def render_keypoints(img: np.array,
+ keypoints: np.array,
+ pairs: List,
+ colors: List,
+ thickness_circle_ratio: float,
+ thickness_line_ratio_wrt_circle: float,
+ pose_scales: List,
+ threshold: float = 0.1,
+ alpha: float = 1.0) -> np.array:
+ """
+ Render keypoints on input image.
+ Args:
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
+ keypoints (np.array): Keypoint array of shape (N, 3).
+ pairs (List): List of keypoint pairs per limb.
+ colors (List): List of colors per keypoint.
+ thickness_circle_ratio (float): Circle thickness ratio.
+ thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle.
+ pose_scales (List): List of pose scales.
+ threshold (float): Only visualize keypoints with confidence above the threshold.
+ Returns:
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
+ """
+ img_orig = img.copy()
+ height, width = img.shape[0], img.shape[1]
+ area = width * height
+
+ lineType = 8
+ shift = 0
+ numberColors = len(colors)
+ thresholdRectangle = 0.1
+
+ person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle)
+ if person_area > 0:
+ ratioAreas = min(1, max(person_width / width, person_height / height))
+ thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2)
+ thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio))
+ thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle))
+ radius = thicknessRatio / 2
+
+ img = np.ascontiguousarray(img.copy())
+ for i, pair in enumerate(pairs):
+ index1, index2 = pair
+ if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold:
+ thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0]))
+ colorIndex = index2
+ color = colors[colorIndex % numberColors]
+ keypoint1 = keypoints[index1, :-1].astype(np.int_)
+ keypoint2 = keypoints[index2, :-1].astype(np.int_)
+ cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift)
+ for part in range(len(keypoints)):
+ faceIndex = part
+ if keypoints[faceIndex, -1] > threshold:
+ radiusScaled = int(round(radius[faceIndex] * pose_scales[0]))
+ thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0]))
+ colorIndex = part
+ color = colors[colorIndex % numberColors]
+ center = keypoints[faceIndex, :-1].astype(np.int_)
+ cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift)
+ return img
+
+def render_hand_keypoints(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0):
+ if use_confidence and map_fn is not None:
+ #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1])
+ thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1])
+ else:
+ #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0])
+ thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0])
+ thicknessLineRatioWRTCircle = 0.75
+ pairs = [0,1, 1,2, 2,3, 3,4, 0,5, 5,6, 6,7, 7,8, 0,9, 9,10, 10,11, 11,12, 0,13, 13,14, 14,15, 15,16, 0,17, 17,18, 18,19, 19,20]
+ pairs = np.array(pairs).reshape(-1,2)
+
+ colors = [100., 100., 100.,
+ 100., 0., 0.,
+ 150., 0., 0.,
+ 200., 0., 0.,
+ 255., 0., 0.,
+ 100., 100., 0.,
+ 150., 150., 0.,
+ 200., 200., 0.,
+ 255., 255., 0.,
+ 0., 100., 50.,
+ 0., 150., 75.,
+ 0., 200., 100.,
+ 0., 255., 125.,
+ 0., 50., 100.,
+ 0., 75., 150.,
+ 0., 100., 200.,
+ 0., 125., 255.,
+ 100., 0., 100.,
+ 150., 0., 150.,
+ 200., 0., 200.,
+ 255., 0., 255.]
+ colors = np.array(colors).reshape(-1,3)
+ #colors = np.zeros_like(colors)
+ poseScales = [1]
+ #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
+ img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha)
+ #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
+ return img
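+
+# Example (sketch): drawing 21 hand keypoints on a dummy frame; `img` is an
+# (H, W, 3) array with values in [0, 255] and each keypoint row is (x, y, conf).
+#
+#   img = np.zeros((256, 256, 3), dtype=np.float32)
+#   kps = np.concatenate([np.random.rand(21, 2) * 256, np.ones((21, 1))], axis=1)
+#   vis = render_hand_keypoints(img, kps)   # (256, 256, 3), same value range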
+
+def render_body_keypoints(img: np.array,
+ body_keypoints: np.array) -> np.array:
+ """
+ Render OpenPose body keypoints on input image.
+ Args:
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
+ Returns:
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
+ """
+
+ thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0])
+ thickness_line_ratio_wrt_circle = 0.75
+ pairs = []
+ pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24]
+ pairs = np.array(pairs).reshape(-1,2)
+ colors = [255., 0., 85.,
+ 255., 0., 0.,
+ 255., 85., 0.,
+ 255., 170., 0.,
+ 255., 255., 0.,
+ 170., 255., 0.,
+ 85., 255., 0.,
+ 0., 255., 0.,
+ 255., 0., 0.,
+ 0., 255., 85.,
+ 0., 255., 170.,
+ 0., 255., 255.,
+ 0., 170., 255.,
+ 0., 85., 255.,
+ 0., 0., 255.,
+ 255., 0., 170.,
+ 170., 0., 255.,
+ 255., 0., 255.,
+ 85., 0., 255.,
+ 0., 0., 255.,
+ 0., 0., 255.,
+ 0., 0., 255.,
+ 0., 255., 255.,
+ 0., 255., 255.,
+ 0., 255., 255.]
+ colors = np.array(colors).reshape(-1,3)
+ pose_scales = [1]
+ return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
+
+def render_openpose(img: np.array,
+ hand_keypoints: np.array) -> np.array:
+ """
+ Render keypoints in the OpenPose format on input image.
+ Args:
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
+ hand_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
+ Returns:
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
+ """
+ #img = render_body_keypoints(img, body_keypoints)
+ img = render_hand_keypoints(img, hand_keypoints)
+ return img
diff --git a/wilor/utils/renderer.py b/wilor/utils/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e161bb05921e52a684427e3eb87c4f8739a5d89
--- /dev/null
+++ b/wilor/utils/renderer.py
@@ -0,0 +1,423 @@
+import os
+if 'PYOPENGL_PLATFORM' not in os.environ:
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
+import torch
+import numpy as np
+import pyrender
+import trimesh
+import cv2
+from yacs.config import CfgNode
+from typing import List, Optional
+
+def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.):
+ # Convert cam_bbox to full image
+ img_w, img_h = img_size[:, 0], img_size[:, 1]
+ cx, cy, b = box_center[:, 0], box_center[:, 1], box_size
+ w_2, h_2 = img_w / 2., img_h / 2.
+ bs = b * cam_bbox[:, 0] + 1e-9
+ tz = 2 * focal_length / bs
+ tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1]
+ ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2]
+ full_cam = torch.stack([tx, ty, tz], dim=-1)
+ return full_cam
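+
+# Example (sketch): converting a crop-space weak-perspective camera (s, tx, ty)
+# to a full-image translation; the tensors and numbers below are illustrative.
+#
+#   cam_bbox = torch.tensor([[1.0, 0.0, 0.0]])
+#   box_center = torch.tensor([[320.0, 240.0]])
+#   box_size = torch.tensor([200.0])
+#   img_size = torch.tensor([[640.0, 480.0]])
+#   cam_t = cam_crop_to_full(cam_bbox, box_center, box_size, img_size)   # (1, 3)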
+
+def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12):
+ # get lights in a circle around origin at elevation
+ thetas = elevation * np.ones(n_lights)
+ phis = 2 * np.pi * np.arange(n_lights) / n_lights
+ poses = []
+ trans = make_translation(torch.tensor([0, 0, dist]))
+ for phi, theta in zip(phis, thetas):
+ rot = make_rotation(rx=-theta, ry=phi, order="xyz")
+ poses.append((rot @ trans).numpy())
+ return poses
+
+def make_translation(t):
+ return make_4x4_pose(torch.eye(3), t)
+
+def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
+ Rx = rotx(rx)
+ Ry = roty(ry)
+ Rz = rotz(rz)
+ if order == "xyz":
+ R = Rz @ Ry @ Rx
+ elif order == "xzy":
+ R = Ry @ Rz @ Rx
+ elif order == "yxz":
+ R = Rz @ Rx @ Ry
+ elif order == "yzx":
+ R = Rx @ Rz @ Ry
+ elif order == "zyx":
+ R = Rx @ Ry @ Rz
+ elif order == "zxy":
+ R = Ry @ Rx @ Rz
+ return make_4x4_pose(R, torch.zeros(3))
+
+def make_4x4_pose(R, t):
+ """
+ :param R (*, 3, 3)
+ :param t (*, 3)
+ return (*, 4, 4)
+ """
+ dims = R.shape[:-2]
+ pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1)
+ bottom = (
+ torch.tensor([0, 0, 0, 1], device=R.device)
+ .reshape(*(1,) * len(dims), 1, 4)
+ .expand(*dims, 1, 4)
+ )
+ return torch.cat([pose_3x4, bottom], dim=-2)
+
+
+def rotx(theta):
+ return torch.tensor(
+ [
+ [1, 0, 0],
+ [0, np.cos(theta), -np.sin(theta)],
+ [0, np.sin(theta), np.cos(theta)],
+ ],
+ dtype=torch.float32,
+ )
+
+
+def roty(theta):
+ return torch.tensor(
+ [
+ [np.cos(theta), 0, np.sin(theta)],
+ [0, 1, 0],
+ [-np.sin(theta), 0, np.cos(theta)],
+ ],
+ dtype=torch.float32,
+ )
+
+
+def rotz(theta):
+ return torch.tensor(
+ [
+ [np.cos(theta), -np.sin(theta), 0],
+ [np.sin(theta), np.cos(theta), 0],
+ [0, 0, 1],
+ ],
+ dtype=torch.float32,
+ )
+
+
+def create_raymond_lights() -> List[pyrender.Node]:
+ """
+ Return raymond light nodes for the scene.
+ """
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
+
+ nodes = []
+
+ for phi, theta in zip(phis, thetas):
+ xp = np.sin(theta) * np.cos(phi)
+ yp = np.sin(theta) * np.sin(phi)
+ zp = np.cos(theta)
+
+ z = np.array([xp, yp, zp])
+ z = z / np.linalg.norm(z)
+ x = np.array([-z[1], z[0], 0.0])
+ if np.linalg.norm(x) == 0:
+ x = np.array([1.0, 0.0, 0.0])
+ x = x / np.linalg.norm(x)
+ y = np.cross(z, x)
+
+ matrix = np.eye(4)
+ matrix[:3,:3] = np.c_[x,y,z]
+ nodes.append(pyrender.Node(
+ light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
+ matrix=matrix
+ ))
+
+ return nodes
+
+class Renderer:
+
+ def __init__(self, cfg: CfgNode, faces: np.array):
+ """
+ Wrapper around the pyrender renderer to render MANO meshes.
+ Args:
+ cfg (CfgNode): Model config file.
+ faces (np.array): Array of shape (F, 3) containing the mesh faces.
+ """
+ self.cfg = cfg
+ self.focal_length = cfg.EXTRA.FOCAL_LENGTH
+ self.img_res = cfg.MODEL.IMAGE_SIZE
+
+ # add faces that make the hand mesh watertight
+ faces_new = np.array([[92, 38, 234],
+ [234, 38, 239],
+ [38, 122, 239],
+ [239, 122, 279],
+ [122, 118, 279],
+ [279, 118, 215],
+ [118, 117, 215],
+ [215, 117, 214],
+ [117, 119, 214],
+ [214, 119, 121],
+ [119, 120, 121],
+ [121, 120, 78],
+ [120, 108, 78],
+ [78, 108, 79]])
+ faces = np.concatenate([faces, faces_new], axis=0)
+
+ self.camera_center = [self.img_res // 2, self.img_res // 2]
+ self.faces = faces
+ self.faces_left = self.faces[:,[0,2,1]]
+
+ def __call__(self,
+ vertices: np.array,
+ camera_translation: np.array,
+ image: torch.Tensor,
+ full_frame: bool = False,
+ imgname: Optional[str] = None,
+ side_view=False, rot_angle=90,
+ mesh_base_color=(1.0, 1.0, 0.9),
+ scene_bg_color=(0,0,0),
+ return_rgba=False,
+ ) -> np.array:
+ """
+ Render meshes on input image
+ Args:
+ vertices (np.array): Array of shape (V, 3) containing the mesh vertices.
+ camera_translation (np.array): Array of shape (3,) with the camera translation.
+ image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values.
+ full_frame (bool): If True, then render on the full image.
+ imgname (Optional[str]): Contains the original image filename. Used only if full_frame == True.
+ """
+
+ if full_frame:
+ image = cv2.imread(imgname).astype(np.float32)[:, :, ::-1] / 255.
+ else:
+ image = image.clone() * torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3,1,1)
+ image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3,1,1)
+ image = image.permute(1, 2, 0).cpu().numpy()
+
+ renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
+ viewport_height=image.shape[0],
+ point_size=1.0)
+ material = pyrender.MetallicRoughnessMaterial(
+ metallicFactor=0.0,
+ alphaMode='OPAQUE',
+ baseColorFactor=(*mesh_base_color, 1.0))
+
+ camera_translation[0] *= -1.
+
+ mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
+ if side_view:
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(rot_angle), [0, 1, 0])
+ mesh.apply_transform(rot)
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(180), [1, 0, 0])
+ mesh.apply_transform(rot)
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
+ ambient_light=(0.3, 0.3, 0.3))
+ scene.add(mesh, 'mesh')
+
+ camera_pose = np.eye(4)
+ camera_pose[:3, 3] = camera_translation
+ camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
+ camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length,
+ cx=camera_center[0], cy=camera_center[1], zfar=1e12)
+ scene.add(camera, pose=camera_pose)
+
+
+ light_nodes = create_raymond_lights()
+ for node in light_nodes:
+ scene.add_node(node)
+
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+ color = color.astype(np.float32) / 255.0
+ renderer.delete()
+
+ if return_rgba:
+ return color
+
+ valid_mask = (color[:, :, -1])[:, :, np.newaxis]
+ if not side_view:
+ output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image)
+ else:
+ output_img = color[:, :, :3]
+
+ output_img = output_img.astype(np.float32)
+ return output_img
+
+ def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9),
+ rot_axis=[1,0,0], rot_angle=0, is_right=1):
+ # material = pyrender.MetallicRoughnessMaterial(
+ # metallicFactor=0.0,
+ # alphaMode='OPAQUE',
+ # baseColorFactor=(*mesh_base_color, 1.0))
+ vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0])
+ if is_right:
+ mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors)
+ else:
+ mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces_left.copy(), vertex_colors=vertex_colors)
+ # mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
+
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(rot_angle), rot_axis)
+ mesh.apply_transform(rot)
+
+ rot = trimesh.transformations.rotation_matrix(
+ np.radians(180), [1, 0, 0])
+ mesh.apply_transform(rot)
+ return mesh
+
+ def render_rgba(
+ self,
+ vertices: np.array,
+ cam_t = None,
+ rot=None,
+ rot_axis=[1,0,0],
+ rot_angle=0,
+ camera_z=3,
+ # camera_translation: np.array,
+ mesh_base_color=(1.0, 1.0, 0.9),
+ scene_bg_color=(0,0,0),
+ render_res=[256, 256],
+ focal_length=None,
+ is_right=None,
+ ):
+
+ renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
+ viewport_height=render_res[1],
+ point_size=1.0)
+ # material = pyrender.MetallicRoughnessMaterial(
+ # metallicFactor=0.0,
+ # alphaMode='OPAQUE',
+ # baseColorFactor=(*mesh_base_color, 1.0))
+
+ focal_length = focal_length if focal_length is not None else self.focal_length
+
+ if cam_t is not None:
+ camera_translation = cam_t.copy()
+ camera_translation[0] *= -1.
+ else:
+ camera_translation = np.array([0, 0, camera_z * focal_length/render_res[1]])
+
+ mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle, is_right=is_right)
+ mesh = pyrender.Mesh.from_trimesh(mesh)
+ # mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
+ ambient_light=(0.3, 0.3, 0.3))
+ scene.add(mesh, 'mesh')
+
+ camera_pose = np.eye(4)
+ camera_pose[:3, 3] = camera_translation
+ camera_center = [render_res[0] / 2., render_res[1] / 2.]
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
+ cx=camera_center[0], cy=camera_center[1], zfar=1e12)
+
+ # Create camera node and add it to pyRender scene
+ camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
+ scene.add_node(camera_node)
+ self.add_point_lighting(scene, camera_node)
+ self.add_lighting(scene, camera_node)
+
+ light_nodes = create_raymond_lights()
+ for node in light_nodes:
+ scene.add_node(node)
+
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+ color = color.astype(np.float32) / 255.0
+ renderer.delete()
+
+ return color
+
+ def render_rgba_multiple(
+ self,
+ vertices: List[np.array],
+ cam_t: List[np.array],
+ rot_axis=[1,0,0],
+ rot_angle=0,
+ mesh_base_color=(1.0, 1.0, 0.9),
+ scene_bg_color=(0,0,0),
+ render_res=[256, 256],
+ focal_length=None,
+ is_right=None,
+ ):
+
+ renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
+ viewport_height=render_res[1],
+ point_size=1.0)
+ # material = pyrender.MetallicRoughnessMaterial(
+ # metallicFactor=0.0,
+ # alphaMode='OPAQUE',
+ # baseColorFactor=(*mesh_base_color, 1.0))
+
+ if is_right is None:
+ is_right = [1 for _ in range(len(vertices))]
+
+ mesh_list = [pyrender.Mesh.from_trimesh(self.vertices_to_trimesh(vvv, ttt.copy(), mesh_base_color, rot_axis, rot_angle, is_right=sss)) for vvv,ttt,sss in zip(vertices, cam_t, is_right)]
+
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
+ ambient_light=(0.3, 0.3, 0.3))
+ for i,mesh in enumerate(mesh_list):
+ scene.add(mesh, f'mesh_{i}')
+
+ camera_pose = np.eye(4)
+ # camera_pose[:3, 3] = camera_translation
+ camera_center = [render_res[0] / 2., render_res[1] / 2.]
+ focal_length = focal_length if focal_length is not None else self.focal_length
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
+ cx=camera_center[0], cy=camera_center[1], zfar=1e12)
+
+ # Create camera node and add it to pyRender scene
+ camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
+ scene.add_node(camera_node)
+ self.add_point_lighting(scene, camera_node)
+ self.add_lighting(scene, camera_node)
+
+ light_nodes = create_raymond_lights()
+ for node in light_nodes:
+ scene.add_node(node)
+
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+ color = color.astype(np.float32) / 255.0
+ renderer.delete()
+
+ return color
+
+ def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
+ # from phalp.visualize.py_renderer import get_light_poses
+ light_poses = get_light_poses()
+ light_poses.append(np.eye(4))
+ cam_pose = scene.get_pose(cam_node)
+ for i, pose in enumerate(light_poses):
+ matrix = cam_pose @ pose
+ node = pyrender.Node(
+ name=f"light-{i:02d}",
+ light=pyrender.DirectionalLight(color=color, intensity=intensity),
+ matrix=matrix,
+ )
+ if scene.has_node(node):
+ continue
+ scene.add_node(node)
+
+ def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
+ # from phalp.visualize.py_renderer import get_light_poses
+ light_poses = get_light_poses(dist=0.5)
+ light_poses.append(np.eye(4))
+ cam_pose = scene.get_pose(cam_node)
+ for i, pose in enumerate(light_poses):
+ matrix = cam_pose @ pose
+ # node = pyrender.Node(
+ # name=f"light-{i:02d}",
+ # light=pyrender.DirectionalLight(color=color, intensity=intensity),
+ # matrix=matrix,
+ # )
+ node = pyrender.Node(
+ name=f"plight-{i:02d}",
+ light=pyrender.PointLight(color=color, intensity=intensity),
+ matrix=matrix,
+ )
+ if scene.has_node(node):
+ continue
+ scene.add_node(node)
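+
+# Example (sketch): offscreen rendering of a single MANO mesh as an RGBA crop.
+# `model_cfg` is assumed to expose EXTRA.FOCAL_LENGTH and MODEL.IMAGE_SIZE, and
+# `faces` / `verts` would come from the MANO layer and a network prediction.
+#
+#   renderer = Renderer(model_cfg, faces)                      # faces: (F, 3)
+#   rgba = renderer.render_rgba(verts,                         # verts: (778, 3)
+#                               cam_t=np.array([0.0, 0.0, 0.5]),
+#                               render_res=[256, 256],
+#                               is_right=1)
+#   overlay = rgba[:, :, :3] * rgba[:, :, 3:]                  # premultiplied RGB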
diff --git a/wilor/utils/rich_utils.py b/wilor/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f97494ed2958ec2c3d75c772360b5367f2dc7b
--- /dev/null
+++ b/wilor/utils/rich_utils.py
@@ -0,0 +1,105 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, OmegaConf, open_dict
+from pytorch_lightning.utilities import rank_zero_only
+from rich.prompt import Prompt
+
+from . import pylogger
+
+log = pylogger.get_pylogger(__name__)
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "datamodule",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ print_order (Sequence[str], optional): Determines in what order config components are printed.
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
+ """
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+ if field in cfg:
+ queue.append(field)
+ else:
+ log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config."""
+
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
+
+
+if __name__ == "__main__":
+ from hydra import compose, initialize
+
+ with initialize(version_base="1.2", config_path="../../configs"):
+ cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[])
+ print_config_tree(cfg, resolve=False, save_to_file=False)
diff --git a/wilor/utils/skeleton_renderer.py b/wilor/utils/skeleton_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..46a5df75bff887eab00984eeb5be3c1f6e752960
--- /dev/null
+++ b/wilor/utils/skeleton_renderer.py
@@ -0,0 +1,124 @@
+import torch
+import numpy as np
+import trimesh
+from typing import Optional
+from yacs.config import CfgNode
+
+from .geometry import perspective_projection
+from .render_openpose import render_openpose
+
+class SkeletonRenderer:
+
+ def __init__(self, cfg: CfgNode):
+ """
+ Object used to render 3D keypoints. Faster for use during training.
+ Args:
+ cfg (CfgNode): Model config file.
+ """
+ self.cfg = cfg
+
+ def __call__(self,
+ pred_keypoints_3d: torch.Tensor,
+ gt_keypoints_3d: torch.Tensor,
+ gt_keypoints_2d: torch.Tensor,
+ images: Optional[np.array] = None,
+ camera_translation: Optional[torch.Tensor] = None) -> np.array:
+ """
+ Render batch of 3D keypoints.
+ Args:
+ pred_keypoints_3d (torch.Tensor): Tensor of shape (B, S, N, 3) containing a batch of predicted 3D keypoints, with S samples per image.
+ gt_keypoints_3d (torch.Tensor): Tensor of shape (B, N, 4) containing corresponding ground truth 3D keypoints; last value is the confidence.
+ gt_keypoints_2d (torch.Tensor): Tensor of shape (B, N, 3) containing corresponding ground truth 2D keypoints.
+ images (Optional[np.array]): Array of shape (B, H, W, 3) containing images with values in the [0,255] range.
+ camera_translation (torch.Tensor): Tensor of shape (B, 3) containing the camera translation.
+ Returns:
+ np.array : Image with the following layout. Each row contains a) the input image with gt 2D keypoints,
+ b) the image with projected gt 3D keypoints,
+ c_1, ... , c_S) the image with projected predicted 3D keypoints,
+ d) gt 3D keypoints rendered from a side view,
+ e_1, ... , e_S) predicted 3D keypoints rendered from a side view
+ """
+ batch_size = pred_keypoints_3d.shape[0]
+# num_samples = pred_keypoints_3d.shape[1]
+ pred_keypoints_3d = pred_keypoints_3d.clone().cpu().float()
+ gt_keypoints_3d = gt_keypoints_3d.clone().cpu().float()
+ gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, [0], :-1] + pred_keypoints_3d[:, [0]]
+ gt_keypoints_2d = gt_keypoints_2d.clone().cpu().float().numpy()
+ gt_keypoints_2d[:, :, :-1] = self.cfg.MODEL.IMAGE_SIZE * (gt_keypoints_2d[:, :, :-1] + 1.0) / 2.0
+
+ #openpose_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
+ #gt_indices = [12, 8, 7, 6, 9, 10, 11, 14, 2, 1, 0, 3, 4, 5]
+ #gt_indices = [25 + i for i in gt_indices]
+ openpose_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
+ gt_indices = openpose_indices
+ keypoints_to_render = torch.ones(batch_size, gt_keypoints_3d.shape[1], 1)
+ rotation = torch.eye(3).unsqueeze(0)
+ if camera_translation is None:
+ camera_translation = torch.tensor([0.0, 0.0, 2 * self.cfg.EXTRA.FOCAL_LENGTH / (0.8 * self.cfg.MODEL.IMAGE_SIZE)]).unsqueeze(0).repeat(batch_size, 1)
+ else:
+ camera_translation = camera_translation.cpu()
+
+ if images is None:
+ images = np.zeros((batch_size, self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE, 3))
+ focal_length = torch.tensor([self.cfg.EXTRA.FOCAL_LENGTH, self.cfg.EXTRA.FOCAL_LENGTH]).reshape(1, 2)
+ camera_center = torch.tensor([self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE], dtype=torch.float).reshape(1, 2) / 2.
+ gt_keypoints_3d_proj = perspective_projection(gt_keypoints_3d[:, :, :-1], rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation[:, :], focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1))
+ pred_keypoints_3d_proj = perspective_projection(pred_keypoints_3d.reshape(batch_size, -1, 3), rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation.reshape(batch_size, -1), focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1)).reshape(batch_size, -1, 2)
+ gt_keypoints_3d_proj = torch.cat([gt_keypoints_3d_proj, gt_keypoints_3d[:, :, [-1]]], dim=-1).cpu().numpy()
+ pred_keypoints_3d_proj = torch.cat([pred_keypoints_3d_proj, keypoints_to_render.reshape(batch_size, -1, 1)], dim=-1).cpu().numpy()
+ rows = []
+ # Rotate keypoints to visualize side view
+ R = torch.tensor(trimesh.transformations.rotation_matrix(np.radians(90), [0, 1, 0])[:3, :3]).float()
+ gt_keypoints_3d_side = gt_keypoints_3d.clone()
+ gt_keypoints_3d_side[:, :, :-1] = torch.einsum('bni,ij->bnj', gt_keypoints_3d_side[:, :, :-1], R)
+ pred_keypoints_3d_side = pred_keypoints_3d.clone()
+ pred_keypoints_3d_side = torch.einsum('bni,ij->bnj', pred_keypoints_3d_side, R)
+ gt_keypoints_3d_proj_side = perspective_projection(gt_keypoints_3d_side[:, :, :-1], rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation[:, :], focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1))
+ pred_keypoints_3d_proj_side = perspective_projection(pred_keypoints_3d_side.reshape(batch_size, -1, 3), rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation.reshape(batch_size, -1), focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1)).reshape(batch_size, -1, 2)
+ gt_keypoints_3d_proj_side = torch.cat([gt_keypoints_3d_proj_side, gt_keypoints_3d_side[:, :, [-1]]], dim=-1).cpu().numpy()
+ pred_keypoints_3d_proj_side = torch.cat([pred_keypoints_3d_proj_side, keypoints_to_render.reshape(batch_size, -1, 1)], dim=-1).cpu().numpy()
+ for i in range(batch_size):
+ img = images[i]
+ side_img = np.zeros((self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE, 3))
+ # gt 2D keypoints
+ body_keypoints_2d = gt_keypoints_2d[i, :21].copy()
+ for op, gt in zip(openpose_indices, gt_indices):
+ if gt_keypoints_2d[i, gt, -1] > body_keypoints_2d[op, -1]:
+ body_keypoints_2d[op] = gt_keypoints_2d[i, gt]
+ gt_keypoints_img = render_openpose(img, body_keypoints_2d) / 255.
+ # gt 3D keypoints
+ body_keypoints_3d_proj = gt_keypoints_3d_proj[i, :21].copy()
+ for op, gt in zip(openpose_indices, gt_indices):
+ if gt_keypoints_3d_proj[i, gt, -1] > body_keypoints_3d_proj[op, -1]:
+ body_keypoints_3d_proj[op] = gt_keypoints_3d_proj[i, gt]
+ gt_keypoints_3d_proj_img = render_openpose(img, body_keypoints_3d_proj) / 255.
+ # gt 3D keypoints from the side
+ body_keypoints_3d_proj = gt_keypoints_3d_proj_side[i, :21].copy()
+ for op, gt in zip(openpose_indices, gt_indices):
+ if gt_keypoints_3d_proj_side[i, gt, -1] > body_keypoints_3d_proj[op, -1]:
+ body_keypoints_3d_proj[op] = gt_keypoints_3d_proj_side[i, gt]
+ gt_keypoints_3d_proj_img_side = render_openpose(side_img, body_keypoints_3d_proj) / 255.
+ # pred 3D keypoints
+ pred_keypoints_3d_proj_imgs = []
+ body_keypoints_3d_proj = pred_keypoints_3d_proj[i, :21].copy()
+ for op, gt in zip(openpose_indices, gt_indices):
+ if pred_keypoints_3d_proj[i, gt, -1] >= body_keypoints_3d_proj[op, -1]:
+ body_keypoints_3d_proj[op] = pred_keypoints_3d_proj[i, gt]
+ pred_keypoints_3d_proj_imgs.append(render_openpose(img, body_keypoints_3d_proj) / 255.)
+ pred_keypoints_3d_proj_img = np.concatenate(pred_keypoints_3d_proj_imgs, axis=1)
+ # pred 3D keypoints from the side
+ pred_keypoints_3d_proj_imgs_side = []
+ body_keypoints_3d_proj = pred_keypoints_3d_proj_side[i, :21].copy()
+ for op, gt in zip(openpose_indices, gt_indices):
+ if pred_keypoints_3d_proj_side[i, gt, -1] >= body_keypoints_3d_proj[op, -1]:
+ body_keypoints_3d_proj[op] = pred_keypoints_3d_proj_side[i, gt]
+ pred_keypoints_3d_proj_imgs_side.append(render_openpose(side_img, body_keypoints_3d_proj) / 255.)
+ pred_keypoints_3d_proj_img_side = np.concatenate(pred_keypoints_3d_proj_imgs_side, axis=1)
+ rows.append(np.concatenate((gt_keypoints_img, gt_keypoints_3d_proj_img, pred_keypoints_3d_proj_img, gt_keypoints_3d_proj_img_side, pred_keypoints_3d_proj_img_side), axis=1))
+ # Concatenate images
+ img = np.concatenate(rows, axis=0)
+ img[:, ::self.cfg.MODEL.IMAGE_SIZE, :] = 1.0
+ img[::self.cfg.MODEL.IMAGE_SIZE, :, :] = 1.0
+ img[:, (1+1+1)*self.cfg.MODEL.IMAGE_SIZE, :] = 0.5
+ return img
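+
+# Example (sketch): rendering a training batch for visual inspection; `model_cfg`
+# is assumed to expose MODEL.IMAGE_SIZE and EXTRA.FOCAL_LENGTH, and the keypoint
+# tensors follow the shapes documented in `__call__` above.
+#
+#   renderer = SkeletonRenderer(model_cfg)
+#   grid = renderer(pred_kp3d, batch['keypoints_3d'], batch['keypoints_2d'],
+#                   images=images_255)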