diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..a049828f0525233dcffc8a52d34dd8e2689373cd 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/test1.png filter=lfs diff=lfs merge=lfs -text +assets/test2.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c0f55d6c64b31a249f0fe376a9038e4111cc3616 --- /dev/null +++ b/.gitignore @@ -0,0 +1,139 @@ +# Byte-compiled / optimized / DLL files +__pycache__ +*.py[cod] +*$py.class + +# pyc +*.pyc + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# VSCode +.vscode + +*.swp +*.h5 +*.mp4 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e74e0a3fdc022a90a91a875481b4b5021ba0c085 --- /dev/null +++ b/app.py @@ -0,0 +1,202 @@ +import os +import sys +os.environ["PYOPENGL_PLATFORM"] = "egl" +os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1" + +import gradio as gr +#import spaces +import cv2 +import numpy as np +import torch +from ultralytics import YOLO +from pathlib import Path +import argparse +import json +from torchvision import transforms +from typing import Dict, Optional +from PIL import Image, ImageDraw +from lang_sam import LangSAM + +from wilor.models import load_wilor +from wilor.utils import recursive_to +from wilor.datasets.vitdet_dataset import ViTDetDataset +from hort.models import load_hort +from hort.utils.renderer import Renderer, cam_crop_to_new +from hort.utils.img_utils import process_bbox, generate_patch_image, PerspectiveCamera +from ultralytics import YOLO +LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353) +STEEL_BLUE=(0.2745098, 0.5098039, 0.7058824) + +# Download and load checkpoints +wilor_model, wilor_model_cfg = load_wilor(checkpoint_path = './pretrained_models/wilor_final.ckpt' , cfg_path= './pretrained_models/model_config.yaml') +hand_detector = YOLO('./pretrained_models/detector.pt') +# Setup the renderer +renderer = Renderer(wilor_model_cfg, faces=wilor_model.mano.faces) +# Setup the SAM model +sam_model = LangSAM(sam_type="sam2.1_hiera_large") +# Setup the HORT model +hort_model = load_hort("./pretrained_models/hort_final.pth.tar") + +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +wilor_model = wilor_model.to(device) +hand_detector = hand_detector.to(device) +hort_model = hort_model.to(device) +wilor_model.eval() +hort_model.eval() + +image_transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + +@spaces.GPU() +def run_model(image, conf, IoU_threshold=0.5): + img_cv2 = image[..., ::-1] + img_pil = Image.fromarray(image) + + pred_obj = sam_model.predict([img_pil], ["manipulated object"]) + pred_hand = sam_model.predict([img_pil], ["hand"]) + + bbox_obj = pred_obj[0]["boxes"][0].reshape((-1, 2)) + mask_obj = pred_obj[0]["masks"][0] + bbox_hand = pred_hand[0]["boxes"][0].reshape((-1, 2)) + mask_hand = pred_hand[0]["masks"][0] + + tl = np.min(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0) + br = np.max(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0) + box_size = br - tl + bbox = np.concatenate([tl - 10, box_size + 20], axis=0) + ho_bbox = process_bbox(bbox) + + detections = hand_detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0] + + bboxes = [] + is_right = [] + for det in detections: + Bbox = det.boxes.data.cpu().detach().squeeze().numpy() + is_right.append(det.boxes.cls.cpu().detach().squeeze().item()) + bboxes.append(Bbox[:4].tolist()) + + if len(bboxes) == 1: + boxes = np.stack(bboxes) + right = np.stack(is_right) + if not right: + new_x1 = img_cv2.shape[1] - boxes[0][2] + new_x2 = 
img_cv2.shape[1] - boxes[0][0] + boxes[0][0] = new_x1 + boxes[0][2] = new_x2 + ho_bbox[0] = img_cv2.shape[1] - (ho_bbox[0] + ho_bbox[2]) + img_cv2 = cv2.flip(img_cv2, 1) + right[0] = 1. + crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0) + + dataset = ViTDetDataset(wilor_model_cfg, img_cv2, boxes, right, rescale_factor=2.0) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0) + + for batch in dataloader: + batch = recursive_to(batch, device) + + with torch.no_grad(): + out = wilor_model(batch) + + pred_cam = out['pred_cam'] + box_center = batch["box_center"].float() + box_size = batch["box_size"].float() + img_size = batch["img_size"].float() + scaled_focal_length = wilor_model_cfg.EXTRA.FOCAL_LENGTH / wilor_model_cfg.MODEL.IMAGE_SIZE * 224 + pred_cam_t_full = cam_crop_to_new(pred_cam, box_center, box_size, img_size, torch.from_numpy(np.array(ho_bbox, dtype=np.float32))[None, :].to(img_size.device), scaled_focal_length).detach().cpu().numpy() + + batch_size = batch['img'].shape[0] + for n in range(batch_size): + verts = out['pred_vertices'][n].detach().cpu().numpy() + joints = out['pred_keypoints_3d'][n].detach().cpu().numpy() + + is_right = batch['right'][n].cpu().numpy() + palm = (verts[95] + verts[22]) / 2 + cam_t = pred_cam_t_full[n] + + img_input = image_transform(crop_img_cv2[:, :, ::-1]).unsqueeze(0).cuda() + camera = PerspectiveCamera(5000 / 256 * 224, 5000 / 256 * 224, 112, 112) + cam_intr = camera.intrinsics + + metas = dict() + metas["right_hand_verts_3d"] = torch.from_numpy((verts + cam_t)[None]).cuda() + metas["right_hand_joints_3d"] = torch.from_numpy((joints + cam_t)[None]).cuda() + metas["right_hand_palm"] = torch.from_numpy((palm + cam_t)[None]).cuda() + metas["cam_intr"] = torch.from_numpy(cam_intr[None]).cuda() + with torch.amp.autocast(device_type='cuda', dtype=torch.float16): + pc_results = hort_model(img_input, metas) + objtrans = pc_results["objtrans"][0].detach().cpu().numpy() + pointclouds_up = pc_results["pointclouds_up"][0].detach().cpu().numpy() * 0.3 + + reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up, 'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length} + + return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions + else: + return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), None + + +def render_reconstruction(image, conf, IoU_threshold=0.3): + input_img, num_dets, reconstructions = run_model(image, conf, IoU_threshold=0.5) + if num_dets == 1: + # Render front view + misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE, scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal']) + cam_view = renderer.render_rgba(reconstructions['verts'], reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'], cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args) + + # Overlay image + input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel + input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:] + + return input_img_overlay, f'{num_dets} hands detected' + else: + return input_img, f'{num_dets} hands detected' + + +header = (''' +
+HORT: Monocular Hand-held Objects Reconstruction with Transformers
+Zerui Chen1, Rolandos Alexandros Potamias2, Shizhe Chen1, Cordelia Schmid1
+1Inria, Ecole normale supérieure, CNRS, PSL Research University; 2Imperial College London
+ + + + +''') + + +with gr.Blocks(title="HORT: Monocular Hand-held Objects Reconstruction with Transformers", css=".gradio-container") as demo: + + gr.Markdown(header) + + with gr.Row(): + with gr.Column(): + input_image = gr.Image(label="Input image", type="numpy") + threshold = gr.Slider(value=0.3, minimum=0.05, maximum=0.95, step=0.05, label='Detection Confidence Threshold') + submit = gr.Button("Submit", variant="primary") + + + with gr.Column(): + reconstruction = gr.Image(label="Reconstructions", type="numpy") + hands_detected = gr.Textbox(label="Hands Detected") + + submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected]) + + with gr.Row(): + example_images = gr.Examples([ + ['/home/user/app/assets/test1.png'], + ['./demo_img/app/assets/test2.png'], + ['./demo_img/app/assets/test3.jpg'], + ['./demo_img/app/assets/test4.jpeg'], + ['./demo_img/app/assets/test5.jpeg'] + ], + inputs=input_image) + +demo.launch(debug=True) diff --git a/assets/test1.png b/assets/test1.png new file mode 100644 index 0000000000000000000000000000000000000000..b0817a749ed4c22c8d322bf88489ced23fd16104 --- /dev/null +++ b/assets/test1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:220310a89f9777975b10d933eb6aef34c3fe036ae2f453c2e31d537f8827111b +size 129676 diff --git a/assets/test2.png b/assets/test2.png new file mode 100644 index 0000000000000000000000000000000000000000..e35549741cf978f8613588d1f2848e50665fae24 --- /dev/null +++ b/assets/test2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29e4602efe21a483442c42a50ebf1c666c9e525dc630ec801a5af1d3acee18b1 +size 132948 diff --git a/assets/test3.jpg b/assets/test3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4ea7fd1b82a40e28c004473e041e63227ee94864 Binary files /dev/null and b/assets/test3.jpg differ diff --git a/assets/test4.jpeg b/assets/test4.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..c6a2d1b4549a89faee6ec1dfe0b13742d52371de Binary files /dev/null and b/assets/test4.jpeg differ diff --git a/assets/test5.jpeg b/assets/test5.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..eeacc48a79a19f104ed6759024be896bd8d5cace Binary files /dev/null and b/assets/test5.jpeg differ diff --git a/hort/models/__init__.py b/hort/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc23c6b01efa0b05d7887fd2fa45c4a302b6187 --- /dev/null +++ b/hort/models/__init__.py @@ -0,0 +1,114 @@ +import torch +import torch.nn as nn +import sys +import os.path as osp +import numpy as np +from yacs.config import CfgNode as CN +this_dir = osp.dirname(__file__) +sys.path.insert(0, this_dir) +import tgs +from network.pointnet import PointNetEncoder + +hort_cfg = CN() +hort_cfg.image_tokenizer_cls = "tgs.models.tokenizers.image.DINOV2SingleImageTokenizer" +hort_cfg.image_tokenizer = CN() +hort_cfg.image_tokenizer.pretrained_model_name_or_path = "facebook/dinov2-large" +hort_cfg.image_tokenizer.width = 224 +hort_cfg.image_tokenizer.height = 224 +hort_cfg.image_tokenizer.modulation = False +hort_cfg.image_tokenizer.modulation_zero_init = True +hort_cfg.image_tokenizer.modulation_cond_dim = 1024 +hort_cfg.image_tokenizer.freeze_backbone_params = False +hort_cfg.image_tokenizer.enable_memory_efficient_attention = False +hort_cfg.image_tokenizer.enable_gradient_checkpointing = False + +hort_cfg.tokenizer_cls = "tgs.models.tokenizers.point.PointLearnablePositionalEmbedding" 
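+# Point tokenizer: 2049 learnable tokens = 2048 object point tokens plus 1 extra token that is
+# read out as the object translation (see forward(): tokens[:, :2048] -> point cloud, tokens[:, -1] -> translation).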
+hort_cfg.tokenizer = CN() +hort_cfg.tokenizer.num_pcl = 2049 +hort_cfg.tokenizer.num_channels = 512 + +hort_cfg.backbone_cls = "tgs.models.transformers.Transformer1D" +hort_cfg.backbone = CN() +hort_cfg.backbone.in_channels = 512 +hort_cfg.backbone.num_attention_heads = 8 +hort_cfg.backbone.attention_head_dim = 64 +hort_cfg.backbone.num_layers = 10 +hort_cfg.backbone.cross_attention_dim = 1024 +hort_cfg.backbone.norm_type = "layer_norm" +hort_cfg.backbone.enable_memory_efficient_attention = False +hort_cfg.backbone.gradient_checkpointing = False + +hort_cfg.post_processor_cls = "tgs.models.networks.PointOutLayer" +hort_cfg.post_processor = CN() +hort_cfg.post_processor.in_channels = 512 +hort_cfg.post_processor.out_channels = 3 + +hort_cfg.pointcloud_upsampler_cls = "tgs.models.snowflake.model_spdpp.SnowflakeModelSPDPP" +hort_cfg.pointcloud_upsampler = CN() +hort_cfg.pointcloud_upsampler.input_channels = 1024 +hort_cfg.pointcloud_upsampler.dim_feat = 128 +hort_cfg.pointcloud_upsampler.num_p0 = 2048 +hort_cfg.pointcloud_upsampler.radius = 1 +hort_cfg.pointcloud_upsampler.bounding = True +hort_cfg.pointcloud_upsampler.use_fps = True +hort_cfg.pointcloud_upsampler.up_factors = [2, 4] +hort_cfg.pointcloud_upsampler.token_type = "image_token" + + +class model(nn.Module): + def __init__(self): + super(model, self).__init__() + self.image_tokenizer = tgs.find(hort_cfg.image_tokenizer_cls)(hort_cfg.image_tokenizer) + self.pointnet = PointNetEncoder(67, 1024) + self.tokenizer = tgs.find(hort_cfg.tokenizer_cls)(hort_cfg.tokenizer) + self.backbone = tgs.find(hort_cfg.backbone_cls)(hort_cfg.backbone) + self.post_processor = tgs.find(hort_cfg.post_processor_cls)(hort_cfg.post_processor) + self.post_processor_trans = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 3)) + self.pointcloud_upsampler = tgs.find(hort_cfg.pointcloud_upsampler_cls)(hort_cfg.pointcloud_upsampler) + + def forward(self, input_img, metas): + with torch.no_grad(): + batch_size = input_img.shape[0] + + encoder_hidden_states = self.image_tokenizer(input_img, None) # B * C * Nt + encoder_hidden_states = encoder_hidden_states.transpose(2, 1) # B * Nt * C + + palm_norm_hand_verts_3d = metas['right_hand_verts_3d'] - metas['right_hand_palm'].unsqueeze(1) + point_idx = torch.arange(778).view(1, 778, 1).expand(batch_size, -1, -1).to(input_img.device) / 778. 
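+ # 67-channel per-vertex hand features for the PointNet encoder: palm-centred xyz (3) + normalized vertex index (1)
+ # + offsets to the 21 hand joints (63), matching PointNetEncoder(67, 1024) above.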
+ palm_norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, point_idx], -1) + tip_norm_hand_verts_3d = (metas['right_hand_verts_3d'].unsqueeze(2) - metas['right_hand_joints_3d'].unsqueeze(1)).reshape((batch_size, 778, -1)) + norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, tip_norm_hand_verts_3d], -1) + hand_feats = self.pointnet(norm_hand_verts_3d) + + tokens = self.tokenizer(batch_size) + tokens = self.backbone(tokens, torch.cat([encoder_hidden_states, hand_feats.unsqueeze(1)], 1), modulation_cond=None) + tokens = self.tokenizer.detokenize(tokens) + + pointclouds = self.post_processor(tokens[:, :2048, :]) + pred_obj_trans = self.post_processor_trans(tokens[:, -1, :]) + + upsampling_input = { + "input_image_tokens": encoder_hidden_states.permute(0, 2, 1), + "intrinsic_cond": metas['cam_intr'], + "points": pointclouds, + "hand_points": metas["right_hand_verts_3d"], + "trans": pred_obj_trans + metas['right_hand_palm'], + "scale": 0.3 + } + up_results = self.pointcloud_upsampler(upsampling_input) + pointclouds_up = up_results[-1] + + pc_results = {} + pc_results['pointclouds'] = pointclouds + pc_results['objtrans'] = pred_obj_trans + pc_results['handpalm'] = metas['right_hand_palm'] + pc_results['pointclouds_up'] = pointclouds_up + + return pc_results + +def load_hort(ckpt_path): + hort_model = model() + ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))["network"] + ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()} + hort_model.load_state_dict(ckpt) + return hort_model diff --git a/hort/models/network/pointnet.py b/hort/models/network/pointnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c81db3db7ef65e80952c674fb4d20ac30c2e25af --- /dev/null +++ b/hort/models/network/pointnet.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + + +class PointNetEncoder(nn.Module): + """Encoder for Pointcloud + """ + def __init__(self, in_channels: int=3, output_channels: int=768): + super().__init__() + + block_channel = [64, 128, 256, 512] + self.mlp = nn.Sequential( + nn.Linear(in_channels, block_channel[0]), + nn.LayerNorm(block_channel[0]), + nn.ReLU(), + nn.Linear(block_channel[0], block_channel[1]), + nn.LayerNorm(block_channel[1]), + nn.ReLU(), + nn.Linear(block_channel[1], block_channel[2]), + nn.LayerNorm(block_channel[2]), + nn.ReLU(), + nn.Linear(block_channel[2], block_channel[3]), + nn.LayerNorm(block_channel[3]), + nn.ReLU(), + ) + + self.final_projection = nn.Sequential( + nn.Linear(block_channel[-1], output_channels), + nn.LayerNorm(output_channels) + ) + + def forward(self, x): + x = self.mlp(x) + x = torch.max(x, 1)[0] + x = self.final_projection(x) + return x \ No newline at end of file diff --git a/hort/models/tgs/__init__.py b/hort/models/tgs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d72c1d1ab4cc2b5481d6028c001c9cab91b35ed9 --- /dev/null +++ b/hort/models/tgs/__init__.py @@ -0,0 +1,9 @@ +import importlib +from tgs.utils.typing import * + +def find(cls_string) -> Type: + module_string = ".".join(cls_string.split(".")[:-1]) + cls_name = cls_string.split(".")[-1] + module = importlib.import_module(module_string, package=None) + cls = getattr(module, cls_name) + return cls \ No newline at end of file diff --git a/hort/models/tgs/data.py b/hort/models/tgs/data.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab123fcc1223232934213d8c29c17b4c1f9df50 --- /dev/null +++ b/hort/models/tgs/data.py @@ -0,0 +1,265 @@ +import json +import math +from dataclasses import 
dataclass, field + +import os +import imageio +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torch.utils.data import Dataset + +from tgs.utils.config import parse_structured +from tgs.utils.ops import get_intrinsic_from_fov, get_ray_directions, get_rays +from tgs.utils.typing import * + + +def _parse_scene_list_single(scene_list_path: str): + if scene_list_path.endswith(".json"): + with open(scene_list_path) as f: + all_scenes = json.loads(f.read()) + elif scene_list_path.endswith(".txt"): + with open(scene_list_path) as f: + all_scenes = [p.strip() for p in f.readlines()] + else: + all_scenes = [scene_list_path] + + return all_scenes + + +def _parse_scene_list(scene_list_path: Union[str, List[str]]): + all_scenes = [] + if isinstance(scene_list_path, str): + scene_list_path = [scene_list_path] + for scene_list_path_ in scene_list_path: + all_scenes += _parse_scene_list_single(scene_list_path_) + return all_scenes + +@dataclass +class CustomImageDataModuleConfig: + image_list: Any = "" + background_color: Tuple[float, float, float] = field( + default_factory=lambda: (1.0, 1.0, 1.0) + ) + + relative_pose: bool = False + cond_height: int = 512 + cond_width: int = 512 + cond_camera_distance: float = 1.6 + cond_fovy_deg: float = 40.0 + cond_elevation_deg: float = 0.0 + cond_azimuth_deg: float = 0.0 + num_workers: int = 16 + + eval_height: int = 512 + eval_width: int = 512 + eval_batch_size: int = 1 + eval_elevation_deg: float = 0.0 + eval_camera_distance: float = 1.6 + eval_fovy_deg: float = 40.0 + n_test_views: int = 120 + num_views_output: int = 120 + only_3dgs: bool = False + +class CustomImageOrbitDataset(Dataset): + def __init__(self, cfg: Any) -> None: + super().__init__() + self.cfg: CustomImageDataModuleConfig = parse_structured(CustomImageDataModuleConfig, cfg) + + self.n_views = self.cfg.n_test_views + assert self.n_views % self.cfg.num_views_output == 0 + + self.all_scenes = _parse_scene_list(self.cfg.image_list) + + azimuth_deg: Float[Tensor, "B"] = torch.linspace(0, 360.0, self.n_views + 1)[ + : self.n_views + ] + elevation_deg: Float[Tensor, "B"] = torch.full_like( + azimuth_deg, self.cfg.eval_elevation_deg + ) + camera_distances: Float[Tensor, "B"] = torch.full_like( + elevation_deg, self.cfg.eval_camera_distance + ) + + elevation = elevation_deg * math.pi / 180 + azimuth = azimuth_deg * math.pi / 180 + + # convert spherical coordinates to cartesian coordinates + # right hand coordinate system, x back, y right, z up + # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) + camera_positions: Float[Tensor, "B 3"] = torch.stack( + [ + camera_distances * torch.cos(elevation) * torch.cos(azimuth), + camera_distances * torch.cos(elevation) * torch.sin(azimuth), + camera_distances * torch.sin(elevation), + ], + dim=-1, + ) + + # default scene center at origin + center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions) + # default camera up direction as +z + up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[ + None, : + ].repeat(self.n_views, 1) + + fovy_deg: Float[Tensor, "B"] = torch.full_like( + elevation_deg, self.cfg.eval_fovy_deg + ) + fovy = fovy_deg * math.pi / 180 + + lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1) + right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1) + up = F.normalize(torch.cross(right, lookat), dim=-1) + c2w3x4: Float[Tensor, "B 3 4"] = torch.cat( + [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, 
None]], + dim=-1, + ) + c2w: Float[Tensor, "B 4 4"] = torch.cat( + [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1 + ) + c2w[:, 3, 3] = 1.0 + + # get directions by dividing directions_unit_focal by focal length + focal_length: Float[Tensor, "B"] = ( + 0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy) + ) + directions_unit_focal = get_ray_directions( + H=self.cfg.eval_height, + W=self.cfg.eval_width, + focal=1.0, + ) + directions: Float[Tensor, "B H W 3"] = directions_unit_focal[ + None, :, :, : + ].repeat(self.n_views, 1, 1, 1) + directions[:, :, :, :2] = ( + directions[:, :, :, :2] / focal_length[:, None, None, None] + ) + # must use normalize=True to normalize directions here + rays_o, rays_d = get_rays(directions, c2w, keepdim=True) + + intrinsic: Float[Tensor, "B 3 3"] = get_intrinsic_from_fov( + self.cfg.eval_fovy_deg * math.pi / 180, + H=self.cfg.eval_height, + W=self.cfg.eval_width, + bs=self.n_views, + ) + intrinsic_normed: Float[Tensor, "B 3 3"] = intrinsic.clone() + intrinsic_normed[..., 0, 2] /= self.cfg.eval_width + intrinsic_normed[..., 1, 2] /= self.cfg.eval_height + intrinsic_normed[..., 0, 0] /= self.cfg.eval_width + intrinsic_normed[..., 1, 1] /= self.cfg.eval_height + + self.rays_o, self.rays_d = rays_o, rays_d + self.intrinsic = intrinsic + self.intrinsic_normed = intrinsic_normed + self.c2w = c2w + self.camera_positions = camera_positions + + self.background_color = torch.as_tensor(self.cfg.background_color) + + # condition + self.intrinsic_cond = get_intrinsic_from_fov( + np.deg2rad(self.cfg.cond_fovy_deg), + H=self.cfg.cond_height, + W=self.cfg.cond_width, + ) + self.intrinsic_normed_cond = self.intrinsic_cond.clone() + self.intrinsic_normed_cond[..., 0, 2] /= self.cfg.cond_width + self.intrinsic_normed_cond[..., 1, 2] /= self.cfg.cond_height + self.intrinsic_normed_cond[..., 0, 0] /= self.cfg.cond_width + self.intrinsic_normed_cond[..., 1, 1] /= self.cfg.cond_height + + + if self.cfg.relative_pose: + self.c2w_cond = torch.as_tensor( + [ + [0, 0, 1, self.cfg.cond_camera_distance], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ] + ).float() + else: + cond_elevation = self.cfg.cond_elevation_deg * math.pi / 180 + cond_azimuth = self.cfg.cond_azimuth_deg * math.pi / 180 + cond_camera_position: Float[Tensor, "3"] = torch.as_tensor( + [ + self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.cos(cond_azimuth), + self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.sin(cond_azimuth), + self.cfg.cond_camera_distance * np.sin(cond_elevation), + ], dtype=torch.float32 + ) + + cond_center: Float[Tensor, "3"] = torch.zeros_like(cond_camera_position) + cond_up: Float[Tensor, "3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32) + cond_lookat: Float[Tensor, "3"] = F.normalize(cond_center - cond_camera_position, dim=-1) + cond_right: Float[Tensor, "3"] = F.normalize(torch.cross(cond_lookat, cond_up), dim=-1) + cond_up = F.normalize(torch.cross(cond_right, cond_lookat), dim=-1) + cond_c2w3x4: Float[Tensor, "3 4"] = torch.cat( + [torch.stack([cond_right, cond_up, -cond_lookat], dim=-1), cond_camera_position[:, None]], + dim=-1, + ) + cond_c2w: Float[Tensor, "4 4"] = torch.cat( + [cond_c2w3x4, torch.zeros_like(cond_c2w3x4[:1])], dim=0 + ) + cond_c2w[3, 3] = 1.0 + self.c2w_cond = cond_c2w + + def __len__(self): + if self.cfg.only_3dgs: + return len(self.all_scenes) + else: + return len(self.all_scenes) * self.n_views // self.cfg.num_views_output + + def __getitem__(self, index): + if self.cfg.only_3dgs: + scene_index = index + view_index = [0] + else: + scene_index 
= index * self.cfg.num_views_output // self.n_views + view_start = index % (self.n_views // self.cfg.num_views_output) + view_index = list(range(self.n_views))[view_start * self.cfg.num_views_output : + (view_start + 1) * self.cfg.num_views_output] + + img_path = self.all_scenes[scene_index] + img_cond = torch.from_numpy( + np.asarray( + Image.fromarray(imageio.v2.imread(img_path)) + .convert("RGBA") + .resize((self.cfg.cond_width, self.cfg.cond_height)) + ) + / 255.0 + ).float() + mask_cond: Float[Tensor, "Hc Wc 1"] = img_cond[:, :, -1:] + rgb_cond: Float[Tensor, "Hc Wc 3"] = img_cond[ + :, :, :3 + ] * mask_cond + self.background_color[None, None, :] * (1 - mask_cond) + + out = { + "rgb_cond": rgb_cond.unsqueeze(0), + "c2w_cond": self.c2w_cond.unsqueeze(0), + "mask_cond": mask_cond.unsqueeze(0), + "intrinsic_cond": self.intrinsic_cond.unsqueeze(0), + "intrinsic_normed_cond": self.intrinsic_normed_cond.unsqueeze(0), + "view_index": torch.as_tensor(view_index), + "rays_o": self.rays_o[view_index], + "rays_d": self.rays_d[view_index], + "intrinsic": self.intrinsic[view_index], + "intrinsic_normed": self.intrinsic_normed[view_index], + "c2w": self.c2w[view_index], + "camera_positions": self.camera_positions[view_index], + } + out["c2w"][..., :3, 1:3] *= -1 + out["c2w_cond"][..., :3, 1:3] *= -1 + instance_id = os.path.split(img_path)[-1].split('.')[0] + out["index"] = torch.as_tensor(scene_index) + out["background_color"] = self.background_color + out["instance_id"] = instance_id + return out + + def collate(self, batch): + batch = torch.utils.data.default_collate(batch) + batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width}) + return batch \ No newline at end of file diff --git a/hort/models/tgs/models/__init__.py b/hort/models/tgs/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hort/models/tgs/models/image_feature.py b/hort/models/tgs/models/image_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c6a0aa8fd3fe4b04315bb56e206153c291dee4 --- /dev/null +++ b/hort/models/tgs/models/image_feature.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass +import torch +import torch.nn.functional as F +from einops import rearrange + +from tgs.utils.base import BaseModule +from tgs.utils.ops import compute_distance_transform +from tgs.utils.typing import * + +class ImageFeature(BaseModule): + @dataclass + class Config(BaseModule.Config): + use_rgb: bool = True + use_feature: bool = True + use_mask: bool = True + feature_dim: int = 128 + out_dim: int = 133 + backbone: str = "default" + freeze_backbone_params: bool = True + + cfg: Config + + def forward(self, rgb, mask=None, feature=None): + B, Nv, H, W = rgb.shape[:4] + rgb = rearrange(rgb, "B Nv H W C -> (B Nv) C H W") + if mask is not None: + mask = rearrange(mask, "B Nv H W C -> (B Nv) C H W") + + assert feature is not None + # reshape dino tokens to image-like size + feature = rearrange(feature, "B (Nv Nt) C -> (B Nv) Nt C", Nv=Nv) + feature = feature[:, 1:].reshape(B * Nv, H // 14, W // 14, -1).permute(0, 3, 1, 2).contiguous() + feature = F.interpolate(feature, size=(H, W), mode='bilinear', align_corners=False) + + if mask is not None and mask.is_floating_point(): + mask = mask > 0.5 + + image_features = [] + if self.cfg.use_rgb: + image_features.append(rgb) + if self.cfg.use_feature: + image_features.append(feature) + if self.cfg.use_mask: + image_features += [mask, compute_distance_transform(mask)] + 
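+ # Channel-wise concatenation of the conditioning signals: RGB, DINO patch features upsampled to image resolution, the mask and its distance transform.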
+ # detach features, occur error when with grad + image_features = torch.cat(image_features, dim=1)#.detach() + return rearrange(image_features, "(B Nv) C H W -> B Nv C H W", B=B, Nv=Nv).squeeze(1) \ No newline at end of file diff --git a/hort/models/tgs/models/networks.py b/hort/models/tgs/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c2bb1f4ec1a5d9e8b824926ba1e4ec51ade25a --- /dev/null +++ b/hort/models/tgs/models/networks.py @@ -0,0 +1,204 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +from einops import rearrange +import numpy as np + +from tgs.utils.base import BaseModule +from tgs.utils.ops import get_activation +from tgs.utils.typing import * + +class PointOutLayer(BaseModule): + @dataclass + class Config(BaseModule.Config): + in_channels: int = 1024 + out_channels: int = 3 + cfg: Config + def configure(self) -> None: + super().configure() + self.point_layer = nn.Linear(self.cfg.in_channels, self.cfg.out_channels) + self.initialize_weights() + + def initialize_weights(self): + nn.init.constant_(self.point_layer.weight, 0) + nn.init.constant_(self.point_layer.bias, 0) + + def forward(self, x): + return self.point_layer(x) + +class TriplaneUpsampleNetwork(BaseModule): + @dataclass + class Config(BaseModule.Config): + in_channels: int = 1024 + out_channels: int = 80 + + cfg: Config + + def configure(self) -> None: + super().configure() + self.upsample = nn.ConvTranspose2d( + self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2 + ) + + def forward( + self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"] + ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]: + triplanes_up = rearrange( + self.upsample( + rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3) + ), + "(B Np) Co Hp Wp -> B Np Co Hp Wp", + Np=3, + ) + return triplanes_up + + +class MLP(nn.Module): + def __init__( + self, + dim_in: int, + dim_out: int, + n_neurons: int, + n_hidden_layers: int, + activation: str = "relu", + output_activation: Optional[str] = None, + bias: bool = True, + ): + super().__init__() + layers = [ + self.make_linear( + dim_in, n_neurons, is_first=True, is_last=False, bias=bias + ), + self.make_activation(activation), + ] + for i in range(n_hidden_layers - 1): + layers += [ + self.make_linear( + n_neurons, n_neurons, is_first=False, is_last=False, bias=bias + ), + self.make_activation(activation), + ] + layers += [ + self.make_linear( + n_neurons, dim_out, is_first=False, is_last=True, bias=bias + ) + ] + self.layers = nn.Sequential(*layers) + self.output_activation = get_activation(output_activation) + + def forward(self, x): + x = self.layers(x) + x = self.output_activation(x) + return x + + def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True): + layer = nn.Linear(dim_in, dim_out, bias=bias) + return layer + + def make_activation(self, activation): + if activation == "relu": + return nn.ReLU(inplace=True) + elif activation == "silu": + return nn.SiLU(inplace=True) + else: + raise NotImplementedError + +class GSProjection(nn.Module): + def __init__(self, + in_channels: int = 80, + sh_degree: int = 3, + init_scaling: float = -5.0, + init_density: float = 0.1) -> None: + super().__init__() + + self.out_keys = GS_KEYS + ["shs"] + self.out_channels = GS_CHANNELS + [(sh_degree + 1) ** 2 * 3] + + self.out_layers = nn.ModuleList() + for key, ch in zip(self.out_keys, self.out_channels): + layer = nn.Linear(in_channels, ch) + # initialize + nn.init.constant_(layer.weight, 0) + nn.init.constant_(layer.bias, 0) + + if 
key == "scaling": + nn.init.constant_(layer.bias, init_scaling) + elif key == "rotation": + nn.init.constant_(layer.bias, 0) + nn.init.constant_(layer.bias[0], 1.0) + elif key == "opacity": + inverse_sigmoid = lambda x: np.log(x / (1 - x)) + nn.init.constant_(layer.bias, inverse_sigmoid(init_density)) + + self.out_layers.append(layer) + + def forward(self, x): + ret = [] + for k, layer in zip(self.out_keys, self.out_layers): + v = layer(x) + if k == "rotation": + v = torch.nn.functional.normalize(v) + elif k == "scaling": + v = torch.exp(v) + # v = v.detach() # FIXME: for DEBUG + elif k == "opacity": + v = torch.sigmoid(v) + # elif k == "shs": + # v = torch.reshape(v, (v.shape[0], -1, 3)) + ret.append(v) + ret = torch.cat(ret, dim=-1) + return ret + +def get_encoding(n_input_dims: int, config) -> nn.Module: + raise NotImplementedError + + +def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module: + raise NotImplementedError + + +# Resnet Blocks for pointnet +class ResnetBlockFC(nn.Module): + ''' Fully connected ResNet Block class. + + Args: + size_in (int): input dimension + size_out (int): output dimension + size_h (int): hidden dimension + ''' + + def __init__(self, size_in, size_out=None, size_h=None): + super().__init__() + # Attributes + if size_out is None: + size_out = size_in + + if size_h is None: + size_h = min(size_in, size_out) + + self.size_in = size_in + self.size_h = size_h + self.size_out = size_out + # Submodules + self.fc_0 = nn.Linear(size_in, size_h) + self.fc_1 = nn.Linear(size_h, size_out) + self.actvn = nn.ReLU() + + if size_in == size_out: + self.shortcut = None + else: + self.shortcut = nn.Linear(size_in, size_out, bias=False) + # Initialization + nn.init.zeros_(self.fc_1.weight) + + def forward(self, x): + net = self.fc_0(self.actvn(x)) + dx = self.fc_1(self.actvn(net)) + + if self.shortcut is not None: + x_s = self.shortcut(x) + else: + x_s = x + + return x_s + dx \ No newline at end of file diff --git a/hort/models/tgs/models/pointclouds/LICENSE_POINTNET b/hort/models/tgs/models/pointclouds/LICENSE_POINTNET new file mode 100644 index 0000000000000000000000000000000000000000..5358ea96300c33c974dafa4c492e2df20b0dbb39 --- /dev/null +++ b/hort/models/tgs/models/pointclouds/LICENSE_POINTNET @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars Mescheder, Marc Pollefeys, Andreas Geiger + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/hort/models/tgs/models/pointclouds/pointnet.py b/hort/models/tgs/models/pointclouds/pointnet.py new file mode 100644 index 0000000000000000000000000000000000000000..14b5cbe9ba43ab514369f06fc6db794fb88c9329 --- /dev/null +++ b/hort/models/tgs/models/pointclouds/pointnet.py @@ -0,0 +1,121 @@ +# modified from https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/src/encoder/pointnet.py +from dataclasses import dataclass +import torch +import torch.nn as nn +from torch_scatter import scatter_mean, scatter_max + +from tgs.utils.base import BaseModule +from tgs.models.networks import ResnetBlockFC +from tgs.utils.ops import scale_tensor + +class LocalPoolPointnet(BaseModule): + ''' PointNet-based encoder network with ResNet blocks for each point. + Number of input points are fixed. + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + plane_resolution (int): defined resolution for plane feature + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + ''' + + @dataclass + class Config(BaseModule.Config): + input_channels: int = 3 + c_dim: int = 128 + hidden_dim: int = 128 + scatter_type: str = "max" + plane_size: int = 32 + n_blocks: int = 5 + radius: float = 1. + + cfg: Config + + def configure(self) -> None: + super().configure() + self.fc_pos = nn.Linear(self.cfg.input_channels, 2 * self.cfg.hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2 * self.cfg.hidden_dim, self.cfg.hidden_dim) for i in range(self.cfg.n_blocks) + ]) + self.fc_c = nn.Linear(self.cfg.hidden_dim, self.cfg.c_dim) + + self.actvn = nn.ReLU() + + if self.cfg.scatter_type == 'max': + self.scatter = scatter_max + elif self.cfg.scatter_type == 'mean': + self.scatter = scatter_mean + else: + raise ValueError('incorrect scatter type') + + + def generate_plane_features(self, index, c): + # acquire indices of features in plane + # xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) + # index = self.coordinate2index(x, self.cfg.plane_size) + + # scatter plane features from points + fea_plane = c.new_zeros(index.shape[0], self.cfg.c_dim, self.cfg.plane_size ** 2) + c = c.permute(0, 2, 1) # B x 512 x T + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape(index.shape[0], self.cfg.c_dim, self.cfg.plane_size, self.cfg.plane_size) # sparce matrix (B x 512 x reso x reso) + + return fea_plane + + def pool_local(self, xy, index, c): + bs, fea_dim = c.shape[0], c.shape[2] + keys = xy.keys() + + c_out = 0 + for key in keys: + # scatter plane features from points + fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.cfg.plane_size ** 2) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out += fea + return c_out.permute(0, 2, 1) + + def coordinate2index(self, x): + x = (x * self.cfg.plane_size).long() + index = x[..., 0] + self.cfg.plane_size * x[..., 1] + assert index.max() < self.cfg.plane_size ** 2 + return index[:, None, :] + + def forward(self, p): + batch_size, T, D = p.shape + + # acquire the index for each point + coord = {} + index = {} + + position = torch.clamp(p[..., :3], 
-self.cfg.radius + 1e-6, self.cfg.radius - 1e-6) + position_norm = scale_tensor(position, (-self.cfg.radius, self.cfg.radius), (0, 1)) + coord["xy"] = position_norm[..., [0, 1]] + coord["xz"] = position_norm[..., [0, 2]] + coord["yz"] = position_norm[..., [1, 2]] + index["xy"] = self.coordinate2index(coord["xy"]) + index["xz"] = self.coordinate2index(coord["xz"]) + index["yz"] = self.coordinate2index(coord["yz"]) + + net = self.fc_pos(p) + + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(coord, index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + + features = torch.stack([ + self.generate_plane_features(index["xy"], c), + self.generate_plane_features(index["xz"], c), + self.generate_plane_features(index["yz"], c) + ], dim=1) + + return features diff --git a/hort/models/tgs/models/pointclouds/simplepoint.py b/hort/models/tgs/models/pointclouds/simplepoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb028d07a7949e93a2116e9d77f4d669636f145 --- /dev/null +++ b/hort/models/tgs/models/pointclouds/simplepoint.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass, field +import torch +from einops import rearrange + +import tgs +from tgs.utils.base import BaseModule +from tgs.utils.typing import * + +class SimplePointGenerator(BaseModule): + @dataclass + class Config(BaseModule.Config): + camera_embedder_cls: str = "" + camera_embedder: dict = field(default_factory=dict) + + image_tokenizer_cls: str = "" + image_tokenizer: dict = field(default_factory=dict) + + tokenizer_cls: str = "" + tokenizer: dict = field(default_factory=dict) + + backbone_cls: str = "" + backbone: dict = field(default_factory=dict) + + post_processor_cls: str = "" + post_processor: dict = field(default_factory=dict) + + pointcloud_upsampling_cls: str = "" + pointcloud_upsampling: dict = field(default_factory=dict) + + flip_c2w_cond: bool = True + + cfg: Config + + def configure(self) -> None: + super().configure() + + self.image_tokenizer = tgs.find(self.cfg.image_tokenizer_cls)( + self.cfg.image_tokenizer + ) + + assert self.cfg.camera_embedder_cls == 'tgs.models.networks.MLP' + weights = self.cfg.camera_embedder.pop("weights") if "weights" in self.cfg.camera_embedder else None + self.camera_embedder = tgs.find(self.cfg.camera_embedder_cls)(**self.cfg.camera_embedder) + if weights: + from tgs.utils.misc import load_module_weights + weights_path, module_name = weights.split(":") + state_dict = load_module_weights( + weights_path, module_name=module_name, map_location="cpu" + ) + self.camera_embedder.load_state_dict(state_dict) + + self.tokenizer = tgs.find(self.cfg.tokenizer_cls)(self.cfg.tokenizer) + + self.backbone = tgs.find(self.cfg.backbone_cls)(self.cfg.backbone) + + self.post_processor = tgs.find(self.cfg.post_processor_cls)( + self.cfg.post_processor + ) + + self.pointcloud_upsampling = tgs.find(self.cfg.pointcloud_upsampling_cls)(self.cfg.pointcloud_upsampling) + + def forward(self, batch, encoder_hidden_states=None, **kwargs): + batch_size, n_input_views = batch["rgb_cond"].shape[:2] + + if encoder_hidden_states is None: + # Camera modulation + c2w_cond = batch["c2w_cond"].clone() + if self.cfg.flip_c2w_cond: + c2w_cond[..., :3, 1:3] *= -1 + camera_extri = c2w_cond.view(*c2w_cond.shape[:-2], -1) + camera_intri = batch["intrinsic_normed_cond"].view( + *batch["intrinsic_normed_cond"].shape[:-2], -1) + camera_feats = torch.cat([camera_intri, camera_extri], dim=-1) + # camera_feats = rearrange(camera_feats, 'B 
Nv C -> (B Nv) C') + + camera_feats = self.camera_embedder(camera_feats) + + encoder_hidden_states: Float[Tensor, "B Cit Nit"] = self.image_tokenizer( + rearrange(batch["rgb_cond"], 'B Nv H W C -> B Nv C H W'), + modulation_cond=camera_feats, + ) + encoder_hidden_states = rearrange( + encoder_hidden_states, 'B Nv C Nt -> B (Nv Nt) C', Nv=n_input_views) + + tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size) + + tokens = self.backbone( + tokens, + encoder_hidden_states=encoder_hidden_states, + modulation_cond=None, + ) + pointclouds = self.post_processor(self.tokenizer.detokenize(tokens)) + + upsampling_input = { + "input_image_tokens": encoder_hidden_states.permute(0, 2, 1), + "input_image_tokens_global": encoder_hidden_states[:, :1], + "c2w_cond": c2w_cond, + "rgb_cond": batch["rgb_cond"], + "intrinsic_cond": batch["intrinsic_cond"], + "intrinsic_normed_cond": batch["intrinsic_normed_cond"], + "points": pointclouds.float() + } + up_results = self.pointcloud_upsampling(upsampling_input) + up_results.insert(0, pointclouds) + pointclouds = up_results[-1] + out = { + "points": pointclouds, + "up_results": up_results + } + return out \ No newline at end of file diff --git a/hort/models/tgs/models/renderer.py b/hort/models/tgs/models/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff65178eb64f8c308acbf25cbf367a68f58f2f3 --- /dev/null +++ b/hort/models/tgs/models/renderer.py @@ -0,0 +1,427 @@ +from dataclasses import dataclass, field +from collections import defaultdict +from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer +from plyfile import PlyData, PlyElement +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +from tgs.utils.typing import * +from tgs.utils.base import BaseModule +from tgs.utils.ops import trunc_exp +from tgs.models.networks import MLP +from tgs.utils.ops import scale_tensor +from einops import rearrange, reduce + +inverse_sigmoid = lambda x: np.log(x / (1 - x)) + +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + +def intrinsic_to_fov(intrinsic, w, h): + fx, fy = intrinsic[0, 0], intrinsic[1, 1] + fov_x = 2 * torch.arctan2(w, 2 * fx) + fov_y = 2 * torch.arctan2(h, 2 * fy) + return fov_x, fov_y + + +class Camera: + def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None: + self.FoVx = FoVx + self.FoVy = FoVy + self.height = height + self.width = width + self.world_view_transform = w2c.transpose(0, 1) + + self.zfar = 100.0 + self.znear = 0.01 + + self.trans = trans + self.scale = scale + + self.projection_matrix = getProjectionMatrix(znear=self.znear, 
zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device) + self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + + @staticmethod + def from_c2w(c2w, intrinsic, height, width): + w2c = torch.inverse(c2w) + FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device)) + return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width) + +class GaussianModel(NamedTuple): + xyz: Tensor + opacity: Tensor + rotation: Tensor + scaling: Tensor + shs: Tensor + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + features_dc = self.shs[:, :1] + features_rest = self.shs[:, 1:] + for i in range(features_dc.shape[1]*features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + for i in range(features_rest.shape[1]*features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(self.scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self.rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path): + + xyz = self.xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + features_dc = self.shs[:, :1] + features_rest = self.shs[:, 1:] + f_dc = features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = features_rest.detach().flatten(start_dim=1).contiguous().cpu().numpy() + opacities = inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3).detach().cpu().numpy()) + scale = np.log(self.scaling.detach().cpu().numpy()) + rotation = self.rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + +class GSLayer(BaseModule): + @dataclass + class Config(BaseModule.Config): + in_channels: int = 128 + feature_channels: dict = field(default_factory=dict) + xyz_offset: bool = True + restrict_offset: bool = False + use_rgb: bool = False + clip_scaling: Optional[float] = None + init_scaling: float = -5.0 + init_density: float = 0.1 + + cfg: Config + + def configure(self, *args, **kwargs) -> None: + self.out_layers = nn.ModuleList() + for key, out_ch in self.cfg.feature_channels.items(): + if key == "shs" and self.cfg.use_rgb: + out_ch = 3 + layer = nn.Linear(self.cfg.in_channels, out_ch) + + # initialize + if not (key == "shs" and self.cfg.use_rgb): + nn.init.constant_(layer.weight, 0) + nn.init.constant_(layer.bias, 0) + if key == "scaling": + nn.init.constant_(layer.bias, self.cfg.init_scaling) + elif key == "rotation": + nn.init.constant_(layer.bias, 0) + nn.init.constant_(layer.bias[0], 1.0) + elif key == "opacity": + nn.init.constant_(layer.bias, inverse_sigmoid(self.cfg.init_density)) + + self.out_layers.append(layer) + + def forward(self, x, pts): + ret = {} + for k, layer in zip(self.cfg.feature_channels.keys(), self.out_layers): + v = layer(x) + if k == "rotation": + v = torch.nn.functional.normalize(v) + elif k == "scaling": + v = trunc_exp(v) + if self.cfg.clip_scaling is not None: + v = torch.clamp(v, min=0, max=self.cfg.clip_scaling) + elif k == "opacity": + v = torch.sigmoid(v) + elif k == "shs": + if 
self.cfg.use_rgb: + v = torch.sigmoid(v) + v = torch.reshape(v, (v.shape[0], -1, 3)) + elif k == "xyz": + if self.cfg.restrict_offset: + max_step = 1.2 / 32 + v = (torch.sigmoid(v) - 0.5) * max_step + v = v + pts if self.cfg.xyz_offset else pts + ret[k] = v + + return GaussianModel(**ret) + +class GS3DRenderer(BaseModule): + @dataclass + class Config(BaseModule.Config): + mlp_network_config: Optional[dict] = None + gs_out: dict = field(default_factory=dict) + sh_degree: int = 3 + scaling_modifier: float = 1.0 + random_background: bool = False + radius: float = 1.0 + feature_reduction: str = "concat" + projection_feature_dim: int = 773 + background_color: Tuple[float, float, float] = field( + default_factory=lambda: (1.0, 1.0, 1.0) + ) + + cfg: Config + + def configure(self, *args, **kwargs) -> None: + if self.cfg.feature_reduction == "mean": + mlp_in = 80 + elif self.cfg.feature_reduction == "concat": + mlp_in = 80 * 3 + else: + raise NotImplementedError + mlp_in = mlp_in + self.cfg.projection_feature_dim + if self.cfg.mlp_network_config is not None: + self.mlp_net = MLP(mlp_in, self.cfg.gs_out.in_channels, **self.cfg.mlp_network_config) + else: + self.cfg.gs_out.in_channels = mlp_in + self.gs_net = GSLayer(self.cfg.gs_out) + + def forward_gs(self, x, p): + if self.cfg.mlp_network_config is not None: + x = self.mlp_net(x) + return self.gs_net(x, p) + + def forward_single_view(self, + gs: GaussianModel, + viewpoint_camera: Camera, + background_color: Optional[Float[Tensor, "3"]], + ret_mask: bool = True, + ): + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0 + try: + screenspace_points.retain_grad() + except: + pass + + bg_color = background_color + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.height), + image_width=int(viewpoint_camera.width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=self.cfg.scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform.float(), + sh_degree=self.cfg.sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=False + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = gs.xyz + means2D = screenspace_points + opacity = gs.opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + scales = gs.scaling + rotations = gs.rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if self.gs_net.cfg.use_rgb: + colors_precomp = gs.shs.squeeze(1) + else: + shs = gs.shs + + # Rasterize visible Gaussians to image, obtain their radii (on screen). 
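+ # Run the rasterizer inside an explicit float32 autocast region so it is not handed reduced-precision tensors from an outer mixed-precision context.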
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32): + rendered_image, radii = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp) + + ret = { + "comp_rgb": rendered_image.permute(1, 2, 0), + "comp_rgb_bg": bg_color + } + + if ret_mask: + mask_bg_color = torch.zeros(3, dtype=torch.float32, device=self.device) + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.height), + image_width=int(viewpoint_camera.width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=mask_bg_color, + scale_modifier=self.cfg.scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform.float(), + sh_degree=0, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=False + ) + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + rendered_mask, radii = rasterizer( + means3D = means3D, + means2D = means2D, + # shs = , + colors_precomp = torch.ones_like(means3D), + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp) + ret["comp_mask"] = rendered_mask.permute(1, 2, 0) + + return ret + + def query_triplane( + self, + positions: Float[Tensor, "*B N 3"], + triplanes: Float[Tensor, "*B 3 Cp Hp Wp"], + ) -> Dict[str, Tensor]: + batched = positions.ndim == 3 + if not batched: + # no batch dimension + triplanes = triplanes[None, ...] + positions = positions[None, ...] + + positions = scale_tensor(positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)) + indices2D: Float[Tensor, "B 3 N 2"] = torch.stack( + (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]), + dim=-3, + ) + out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample( + rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3), + rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3), + align_corners=False, + mode="bilinear", + ) + if self.cfg.feature_reduction == "concat": + out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3) + elif self.cfg.feature_reduction == "mean": + out = reduce(out, "(B Np) Cp () N -> B N Cp", Np=3, reduction="mean") + else: + raise NotImplementedError + + if not batched: + out = out.squeeze(0) + + return out + + def forward_single_batch( + self, + gs_hidden_features: Float[Tensor, "Np Cp"], + query_points: Float[Tensor, "Np 3"], + c2ws: Float[Tensor, "Nv 4 4"], + intrinsics: Float[Tensor, "Nv 4 4"], + height: int, + width: int, + background_color: Optional[Float[Tensor, "3"]], + ): + gs: GaussianModel = self.forward_gs(gs_hidden_features, query_points) + out_list = [] + + for c2w, intrinsic in zip(c2ws, intrinsics): + out_list.append(self.forward_single_view( + gs, + Camera.from_c2w(c2w, intrinsic, height, width), + background_color + )) + + out = defaultdict(list) + for out_ in out_list: + for k, v in out_.items(): + out[k].append(v) + out = {k: torch.stack(v, dim=0) for k, v in out.items()} + out["3dgs"] = gs + + return out + + def forward(self, + gs_hidden_features: Float[Tensor, "B Np Cp"], + query_points: Float[Tensor, "B Np 3"], + c2w: Float[Tensor, "B Nv 4 4"], + intrinsic: Float[Tensor, "B Nv 4 4"], + height, + width, + additional_features: Optional[Float[Tensor, "B C H W"]] = None, + background_color: Optional[Float[Tensor, "B 3"]] = None, + **kwargs): + batch_size = gs_hidden_features.shape[0] + 
out_list = [] + gs_hidden_features = self.query_triplane(query_points, gs_hidden_features) + if additional_features is not None: + gs_hidden_features = torch.cat([gs_hidden_features, additional_features], dim=-1) + + for b in range(batch_size): + out_list.append(self.forward_single_batch( + gs_hidden_features[b], + query_points[b], + c2w[b], + intrinsic[b], + height, width, + background_color[b] if background_color is not None else None)) + + out = defaultdict(list) + for out_ in out_list: + for k, v in out_.items(): + out[k].append(v) + for k, v in out.items(): + if isinstance(v[0], torch.Tensor): + out[k] = torch.stack(v, dim=0) + else: + out[k] = v + return out + \ No newline at end of file diff --git a/hort/models/tgs/models/snowflake/LICENSE b/hort/models/tgs/models/snowflake/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..01c2b54a0360c374c480e1c6b606b18be9e73792 --- /dev/null +++ b/hort/models/tgs/models/snowflake/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 AllenXiang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
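For orientation, here is a minimal, hypothetical sketch of driving the GS3DRenderer defined above. The shapes (one sample, 2048 query points, 80-channel 64x64 triplanes, a single 512x512 view), the identity camera matrices, and the pre-built `renderer` instance are illustrative assumptions, not values taken from this repository.

import torch

B, Np, Nv, H, W = 1, 2048, 1, 512, 512
device = torch.device("cuda")

# `renderer` is assumed to be an already-configured GS3DRenderer instance;
# its construction depends on the surrounding config system.
triplanes = torch.randn(B, 3, 80, 64, 64, device=device)            # triplane features
query_points = (torch.rand(B, Np, 3, device=device) * 2 - 1) * 0.5  # points inside the radius cube
c2w = torch.eye(4, device=device).expand(B, Nv, 4, 4)               # camera-to-world per view
intrinsic = torch.eye(4, device=device).expand(B, Nv, 4, 4)         # 4x4 intrinsics per view

# If the configured projection_feature_dim > 0, per-point projection features of
# that width must also be passed via `additional_features` (they are concatenated
# along the channel dimension before the Gaussian heads).
out = renderer(
    gs_hidden_features=triplanes,
    query_points=query_points,
    c2w=c2w,
    intrinsic=intrinsic,
    height=H,
    width=W,
    background_color=torch.ones(B, 3, device=device),
)
# out["comp_rgb"]: (B, Nv, H, W, 3) rendered views; out["3dgs"]: one GaussianModel per batch element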
diff --git a/hort/models/tgs/models/snowflake/SPD.py b/hort/models/tgs/models/snowflake/SPD.py new file mode 100644 index 0000000000000000000000000000000000000000..eae592c8c4792cc4067bb5fa0ea97c69a250a6bb --- /dev/null +++ b/hort/models/tgs/models/snowflake/SPD.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# @Author: Peng Xiang + +import torch +import torch.nn as nn +from .utils import MLP_Res, MLP_CONV +from .skip_transformer import SkipTransformer + + +class SPD(nn.Module): + def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True): + """Snowflake Point Deconvolution""" + super(SPD, self).__init__() + self.i = i + self.up_factor = up_factor + + self.bounding = bounding + self.radius = radius + + self.global_feat = global_feat + self.ps_dim = 32 if global_feat else 64 + + self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128]) + self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128]) + + self.skip_transformer = SkipTransformer(in_channel=128, dim=64) + + self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim]) + self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting + + self.up_sampler = nn.Upsample(scale_factor=up_factor) + self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128) + + self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3]) + + def forward(self, pcd_prev, feat_global=None, K_prev=None): + """ + Args: + pcd_prev: Tensor, (B, 3, N_prev) + feat_global: Tensor, (B, dim_feat, 1) + K_prev: Tensor, (B, 128, N_prev) + + Returns: + pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor) + K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor) + """ + b, _, n_prev = pcd_prev.shape + feat_1 = self.mlp_1(pcd_prev) + feat_1 = torch.cat([feat_1, + torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))), + feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1 + Q = self.mlp_2(feat_1) + + H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q) + + feat_child = self.mlp_ps(H) + feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor) + H_up = self.up_sampler(H) + K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1)) + + delta = self.mlp_delta(torch.relu(K_curr)) + if self.bounding: + delta = torch.tanh(delta) / self.radius**self.i # (B, 3, N_prev * up_factor) + + pcd_child = self.up_sampler(pcd_prev) + pcd_child = pcd_child + delta + + return pcd_child, K_curr \ No newline at end of file diff --git a/hort/models/tgs/models/snowflake/SPD_crossattn.py b/hort/models/tgs/models/snowflake/SPD_crossattn.py new file mode 100644 index 0000000000000000000000000000000000000000..90273e9e4b58df393858caab8bf5110eccccee31 --- /dev/null +++ b/hort/models/tgs/models/snowflake/SPD_crossattn.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# @Author: Peng Xiang + +import torch +import torch.nn as nn +from .utils import MLP_Res, MLP_CONV +from .skip_transformer import SkipTransformer +from .attention import ResidualTransformerBlock + +class SPD_crossattn(nn.Module): + def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True): + """Snowflake Point Deconvolution""" + super().__init__() + self.i = i + self.up_factor = up_factor + + self.bounding = bounding + self.radius = radius + + self.global_feat = global_feat + self.ps_dim = 32 if global_feat else 64 + + self.mlp_1 = MLP_CONV(in_channel=3, 
layer_dims=[64, 128]) + self.pcd_image_attn = ResidualTransformerBlock( + device=torch.device('cuda'), + dtype=torch.float32, + n_data=128, + width=128, + heads=8, + init_scale=1.0, + ) + + self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128]) + + self.skip_transformer = SkipTransformer(in_channel=128, dim=64) + + self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim]) + self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting + + self.up_sampler = nn.Upsample(scale_factor=up_factor) + self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128) + + self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3]) + + def forward(self, pcd_prev, feat_global=None, K_prev=None): + """ + Args: + pcd_prev: Tensor, (B, 3, N_prev) + feat_global: Tensor, (B, dim_feat, 1) + K_prev: Tensor, (B, 128, N_prev) + + Returns: + pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor) + K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor) + """ + b, _, n_prev = pcd_prev.shape + feat_1 = self.mlp_1(pcd_prev) + # feat_1 = torch.cat([feat_1, + # torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))), + # feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1 + feat_1 = torch.permute(feat_1, (0, 2, 1)) + feat_global = torch.permute(feat_global, (0, 2, 1)) + feat_1 = self.pcd_image_attn(feat_1, feat_global) + Q = torch.permute(feat_1, (0, 2, 1)) + # Q = self.mlp_2(feat_1) + + H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q) + + feat_child = self.mlp_ps(H) + feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor) + H_up = self.up_sampler(H) + K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1)) + + delta = self.mlp_delta(torch.relu(K_curr)) + if self.bounding: + delta = torch.tanh(delta) / self.radius**self.i # (B, 3, N_prev * up_factor) + + pcd_child = self.up_sampler(pcd_prev) + pcd_child = pcd_child + delta + + return pcd_child, K_curr \ No newline at end of file diff --git a/hort/models/tgs/models/snowflake/SPD_pp.py b/hort/models/tgs/models/snowflake/SPD_pp.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a3d3511337b864d083396ae4236922d6c974e2 --- /dev/null +++ b/hort/models/tgs/models/snowflake/SPD_pp.py @@ -0,0 +1,71 @@ + +import torch +import torch.nn as nn +from .utils import MLP_Res, MLP_CONV +from .skip_transformer import SkipTransformer + +class SPD_pp(nn.Module): + def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True): + """Snowflake Point Deconvolution""" + super(SPD_pp, self).__init__() + self.i = i + self.up_factor = up_factor + + self.bounding = bounding + self.radius = radius + + self.global_feat = global_feat + self.ps_dim = 32 if global_feat else 64 + + self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128]) + self.mlp_2 = MLP_CONV( + in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128]) + + self.skip_transformer = SkipTransformer(in_channel=128, dim=64) + + self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim]) + self.ps = nn.ConvTranspose1d( + self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting + + self.up_sampler = nn.Upsample(scale_factor=up_factor) + self.mlp_delta_feature = MLP_Res( + in_dim=256, hidden_dim=128, out_dim=128) + + self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3]) + + def forward(self, 
pcd_prev, feat_cond=None, K_prev=None): + """ + Args: + pcd_prev: Tensor, (B, 3, N_prev) + feat_cond: Tensor, (B, dim_feat, N_prev) + K_prev: Tensor, (B, 128, N_prev) + + Returns: + pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor) + K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor) + """ + b, _, n_prev = pcd_prev.shape + feat_1 = self.mlp_1(pcd_prev) + feat_1 = torch.cat([feat_1, + torch.max(feat_1, 2, keepdim=True)[ + 0].repeat((1, 1, feat_1.size(2))), + feat_cond], 1) if self.global_feat else feat_1 + Q = self.mlp_2(feat_1) + + H = self.skip_transformer( + pcd_prev, K_prev if K_prev is not None else Q, Q) + + feat_child = self.mlp_ps(H) + feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor) + H_up = self.up_sampler(H) + K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1)) + + delta = self.mlp_delta(torch.relu(K_curr)) + if self.bounding: + # (B, 3, N_prev * up_factor) + delta = torch.tanh(delta) / self.radius**self.i + + pcd_child = self.up_sampler(pcd_prev) + pcd_child = pcd_child + delta + + return pcd_child, K_curr diff --git a/hort/models/tgs/models/snowflake/attention.py b/hort/models/tgs/models/snowflake/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..51d4d80f25182062f8fee74a708807b6e505e64f --- /dev/null +++ b/hort/models/tgs/models/snowflake/attention.py @@ -0,0 +1,239 @@ +import torch +import torch.nn as nn +import math +import math +from typing import Optional +from typing import Callable, Iterable, Sequence, Union + +import torch + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, +): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) + for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
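# view_as(x) returns a new Tensor object that shares x's storage, so any
# in-place op performed by run_function mutates the view rather than the
# detached leaf itself, and gradients can still flow back to ctx.input_tensors.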
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + +class MLP(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: int): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_data: int, + width: int, + heads: int, + init_scale: float, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, + device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), True) + x = self.c_proj(x) + return x + + +class 
MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), True) + x = self.c_proj(x) + return x + + +class ResidualTransformerBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_data: int, + width: int, + heads: int, + data_width: Optional[int] = None, + init_scale: float = 1.0, + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn_cross = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + ) + self.attn_self = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_data, + width=width, + heads=heads, + init_scale=init_scale, + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, + width=width, init_scale=init_scale) + self.ln_4 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn_cross(self.ln_1(x), self.ln_2(data)) + x = x + self.attn_self(self.ln_3(x)) + x = x + self.mlp(self.ln_4(x)) + return x \ No newline at end of file diff --git a/hort/models/tgs/models/snowflake/model_spdpp.py b/hort/models/tgs/models/snowflake/model_spdpp.py new file mode 100644 index 0000000000000000000000000000000000000000..8787859f784a085b79b8dbdef6cff0de63fa8c0c --- /dev/null +++ b/hort/models/tgs/models/snowflake/model_spdpp.py @@ -0,0 +1,239 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tgs.utils.base import BaseModule +from tgs.utils.typing import * +from dataclasses import dataclass, field + +from pytorch3d.renderer import ( + AlphaCompositor, + NormWeightedCompositor, + PointsRasterizationSettings, + PointsRasterizer, + PointsRenderer) +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds +from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection + +from .utils import fps_subsample +from einops import rearrange + +from .utils import MLP_CONV +from .SPD import SPD +from .SPD_crossattn import SPD_crossattn +from .SPD_pp import SPD_pp + +SPD_BLOCK = { + 'SPD': SPD, + 'SPD_crossattn': SPD_crossattn, + 'SPD_PP': SPD_pp, +} + + +def homoify(points): + """ + Convert a batch of points to homogeneous coordinates. + Args: + points: e.g. (B, N, 3) or (N, 3) + Returns: + homoified points: e.g., (B, N, 4) + """ + points_dim = points.shape[:-1] + (1,) + ones = points.new_ones(points_dim) + + return torch.cat([points, ones], dim=-1) + + +def dehomoify(points): + """ + Convert a batch of homogeneous points to cartesian coordinates. 
+ Args: + homogeneous points: (B, N, 4/3) or (N, 4/3) + Returns: + cartesian points: (B, N, 3/2) + """ + return points[..., :-1] / points[..., -1:] + + +def mask_generation(points: Float[Tensor, "B Np 3"], + intrinsics: Float[Tensor, "B 3 3"], + input_img: Float[Tensor, "B C H W"], + raster_point_radius: float = 0.01, # point size + raster_points_per_pixel: int = 1, # a single point per pixel, for now + bin_size: int = 0): + """ + points: (B, Np, 3) + """ + B, C, H, W = input_img.shape + device = intrinsics.device + + cam_R = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1) + cam_t = torch.zeros(3).to(device).unsqueeze(0).repeat(B, 1) + + raster_settings = PointsRasterizationSettings(image_size=(H, W), radius=raster_point_radius, points_per_pixel=raster_points_per_pixel, bin_size=bin_size) + + image_size = torch.as_tensor([H, W]).view(1, 2).expand(B, -1).to(device) + cameras = cameras_from_opencv_projection(cam_R, cam_t, intrinsics, image_size) + + rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) + fragments = rasterize(Pointclouds(points)) + + fragments_idx: Tensor = fragments.idx.long() + mask = (fragments_idx[..., 0] > -1) + + return mask.float() + + +def points_projection(points: Float[Tensor, "B Np 3"], + intrinsics: Float[Tensor, "B 3 3"], + local_features: Float[Tensor, "B C H W"], + raster_point_radius: float = 0.0075, # point size + raster_points_per_pixel: int = 1, # a single point per pixel, for now + bin_size: int = 0): + """ + points: (B, Np, 3) + """ + B, C, H, W = local_features.shape + device = local_features.device + cam_R = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1) + cam_t = torch.zeros(3).to(device).unsqueeze(0).repeat(B, 1) + + raster_settings = PointsRasterizationSettings(image_size=(H, W), radius=raster_point_radius, points_per_pixel=raster_points_per_pixel, bin_size=bin_size) + Np = points.shape[1] + R = raster_settings.points_per_pixel + image_size = torch.as_tensor([H, W]).view(1, 2).expand(B, -1).to(device) + cameras = cameras_from_opencv_projection(cam_R, cam_t, intrinsics, image_size) + rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) + fragments = rasterize(Pointclouds(points)) + fragments_idx: Tensor = fragments.idx.long() + visible_pixels = (fragments_idx > -1) # (B, H, W, R) + points_to_visible_pixels = fragments_idx[visible_pixels] + # Reshape local features to (B, H, W, R, C) + local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C) + # Get local features corresponding to visible points + local_features_proj = torch.zeros(B * Np, C, device=device) + local_features_proj[points_to_visible_pixels] = local_features[visible_pixels] + local_features_proj = local_features_proj.reshape(B, Np, C) + return local_features_proj + + +def points_projection_v2(input_xyz_points, cam_intr, feature_maps): + input_points = input_xyz_points.clone() + batch_size = input_points.shape[0] + xyz = input_points[:, :, :3] + homo_xyz = homoify(xyz) + homo_xyz_2d = torch.matmul(cam_intr, homo_xyz.transpose(1, 2)).transpose(1, 2) + xyz_2d = (homo_xyz_2d[:, :, :2] / homo_xyz_2d[:, :, [2]]).unsqueeze(2) + uv_2d = xyz_2d / 224 * 2 - 1 + sample_feat = torch.nn.functional.grid_sample(feature_maps, uv_2d, align_corners=True)[:, :, :, 0].transpose(1, 2) + uv_2d = uv_2d.squeeze(2).reshape((-1, 2)) + validity = (uv_2d[:, 0] >= -1.0) & (uv_2d[:, 0] <= 1.0) & (uv_2d[:, 1] >= -1.0) & (uv_2d[:, 1] <= 1.0) + validity = validity.unsqueeze(1) + + return sample_feat + + 
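The Decoder defined next conditions every SPD refinement step on these per-point image features. As a quick, hypothetical usage sketch of points_projection_v2, the shapes, focal length and principal point below are made-up values for a 224x224 crop with 14-pixel DINOv2 patches, not constants from this repository.

import torch

B, N, C = 2, 512, 768
# Points roughly 0.5 units in front of the camera so they project inside the crop.
points = torch.rand(B, N, 3) * 0.1 - 0.05
points[..., 2] += 0.5
feature_maps = torch.randn(B, C, 16, 16)          # 224 / 14 = 16 patch tokens per side

fx = fy = 1000.0                                  # assumed focal length (pixels)
cx = cy = 112.0                                   # assumed principal point
cam_intr = torch.tensor([[fx, 0.0, cx, 0.0],
                         [0.0, fy, cy, 0.0],
                         [0.0, 0.0, 1.0, 0.0]]).expand(B, 3, 4)

feats = points_projection_v2(points, cam_intr, feature_maps)
print(feats.shape)                                # torch.Size([2, 512, 768])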
+class Decoder(nn.Module): + def __init__(self, input_channels=1152, dim_feat=512, num_p0=512, + radius=1, bounding=True, up_factors=None, + SPD_type='SPD', + token_type='image_token' + ): + super(Decoder, self).__init__() + # self.decoder_coarse = SeedGenerator(dim_feat=dim_feat, num_pc=num_p0) + if up_factors is None: + up_factors = [1] + else: + up_factors = up_factors + uppers = [] + self.num_p0 = num_p0 + self.mlp_feat_cond = MLP_CONV(in_channel=input_channels, + layer_dims=[dim_feat*2, dim_feat]) + + for i, factor in enumerate(up_factors): + uppers.append( + SPD_BLOCK[SPD_type](dim_feat=dim_feat, up_factor=factor, + i=i, bounding=bounding, radius=radius)) + self.uppers = nn.ModuleList(uppers) + self.token_type = token_type + + def calculate_pcl_token(self, pcl_token, up_factor): + up_token = F.interpolate(pcl_token, scale_factor=up_factor, mode='nearest') + return up_token + + def calculate_image_token(self, pcd, input_image_tokens, batch): + """ + Args: + """ + batch_size = input_image_tokens.shape[0] + h_cond, w_cond = 224, 224 + input_image_tokens = input_image_tokens.permute(0, 2, 1) + local_features = input_image_tokens[:, 1:].reshape(batch_size, h_cond // 14, w_cond // 14, -1).permute(0, 3, 1, 2) + # local_features = F.interpolate(local_features, size=(h_cond, w_cond), mode='bilinear', align_corners=False) + local_features_proj = points_projection_v2(pcd * batch['scale'] + batch['trans'].unsqueeze(1), batch['intrinsic_cond'], local_features) + local_features_proj = local_features_proj.permute(0, 2, 1).contiguous() + + return local_features_proj + + def forward(self, x): + """ + Args: + points: Tensor, (b, num_p0, 3) + feat_cond: Tensor, (b, dim_feat) dinov2: 325x768 + # partial_coarse: Tensor, (b, n_coarse, 3) + """ + points = x['points'] + if self.token_type == 'pcl_token': + feat_cond = x['pcl_token'] + elif self.token_type == 'image_token': + feat_cond = x['input_image_tokens'] + feat_cond = self.mlp_feat_cond(feat_cond) + arr_pcd = [] + feat_prev = None + + pcd = torch.permute(points, (0, 2, 1)).contiguous() + pcl_up_scale = 1 + for upper in self.uppers: + if self.token_type == 'pcl_token': + up_cond = self.calculate_pcl_token( + feat_cond, pcl_up_scale) + pcl_up_scale *= upper.up_factor + elif self.token_type == 'image_token': + up_cond = self.calculate_image_token(points, feat_cond, x) + pcd, feat_prev = upper(pcd, up_cond, feat_prev) + points = torch.permute(pcd, (0, 2, 1)).contiguous() + arr_pcd.append(points) + return arr_pcd + + +class SnowflakeModelSPDPP(BaseModule): + """ + apply PC^2 / PCL token to decoder + """ + @dataclass + class Config(BaseModule.Config): + input_channels: int = 1152 + dim_feat: int = 128 + num_p0: int = 512 + radius: float = 1 + bounding: bool = True + use_fps: bool = True + up_factors: List[int] = field(default_factory=lambda: [2, 2]) + image_full_token_cond: bool = False + SPD_type: str = 'SPD_PP' + token_type: str = 'pcl_token' + cfg: Config + + def configure(self) -> None: + super().configure() + self.decoder = Decoder(input_channels=self.cfg.input_channels, + dim_feat=self.cfg.dim_feat, num_p0=self.cfg.num_p0, + radius=self.cfg.radius, up_factors=self.cfg.up_factors, bounding=self.cfg.bounding, + SPD_type=self.cfg.SPD_type, + token_type=self.cfg.token_type + ) + + def forward(self, x): + results = self.decoder(x) + return results diff --git a/hort/models/tgs/models/snowflake/pointnet2.py b/hort/models/tgs/models/snowflake/pointnet2.py new file mode 100644 index 
0000000000000000000000000000000000000000..01aae13aa2fd00a4f4e9a843f310a6698748f135 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule + + +class PointNet2ClassificationSSG(nn.Module): + def __init__(self): + super().__init__() + self._build_model() + + def _build_model(self): + self.SA_modules = nn.ModuleList() + self.SA_modules.append( + PointnetSAModule( + npoint=512, + radius=0.2, + nsample=64, + mlp=[3, 64, 64, 128], + use_xyz=True, + ) + ) + self.SA_modules.append( + PointnetSAModule( + npoint=128, + radius=0.4, + nsample=64, + mlp=[128, 128, 128, 256], + use_xyz=True, + ) + ) + self.SA_modules.append( + PointnetSAModule( + mlp=[256, 256, 512, 1024], use_xyz=True, + ) + ) + + self.fc_layer = nn.Sequential( + nn.Linear(1024, 512, bias=False), + nn.BatchNorm1d(512), + nn.ReLU(True), + nn.Linear(512, 256, bias=False), + nn.BatchNorm1d(256), + nn.ReLU(True), + nn.Dropout(0.5), + nn.Linear(256, 40), + ) + + def _break_up_pc(self, pc): + xyz = pc[..., 0:3].contiguous() + features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None + + return xyz, features + + def forward(self, pointcloud): + r""" + Forward pass of the network + + Parameters + ---------- + pointcloud: Variable(torch.cuda.FloatTensor) + (B, N, 3 + input_channels) tensor + Point cloud to run predicts on + Each point in the point-cloud MUST + be formated as (x, y, z, features...) + """ + xyz, features = self._break_up_pc(pointcloud) + + for module in self.SA_modules: + xyz, features = module(xyz, features) + + return self.fc_layer(features.squeeze(-1)) + + +class PointNet2SemSegSSG(PointNet2ClassificationSSG): + def _build_model(self): + self.SA_modules = nn.ModuleList() + self.SA_modules.append( + PointnetSAModule( + npoint=256, + radius=0.05, + nsample=32, + mlp=[1, 32, 64], + use_xyz=True, + ) + ) + self.SA_modules.append( + PointnetSAModule( + npoint=64, + radius=0.10, + nsample=32, + mlp=[64, 128, 256], + use_xyz=True, + ) + ) + self.SA_modules.append( + PointnetSAModule( + npoint=16, + radius=0.20, + nsample=32, + mlp=[256, 512, 768], + use_xyz=True, + ) + ) + + def forward(self, pointcloud): + r""" + Forward pass of the network + + Parameters + ---------- + pointcloud: Variable(torch.cuda.FloatTensor) + (B, N, 3 + input_channels) tensor + Point cloud to run predicts on + Each point in the point-cloud MUST + be formated as (x, y, z, features...) 
+ """ + xyz, features = self._break_up_pc(pointcloud) + + l_xyz, l_features = [xyz], [features] + for i in range(len(self.SA_modules)): + li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) + l_xyz.append(li_xyz) + l_features.append(li_features) + + return l_features[-1].transpose(2, 1) diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd361f9abbacc218f7699b8c439902b9d1bf745 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py @@ -0,0 +1,3 @@ +import pointnet2_ops.pointnet2_modules +import pointnet2_ops.pointnet2_utils +from pointnet2_ops._version import __version__ diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h new file mode 100644 index 0000000000000000000000000000000000000000..1bbc6389c2837ad801e991643dae0c0401ee782c --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h @@ -0,0 +1,5 @@ +#pragma once +#include + +at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, + const int nsample); diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0fd5b6edcfe3e7f7a03bd75d0ffdc9fc92be25eb --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h @@ -0,0 +1,41 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include +#include +#include + +#include +#include + +#include + +#define TOTAL_THREADS 512 + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} + +inline dim3 opt_block_config(int x, int y) { + const int x_threads = opt_n_threads(x); + const int y_threads = + max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + + return block_config; +} + +#define CUDA_CHECK_ERRORS() \ + do { \ + cudaError_t err = cudaGetLastError(); \ + if (cudaSuccess != err) { \ + fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ + cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ + __FILE__); \ + exit(-1); \ + } \ + } while (0) + +#endif diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h new file mode 100644 index 0000000000000000000000000000000000000000..ad20cda9e68c92fd4e05a319f31fdcf7ef1e6427 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h @@ -0,0 +1,5 @@ +#pragma once +#include + +at::Tensor group_points(at::Tensor points, at::Tensor idx); +at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h new file mode 100644 index 
0000000000000000000000000000000000000000..26b34648396783c9858f8f4e10b869a6e74e6a11 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +std::vector three_nn(at::Tensor unknowns, at::Tensor knows); +at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, + at::Tensor weight); +at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, + at::Tensor weight, const int m); diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h new file mode 100644 index 0000000000000000000000000000000000000000..d795271200c1a65525a509282ccac058f22f2422 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h @@ -0,0 +1,6 @@ +#pragma once +#include + +at::Tensor gather_points(at::Tensor points, at::Tensor idx); +at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); +at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..5f080ed1e455d3bff318e11cc893e3856c132099 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h @@ -0,0 +1,25 @@ +#pragma once +#include +#include + +#define CHECK_CUDA(x) \ + do { \ + AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ + } while (0) + +#define CHECK_CONTIGUOUS(x) \ + do { \ + AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ + } while (0) + +#define CHECK_IS_INT(x) \ + do { \ + AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ + #x " must be an int tensor"); \ + } while (0) + +#define CHECK_IS_FLOAT(x) \ + do { \ + AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ + #x " must be a float tensor"); \ + } while (0) diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b1797c1aeecd32ecaf0cb614c4f2a23a1be9b777 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp @@ -0,0 +1,32 @@ +#include "ball_query.h" +#include "utils.h" + +void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, + int nsample, const float *new_xyz, + const float *xyz, int *idx); + +at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, + const int nsample) { + CHECK_CONTIGUOUS(new_xyz); + CHECK_CONTIGUOUS(xyz); + CHECK_IS_FLOAT(new_xyz); + CHECK_IS_FLOAT(xyz); + + if (new_xyz.is_cuda()) { + CHECK_CUDA(xyz); + } + + at::Tensor idx = + torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, + at::device(new_xyz.device()).dtype(at::ScalarType::Int)); + + if (new_xyz.is_cuda()) { + query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), + radius, nsample, new_xyz.data_ptr(), + xyz.data_ptr(), idx.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return idx; +} diff --git 
a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..8e38e9c1cfb2fd240b1b02d8ef5371b10017aaea --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu @@ -0,0 +1,54 @@ +#include +#include +#include + +#include "cuda_utils.h" + +// input: new_xyz(b, m, 3) xyz(b, n, 3) +// output: idx(b, m, nsample) +__global__ void query_ball_point_kernel(int b, int n, int m, float radius, + int nsample, + const float *__restrict__ new_xyz, + const float *__restrict__ xyz, + int *__restrict__ idx) { + int batch_index = blockIdx.x; + xyz += batch_index * n * 3; + new_xyz += batch_index * m * 3; + idx += m * nsample * batch_index; + + int index = threadIdx.x; + int stride = blockDim.x; + + float radius2 = radius * radius; + for (int j = index; j < m; j += stride) { + float new_x = new_xyz[j * 3 + 0]; + float new_y = new_xyz[j * 3 + 1]; + float new_z = new_xyz[j * 3 + 2]; + for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { + float x = xyz[k * 3 + 0]; + float y = xyz[k * 3 + 1]; + float z = xyz[k * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + + (new_z - z) * (new_z - z); + if (d2 < radius2) { + if (cnt == 0) { + for (int l = 0; l < nsample; ++l) { + idx[j * nsample + l] = k; + } + } + idx[j * nsample + cnt] = k; + ++cnt; + } + } + } +} + +void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, + int nsample, const float *new_xyz, + const float *xyz, int *idx) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + query_ball_point_kernel<<>>( + b, n, m, radius, nsample, new_xyz, xyz, idx); + + //CUDA_CHECK_ERRORS(); +} diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d1916ce1d57f9728de961eaa65e59c7318c683ae --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp @@ -0,0 +1,19 @@ +#include "ball_query.h" +#include "group_points.h" +#include "interpolate.h" +#include "sampling.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("gather_points", &gather_points); + m.def("gather_points_grad", &gather_points_grad); + m.def("furthest_point_sampling", &furthest_point_sampling); + + m.def("three_nn", &three_nn); + m.def("three_interpolate", &three_interpolate); + m.def("three_interpolate_grad", &three_interpolate_grad); + + m.def("ball_query", &ball_query); + + m.def("group_points", &group_points); + m.def("group_points_grad", &group_points_grad); +} diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp new file mode 100644 index 0000000000000000000000000000000000000000..285a4bd42688aabcc829bbeb120a06daaa87a96e --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp @@ -0,0 +1,62 @@ +#include "group_points.h" +#include "utils.h" + +void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, + float *out); + +void group_points_grad_kernel_wrapper(int b, int c, 
int n, int npoints, + int nsample, const float *grad_out, + const int *idx, float *grad_points); + +at::Tensor group_points(at::Tensor points, at::Tensor idx) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), + idx.size(1), idx.size(2), + points.data_ptr(), idx.data_ptr(), + output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} + +at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), n}, + at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + group_points_grad_kernel_wrapper( + grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), + grad_out.data_ptr(), idx.data_ptr(), + output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..38283b72126e464f1ec89d4e05b40aecb70ede6d --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu @@ -0,0 +1,75 @@ +#include +#include + +#include "cuda_utils.h" + +// input: points(b, c, n) idx(b, npoints, nsample) +// output: out(b, c, npoints, nsample) +__global__ void group_points_kernel(int b, int c, int n, int npoints, + int nsample, + const float *__restrict__ points, + const int *__restrict__ idx, + float *__restrict__ out) { + int batch_index = blockIdx.x; + points += batch_index * n * c; + idx += batch_index * npoints * nsample; + out += batch_index * npoints * nsample * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * npoints; i += stride) { + const int l = i / npoints; + const int j = i % npoints; + for (int k = 0; k < nsample; ++k) { + int ii = idx[j * nsample + k]; + out[(l * npoints + j) * nsample + k] = points[l * n + ii]; + } + } +} + +void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, + float *out) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + group_points_kernel<<>>( + b, c, n, npoints, nsample, points, idx, out); + + //CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) +// output: grad_points(b, c, n) +__global__ void group_points_grad_kernel(int b, int c, int n, int npoints, + int nsample, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + int batch_index = blockIdx.x; + grad_out += batch_index * npoints * nsample * c; + idx += batch_index * npoints * nsample; + grad_points += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = 
blockDim.y * blockDim.x; + for (int i = index; i < c * npoints; i += stride) { + const int l = i / npoints; + const int j = i % npoints; + for (int k = 0; k < nsample; ++k) { + int ii = idx[j * nsample + k]; + atomicAdd(grad_points + l * n + ii, + grad_out[(l * npoints + j) * nsample + k]); + } + } +} + +void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + int nsample, const float *grad_out, + const int *idx, float *grad_points) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + group_points_grad_kernel<<>>( + b, c, n, npoints, nsample, grad_out, idx, grad_points); + + //CUDA_CHECK_ERRORS(); +} diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cdee31ca7758278729c0dc2855ce5aba197104e3 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp @@ -0,0 +1,99 @@ +#include "interpolate.h" +#include "utils.h" + +void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx); +void three_interpolate_kernel_wrapper(int b, int c, int m, int n, + const float *points, const int *idx, + const float *weight, float *out); +void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, + const float *grad_out, + const int *idx, const float *weight, + float *grad_points); + +std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { + CHECK_CONTIGUOUS(unknowns); + CHECK_CONTIGUOUS(knows); + CHECK_IS_FLOAT(unknowns); + CHECK_IS_FLOAT(knows); + + if (unknowns.is_cuda()) { + CHECK_CUDA(knows); + } + + at::Tensor idx = + torch::zeros({unknowns.size(0), unknowns.size(1), 3}, + at::device(unknowns.device()).dtype(at::ScalarType::Int)); + at::Tensor dist2 = + torch::zeros({unknowns.size(0), unknowns.size(1), 3}, + at::device(unknowns.device()).dtype(at::ScalarType::Float)); + + if (unknowns.is_cuda()) { + three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), + unknowns.data_ptr(), knows.data_ptr(), + dist2.data_ptr(), idx.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return {dist2, idx}; +} + +at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, + at::Tensor weight) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_CONTIGUOUS(weight); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + CHECK_IS_FLOAT(weight); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + CHECK_CUDA(weight); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + three_interpolate_kernel_wrapper( + points.size(0), points.size(1), points.size(2), idx.size(1), + points.data_ptr(), idx.data_ptr(), weight.data_ptr(), + output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} +at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, + at::Tensor weight, const int m) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_CONTIGUOUS(weight); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + CHECK_IS_FLOAT(weight); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + CHECK_CUDA(weight); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), m}, + 
at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + three_interpolate_grad_kernel_wrapper( + grad_out.size(0), grad_out.size(1), grad_out.size(2), m, + grad_out.data_ptr(), idx.data_ptr(), + weight.data_ptr(), output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..ef245848ac797aef1ae08f8ce49cfb4338e92dc6 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu @@ -0,0 +1,154 @@ +#include +#include +#include + +#include "cuda_utils.h" + +// input: unknown(b, n, 3) known(b, m, 3) +// output: dist2(b, n, 3), idx(b, n, 3) +__global__ void three_nn_kernel(int b, int n, int m, + const float *__restrict__ unknown, + const float *__restrict__ known, + float *__restrict__ dist2, + int *__restrict__ idx) { + int batch_index = blockIdx.x; + unknown += batch_index * n * 3; + known += batch_index * m * 3; + dist2 += batch_index * n * 3; + idx += batch_index * n * 3; + + int index = threadIdx.x; + int stride = blockDim.x; + for (int j = index; j < n; j += stride) { + float ux = unknown[j * 3 + 0]; + float uy = unknown[j * 3 + 1]; + float uz = unknown[j * 3 + 2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = 0, besti2 = 0, besti3 = 0; + for (int k = 0; k < m; ++k) { + float x = known[k * 3 + 0]; + float y = known[k * 3 + 1]; + float z = known[k * 3 + 2]; + float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best1) { + best3 = best2; + besti3 = besti2; + best2 = best1; + besti2 = besti1; + best1 = d; + besti1 = k; + } else if (d < best2) { + best3 = best2; + besti3 = besti2; + best2 = d; + besti2 = k; + } else if (d < best3) { + best3 = d; + besti3 = k; + } + } + dist2[j * 3 + 0] = best1; + dist2[j * 3 + 1] = best2; + dist2[j * 3 + 2] = best3; + + idx[j * 3 + 0] = besti1; + idx[j * 3 + 1] = besti2; + idx[j * 3 + 2] = besti3; + } +} + +void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_nn_kernel<<>>(b, n, m, unknown, known, + dist2, idx); + + //CUDA_CHECK_ERRORS(); +} + +// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) +// output: out(b, c, n) +__global__ void three_interpolate_kernel(int b, int c, int m, int n, + const float *__restrict__ points, + const int *__restrict__ idx, + const float *__restrict__ weight, + float *__restrict__ out) { + int batch_index = blockIdx.x; + points += batch_index * m * c; + + idx += batch_index * n * 3; + weight += batch_index * n * 3; + + out += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weight[j * 3 + 0]; + float w2 = weight[j * 3 + 1]; + float w3 = weight[j * 3 + 2]; + + int i1 = idx[j * 3 + 0]; + int i2 = idx[j * 3 + 1]; + int i3 = idx[j * 3 + 2]; + + out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + + points[l * m + i3] * w3; + } +} + +void three_interpolate_kernel_wrapper(int b, int c, int m, int n, + const float *points, const int *idx, + const float 
*weight, float *out) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_kernel<<>>( + b, c, m, n, points, idx, weight, out); + + //CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) +// output: grad_points(b, c, m) + +__global__ void three_interpolate_grad_kernel( + int b, int c, int n, int m, const float *__restrict__ grad_out, + const int *__restrict__ idx, const float *__restrict__ weight, + float *__restrict__ grad_points) { + int batch_index = blockIdx.x; + grad_out += batch_index * n * c; + idx += batch_index * n * 3; + weight += batch_index * n * 3; + grad_points += batch_index * m * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weight[j * 3 + 0]; + float w2 = weight[j * 3 + 1]; + float w3 = weight[j * 3 + 2]; + + int i1 = idx[j * 3 + 0]; + int i2 = idx[j * 3 + 1]; + int i3 = idx[j * 3 + 2]; + + atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); + atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); + atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); + } +} + +void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, + const float *grad_out, + const int *idx, const float *weight, + float *grad_points) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_grad_kernel<<>>( + b, c, n, m, grad_out, idx, weight, grad_points); + + CUDA_CHECK_ERRORS(); +} diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ddbdc11b6bbecb3aa1238f0d760bd6a47303b858 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp @@ -0,0 +1,87 @@ +#include "sampling.h" +#include "utils.h" + +void gather_points_kernel_wrapper(int b, int c, int n, int npoints, + const float *points, const int *idx, + float *out); +void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, + float *grad_points); + +void furthest_point_sampling_kernel_wrapper(int b, int n, int m, + const float *dataset, float *temp, + int *idxs); + +at::Tensor gather_points(at::Tensor points, at::Tensor idx) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), + idx.size(1), points.data_ptr(), + idx.data_ptr(), output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} + +at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, + const int n) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), n}, + at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, + idx.size(1), 
grad_out.data_ptr(), + idx.data_ptr(), + output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} +at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { + CHECK_CONTIGUOUS(points); + CHECK_IS_FLOAT(points); + + at::Tensor output = + torch::zeros({points.size(0), nsamples}, + at::device(points.device()).dtype(at::ScalarType::Int)); + + at::Tensor tmp = + torch::full({points.size(0), points.size(1)}, 1e10, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + furthest_point_sampling_kernel_wrapper( + points.size(0), points.size(1), nsamples, points.data_ptr(), + tmp.data_ptr(), output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..d6f83aa6cca7f8acb3b288d1701c7b5507c532a7 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu @@ -0,0 +1,229 @@ +#include +#include + +#include "cuda_utils.h" + +// input: points(b, c, n) idx(b, m) +// output: out(b, c, m) +__global__ void gather_points_kernel(int b, int c, int n, int m, + const float *__restrict__ points, + const int *__restrict__ idx, + float *__restrict__ out) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int l = blockIdx.y; l < c; l += gridDim.y) { + for (int j = threadIdx.x; j < m; j += blockDim.x) { + int a = idx[i * m + j]; + out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; + } + } + } +} + +void gather_points_kernel_wrapper(int b, int c, int n, int npoints, + const float *points, const int *idx, + float *out) { + gather_points_kernel<<>>(b, c, n, npoints, + points, idx, out); + + //CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, m) idx(b, m) +// output: grad_points(b, c, n) +__global__ void gather_points_grad_kernel(int b, int c, int n, int m, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int l = blockIdx.y; l < c; l += gridDim.y) { + for (int j = threadIdx.x; j < m; j += blockDim.x) { + int a = idx[i * m + j]; + atomicAdd(grad_points + (i * c + l) * n + a, + grad_out[(i * c + l) * m + j]); + } + } + } +} + +void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, + float *grad_points) { + gather_points_grad_kernel<<>>( + b, c, n, npoints, grad_out, idx, grad_points); + + //CUDA_CHECK_ERRORS(); +} + +__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, + int idx1, int idx2) { + const float v1 = dists[idx1], v2 = dists[idx2]; + const int i1 = dists_i[idx1], i2 = dists_i[idx2]; + dists[idx1] = max(v1, v2); + dists_i[idx1] = v2 > v1 ? 
i2 : i1; +} + +// Input dataset: (b, n, 3), tmp: (b, n) +// Ouput idxs (b, m) +template +__global__ void furthest_point_sampling_kernel( + int b, int n, int m, const float *__restrict__ dataset, + float *__restrict__ temp, int *__restrict__ idxs) { + if (m <= 0) return; + __shared__ float dists[block_size]; + __shared__ int dists_i[block_size]; + + int batch_index = blockIdx.x; + dataset += batch_index * n * 3; + temp += batch_index * n; + idxs += batch_index * m; + + int tid = threadIdx.x; + const int stride = block_size; + + int old = 0; + if (threadIdx.x == 0) idxs[0] = old; + + __syncthreads(); + for (int j = 1; j < m; j++) { + int besti = 0; + float best = -1; + float x1 = dataset[old * 3 + 0]; + float y1 = dataset[old * 3 + 1]; + float z1 = dataset[old * 3 + 2]; + for (int k = tid; k < n; k += stride) { + float x2, y2, z2; + x2 = dataset[k * 3 + 0]; + y2 = dataset[k * 3 + 1]; + z2 = dataset[k * 3 + 2]; + float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); + if (mag <= 1e-3) continue; + + float d = + (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + + float d2 = min(d, temp[k]); + temp[k] = d2; + besti = d2 > best ? k : besti; + best = d2 > best ? d2 : best; + } + dists[tid] = best; + dists_i[tid] = besti; + __syncthreads(); + + if (block_size >= 512) { + if (tid < 256) { + __update(dists, dists_i, tid, tid + 256); + } + __syncthreads(); + } + if (block_size >= 256) { + if (tid < 128) { + __update(dists, dists_i, tid, tid + 128); + } + __syncthreads(); + } + if (block_size >= 128) { + if (tid < 64) { + __update(dists, dists_i, tid, tid + 64); + } + __syncthreads(); + } + if (block_size >= 64) { + if (tid < 32) { + __update(dists, dists_i, tid, tid + 32); + } + __syncthreads(); + } + if (block_size >= 32) { + if (tid < 16) { + __update(dists, dists_i, tid, tid + 16); + } + __syncthreads(); + } + if (block_size >= 16) { + if (tid < 8) { + __update(dists, dists_i, tid, tid + 8); + } + __syncthreads(); + } + if (block_size >= 8) { + if (tid < 4) { + __update(dists, dists_i, tid, tid + 4); + } + __syncthreads(); + } + if (block_size >= 4) { + if (tid < 2) { + __update(dists, dists_i, tid, tid + 2); + } + __syncthreads(); + } + if (block_size >= 2) { + if (tid < 1) { + __update(dists, dists_i, tid, tid + 1); + } + __syncthreads(); + } + + old = dists_i[0]; + if (tid == 0) idxs[j] = old; + } +} + +void furthest_point_sampling_kernel_wrapper(int b, int n, int m, + const float *dataset, float *temp, + int *idxs) { + unsigned int n_threads = opt_n_threads(n); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (n_threads) { + case 512: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 256: + furthest_point_sampling_kernel<256> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 128: + furthest_point_sampling_kernel<128> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 64: + furthest_point_sampling_kernel<64> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 32: + furthest_point_sampling_kernel<32> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 16: + furthest_point_sampling_kernel<16> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 8: + furthest_point_sampling_kernel<8> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 4: + furthest_point_sampling_kernel<4> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 2: + furthest_point_sampling_kernel<2> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 1: + furthest_point_sampling_kernel<1> + <<>>(b, n, m, dataset, temp, idxs); + break; + 
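// opt_n_threads() only ever returns a power of two in [1, 512], so the cases
// above cover every reachable block size; this default is a conservative
// fallback that launches the largest configuration.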
default: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + } + + //CUDA_CHECK_ERRORS(); +} diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..528787cfc8ad81ed41822a8104b60b4896632906 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py @@ -0,0 +1 @@ +__version__ = "3.0.0" diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ad4f6bc23f54ca2d61454e657a6f533e9b875c --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py @@ -0,0 +1,209 @@ +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pointnet2_ops import pointnet2_utils + + +def build_shared_mlp(mlp_spec: List[int], bn: bool = True): + layers = [] + for i in range(1, len(mlp_spec)): + layers.append( + nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) + ) + if bn: + layers.append(nn.BatchNorm2d(mlp_spec[i])) + layers.append(nn.ReLU(True)) + + return nn.Sequential(*layers) + + +class _PointnetSAModuleBase(nn.Module): + def __init__(self): + super(_PointnetSAModuleBase, self).__init__() + self.npoint = None + self.groupers = None + self.mlps = None + + def forward( + self, xyz: torch.Tensor, features: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor of the xyz coordinates of the features + features : torch.Tensor + (B, C, N) tensor of the descriptors of the the features + + Returns + ------- + new_xyz : torch.Tensor + (B, npoint, 3) tensor of the new features' xyz + new_features : torch.Tensor + (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors + """ + + new_features_list = [] + + xyz_flipped = xyz.transpose(1, 2).contiguous() + new_xyz = ( + pointnet2_utils.gather_operation( + xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) + ) + .transpose(1, 2) + .contiguous() + if self.npoint is not None + else None + ) + + for i in range(len(self.groupers)): + new_features = self.groupers[i]( + xyz, new_xyz, features + ) # (B, C, npoint, nsample) + + new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) + + new_features_list.append(new_features) + + return new_xyz, torch.cat(new_features_list, dim=1) + + +class PointnetSAModuleMSG(_PointnetSAModuleBase): + r"""Pointnet set abstrction layer with multiscale grouping + + Parameters + ---------- + npoint : int + Number of features + radii : list of float32 + list of radii to group with + nsamples : list of int32 + Number of samples in each ball query + mlps : list of list of int32 + Spec of the pointnet before the global max_pool for each scale + bn : bool + Use batchnorm + """ + + def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): + # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None + super(PointnetSAModuleMSG, 
self).__init__() + + assert len(radii) == len(nsamples) == len(mlps) + + self.npoint = npoint + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + for i in range(len(radii)): + radius = radii[i] + nsample = nsamples[i] + self.groupers.append( + pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) + if npoint is not None + else pointnet2_utils.GroupAll(use_xyz) + ) + mlp_spec = mlps[i] + if use_xyz: + mlp_spec[0] += 3 + + self.mlps.append(build_shared_mlp(mlp_spec, bn)) + + +class PointnetSAModule(PointnetSAModuleMSG): + r"""Pointnet set abstrction layer + + Parameters + ---------- + npoint : int + Number of features + radius : float + Radius of ball + nsample : int + Number of samples in the ball query + mlp : list + Spec of the pointnet before the global max_pool + bn : bool + Use batchnorm + """ + + def __init__( + self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True + ): + # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None + super(PointnetSAModule, self).__init__( + mlps=[mlp], + npoint=npoint, + radii=[radius], + nsamples=[nsample], + bn=bn, + use_xyz=use_xyz, + ) + + +class PointnetFPModule(nn.Module): + r"""Propigates the features of one set to another + + Parameters + ---------- + mlp : list + Pointnet module parameters + bn : bool + Use batchnorm + """ + + def __init__(self, mlp, bn=True): + # type: (PointnetFPModule, List[int], bool) -> None + super(PointnetFPModule, self).__init__() + self.mlp = build_shared_mlp(mlp, bn=bn) + + def forward(self, unknown, known, unknow_feats, known_feats): + # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + Parameters + ---------- + unknown : torch.Tensor + (B, n, 3) tensor of the xyz positions of the unknown features + known : torch.Tensor + (B, m, 3) tensor of the xyz positions of the known features + unknow_feats : torch.Tensor + (B, C1, n) tensor of the features to be propigated to + known_feats : torch.Tensor + (B, C2, m) tensor of features to be propigated + + Returns + ------- + new_features : torch.Tensor + (B, mlp[-1], n) tensor of the features of the unknown features + """ + + if known is not None: + dist, idx = pointnet2_utils.three_nn(unknown, known) + dist_recip = 1.0 / (dist + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + + interpolated_feats = pointnet2_utils.three_interpolate( + known_feats, idx, weight + ) + else: + interpolated_feats = known_feats.expand( + *(known_feats.size()[0:2] + [unknown.size(1)]) + ) + + if unknow_feats is not None: + new_features = torch.cat( + [interpolated_feats, unknow_feats], dim=1 + ) # (B, C2 + C1, n) + else: + new_features = interpolated_feats + + new_features = new_features.unsqueeze(-1) + new_features = self.mlp(new_features) + + return new_features.squeeze(-1) diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2972e2e7793ae73f5f52a533f938f14db8a0341a --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py @@ -0,0 +1,391 @@ +import torch +import torch.nn as nn +import warnings +from torch.autograd import Function +from typing import * + +try: + import pointnet2_ops._ext as _ext +except ImportError: + from torch.utils.cpp_extension import load + import glob + import os.path as osp + 
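PointnetFPModule above propagates coarse features to a dense point set: it finds the three nearest known points for every unknown point, weights their features by normalized inverse distance, and feeds the result through a shared MLP. A pure-PyTorch sketch of just the interpolation step (reference only; the module itself calls the `three_nn`/`three_interpolate` CUDA ops):

```python
import torch

def three_interpolate_ref(unknown: torch.Tensor, known: torch.Tensor,
                          known_feats: torch.Tensor) -> torch.Tensor:
    """unknown: (B, n, 3), known: (B, m, 3), known_feats: (B, C, m) -> (B, C, n)."""
    d = torch.cdist(unknown, known)                  # (B, n, m) Euclidean distances
    dist, idx = d.topk(3, dim=-1, largest=False)     # three nearest known points per unknown point
    w = 1.0 / (dist + 1e-8)
    w = w / w.sum(dim=-1, keepdim=True)              # (B, n, 3) normalized inverse-distance weights
    B, C, m = known_feats.shape
    n = unknown.shape[1]
    idx_exp = idx.unsqueeze(1).expand(B, C, n, 3)
    neigh = known_feats.unsqueeze(2).expand(B, C, n, m).gather(3, idx_exp)  # (B, C, n, 3)
    return (neigh * w.unsqueeze(1)).sum(-1)          # weighted sum over the 3 neighbours -> (B, C, n)
```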
import os + + warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.") + + _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src") + _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( + osp.join(_ext_src_root, "src", "*.cu") + ) + _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) + + os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" + _ext = load( + "_ext", + sources=_ext_sources, + extra_include_paths=[osp.join(_ext_src_root, "include")], + extra_cflags=["-O3"], + extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"], + with_cuda=True, + ) + + +class FurthestPointSampling(Function): + @staticmethod + @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda") + def forward(ctx, xyz, npoint): + # type: (Any, torch.Tensor, int) -> torch.Tensor + r""" + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance + + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor where N > npoint + npoint : int32 + number of features in the sampled set + + Returns + ------- + torch.Tensor + (B, npoint) tensor containing the set + """ + out = _ext.furthest_point_sampling(xyz, npoint) + + ctx.mark_non_differentiable(out) + + return out + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_out): + return () + + +furthest_point_sample = FurthestPointSampling.apply + + +class GatherOperation(Function): + @staticmethod + @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda") + def forward(ctx, features, idx): + # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + features : torch.Tensor + (B, C, N) tensor + + idx : torch.Tensor + (B, npoint) tensor of the features to gather + + Returns + ------- + torch.Tensor + (B, C, npoint) tensor + """ + + ctx.save_for_backward(idx, features) + + return _ext.gather_points(features, idx) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_out): + idx, features = ctx.saved_tensors + N = features.size(2) + + grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) + return grad_features, None + + +gather_operation = GatherOperation.apply + + +class ThreeNN(Function): + @staticmethod + @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda") + def forward(ctx, unknown, known): + # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] + r""" + Find the three nearest neighbors of unknown in known + Parameters + ---------- + unknown : torch.Tensor + (B, n, 3) tensor of known features + known : torch.Tensor + (B, m, 3) tensor of unknown features + + Returns + ------- + dist : torch.Tensor + (B, n, 3) l2 distance to the three nearest neighbors + idx : torch.Tensor + (B, n, 3) index of 3 nearest neighbors + """ + dist2, idx = _ext.three_nn(unknown, known) + dist = torch.sqrt(dist2) + + ctx.mark_non_differentiable(dist, idx) + + return dist, idx + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_dist, grad_idx): + return () + + +three_nn = ThreeNN.apply + + +class ThreeInterpolate(Function): + @staticmethod + @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda") + def forward(ctx, features, idx, weight): + # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor + r""" + Performs weight linear interpolation on 3 features + Parameters + ---------- + features : torch.Tensor + (B, c, m) 
Features descriptors to be interpolated from + idx : torch.Tensor + (B, n, 3) three nearest neighbors of the target features in features + weight : torch.Tensor + (B, n, 3) weights + + Returns + ------- + torch.Tensor + (B, c, n) tensor of the interpolated features + """ + ctx.save_for_backward(idx, weight, features) + + return _ext.three_interpolate(features, idx, weight) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_out): + # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + r""" + Parameters + ---------- + grad_out : torch.Tensor + (B, c, n) tensor with gradients of ouputs + + Returns + ------- + grad_features : torch.Tensor + (B, c, m) tensor with gradients of features + + None + + None + """ + idx, weight, features = ctx.saved_tensors + m = features.size(2) + + grad_features = _ext.three_interpolate_grad( + grad_out.contiguous(), idx, weight, m + ) + + return grad_features, torch.zeros_like(idx), torch.zeros_like(weight) + + +three_interpolate = ThreeInterpolate.apply + + +class GroupingOperation(Function): + @staticmethod + @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda") + def forward(ctx, features, idx): + # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + features : torch.Tensor + (B, C, N) tensor of features to group + idx : torch.Tensor + (B, npoint, nsample) tensor containing the indicies of features to group with + + Returns + ------- + torch.Tensor + (B, C, npoint, nsample) tensor + """ + ctx.save_for_backward(idx, features) + + return _ext.group_points(features, idx) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_out): + # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] + r""" + + Parameters + ---------- + grad_out : torch.Tensor + (B, C, npoint, nsample) tensor of the gradients of the output from forward + + Returns + ------- + torch.Tensor + (B, C, N) gradient of the features + None + """ + idx, features = ctx.saved_tensors + N = features.size(2) + + grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) + + return grad_features, torch.zeros_like(idx) + + +grouping_operation = GroupingOperation.apply + + +class BallQuery(Function): + @staticmethod + @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda") + def forward(ctx, radius, nsample, xyz, new_xyz): + # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + radius : float + radius of the balls + nsample : int + maximum number of features in the balls + xyz : torch.Tensor + (B, N, 3) xyz coordinates of the features + new_xyz : torch.Tensor + (B, npoint, 3) centers of the ball query + + Returns + ------- + torch.Tensor + (B, npoint, nsample) tensor with the indicies of the features that form the query balls + """ + output = _ext.ball_query(new_xyz, xyz, radius, nsample) + + ctx.mark_non_differentiable(output) + + return output + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_out): + return () + + +ball_query = BallQuery.apply + + +class QueryAndGroup(nn.Module): + r""" + Groups with a ball query of radius + + Parameters + --------- + radius : float32 + Radius of ball + nsample : int32 + Maximum number of features to gather in the ball + """ + + def __init__(self, radius, nsample, use_xyz=True): + # type: (QueryAndGroup, float, int, bool) -> None + super(QueryAndGroup, self).__init__() + self.radius, self.nsample, 
self.use_xyz = radius, nsample, use_xyz + + def forward(self, xyz, new_xyz, features=None): + # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] + r""" + Parameters + ---------- + xyz : torch.Tensor + xyz coordinates of the features (B, N, 3) + new_xyz : torch.Tensor + centriods (B, npoint, 3) + features : torch.Tensor + Descriptors of the features (B, C, N) + + Returns + ------- + new_features : torch.Tensor + (B, 3 + C, npoint, nsample) tensor + """ + + idx = ball_query(self.radius, self.nsample, xyz, new_xyz) + xyz_trans = xyz.transpose(1, 2).contiguous() + grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) + grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + new_features = torch.cat( + [grouped_xyz, grouped_features], dim=1 + ) # (B, C + 3, npoint, nsample) + else: + new_features = grouped_features + else: + assert ( + self.use_xyz + ), "Cannot have not features and not use xyz as a feature!" + new_features = grouped_xyz + + return new_features + + +class GroupAll(nn.Module): + r""" + Groups all features + + Parameters + --------- + """ + + def __init__(self, use_xyz=True): + # type: (GroupAll, bool) -> None + super(GroupAll, self).__init__() + self.use_xyz = use_xyz + + def forward(self, xyz, new_xyz, features=None): + # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] + r""" + Parameters + ---------- + xyz : torch.Tensor + xyz coordinates of the features (B, N, 3) + new_xyz : torch.Tensor + Ignored + features : torch.Tensor + Descriptors of the features (B, C, N) + + Returns + ------- + new_features : torch.Tensor + (B, C + 3, 1, N) tensor + """ + + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + new_features = torch.cat( + [grouped_xyz, grouped_features], dim=1 + ) # (B, 3 + C, 1, N) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + return new_features diff --git a/hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d5cbe1ca5d45acab7bfcd880a743f1f883cd7002 --- /dev/null +++ b/hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py @@ -0,0 +1,41 @@ +import glob +import os +import os.path as osp + +from setuptools import find_packages, setup +import torch +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +this_dir = osp.dirname(osp.abspath(__file__)) +_ext_src_root = osp.join("pointnet2_ops", "_ext-src") +_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( + osp.join(_ext_src_root, "src", "*.cu") +) +_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) + +requirements = ["torch>=1.4"] + +exec(open(osp.join("pointnet2_ops", "_version.py")).read()) + +# os.environ["TORCH_CUDA_ARCH_LIST"] = ".".join(map(str, torch.cuda.get_device_capability())) +os.environ["TORCH_CUDA_ARCH_LIST"] = "5.0;6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" +setup( + name="pointnet2_ops", + version=__version__, + author="Erik Wijmans", + packages=find_packages(), + install_requires=requirements, + ext_modules=[ + CUDAExtension( + name="pointnet2_ops._ext", + sources=_ext_sources, + extra_compile_args={ + "cxx": ["-O3"], + "nvcc": ["-O3", "-Xfatbin", "-compress-all"], + }, + include_dirs=[osp.join(this_dir, 
_ext_src_root, "include")], + ) + ], + cmdclass={"build_ext": BuildExtension}, + include_package_data=True, +) diff --git a/hort/models/tgs/models/snowflake/skip_transformer.py b/hort/models/tgs/models/snowflake/skip_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..36462f7242dd3dbe58495c19a61d8cf0cc54857e --- /dev/null +++ b/hort/models/tgs/models/snowflake/skip_transformer.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# @Author: Peng Xiang + +import torch +from torch import nn, einsum +from .utils import MLP_Res, grouping_operation, query_knn + + +class SkipTransformer(nn.Module): + def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4): + super(SkipTransformer, self).__init__() + self.mlp_v = MLP_Res(in_dim=in_channel*2, hidden_dim=in_channel, out_dim=in_channel) + self.n_knn = n_knn + self.conv_key = nn.Conv1d(in_channel, dim, 1) + self.conv_query = nn.Conv1d(in_channel, dim, 1) + self.conv_value = nn.Conv1d(in_channel, dim, 1) + + self.pos_mlp = nn.Sequential( + nn.Conv2d(3, pos_hidden_dim, 1), + nn.BatchNorm2d(pos_hidden_dim), + nn.ReLU(), + nn.Conv2d(pos_hidden_dim, dim, 1) + ) + + self.attn_mlp = nn.Sequential( + nn.Conv2d(dim, dim * attn_hidden_multiplier, 1), + nn.BatchNorm2d(dim * attn_hidden_multiplier), + nn.ReLU(), + nn.Conv2d(dim * attn_hidden_multiplier, dim, 1) + ) + + self.conv_end = nn.Conv1d(dim, in_channel, 1) + + def forward(self, pos, key, query, include_self=True): + """ + Args: + pos: (B, 3, N) + key: (B, in_channel, N) + query: (B, in_channel, N) + include_self: boolean + + Returns: + Tensor: (B, in_channel, N), shape context feature + """ + value = self.mlp_v(torch.cat([key, query], 1)) + identity = value + key = self.conv_key(key) + query = self.conv_query(query) + value = self.conv_value(value) + b, dim, n = value.shape + + pos_flipped = pos.permute(0, 2, 1).contiguous() + idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped, include_self=include_self) + + key = grouping_operation(key, idx_knn) # b, dim, n, n_knn + qk_rel = query.reshape((b, -1, n, 1)) - key + + pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn + pos_embedding = self.pos_mlp(pos_rel) + + attention = self.attn_mlp(qk_rel + pos_embedding) # b, dim, n, n_knn + attention = torch.softmax(attention, -1) + + value = value.reshape((b, -1, n, 1)) + pos_embedding # + + agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n + y = self.conv_end(agg) + + return y + identity diff --git a/hort/models/tgs/models/snowflake/utils.py b/hort/models/tgs/models/snowflake/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6661e38da308f5a05714274c8b9ade378be61486 --- /dev/null +++ b/hort/models/tgs/models/snowflake/utils.py @@ -0,0 +1,741 @@ +# -*- coding: utf-8 -*- +# @Author: Peng Xiang + +import types +import torch +import torch.nn.functional as F +import numpy as np +from torch import nn, einsum +from pointnet2_ops.pointnet2_utils import furthest_point_sample, \ + gather_operation, ball_query, three_nn, three_interpolate, grouping_operation + +class Conv1d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size=1, stride=1, if_bn=True, activation_fn=torch.relu): + super(Conv1d, self).__init__() + self.conv = nn.Conv1d(in_channel, out_channel, kernel_size, stride=stride) + self.if_bn = if_bn + self.bn = nn.BatchNorm1d(out_channel) + self.activation_fn = activation_fn + + def forward(self, input): + out = self.conv(input) + if self.if_bn: + out 
= self.bn(out) + + if self.activation_fn is not None: + out = self.activation_fn(out) + + return out + +class Conv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size=(1, 1), stride=(1, 1), if_bn=True, activation_fn=torch.relu): + super(Conv2d, self).__init__() + self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride) + self.if_bn = if_bn + self.bn = nn.BatchNorm2d(out_channel) + self.activation_fn = activation_fn + + def forward(self, input): + out = self.conv(input) + if self.if_bn: + out = self.bn(out) + + if self.activation_fn is not None: + out = self.activation_fn(out) + + return out + +class MLP(nn.Module): + def __init__(self, in_channel, layer_dims, bn=None): + super(MLP, self).__init__() + layers = [] + last_channel = in_channel + for out_channel in layer_dims[:-1]: + layers.append(nn.Linear(last_channel, out_channel)) + if bn: + layers.append(nn.BatchNorm1d(out_channel)) + layers.append(nn.ReLU()) + last_channel = out_channel + layers.append(nn.Linear(last_channel, layer_dims[-1])) + self.mlp = nn.Sequential(*layers) + + def forward(self, inputs): + return self.mlp(inputs) + +class MLP_CONV(nn.Module): + def __init__(self, in_channel, layer_dims, bn=None): + super(MLP_CONV, self).__init__() + layers = [] + last_channel = in_channel + for out_channel in layer_dims[:-1]: + layers.append(nn.Conv1d(last_channel, out_channel, 1)) + if bn: + layers.append(nn.BatchNorm1d(out_channel)) + layers.append(nn.ReLU()) + last_channel = out_channel + layers.append(nn.Conv1d(last_channel, layer_dims[-1], 1)) + self.mlp = nn.Sequential(*layers) + + def forward(self, inputs): + return self.mlp(inputs) + +class MLP_Res(nn.Module): + def __init__(self, in_dim=128, hidden_dim=None, out_dim=128): + super(MLP_Res, self).__init__() + if hidden_dim is None: + hidden_dim = in_dim + self.conv_1 = nn.Conv1d(in_dim, hidden_dim, 1) + self.conv_2 = nn.Conv1d(hidden_dim, out_dim, 1) + self.conv_shortcut = nn.Conv1d(in_dim, out_dim, 1) + + def forward(self, x): + """ + Args: + x: (B, out_dim, n) + """ + shortcut = self.conv_shortcut(x) + out = self.conv_2(torch.relu(self.conv_1(x))) + shortcut + return out + + +def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True): + """ + Args: + xyz: Tensor, (B, 3, N) + points: Tensor, (B, f, N) + npoint: int + nsample: int + radius: float + use_xyz: boolean + + Returns: + new_xyz: Tensor, (B, 3, npoint) + new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample) + idx_local: Tensor, (B, npoint, nsample) + grouped_xyz: Tensor, (B, 3, npoint, nsample) + + """ + xyz_flipped = xyz.permute(0, 2, 1).contiguous() # (B, N, 3) + new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint)) # (B, 3, npoint) + + idx = ball_query(radius, nsample, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous()) # (B, npoint, nsample) + grouped_xyz = grouping_operation(xyz, idx) # (B, 3, npoint, nsample) + grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, nsample) + + if points is not None: + grouped_points = grouping_operation(points, idx) # (B, f, npoint, nsample) + if use_xyz: + new_points = torch.cat([grouped_xyz, grouped_points], 1) + else: + new_points = grouped_points + else: + new_points = grouped_xyz + + return new_xyz, new_points, idx, grouped_xyz + + +def sample_and_group_all(xyz, points, use_xyz=True): + """ + Args: + xyz: Tensor, (B, 3, nsample) + points: Tensor, (B, f, nsample) + use_xyz: boolean + + Returns: + new_xyz: Tensor, (B, 3, 1) + new_points: Tensor, (B, f|f+3|3, 1, nsample) + idx: Tensor, (B, 1, nsample) + 
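sample_and_group above builds one local patch per FPS centroid: a ball query collects up to `nsample` neighbours and the grouped coordinates are recentered on their centroid before the shared MLP. A naive PyTorch sketch of that grouping step (kNN stands in for the radius ball query here purely to keep the sketch short; the helper name is hypothetical):

```python
import torch

def group_local_patches(xyz: torch.Tensor, new_xyz: torch.Tensor, nsample: int) -> torch.Tensor:
    """xyz: (B, N, 3) all points, new_xyz: (B, S, 3) centroids -> (B, S, nsample, 3) local patches."""
    d = torch.cdist(new_xyz, xyz)                               # (B, S, N)
    idx = d.topk(nsample, dim=-1, largest=False).indices        # nearest neighbours of each centroid
    B, S, _ = idx.shape
    batch = torch.arange(B, device=xyz.device).view(B, 1, 1).expand(B, S, nsample)
    grouped = xyz[batch, idx]                                   # (B, S, nsample, 3)
    return grouped - new_xyz.unsqueeze(2)                       # recenter each patch on its centroid
```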
grouped_xyz: Tensor, (B, 3, 1, nsample) + """ + b, _, nsample = xyz.shape + device = xyz.device + new_xyz = torch.zeros((1, 3, 1), dtype=torch.float, device=device).repeat(b, 1, 1) + grouped_xyz = xyz.reshape((b, 3, 1, nsample)) + idx = torch.arange(nsample, device=device).reshape(1, 1, nsample).repeat(b, 1, 1) + if points is not None: + if use_xyz: + new_points = torch.cat([xyz, points], 1) + else: + new_points = points + new_points = new_points.unsqueeze(2) + else: + new_points = grouped_xyz + + return new_xyz, new_points, idx, grouped_xyz + + +class PointNet_SA_Module(nn.Module): + def __init__(self, npoint, nsample, radius, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True): + """ + Args: + npoint: int, number of points to sample + nsample: int, number of points in each local region + radius: float + in_channel: int, input channel of features(points) + mlp: list of int, + """ + super(PointNet_SA_Module, self).__init__() + self.npoint = npoint + self.nsample = nsample + self.radius = radius + self.mlp = mlp + self.group_all = group_all + self.use_xyz = use_xyz + if use_xyz: + in_channel += 3 + + last_channel = in_channel + self.mlp_conv = [] + for out_channel in mlp: + self.mlp_conv.append(Conv2d(last_channel, out_channel, if_bn=if_bn)) + last_channel = out_channel + + self.mlp_conv = nn.Sequential(*self.mlp_conv) + + def forward(self, xyz, points): + """ + Args: + xyz: Tensor, (B, 3, N) + points: Tensor, (B, f, N) + + Returns: + new_xyz: Tensor, (B, 3, npoint) + new_points: Tensor, (B, mlp[-1], npoint) + """ + if self.group_all: + new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz) + else: + new_xyz, new_points, idx, grouped_xyz = sample_and_group(xyz, points, self.npoint, self.nsample, self.radius, self.use_xyz) + + new_points = self.mlp_conv(new_points) + new_points = torch.max(new_points, 3)[0] + + return new_xyz, new_points + + +class PointNet_FP_Module(nn.Module): + def __init__(self, in_channel, mlp, use_points1=False, in_channel_points1=None, if_bn=True): + """ + Args: + in_channel: int, input channel of points2 + mlp: list of int + use_points1: boolean, if use points + in_channel_points1: int, input channel of points1 + """ + super(PointNet_FP_Module, self).__init__() + self.use_points1 = use_points1 + + if use_points1: + in_channel += in_channel_points1 + + last_channel = in_channel + self.mlp_conv = [] + for out_channel in mlp: + self.mlp_conv.append(Conv1d(last_channel, out_channel, if_bn=if_bn)) + last_channel = out_channel + + self.mlp_conv = nn.Sequential(*self.mlp_conv) + + def forward(self, xyz1, xyz2, points1, points2): + """ + Args: + xyz1: Tensor, (B, 3, N) + xyz2: Tensor, (B, 3, M) + points1: Tensor, (B, in_channel, N) + points2: Tensor, (B, in_channel, M) + + Returns:MLP_CONV + new_points: Tensor, (B, mlp[-1], N) + """ + dist, idx = three_nn(xyz1.permute(0, 2, 1).contiguous(), xyz2.permute(0, 2, 1).contiguous()) + dist = torch.clamp_min(dist, 1e-10) # (B, N, 3) + recip_dist = 1.0/dist + norm = torch.sum(recip_dist, 2, keepdim=True).repeat((1, 1, 3)) + weight = recip_dist / norm + interpolated_points = three_interpolate(points2, idx, weight) # B, in_channel, N + + if self.use_points1: + new_points = torch.cat([interpolated_points, points1], 1) + else: + new_points = interpolated_points + + new_points = self.mlp_conv(new_points) + return new_points + + +def square_distance(src, dst): + """ + Calculate Euclid distance between each two points. 
+ + src^T * dst = xn * xm + yn * ym + zn * zm; + sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; + sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; + dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 + = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst + + Input: + src: source points, [B, N, C] + dst: target points, [B, M, C] + Output: + dist: per-point square distance, [B, N, M] + """ + B, N, _ = src.shape + _, M, _ = dst.shape + dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) # B, N, M + dist += torch.sum(src ** 2, -1).view(B, N, 1) + dist += torch.sum(dst ** 2, -1).view(B, 1, M) + return dist + + +def query_knn(nsample, xyz, new_xyz, include_self=True): + """Find k-NN of new_xyz in xyz""" + pad = 0 if include_self else 1 + sqrdists = square_distance(new_xyz, xyz) # B, S, N + idx = torch.argsort(sqrdists, dim=-1, descending=False)[:, :, pad: nsample+pad] + return idx.int() + + +def sample_and_group_knn(xyz, points, npoint, k, use_xyz=True, idx=None): + """ + Args: + xyz: Tensor, (B, 3, N) + points: Tensor, (B, f, N) + npoint: int + nsample: int + radius: float + use_xyz: boolean + + Returns: + new_xyz: Tensor, (B, 3, npoint) + new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample) + idx_local: Tensor, (B, npoint, nsample) + grouped_xyz: Tensor, (B, 3, npoint, nsample) + + """ + xyz_flipped = xyz.permute(0, 2, 1).contiguous() # (B, N, 3) + new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint)) # (B, 3, npoint) + if idx is None: + idx = query_knn(k, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous()) + grouped_xyz = grouping_operation(xyz, idx) # (B, 3, npoint, nsample) + grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, k) + + if points is not None: + grouped_points = grouping_operation(points, idx) # (B, f, npoint, nsample) + if use_xyz: + new_points = torch.cat([grouped_xyz, grouped_points], 1) + else: + new_points = grouped_points + else: + new_points = grouped_xyz + + return new_xyz, new_points, idx, grouped_xyz + + +class PointNet_SA_Module_KNN(nn.Module): + def __init__(self, npoint, nsample, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True, if_idx=False): + """ + Args: + npoint: int, number of points to sample + nsample: int, number of points in each local region + radius: float + in_channel: int, input channel of features(points) + mlp: list of int, + """ + super(PointNet_SA_Module_KNN, self).__init__() + self.npoint = npoint + self.nsample = nsample + self.mlp = mlp + self.group_all = group_all + self.use_xyz = use_xyz + self.if_idx = if_idx + if use_xyz: + in_channel += 3 + + last_channel = in_channel + self.mlp_conv = [] + for out_channel in mlp[:-1]: + self.mlp_conv.append(Conv2d(last_channel, out_channel, if_bn=if_bn)) + last_channel = out_channel + self.mlp_conv.append(Conv2d(last_channel, mlp[-1], if_bn=False, activation_fn=None)) + self.mlp_conv = nn.Sequential(*self.mlp_conv) + + def forward(self, xyz, points, idx=None): + """ + Args: + xyz: Tensor, (B, 3, N) + points: Tensor, (B, f, N) + + Returns: + new_xyz: Tensor, (B, 3, npoint) + new_points: Tensor, (B, mlp[-1], npoint) + """ + if self.group_all: + new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz) + else: + new_xyz, new_points, idx, grouped_xyz = sample_and_group_knn(xyz, points, self.npoint, self.nsample, self.use_xyz, idx=idx) + + new_points = self.mlp_conv(new_points) + new_points = torch.max(new_points, 3)[0] + + if self.if_idx: + return new_xyz, new_points, idx + else: + return new_xyz, new_points + + +def fps_subsample(pcd, n_points=2048): + """ + 
Args + pcd: (b, 16384, 3) + + returns + new_pcd: (b, n_points, 3) + """ + new_pcd = gather_operation(pcd.permute(0, 2, 1).contiguous(), furthest_point_sample(pcd, n_points)) + new_pcd = new_pcd.permute(0, 2, 1).contiguous() + return new_pcd + + +class Transformer(nn.Module): + def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4): + super(Transformer, self).__init__() + self.n_knn = n_knn + self.conv_key = nn.Conv1d(dim, dim, 1) + self.conv_query = nn.Conv1d(dim, dim, 1) + self.conv_value = nn.Conv1d(dim, dim, 1) + + self.pos_mlp = nn.Sequential( + nn.Conv2d(3, pos_hidden_dim, 1), + nn.BatchNorm2d(pos_hidden_dim), + nn.ReLU(), + nn.Conv2d(pos_hidden_dim, dim, 1) + ) + + self.attn_mlp = nn.Sequential( + nn.Conv2d(dim, dim * attn_hidden_multiplier, 1), + nn.BatchNorm2d(dim * attn_hidden_multiplier), + nn.ReLU(), + nn.Conv2d(dim * attn_hidden_multiplier, dim, 1) + ) + + self.linear_start = nn.Conv1d(in_channel, dim, 1) + self.linear_end = nn.Conv1d(dim, in_channel, 1) + + def forward(self, x, pos): + """feed forward of transformer + Args: + x: Tensor of features, (B, in_channel, n) + pos: Tensor of positions, (B, 3, n) + + Returns: + y: Tensor of features with attention, (B, in_channel, n) + """ + + identity = x + + x = self.linear_start(x) + b, dim, n = x.shape + + pos_flipped = pos.permute(0, 2, 1).contiguous() + idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped) + key = self.conv_key(x) + value = self.conv_value(x) + query = self.conv_query(x) + + key = grouping_operation(key, idx_knn) # b, dim, n, n_knn + qk_rel = query.reshape((b, -1, n, 1)) - key + + pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn + pos_embedding = self.pos_mlp(pos_rel) # b, dim, n, n_knn + + attention = self.attn_mlp(qk_rel + pos_embedding) + attention = torch.softmax(attention, -1) + + value = value.reshape((b, -1, n, 1)) + pos_embedding + + agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n + y = self.linear_end(agg) + + return y+identity + + +class CouplingLayer(nn.Module): + + def __init__(self, d, intermediate_dim, swap=False): + nn.Module.__init__(self) + self.d = d - (d // 2) + self.swap = swap + self.net_s_t = nn.Sequential( + nn.Linear(self.d, intermediate_dim), + nn.ReLU(inplace=True), + nn.Linear(intermediate_dim, intermediate_dim), + nn.ReLU(inplace=True), + nn.Linear(intermediate_dim, (d - self.d) * 2), + ) + + def forward(self, x, logpx=None, reverse=False): + + if self.swap: + x = torch.cat([x[:, self.d:], x[:, :self.d]], 1) + + in_dim = self.d + out_dim = x.shape[1] - self.d + + s_t = self.net_s_t(x[:, :in_dim]) + scale = torch.sigmoid(s_t[:, :out_dim] + 2.) + shift = s_t[:, out_dim:] + + logdetjac = torch.sum(torch.log(scale).view(scale.shape[0], -1), 1, keepdim=True) + + if not reverse: + y1 = x[:, self.d:] * scale + shift + delta_logp = -logdetjac + else: + y1 = (x[:, self.d:] - shift) / scale + delta_logp = logdetjac + + y = torch.cat([x[:, :self.d], y1], 1) if not self.swap else torch.cat([y1, x[:, :self.d]], 1) + + if logpx is None: + return y + else: + return y, logpx + delta_logp + + +class SequentialFlow(nn.Module): + """A generalized nn.Sequential container for normalizing flows. 
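CouplingLayer above is an affine coupling block: the first `self.d` dimensions pass through unchanged and parameterize a sigmoid scale and a shift for the remaining dimensions, so the Jacobian is triangular and `logdetjac` is just the sum of log-scales. Because the conditioning half is untouched, the transform is exactly invertible; a short usage sketch can confirm this (assumes the class above is importable; batch size and `d` are made up):

```python
import torch

layer = CouplingLayer(d=128, intermediate_dim=256, swap=False)
x = torch.randn(4, 128)
logpx = torch.zeros(4, 1)

y, logpy = layer(x, logpx)                  # forward: y1 = x1 * scale + shift, logp accumulates -log|det J|
x_rec = layer(y, reverse=True)              # inverse: y1 = (x1 - shift) / scale
print(torch.allclose(x, x_rec, atol=1e-5))  # True: the coupling transform is exactly invertible
```

Stacking such layers with alternating `swap` (as build_latent_flow does) lets every dimension eventually be transformed.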
+ """ + + def __init__(self, layersList): + super(SequentialFlow, self).__init__() + self.chain = nn.ModuleList(layersList) + + def forward(self, x, logpx=None, reverse=False, inds=None): + if inds is None: + if reverse: + inds = range(len(self.chain) - 1, -1, -1) + else: + inds = range(len(self.chain)) + + if logpx is None: + for i in inds: + x = self.chain[i](x, reverse=reverse) + return x + else: + for i in inds: + x, logpx = self.chain[i](x, logpx, reverse=reverse) + return x, logpx + + +def build_latent_flow(args): + chain = [] + for i in range(args.latent_flow_depth): + chain.append(CouplingLayer(args.latent_dim, args.latent_flow_hidden_dim, swap=(i % 2 == 0))) + return SequentialFlow(chain) + + +################## +## SpectralNorm ## +################## + +POWER_ITERATION_FN = "spectral_norm_power_iteration" + + +class SpectralNorm(object): + def __init__(self, name='weight', dim=0, eps=1e-12): + self.name = name + self.dim = dim + self.eps = eps + + def compute_weight(self, module, n_power_iterations): + if n_power_iterations < 0: + raise ValueError( + 'Expected n_power_iterations to be non-negative, but ' + 'got n_power_iterations={}'.format(n_power_iterations) + ) + + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute(self.dim, * [d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + weight_mat = weight_mat.reshape(height, -1) + with torch.no_grad(): + for _ in range(n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) + u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) + setattr(module, self.name + '_u', u) + setattr(module, self.name + '_v', v) + + sigma = torch.dot(u, torch.matmul(weight_mat, v)) + weight = weight / sigma + setattr(module, self.name, weight) + + def remove(self, module): + weight = getattr(module, self.name) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_orig') + module.register_parameter(self.name, torch.nn.Parameter(weight)) + + def get_update_method(self, module): + def update_fn(module, n_power_iterations): + self.compute_weight(module, n_power_iterations) + + return update_fn + + def __call__(self, module, unused_inputs): + del unused_inputs + self.compute_weight(module, n_power_iterations=0) + + # requires_grad might be either True or False during inference. + if not module.training: + r_g = getattr(module, self.name + '_orig').requires_grad + setattr(module, self.name, getattr(module, self.name).detach().requires_grad_(r_g)) + + @staticmethod + def apply(module, name, dim, eps): + fn = SpectralNorm(name, dim, eps) + weight = module._parameters[name] + height = weight.size(dim) + + u = F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) + v = F.normalize(weight.new_empty(int(weight.numel() / height)).normal_(0, 1), dim=0, eps=fn.eps) + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. 
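compute_weight above estimates the largest singular value sigma of the flattened weight matrix by power iteration (alternately normalizing W^T u and W v) and then divides the weight by it. A standalone sketch of the same estimate, checked against an exact SVD (illustrative only):

```python
import torch
import torch.nn.functional as F

def spectral_sigma(weight: torch.Tensor, n_iter: int = 20, eps: float = 1e-12) -> torch.Tensor:
    """Approximate the largest singular value of a weight matrix by power iteration."""
    W = weight.reshape(weight.size(0), -1)
    u = F.normalize(torch.randn(W.size(0)), dim=0, eps=eps)
    for _ in range(n_iter):
        v = F.normalize(W.t() @ u, dim=0, eps=eps)   # right singular vector estimate
        u = F.normalize(W @ v, dim=0, eps=eps)       # left singular vector estimate
    return torch.dot(u, W @ v)                       # sigma ~= u^T W v

W = torch.randn(40, 20)
print(spectral_sigma(W), torch.linalg.svdvals(W)[0])  # the two values should be close
```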
+ # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a + # buffer, which will cause weight to be included in the state dict + # and also supports nn.init due to shared storage. + module.register_buffer(fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + setattr(module, POWER_ITERATION_FN, types.MethodType(fn.get_update_method(module), module)) + + module.register_forward_pre_hook(fn) + return fn + + +def inplace_spectral_norm(module, name='weight', dim=None, eps=1e-12): + r"""Applies spectral normalization to a parameter in the given module. + .. math:: + \mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\ + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + Spectral normalization stabilizes the training of discriminators (critics) + in Generaive Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + See `Spectral Normalization for Generative Adversarial Networks`_ . + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectal norm + dim (int, optional): dimension corresponding to number of outputs, + the default is 0, except for modules that are instances of + ConvTranspose1/2/3d, when it is 1 + eps (float, optional): epsilon for numerical stability in + calculating norms + Returns: + The original module with the spectal norm hook + Example:: + >>> m = spectral_norm(nn.Linear(20, 40)) + Linear (20 -> 40) + >>> m.weight_u.size() + torch.Size([20]) + """ + if dim is None: + if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, dim=dim, eps=eps) + return module + + +def remove_spectral_norm(module, name='weight'): + r"""Removes the spectral normalization reparameterization from a module. 
+ Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError("spectral_norm of '{}' not found in {}".format(name, module)) + + +def add_spectral_norm(model, logger=None): + """Applies spectral norm to all modules within the scope of a CNF.""" + + def apply_spectral_norm(module): + if 'weight' in module._parameters: + if logger: logger.info("Adding spectral norm to {}".format(module)) + inplace_spectral_norm(module, 'weight') + + def find_coupling_layer(module): + if isinstance(module, CouplingLayer): + module.apply(apply_spectral_norm) + else: + for child in module.children(): + find_coupling_layer(child) + + find_coupling_layer(model) + + +def spectral_norm_power_iteration(model, n_power_iterations=1): + + def recursive_power_iteration(module): + if hasattr(module, POWER_ITERATION_FN): + getattr(module, POWER_ITERATION_FN)(n_power_iterations) + + model.apply(recursive_power_iteration) + +def reparameterize_gaussian(mean, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn(std.size()).to(mean) + return mean + std * eps + + + +def gaussian_entropy(logvar): + const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2)) + ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const + return ent + + +def standard_normal_logprob(z): + dim = z.size(-1) + log_z = -0.5 * dim * np.log(2 * np.pi) + return log_z - z.pow(2) / 2 + +def truncated_normal_(tensor, mean=0, std=1, trunc_std=2): + """ + Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15 + """ + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + valid = (tmp < trunc_std) & (tmp > -trunc_std) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + return tensor \ No newline at end of file diff --git a/hort/models/tgs/models/tokenizers/dinov2.py b/hort/models/tgs/models/tokenizers/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..64f939b2d1f0d31c64ba21328fdecf72c280c6fa --- /dev/null +++ b/hort/models/tgs/models/tokenizers/dinov2.py @@ -0,0 +1,1179 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
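reparameterize_gaussian and gaussian_entropy above are the usual diagonal-Gaussian VAE pieces: z = mean + exp(0.5 * logvar) * eps keeps sampling differentiable, and the entropy has the closed form 0.5 * sum(logvar) + 0.5 * D * (1 + ln 2*pi). A quick numerical cross-check against torch.distributions (illustrative; assumes the helpers above are importable):

```python
import torch
from torch.distributions import Normal

mean, logvar = torch.zeros(2, 8), torch.randn(2, 8)

# reparameterize_gaussian draws z = mean + std * eps with eps ~ N(0, I), differentiable w.r.t. mean/logvar
z = reparameterize_gaussian(mean, logvar)    # (2, 8)

# gaussian_entropy matches the closed-form entropy of the diagonal Gaussian
q = Normal(mean, (0.5 * logvar).exp())
print(torch.allclose(gaussian_entropy(logvar), q.entropy().sum(-1), atol=1e-5))  # True
```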
+""" PyTorch DINOv2 model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union +from dataclasses import dataclass + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BackboneOutput, + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.backbone_utils import BackboneMixin +from transformers.models.dinov2.configuration_dinov2 import Dinov2Config + +from tgs.models.transformers import MemoryEfficientAttentionMixin +from tgs.utils.typing import * + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Dinov2Config" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2-base" +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base" + + +DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/dinov2-base", + # See all DINOv2 models at https://huggingface.co/models?filter=dinov2 +] + + +class Dinov2Embeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + # register as mask token as it's not used in optimization + # to avoid the use of find_unused_parameters_true + # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_buffer("mask_token", torch.zeros(1, config.hidden_size)) + self.patch_embeddings = Dinov2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, config.hidden_size) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with the interpolated position embeddings" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeddings = self.patch_embeddings(pixel_values) + embeddings = patch_embeddings + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), + self.mask_token.to(embeddings.dtype).unsqueeze(0), + embeddings, + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
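interpolate_pos_encoding above lets the ViT run at resolutions other than the pre-training one: the learned patch position table is reshaped to its sqrt(N) x sqrt(N) grid, bicubically resized to the new patch grid, and the [CLS] position is re-attached. A stripped-down sketch of that resizing (hypothetical helper; it uses an explicit `size=` instead of the scale-factor-plus-0.1 trick in the file, and assumes a square source grid):

```python
import math
import torch
import torch.nn.functional as F

def resize_patch_pos_embed(pos: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    """pos: (1, N, D) patch position embeddings on a square grid -> (1, new_h*new_w, D)."""
    n, dim = pos.shape[1], pos.shape[2]
    side = int(math.sqrt(n))
    grid = pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)        # (1, D, side, side)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)

pos = torch.randn(1, 16 * 16, 768)                 # e.g. a 16x16 patch grid at pre-training resolution
print(resize_patch_pos_embed(pos, 32, 32).shape)   # torch.Size([1, 1024, 768])
```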
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 +class Dinov2SelfAttention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + self.query = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.key = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.value = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.use_memory_efficient_attention_xformers: bool = False + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + if self.use_memory_efficient_attention_xformers: + import xformers + assert head_mask is None and not output_attentions + new_size = hidden_states.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + key_layer = self.key(hidden_states).view(new_size) + value_layer = self.value(hidden_states).view(new_size) + query_layer = mixed_query_layer.view(new_size) + context_layer = xformers.ops.memory_efficient_attention( + query_layer, key_layer, value_layer, p=self.attention_probs_dropout_prob + ) + context_layer = context_layer.view(*hidden_states.size()[:-1], -1) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + 
query_layer = self.transpose_for_scores(mixed_query_layer) + + try: + context_layer = F.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=head_mask, dropout_p=(self.dropout.p if self.training else 0.0), scale=1/math.sqrt(self.attention_head_size)) + except: + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + return outputs + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ): + self.use_memory_efficient_attention_xformers = valid + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 +class Dinov2SelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
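The attention forward above prefers the fused F.scaled_dot_product_attention call, passing scale = 1/sqrt(head_dim), and only falls back to the explicit softmax(Q K^T / sqrt(d)) V path if the fused call raises. Up to dropout the two paths are numerically equivalent, which is easy to check on random tensors (illustrative):

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 12, 257, 64)   # (batch, heads, tokens, head_dim)
k = torch.randn(2, 12, 257, 64)
v = torch.randn(2, 12, 257, 64)

fused = F.scaled_dot_product_attention(q, k, v)            # default scale is 1/sqrt(head_dim)
scores = (q @ k.transpose(-1, -2)) / math.sqrt(q.size(-1))
manual = F.softmax(scores, dim=-1) @ v
print(torch.allclose(fused, manual, atol=1e-4))            # True (no dropout, no mask)
```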
+ """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2 +class Dinov2Attention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.attention = Dinov2SelfAttention(config) + self.output = Dinov2SelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.attention.num_attention_heads, + self.attention.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len( + heads + ) + self.attention.all_head_size = ( + self.attention.attention_head_size * self.attention.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class Dinov2LayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter( + config.layerscale_value * torch.ones(config.hidden_size) + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path( + input: torch.Tensor, drop_prob: float = 0.0, training: bool = False +) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
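drop_path, whose body follows below, implements stochastic depth: during training each sample's residual branch is zeroed with probability `drop_prob` and the survivors are rescaled by 1/(1 - p), so the output keeps its expected value. A self-contained sketch of the same rule (hypothetical helper name):

```python
import torch

def drop_path_ref(x: torch.Tensor, drop_prob: float, training: bool) -> torch.Tensor:
    """Zero a residual branch per sample with probability p; rescale survivors by 1/(1-p)."""
    if drop_prob == 0.0 or not training:
        return x
    keep = 1.0 - drop_prob
    # one Bernoulli(keep) draw per sample, broadcast over all remaining dimensions
    mask = torch.rand(x.shape[0], *([1] * (x.ndim - 1)), device=x.device).lt(keep).to(x.dtype)
    return x / keep * mask

x = torch.ones(10_000, 8)
out = drop_path_ref(x, drop_prob=0.2, training=True)
print((out.sum(1) > 0).float().mean())   # ~0.8: each sample survives with probability 1 - p
print(out.mean())                        # ~1.0: the 1/(1-p) rescale preserves the expectation
```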
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * ( + input.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=input.dtype, device=input.device + ) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath +class Dinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2MLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2SwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class Dinov2Layer(nn.Module, MemoryEfficientAttentionMixin): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm1_modulation = None + self.attention = Dinov2Attention(config) + self.layer_scale1 = Dinov2LayerScale(config) + self.drop_path1 = ( + Dinov2DropPath(config.drop_path_rate) + if config.drop_path_rate > 0.0 + else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm2_modulation = None + + if config.use_swiglu_ffn: + self.mlp = Dinov2SwiGLUFFN(config) + else: + self.mlp = Dinov2MLP(config) + self.layer_scale2 = Dinov2LayerScale(config) + self.drop_path2 = ( + Dinov2DropPath(config.drop_path_rate) + if config.drop_path_rate > 0.0 + else nn.Identity() + ) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + modulation_cond: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + hidden_states_norm = self.norm1(hidden_states) + if self.norm1_modulation is not None: + assert modulation_cond is not None + 
hidden_states_norm = self.norm1_modulation( + hidden_states_norm, modulation_cond + ) + self_attention_outputs = self.attention( + hidden_states_norm, # in Dinov2, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[ + 1: + ] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + if self.norm2_modulation is not None: + assert modulation_cond is not None + layer_output = self.norm2_modulation(layer_output, modulation_cond) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = layer_output + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module): + self.norm1_modulation = norm1_mod + self.norm2_modulation = norm2_mod + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2 +class Dinov2Encoder(nn.Module, MemoryEfficientAttentionMixin): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Dinov2Layer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + modulation_cond: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + modulation_cond, + ) + else: + layer_outputs = layer_module( + hidden_states, layer_head_mask, modulation_cond, output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Dinov2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2Embeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + def _set_gradient_checkpointing( + self, module: Dinov2Encoder, value: bool = False + ) -> None: + if isinstance(module, Dinov2Encoder): + module.gradient_checkpointing = value + + +DINOV2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +DINOV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. 
+ + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +@dataclass +class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling): + patch_embeddings: Optional[torch.FloatTensor] = None + + +@add_start_docstrings( + "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_START_DOCSTRING, +) +class Dinov2Model(Dinov2PreTrainedModel, MemoryEfficientAttentionMixin): + def __init__(self, config: Dinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + modulation_cond: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + 
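+            # `modulation_cond` is forwarded to every Dinov2Layer and is only used when AdaLN
+            # modulation has been registered on that layer (see `register_ada_norm_modulation`).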
modulation_cond=modulation_cond, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return CustomBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + patch_embeddings=embedding_output + ) + + def set_gradient_checkpointing(self, value: bool = False) -> None: + self._set_gradient_checkpointing(self.encoder, value) + + +@add_start_docstrings( + """ + Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINOV2_START_DOCSTRING, +) +class Dinov2ForImageClassification(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2 = Dinov2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) + if config.num_labels > 0 + else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
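+
+        Note that the classification head consumes the concatenation of the final [CLS] token and
+        the mean of the patch tokens, which is why `self.classifier` is a
+        `nn.Linear(config.hidden_size * 2, config.num_labels)` layer.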
+ """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.dinov2( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] # batch_size, sequence_length, hidden_size + + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 backbone, to be used with frameworks like DETR and MaskFormer. + """, + DINOV2_START_DOCSTRING, +) +class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [ + config.hidden_size for _ in range(config.num_hidden_layers + 1) + ] + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... 
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state[:, 1:, :].reshape( + batch_size, width // patch_size, height // patch_size, -1 + ) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + + +class CustomPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int): + super().__init__() + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class CustomEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. 
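+
+    Shape sketch (illustrative, assuming a 224x224 RGB input and `patch_size=14`): `pixel_values`
+    of shape `(B, 3, 224, 224)` become patch embeddings of shape `(B, 256, hidden_size)`, and
+    `(B, 257, hidden_size)` once the [CLS] token is prepended and the (interpolated) position
+    embeddings are added.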
+ """ + + def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int) -> None: + super().__init__() + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.hidden_size = hidden_size + + self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size)) + + self.patch_embeddings = CustomPatchEmbeddings(image_size, patch_size, num_channels, hidden_size) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, self.hidden_size) + ) + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.patch_size + width = width // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with the interpolated position embeddings" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, pixel_values: torch.Tensor, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeddings = self.patch_embeddings(pixel_values) + embeddings = patch_embeddings + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + + return embeddings diff --git a/hort/models/tgs/models/tokenizers/image.py b/hort/models/tgs/models/tokenizers/image.py new file mode 100644 index 0000000000000000000000000000000000000000..548aee829200cc37e8e44ac363556de776bc9e74 --- /dev/null +++ b/hort/models/tgs/models/tokenizers/image.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +from einops import rearrange + +from tgs.utils.base import BaseModule +from tgs.models.tokenizers.dinov2 import Dinov2Model +from tgs.models.transformers import Modulation +from tgs.utils.typing import * + +class DINOV2SingleImageTokenizer(BaseModule): + @dataclass + class Config(BaseModule.Config): + pretrained_model_name_or_path: str = "facebook/dinov2-base" 
+ width: int = 224 + height: int = 224 + modulation: bool = False + modulation_zero_init: bool = False + modulation_single_layer: bool = False + modulation_cond_dim: int = 16 + freeze_backbone_params: bool = True + enable_memory_efficient_attention: bool = False + enable_gradient_checkpointing: bool = False + use_patch_embeddings: bool = False + patch_embeddings_aggr_method: str = 'concat' + + cfg: Config + + def configure(self) -> None: + super().configure() + model: Dinov2Model + + if self.cfg.freeze_backbone_params: + # freeze dino backbone parameters + self.register_non_module( + "model", + Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path).to( + self.device + ), + ) + + model = self.non_module("model") + for p in model.parameters(): + p.requires_grad_(False) + model.eval() + else: + self.model = Dinov2Model.from_pretrained( + self.cfg.pretrained_model_name_or_path + ).to(self.device) + model = self.model + + model.set_use_memory_efficient_attention_xformers( + self.cfg.enable_memory_efficient_attention + ) + model.set_gradient_checkpointing(self.cfg.enable_gradient_checkpointing) + + # add modulation + if self.cfg.modulation: + modulations = [] + for layer in model.encoder.layer: + norm1_modulation = Modulation( + model.config.hidden_size, + self.cfg.modulation_cond_dim, + zero_init=self.cfg.modulation_zero_init, + single_layer=self.cfg.modulation_single_layer, + ) + norm2_modulation = Modulation( + model.config.hidden_size, + self.cfg.modulation_cond_dim, + zero_init=self.cfg.modulation_zero_init, + single_layer=self.cfg.modulation_single_layer, + ) + layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation) + modulations += [norm1_modulation, norm2_modulation] + self.modulations = nn.ModuleList(modulations) + + def forward( + self, + images: Float[Tensor, "B *N C H W"], + modulation_cond: Optional[Float[Tensor, "B *N Cc"]], + ) -> Float[Tensor, "B *N Ct Nt"]: + model: Dinov2Model + if self.cfg.freeze_backbone_params: + model = self.non_module("model") + else: + model = self.model + + packed = False + if images.ndim == 4: + packed = True + images = images.unsqueeze(1) + if modulation_cond is not None: + assert modulation_cond.ndim == 2 + modulation_cond = modulation_cond.unsqueeze(1) + + batch_size, n_input_views = images.shape[:2] + out = model( + rearrange(images, "B N C H W -> (B N) C H W"), + modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc") + if modulation_cond is not None + else None, + ) + local_features, global_features = out.last_hidden_state, out.pooler_output + if self.cfg.use_patch_embeddings: + patch_embeddings = out.patch_embeddings + if self.cfg.patch_embeddings_aggr_method == 'concat': + local_features = torch.cat([local_features, patch_embeddings], dim=1) + elif self.cfg.patch_embeddings_aggr_method == 'add': + local_features = local_features + patch_embeddings + else: + raise NotImplementedError + local_features = local_features.permute(0, 2, 1) + local_features = rearrange( + local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size + ) + if packed: + local_features = local_features.squeeze(1) + + return local_features + + def detokenize(self, *args, **kwargs): + raise NotImplementedError diff --git a/hort/models/tgs/models/tokenizers/point.py b/hort/models/tgs/models/tokenizers/point.py new file mode 100644 index 0000000000000000000000000000000000000000..99afb6cc6f2736dd5a24b3c7cdf337ea860bc5ec --- /dev/null +++ b/hort/models/tgs/models/tokenizers/point.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +import 
torch.nn as nn +from tgs.utils.base import BaseModule +from tgs.utils.typing import * +import torch + +class PointLearnablePositionalEmbedding(BaseModule): + @dataclass + class Config(BaseModule.Config): + num_pcl: int = 2048 + num_channels: int = 512 + + cfg: Config + + def configure(self) -> None: + super().configure() + self.pcl_embeddings = nn.Embedding( + self.cfg.num_pcl , self.cfg.num_channels + ) + + def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]: + range_ = torch.arange(self.cfg.num_pcl, device=self.device) + embeddings = self.pcl_embeddings(range_).unsqueeze(0).repeat((batch_size,1,1)) + return torch.permute(embeddings, (0,2,1)) + + def detokenize( + self, tokens: Float[Tensor, "B Ct Nt"] + ) -> Float[Tensor, "B 3 Ct Hp Wp"]: + return torch.permute(tokens, (0,2,1)) \ No newline at end of file diff --git a/hort/models/tgs/models/tokenizers/triplane.py b/hort/models/tgs/models/tokenizers/triplane.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd6774c5028d25308878c5218ee4664cede6e63 --- /dev/null +++ b/hort/models/tgs/models/tokenizers/triplane.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +import math + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from tgs.utils.base import BaseModule +from tgs.utils.typing import * + + +class TriplaneLearnablePositionalEmbedding(BaseModule): + @dataclass + class Config(BaseModule.Config): + plane_size: int = 32 + num_channels: int = 1024 + + cfg: Config + + def configure(self) -> None: + super().configure() + self.embeddings = nn.Parameter( + torch.randn( + (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size), + dtype=torch.float32, + ) + * 1 + / math.sqrt(self.cfg.num_channels) + ) + + def forward(self, batch_size: int, cond_embeddings: Float[Tensor, "B Ct"] = None) -> Float[Tensor, "B Ct Nt"]: + embeddings = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size) + if cond_embeddings is not None: + embeddings = embeddings + cond_embeddings + return rearrange( + embeddings, + "B Np Ct Hp Wp -> B Ct (Np Hp Wp)", + ) + + def detokenize( + self, tokens: Float[Tensor, "B Ct Nt"] + ) -> Float[Tensor, "B 3 Ct Hp Wp"]: + batch_size, Ct, Nt = tokens.shape + assert Nt == self.cfg.plane_size**2 * 3 + assert Ct == self.cfg.num_channels + return rearrange( + tokens, + "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", + Np=3, + Hp=self.cfg.plane_size, + Wp=self.cfg.plane_size, + ) diff --git a/hort/models/tgs/models/transformers.py b/hort/models/tgs/models/transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..6c7ce3b82f3c06a44ba97986307ddc30284163d0 --- /dev/null +++ b/hort/models/tgs/models/transformers.py @@ -0,0 +1,908 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
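+#
+# Transformer building blocks shared by the TGS models: BasicTransformerBlock, FeedForward,
+# the Ada*Norm / Modulation conditioning layers, and a 1D token transformer (Transformer1D).
+# They largely mirror diffusers' attention utilities, with an extra "ada_norm_continuous"
+# path for conditioning on an arbitrary feature vector instead of a diffusion timestep.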
+ +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings + +from dataclasses import dataclass +from tgs.utils.base import BaseModule +from tgs.utils.typing import * + + +class MemoryEfficientAttentionMixin: + def enable_xformers_memory_efficient_attention( + self, attention_op: Optional[Callable] = None + ): + r""" + Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this + option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed + up during training is not guaranteed. + + + + ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes + precedent. + + + + Parameters: + attention_op (`Callable`, *optional*): + Override the default `None` operator for use as `op` argument to the + [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) + function of xFormers. + + Examples: + + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + >>> # Workaround for not accepting attention shape using VAE for Flash Attention + >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None) + ``` + """ + self.set_use_memory_efficient_attention_xformers(True, attention_op) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). + """ + self.set_use_memory_efficient_attention_xformers(False) + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. 
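+
+    Example (illustrative sketch; the sizes below are arbitrary placeholders):
+
+    ```py
+    >>> import torch
+    >>> fuser = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)
+    >>> x = torch.randn(2, 4096, 320)    # visual tokens
+    >>> objs = torch.randn(2, 30, 768)   # object/grounding tokens
+    >>> fuser(x, objs).shape
+    torch.Size([2, 4096, 320])
+    ```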
+ """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = ( + x + + self.alpha_attn.tanh() + * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + ) + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module, MemoryEfficientAttentionMixin): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + cond_dim_ada_norm_continuous: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + attention_type: str = "default", + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm" + self.use_ada_layer_norm_continuous = ( + cond_dim_ada_norm_continuous is not None + ) and norm_type == "ada_norm_continuous" + + assert ( + int(self.use_ada_layer_norm) + + int(self.use_ada_layer_norm_continuous) + + int(self.use_ada_layer_norm_zero) + <= 1 + ) + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm1 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if self.use_ada_layer_norm: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm2 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous) + else: + self.norm2 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine + ) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim + if not double_self_attention + else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if self.use_ada_layer_norm_continuous: + self.norm3 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous) + else: + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 4. 
Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense( + dim, cross_attention_dim, num_attention_heads, attention_head_dim + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + modulation_cond: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm1(hidden_states, modulation_cond) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Retrieve lora scale. + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + # 2.5 ends + + # 3. Cross-Attention + if self.attn2 is not None: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm2(hidden_states, modulation_cond) + else: + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm3(hidden_states, modulation_cond) + else: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. 
Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk( + num_chunks, dim=self._chunk_dim + ) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to( + dtype=gate.dtype + ) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. 
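+
+    The input is projected to `2 * dim_out` channels and split in half: one half is the value,
+    the other the gate, combined as `value * gelu(gate)`.
+
+    Example:
+
+    ```py
+    >>> import torch
+    >>> geglu = GEGLU(dim_in=64, dim_out=128)
+    >>> geglu(torch.randn(2, 10, 64)).shape
+    torch.Size([2, 10, 128])
+    ```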
+ """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + linear_cls = nn.Linear + + self.proj = linear_cls(dim_in, dim_out * 2) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states, scale: float = 1.0): + args = () + hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2: + https://arxiv.org/abs/1606.08415. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the dictionary of embeddings. + """ + + def __init__(self, embedding_dim: int, num_embeddings: int): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return x + + +class AdaLayerNormContinuous(nn.Module): + r""" + Norm layer modified to incorporate arbitrary continuous embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__(self, embedding_dim: int, condition_dim: int): + super().__init__() + self.silu = nn.SiLU() + self.linear1 = nn.Linear(condition_dim, condition_dim) + self.linear2 = nn.Linear(condition_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor: + emb = self.linear2(self.silu(self.linear1(condition))) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return x + + +class Modulation(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, zero_init: bool = False, single_layer: bool = False): + super().__init__() + self.silu = nn.SiLU() + if single_layer: + self.linear1 = nn.Identity() + else: + self.linear1 = nn.Linear(condition_dim, condition_dim) + + self.linear2 = nn.Linear(condition_dim, embedding_dim * 2) + + # Only zero init the last linear layer + if zero_init: + nn.init.zeros_(self.linear2.weight) + nn.init.zeros_(self.linear2.bias) + + def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor: + emb = self.linear2(self.silu(self.linear1(condition))) + scale, shift = torch.chunk(emb, 2, dim=1) + x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return x + + +class AdaLayerNormZero(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). 
+ + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the dictionary of embeddings. + """ + + def __init__(self, embedding_dim: int, num_embeddings: int): + super().__init__() + + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + + def forward( + self, + x: torch.Tensor, + timestep: torch.Tensor, + class_labels: torch.LongTensor, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear( + self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)) + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk( + 6, dim=1 + ) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaGroupNorm(nn.Module): + r""" + GroupNorm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the dictionary of embeddings. + num_groups (`int`): The number of groups to separate the channels into. + act_fn (`str`, *optional*, defaults to `None`): The activation function to use. + eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability. + """ + + def __init__( + self, + embedding_dim: int, + out_dim: int, + num_groups: int, + act_fn: Optional[str] = None, + eps: float = 1e-5, + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) + + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x + +class Transformer1D(BaseModule, MemoryEfficientAttentionMixin): + """ + A 1D Transformer model for sequence data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. 
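+
+    The forward pass expects a channels-first token sequence of shape `(batch, channels, seq_len)`;
+    it is group-normalized, permuted to `(batch, seq_len, channels)`, projected to the inner
+    dimension (`num_attention_heads * attention_head_dim`), run through the stack of
+    `BasicTransformerBlock`s, and projected back to the input channel count.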
+ """ + + @dataclass + class Config(BaseModule.Config): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + in_channels: Optional[int] = None + out_channels: Optional[int] = None + num_layers: int = 1 + dropout: float = 0.0 + norm_num_groups: int = 32 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + cond_dim_ada_norm_continuous: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_type: str = "layer_norm" + norm_elementwise_affine: bool = True + attention_type: str = "default" + enable_memory_efficient_attention: bool = False + gradient_checkpointing: bool = False + + cfg: Config + + def configure(self) -> None: + super().configure() + + self.num_attention_heads = self.cfg.num_attention_heads + self.attention_head_dim = self.cfg.attention_head_dim + inner_dim = self.num_attention_heads * self.attention_head_dim + + linear_cls = nn.Linear + + if self.cfg.norm_type == "layer_norm" and ( + self.cfg.num_embeds_ada_norm is not None + or self.cfg.cond_dim_ada_norm_continuous is not None + ): + raise ValueError("Incorrect norm_type.") + + # 2. Define input layers + self.in_channels = self.cfg.in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=self.cfg.norm_num_groups, + num_channels=self.cfg.in_channels, + eps=1e-6, + affine=True, + ) + self.proj_in = linear_cls(self.cfg.in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + self.num_attention_heads, + self.attention_head_dim, + dropout=self.cfg.dropout, + cross_attention_dim=self.cfg.cross_attention_dim, + activation_fn=self.cfg.activation_fn, + num_embeds_ada_norm=self.cfg.num_embeds_ada_norm, + cond_dim_ada_norm_continuous=self.cfg.cond_dim_ada_norm_continuous, + attention_bias=self.cfg.attention_bias, + only_cross_attention=self.cfg.only_cross_attention, + double_self_attention=self.cfg.double_self_attention, + upcast_attention=self.cfg.upcast_attention, + norm_type=self.cfg.norm_type, + norm_elementwise_affine=self.cfg.norm_elementwise_affine, + attention_type=self.cfg.attention_type, + ) + for d in range(self.cfg.num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = ( + self.cfg.in_channels + if self.cfg.out_channels is None + else self.cfg.out_channels + ) + + self.proj_out = linear_cls(inner_dim, self.cfg.in_channels) + + self.gradient_checkpointing = self.cfg.gradient_checkpointing + + self.set_use_memory_efficient_attention_xformers( + self.cfg.enable_memory_efficient_attention + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + modulation_cond: Optional[torch.FloatTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ): + """ + The [`Transformer1DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. 
If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch, _, seq_len = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 1).reshape( + batch, seq_len, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # 2. 
Blocks + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + modulation_cond, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + modulation_cond=modulation_cond, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, seq_len, inner_dim) + .permute(0, 2, 1) + .contiguous() + ) + + output = hidden_states + residual + + return output diff --git a/hort/models/tgs/utils/__init__.py b/hort/models/tgs/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hort/models/tgs/utils/base.py b/hort/models/tgs/utils/base.py new file mode 100644 index 0000000000000000000000000000000000000000..36cee1fad8741cb8e0c06f9126a6aef51ceb206e --- /dev/null +++ b/hort/models/tgs/utils/base.py @@ -0,0 +1,129 @@ +from dataclasses import dataclass + +import torch.nn as nn + +from tgs.utils.config import parse_structured +from tgs.utils.misc import get_device, load_module_weights +from tgs.utils.typing import * + + +class Configurable: + @dataclass + class Config: + pass + + def __init__(self, cfg: Optional[dict] = None) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + + +class Updateable: + def do_update_step( + self, epoch: int, global_step: int, on_load_weights: bool = False + ): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? + if isinstance(module, Updateable): + module.do_update_step( + epoch, global_step, on_load_weights=on_load_weights + ) + self.update_step(epoch, global_step, on_load_weights=on_load_weights) + + def do_update_step_end(self, epoch: int, global_step: int): + for attr in self.__dir__(): + if attr.startswith("_"): + continue + try: + module = getattr(self, attr) + except: + continue # ignore attributes like property, which can't be retrived using getattr? 
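+        # do_update_step_end mirrors do_update_step above: it walks the public attributes, forwards the
+        # end-of-step hook to any attribute that is itself Updateable, and finally calls the subclass's
+        # own update_step_end (descriptive comment added for clarity).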
+ if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + self.update_step_end(epoch, global_step) + + def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): + # override this method to implement custom update logic + # if on_load_weights is True, you should be careful doing things related to model evaluations, + # as the models and tensors are not guarenteed to be on the same device + pass + + def update_step_end(self, epoch: int, global_step: int): + pass + + +def update_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step(epoch, global_step) + + +def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: + if isinstance(module, Updateable): + module.do_update_step_end(epoch, global_step) + + +class BaseObject(Updateable): + @dataclass + class Config: + pass + + cfg: Config # add this to every subclass of BaseObject to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.device = get_device() + self.configure(*args, **kwargs) + + def configure(self, *args, **kwargs) -> None: + pass + + +class BaseModule(nn.Module, Updateable): + @dataclass + class Config: + weights: Optional[str] = None + freeze: Optional[bool] = False + + cfg: Config # add this to every subclass of BaseModule to enable static type checking + + def __init__( + self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs + ) -> None: + super().__init__() + self.cfg = parse_structured(self.Config, cfg) + self.device = get_device() + self._non_modules = {} + self.configure(*args, **kwargs) + if self.cfg.weights is not None: + # format: path/to/weights:module_name + weights_path, module_name = self.cfg.weights.split(":") + state_dict = load_module_weights( + weights_path, module_name=module_name, map_location="cpu" + ) + self.load_state_dict(state_dict, strict=False) + # self.do_update_step( + # epoch, global_step, on_load_weights=True + # ) # restore states + + if self.cfg.freeze: + for params in self.parameters(): + params.requires_grad = False + + def configure(self, *args, **kwargs) -> None: + pass + + def register_non_module(self, name: str, module: nn.Module) -> None: + # non-modules won't be treated as model parameters + if name in self._non_modules: + raise ValueError(f"Non-module {name} already exists!") + self._non_modules[name] = module + + def non_module(self, name: str): + return self._non_modules.get(name, None) diff --git a/hort/models/tgs/utils/config.py b/hort/models/tgs/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe2c623f5a0458896d958e00e8e8327bfafd38b --- /dev/null +++ b/hort/models/tgs/utils/config.py @@ -0,0 +1,74 @@ +import os +from dataclasses import dataclass, field + +from omegaconf import OmegaConf + +from tgs.utils.typing import * + +# ============ Register OmegaConf Recolvers ============= # +OmegaConf.register_new_resolver( + "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) +) +OmegaConf.register_new_resolver("add", lambda a, b: a + b) +OmegaConf.register_new_resolver("sub", lambda a, b: a - b) +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) +OmegaConf.register_new_resolver("div", lambda a, b: a / b) +OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) +OmegaConf.register_new_resolver("basename", lambda p: 
os.path.basename(p)) +OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) +OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) +OmegaConf.register_new_resolver("gt0", lambda s: s > 0) +OmegaConf.register_new_resolver("not", lambda s: not s) +OmegaConf.register_new_resolver("shsdim", lambda sh_degree: (sh_degree + 1) ** 2 * 3) +# ======================================================= # + +# ============== Automatic Name Resolvers =============== # +def get_naming_convention(cfg): + # TODO + name = f"tgs_{cfg.system.backbone.num_layers}" + return name + +# ======================================================= # + +@dataclass +class ExperimentConfig: + n_gpus: int = 1 + data: dict = field(default_factory=dict) + system: dict = field(default_factory=dict) + +def load_config( + *yamls: str, cli_args: list = [], from_string=False, makedirs=True, **kwargs +) -> Any: + if from_string: + parse_func = OmegaConf.create + else: + parse_func = OmegaConf.load + yaml_confs = [] + for y in yamls: + conf = parse_func(y) + extends = conf.pop("extends", None) + if extends: + assert os.path.exists(extends), f"File {extends} does not exist." + yaml_confs.append(OmegaConf.load(extends)) + yaml_confs.append(conf) + cli_conf = OmegaConf.from_cli(cli_args) + cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) + OmegaConf.resolve(cfg) + assert isinstance(cfg, DictConfig) + scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg) + + return scfg + + +def config_to_primitive(config, resolve: bool = True) -> Any: + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path: str, config) -> None: + with open(path, "w") as fp: + OmegaConf.save(config=config, f=fp) + + +def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: + scfg = OmegaConf.structured(fields(**cfg)) + return scfg diff --git a/hort/models/tgs/utils/misc.py b/hort/models/tgs/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..85442996c7324d86aa94fe3661cee44d8b404a25 --- /dev/null +++ b/hort/models/tgs/utils/misc.py @@ -0,0 +1,88 @@ +import os +import re + +import torch +from packaging import version + +from tgs.utils.typing import * + + +def parse_version(ver: str): + return version.parse(ver) + + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def get_device(): + return torch.device(f"cuda:{get_rank()}") + + +def load_module_weights( + path, module_name=None, ignore_modules=None, map_location=None +) -> Tuple[dict, int, int]: + if module_name is not None and ignore_modules is not None: + raise ValueError("module_name and ignore_modules cannot be both set") + if map_location is None: + map_location = get_device() + + ckpt = torch.load(path, map_location=map_location) + state_dict = ckpt["state_dict"] + state_dict_to_load = state_dict + + if ignore_modules is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + ignore = any( + [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] + ) + if ignore: + continue + state_dict_to_load[k] = v + + if module_name is not None: + state_dict_to_load = {} + for k, v in state_dict.items(): + m = re.match(rf"^{module_name}\.(.*)$", k) + if m is None: + 
continue + state_dict_to_load[m.group(1)] = v + + return state_dict_to_load + +# convert a function into recursive style to handle nested dict/list/tuple variables +def make_recursive_func(func): + def wrapper(vars, *args, **kwargs): + if isinstance(vars, list): + return [wrapper(x, *args, **kwargs) for x in vars] + elif isinstance(vars, tuple): + return tuple([wrapper(x, *args, **kwargs) for x in vars]) + elif isinstance(vars, dict): + return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()} + else: + return func(vars, *args, **kwargs) + + return wrapper + +@make_recursive_func +def todevice(vars, device="cuda"): + if isinstance(vars, torch.Tensor): + return vars.to(device) + elif isinstance(vars, str): + return vars + elif isinstance(vars, bool): + return vars + elif isinstance(vars, float): + return vars + elif isinstance(vars, int): + return vars + else: + raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) diff --git a/hort/models/tgs/utils/ops.py b/hort/models/tgs/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..531474faf282a42db0a06e1440dd8b58d441c8bb --- /dev/null +++ b/hort/models/tgs/utils/ops.py @@ -0,0 +1,279 @@ +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.amp import custom_bwd, custom_fwd +from pytorch3d import io +from pytorch3d.renderer import ( + PointsRasterizationSettings, + PointsRasterizer) +from pytorch3d.structures import Pointclouds +from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection +import cv2 + +from tgs.utils.typing import * + +ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] + +def scale_tensor( + dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale +): + if inp_scale is None: + inp_scale = (0, 1) + if tgt_scale is None: + tgt_scale = (0, 1) + if isinstance(tgt_scale, Tensor): + assert dat.shape[-1] == tgt_scale.shape[-1] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +class _TruncExp(Function): # pylint: disable=abstract-method + # Implementation from torch-ngp: + # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py + @staticmethod + @custom_fwd(cast_inputs=torch.float32, device_type="cuda") + def forward(ctx, x): # pylint: disable=arguments-differ + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, g): # pylint: disable=arguments-differ + x = ctx.saved_tensors[0] + return g * torch.exp(torch.clamp(x, max=15)) + + +trunc_exp = _TruncExp.apply + + +def get_activation(name) -> Callable: + if name is None: + return lambda x: x + name = name.lower() + if name == "none": + return lambda x: x + elif name == "lin2srgb": + return lambda x: torch.where( + x > 0.0031308, + torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, + 12.92 * x, + ).clamp(0.0, 1.0) + elif name == "exp": + return lambda x: torch.exp(x) + elif name == "shifted_exp": + return lambda x: torch.exp(x - 1.0) + elif name == "trunc_exp": + return trunc_exp + elif name == "shifted_trunc_exp": + return lambda x: trunc_exp(x - 1.0) + elif name == "sigmoid": + return lambda x: torch.sigmoid(x) + elif name == "tanh": + return lambda x: torch.tanh(x) + elif name == "shifted_softplus": + return lambda x: F.softplus(x - 1.0) + elif name == "scale_-11_01": + return lambda x: x * 
0.5 + 0.5 + else: + try: + return getattr(F, name) + except AttributeError: + raise ValueError(f"Unknown activation function: {name}") + +def get_ray_directions( + H: int, + W: int, + focal: Union[float, Tuple[float, float]], + principal: Optional[Tuple[float, float]] = None, + use_pixel_centers: bool = True, +) -> Float[Tensor, "H W 3"]: + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + + Inputs: + H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + pixel_center = 0.5 if use_pixel_centers else 0 + + if isinstance(focal, float): + fx, fy = focal, focal + cx, cy = W / 2, H / 2 + else: + fx, fy = focal + assert principal is not None + cx, cy = principal + + i, j = torch.meshgrid( + torch.arange(W, dtype=torch.float32) + pixel_center, + torch.arange(H, dtype=torch.float32) + pixel_center, + indexing="xy", + ) + + directions: Float[Tensor, "H W 3"] = torch.stack( + [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1 + ) + + return directions + + +def get_rays( + directions: Float[Tensor, "... 3"], + c2w: Float[Tensor, "... 4 4"], + keepdim=False, + noise_scale=0.0, +) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]: + # Rotate ray directions from camera coordinate to the world coordinate + assert directions.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + if c2w.ndim == 2: # (4, 4) + c2w = c2w[None, :, :] + assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4) + rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3) + rays_o = c2w[:, :3, 3].expand(rays_d.shape) + elif directions.ndim == 3: # (H, W, 3) + assert c2w.ndim in [2, 3] + if c2w.ndim == 2: # (4, 4) + rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum( + -1 + ) # (H, W, 3) + rays_o = c2w[None, None, :3, 3].expand(rays_d.shape) + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + elif directions.ndim == 4: # (B, H, W, 3) + assert c2w.ndim == 3 # (B, 4, 4) + rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( + -1 + ) # (B, H, W, 3) + rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) + + # add camera noise to avoid grid-like artifect + # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373 + if noise_scale > 0: + rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale + rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale + + rays_d = F.normalize(rays_d, dim=-1) + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d + + +def get_projection_matrix( + fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float +) -> Float[Tensor, "*B 4 4"]: + if isinstance(fovy, float): + proj_mtx = torch.zeros(4, 4, dtype=torch.float32) + proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh) + proj_mtx[1, 1] = -1.0 / math.tan( + fovy / 2.0 + ) # add a negative sign here as the y axis is flipped in nvdiffrast output + proj_mtx[2, 2] = -(far + near) / (far - near) + proj_mtx[2, 3] = -2.0 * far * near / (far - near) + proj_mtx[3, 2] = -1.0 + else: + batch_size = 
fovy.shape[0] + proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32) + proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh) + proj_mtx[:, 1, 1] = -1.0 / torch.tan( + fovy / 2.0 + ) # add a negative sign here as the y axis is flipped in nvdiffrast output + proj_mtx[:, 2, 2] = -(far + near) / (far - near) + proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near) + proj_mtx[:, 3, 2] = -1.0 + return proj_mtx + + +def get_mvp_matrix( + c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"] +) -> Float[Tensor, "*B 4 4"]: + # calculate w2c from c2w: R' = Rt, t' = -Rt * t + # mathematically equivalent to (c2w)^-1 + if c2w.ndim == 2: + assert proj_mtx.ndim == 2 + w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w) + w2c[:3, :3] = c2w[:3, :3].permute(1, 0) + w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:] + w2c[3, 3] = 1.0 + else: + w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w) + w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1) + w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:] + w2c[:, 3, 3] = 1.0 + # calculate mvp matrix by proj_mtx @ w2c (mv_mtx) + mvp_mtx = proj_mtx @ w2c + return mvp_mtx + +def get_intrinsic_from_fov(fov, H, W, bs=-1): + focal_length = 0.5 * H / np.tan(0.5 * fov) + intrinsic = np.identity(3, dtype=np.float32) + intrinsic[0, 0] = focal_length + intrinsic[1, 1] = focal_length + intrinsic[0, 2] = W / 2.0 + intrinsic[1, 2] = H / 2.0 + + if bs > 0: + intrinsic = intrinsic[None].repeat(bs, axis=0) + + return torch.from_numpy(intrinsic) + +def points_projection(points: Float[Tensor, "B Np 3"], + c2ws: Float[Tensor, "B 4 4"], + intrinsics: Float[Tensor, "B 3 3"], + local_features: Float[Tensor, "B C H W"], + # Rasterization settings + raster_point_radius: float = 0.0075, # point size + raster_points_per_pixel: int = 1, # a single point per pixel, for now + bin_size: int = 0): + B, C, H, W = local_features.shape + device = local_features.device + raster_settings = PointsRasterizationSettings( + image_size=(H, W), + radius=raster_point_radius, + points_per_pixel=raster_points_per_pixel, + bin_size=bin_size, + ) + Np = points.shape[1] + R = raster_settings.points_per_pixel + + w2cs = torch.inverse(c2ws) + image_size = torch.as_tensor([H, W]).view(1, 2).expand(w2cs.shape[0], -1).to(device) + cameras = cameras_from_opencv_projection(w2cs[:, :3, :3], w2cs[:, :3, 3], intrinsics, image_size) + + rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) + fragments = rasterize(Pointclouds(points)) + fragments_idx: Tensor = fragments.idx.long() + visible_pixels = (fragments_idx > -1) # (B, H, W, R) + points_to_visible_pixels = fragments_idx[visible_pixels] + + # Reshape local features to (B, H, W, R, C) + local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C) + + # Get local features corresponding to visible points + local_features_proj = torch.zeros(B * Np, C, device=device) + local_features_proj[points_to_visible_pixels] = local_features[visible_pixels] + local_features_proj = local_features_proj.reshape(B, Np, C) + + return local_features_proj + +def compute_distance_transform(mask: torch.Tensor): + image_size = mask.shape[-1] + distance_transform = torch.stack([ + torch.from_numpy(cv2.distanceTransform( + (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3 + ) / (image_size / 2)) + for m in mask.squeeze(1).detach().cpu().numpy().astype(np.uint8) + ]).unsqueeze(1).clip(0, 1).to(mask.device) + return distance_transform diff --git 
a/hort/models/tgs/utils/saving.py b/hort/models/tgs/utils/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..a497823b74a6cf9bc852f2223e95b3064a4a64ef --- /dev/null +++ b/hort/models/tgs/utils/saving.py @@ -0,0 +1,315 @@ +import os +import re +import shutil + +import cv2 +import imageio +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +from PIL import Image, ImageDraw + +import tgs +from tgs.utils.typing import * + +class SaverMixin: + _save_dir: Optional[str] = None + + def set_save_dir(self, save_dir: str): + self._save_dir = save_dir + + def get_save_dir(self): + if self._save_dir is None: + raise ValueError("Save dir is not set") + return self._save_dir + + def convert_data(self, data): + if data is None: + return None + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + return data.detach().cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + def get_save_path(self, filename): + save_path = os.path.join(self.get_save_dir(), filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)} + DEFAULT_UV_KWARGS = { + "data_format": "HWC", + "data_range": (0, 1), + "cmap": "checkerboard", + } + DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} + DEFAULT_GRID_KWARGS = {"align": "max"} + + def get_rgb_image_(self, img, data_format, data_range, rgba=False): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + if img.dtype != np.uint8: + img = img.clip(min=data_range[0], max=data_range[1]) + img = ( + (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0 + ).astype(np.uint8) + nc = 4 if rgba else 3 + imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)] + imgs = [ + img_ + if img_.shape[-1] == nc + else np.concatenate( + [ + img_, + np.zeros( + (img_.shape[0], img_.shape[1], nc - img_.shape[2]), + dtype=img_.dtype, + ), + ], + axis=-1, + ) + for img_ in imgs + ] + img = np.concatenate(imgs, axis=1) + if rgba: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_rgb_image( + self, + filename, + img, + data_format, + data_range + ): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(filename, img) + + def save_rgb_image( + self, + filename, + img, + data_format=DEFAULT_RGB_KWARGS["data_format"], + data_range=DEFAULT_RGB_KWARGS["data_range"], + ) -> str: + save_path = self.get_save_path(filename) + self._save_rgb_image(save_path, img, data_format, data_range) + return save_path + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, "jet", "magma", "spectral"] + if cmap == None: + img = (img * 255.0).astype(np.uint8) + img = np.repeat(img[..., None], 3, axis=2) + elif cmap == "jet": + img = (img * 
255.0).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == "magma": + img = 1.0 - img + base = cm.get_cmap("magma") + num_bins = 256 + colormap = LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins + )(np.linspace(0, 1, num_bins))[:, :3] + a = np.floor(img * 255.0) + b = (a + 1).clip(max=255.0) + f = img * 255.0 - a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] + img = (img * 255.0).astype(np.uint8) + elif cmap == "spectral": + colormap = plt.get_cmap("Spectral") + + def blend_rgba(image): + image = image[..., :3] * image[..., -1:] + ( + 1.0 - image[..., -1:] + ) # blend A to RGB + return image + + img = colormap(img) + img = blend_rgba(img) + img = (img * 255).astype(np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_grayscale_image( + self, + filename, + img, + data_range, + cmap, + ): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(filename, img) + + def save_grayscale_image( + self, + filename, + img, + data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], + cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], + ) -> str: + save_path = self.get_save_path(filename) + self._save_grayscale_image(save_path, img, data_range, cmap) + return save_path + + def get_image_grid_(self, imgs, align): + if isinstance(imgs[0], list): + return np.concatenate( + [self.get_image_grid_(row, align) for row in imgs], axis=0 + ) + cols = [] + for col in imgs: + assert col["type"] in ["rgb", "uv", "grayscale"] + if col["type"] == "rgb": + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col["kwargs"]) + cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) + elif col["type"] == "uv": + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col["kwargs"]) + cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) + elif col["type"] == "grayscale": + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col["kwargs"]) + cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) + + if align == "max": + h = max([col.shape[0] for col in cols]) + w = max([col.shape[1] for col in cols]) + elif align == "min": + h = min([col.shape[0] for col in cols]) + w = min([col.shape[1] for col in cols]) + elif isinstance(align, int): + h = align + w = align + elif ( + isinstance(align, tuple) + and isinstance(align[0], int) + and isinstance(align[1], int) + ): + h, w = align + else: + raise ValueError( + f"Unsupported image grid align: {align}, should be min, max, int or (int, int)" + ) + + for i in range(len(cols)): + if cols[i].shape[0] != h or cols[i].shape[1] != w: + cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_LINEAR) + return np.concatenate(cols, axis=1) + + def save_image_grid( + self, + filename, + imgs, + align=DEFAULT_GRID_KWARGS["align"], + texts: Optional[List[float]] = None, + ): + save_path = self.get_save_path(filename) + img = self.get_image_grid_(imgs, align=align) + + if texts is not None: + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + black, white = (0, 0, 0), (255, 255, 255) + for i, text in enumerate(texts): + draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + 
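                # The four white draw.text calls above stamp the label at one-pixel offsets around the
                # target position, and the single black call below is drawn on top, so each label is
                # rendered as black text with a thin white outline that stays readable on both dark and
                # bright image regions (descriptive comment added for clarity).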
draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) + img = np.asarray(img) + + cv2.imwrite(save_path, img) + return save_path + + def save_image(self, filename, img) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.dtype == np.uint8 or img.dtype == np.uint16 + if img.ndim == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.ndim == 3 and img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + cv2.imwrite(save_path, img) + return save_path + + def save_img_sequence( + self, + filename, + img_dir, + matcher, + save_format="mp4", + fps=30, + ) -> str: + assert save_format in ["gif", "mp4"] + if not filename.endswith(save_format): + filename += f".{save_format}" + save_path = self.get_save_path(filename) + matcher = re.compile(matcher) + img_dir = os.path.join(self.get_save_dir(), img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == "gif": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps, palettesize=256) + elif save_format == "mp4": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps) + return save_path + + def save_img_sequences( + self, + seq_dir, + matcher, + save_format="mp4", + fps=30, + delete=True + ): + seq_dir_ = os.path.join(self.get_save_dir(), seq_dir) + for f in os.listdir(seq_dir_): + img_dir_ = os.path.join(seq_dir_, f) + if not os.path.isdir(img_dir_): + continue + try: + self.save_img_sequence( + os.path.join(seq_dir, f), + os.path.join(seq_dir, f), + matcher, + save_format=save_format, + fps=fps + ) + except: + tgs.warn(f"Video saving for directory {seq_dir_} failed!") + + if delete: + shutil.rmtree(img_dir_) diff --git a/hort/models/tgs/utils/typing.py b/hort/models/tgs/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..dee9f967c21f94db1ad939d7dead156d86748752 --- /dev/null +++ b/hort/models/tgs/utils/typing.py @@ -0,0 +1,40 @@ +""" +This module contains type annotations for the project, using +1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects +2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors + +Two types of typing checking can be used: +1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) +2. 
Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) +""" + +# Basic types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, +) + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# Config type +from omegaconf import DictConfig + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker diff --git a/hort/utils/__init__.py b/hort/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fa5d45defd84f4b5acb11e17f4b440b6a780168 --- /dev/null +++ b/hort/utils/__init__.py @@ -0,0 +1,3 @@ +import torch +from typing import Any +from .renderer import Renderer \ No newline at end of file diff --git a/hort/utils/img_utils.py b/hort/utils/img_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..71c1dad27a7d52566aef20203fea97e4c187ad34 --- /dev/null +++ b/hort/utils/img_utils.py @@ -0,0 +1,181 @@ +import numpy as np +import cv2 + + +def process_bbox(bbox, factor=1.25): + # aspect ratio preserving bbox + w = bbox[2] + h = bbox[3] + c_x = bbox[0] + w / 2. + c_y = bbox[1] + h / 2. + aspect_ratio = 1. + if w > aspect_ratio * h: + h = w / aspect_ratio + elif w < aspect_ratio * h: + w = h * aspect_ratio + bbox[2] = w * factor + bbox[3] = h * factor + bbox[0] = c_x - bbox[2] / 2. + bbox[1] = c_y - bbox[3] / 2. + + return bbox + + +def generate_patch_image(cvimg, bbox, input_shape, do_flip, scale, rot): + """ + @description: Modified from https://github.com/mks0601/3DMPPE_ROOTNET_RELEASE/blob/master/data/dataset.py. + generate the patch image from the bounding box and other parameters. + --------- + @param: input image, bbox(x1, y1, w, h), dest image shape, do_flip, scale factor, rotation degrees. + ------- + @Returns: processed image, affine_transform matrix to get the processed image. + ------- + """ + + img = cvimg.copy() + img_height, img_width, _ = img.shape + + bb_c_x = float(bbox[0] + 0.5 * bbox[2]) + bb_c_y = float(bbox[1] + 0.5 * bbox[3]) + bb_width = float(bbox[2]) + bb_height = float(bbox[3]) + + if do_flip: + img = img[:, ::-1, :] + bb_c_x = img_width - bb_c_x - 1 + + trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, input_shape[1], input_shape[0], scale, rot, inv=False) + img_patch = cv2.warpAffine(img, trans, (int(input_shape[1]), int(input_shape[0])), flags=cv2.INTER_LINEAR) + new_trans = np.zeros((3, 3), dtype=np.float32) + new_trans[:2, :] = trans + new_trans[2, 2] = 1 + + return img_patch, new_trans + + +def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False): + """ + @description: Modified from https://github.com/mks0601/3DMPPE_ROOTNET_RELEASE/blob/master/data/dataset.py. + get affine transform matrix + --------- + @param: image center, original image size, desired image size, scale factor, rotation degree, whether to get inverse transformation. 
+ ------- + @Returns: affine transformation matrix + ------- + """ + + def rotate_2d(pt_2d, rot_rad): + x = pt_2d[0] + y = pt_2d[1] + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + xx = x * cs - y * sn + yy = x * sn + y * cs + return np.array([xx, yy], dtype=np.float32) + + # augment size with scale + src_w = src_width * scale + src_h = src_height * scale + src_center = np.array([c_x, c_y], dtype=np.float32) + + # augment rotation + rot_rad = np.pi * rot / 180 + src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad) + src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad) + + dst_w = dst_width + dst_h = dst_height + dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32) + dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32) + dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = src_center + src[1, :] = src_center + src_downdir + src[2, :] = src_center + src_rightdir + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = dst_center + dst[1, :] = dst_center + dst_downdir + dst[2, :] = dst_center + dst_rightdir + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +class PerspectiveCamera: + def __init__(self, fx, fy, cx, cy, R=np.eye(3), t=np.zeros(3)): + self.K = np.array([[fx, 0, cx, 0], [0, fy, cy, 0], [0, 0, 1, 0]], dtype=np.float32) + + self.R = np.array(R, dtype=np.float32).copy() + assert self.R.shape == (3, 3) + + self.t = np.array(t, dtype=np.float32).copy() + assert self.t.size == 3 + self.t = self.t.reshape(3, 1) + + def update_virtual_camera_after_crop(self, bbox, option='same'): + left, upper, width, height = bbox + new_img_center = np.array([left + width / 2, upper + height / 2, 1], dtype=np.float32).reshape(3, 1) + new_cam_center = np.linalg.inv(self.K[:3, :3]).dot(new_img_center) + self.K[0, 2], self.K[1, 2] = width / 2, height / 2 + + x, y, z = new_cam_center[0], new_cam_center[1], new_cam_center[2] + sin_theta = -y / np.sqrt(1 + x ** 2 + y ** 2) + cos_theta = np.sqrt(1 + x ** 2) / np.sqrt(1 + x ** 2 + y ** 2) + R_x = np.array([[1, 0, 0], [0, cos_theta, -sin_theta], [0, sin_theta, cos_theta]], dtype=np.float32) + sin_phi = x / np.sqrt(1 + x ** 2) + cos_phi = 1 / np.sqrt(1 + x ** 2) + R_y = np.array([[cos_phi, 0, sin_phi], [0, 1, 0], [-sin_phi, 0, cos_phi]], dtype=np.float32) + self.R = R_y @ R_x + + # update focal length for virtual camera; please refer to the paper "PCLs: Geometry-aware Neural Reconstruction of 3D Pose with Perspective Crop Layers" for more details. 
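        # Geometry note (comment added for clarity): (x, y, 1) above is the ray through the new crop
        # center in normalized camera coordinates, so sqrt(1 + x**2 + y**2) equals 1 / cos(theta), where
        # theta is the angle between that ray and the original optical axis. The 'length' option below
        # scales both focal lengths by this secant, while 'scale' applies an additional anisotropic
        # correction in x and y, following the PCL formulation referenced above.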
+ if option == 'length': + self.K[0, 0] = self.K[0, 0] * np.sqrt(1 + x ** 2 + y ** 2) + self.K[1, 1] = self.K[1, 1] * np.sqrt(1 + x ** 2 + y ** 2) + + if option == 'scale': + self.K[0, 0] = self.K[0, 0] * np.sqrt(1 + x ** 2 + y ** 2) * np.sqrt(1 + x ** 2) + self.K[1, 1] = self.K[1, 1] * (1 + x ** 2 + y ** 2)/ np.sqrt(1 + x ** 2) + + def update_intrinsics_after_crop(self, bbox): + left, upper, _, _ = bbox + + cx, cy = self.K[0, 2], self.K[1, 2] + + new_cx = cx - left + new_cy = cy - upper + + self.K[0, 2], self.K[1, 2] = new_cx, new_cy + + def update_intrinsics_after_resize(self, image_shape, new_image_shape): + height, width = image_shape + new_height, new_width = new_image_shape + + fx, fy, cx, cy = self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] + + new_fx = fx * (new_width / width) + new_fy = fy * (new_height / height) + new_cx = cx * (new_width / width) + new_cy = cy * (new_height / height) + + self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] = new_fx, new_fy, new_cx, new_cy + + def update_intrinsics_after_scale(self, scale_factor): + self.K[0, 0] /= scale_factor + self.K[1, 1] /= scale_factor + + @property + def projection(self): + return self.K.dot(self.extrinsics) + + @property + def intrinsics(self): + return self.K + + @property + def extrinsics(self): + return np.hstack([self.R, self.t]) \ No newline at end of file diff --git a/hort/utils/renderer.py b/hort/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..4f6ff18ac90900d936fb75eb7bfbb56e815e8c98 --- /dev/null +++ b/hort/utils/renderer.py @@ -0,0 +1,375 @@ +import os +if 'PYOPENGL_PLATFORM' not in os.environ: + os.environ['PYOPENGL_PLATFORM'] = 'egl' +import torch +import numpy as np +import pyrender +import trimesh +import cv2 +from yacs.config import CfgNode +from typing import List, Optional + +def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.): + # Convert cam_bbox to full image + img_w, img_h = img_size[:, 0], img_size[:, 1] + cx, cy, b = box_center[:, 0], box_center[:, 1], box_size + w_2, h_2 = img_w / 2., img_h / 2. + bs = b * cam_bbox[:, 0] + 1e-9 + tz = 2 * focal_length / bs + tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1] + ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2] + full_cam = torch.stack([tx, ty, tz], dim=-1) + return full_cam + +def cam_crop_to_new(cam_bbox, box_center, box_size, img_size, new_bbox, focal_length=5000.): + img_w, img_h = img_size[:, 0], img_size[:, 1] + cx, cy, b = box_center[:, 0], box_center[:, 1], box_size + w_2, h_2 = img_w / 2., img_h / 2. + bs = b * cam_bbox[:, 0] + 1e-9 + tz = 2 * focal_length / bs + + x0, y0, new_w, new_h = new_bbox[:, 0], new_bbox[:, 1], new_bbox[:, 2], new_bbox[:, 3] + tz = tz * new_w / 224 + new_cx, new_cy = x0 + new_w / 2., y0 + new_h / 2. 
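# Aside: a minimal illustrative sketch (placeholder numbers, not part of the pipeline code) of calling
# cam_crop_to_full defined above. Shapes follow the indexing used in that function: batch of 1,
# img_size ordered as (width, height).
import torch
cam_bbox = torch.tensor([[0.9, 0.05, -0.02]])   # predicted crop camera [s, tx, ty]
box_center = torch.tensor([[320.0, 240.0]])     # crop center in full-image pixels
box_size = torch.tensor([256.0])                # square crop side length in pixels
img_size = torch.tensor([[640.0, 480.0]])       # full image (width, height)
t_full = cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.0)
# t_full has shape (1, 3): the translation placing the mesh in full-image camera coordinates.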
+ tx = (2 * (cx - new_cx) / bs) + cam_bbox[:, 1] + ty = (2 * (cy - new_cy) / bs) + cam_bbox[:, 2] + + full_cam_resized = torch.stack([tx, ty, tz], dim=-1) + return full_cam_resized + +def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12): + # get lights in a circle around origin at elevation + thetas = elevation * np.ones(n_lights) + phis = 2 * np.pi * np.arange(n_lights) / n_lights + poses = [] + trans = make_translation(torch.tensor([0, 0, dist])) + for phi, theta in zip(phis, thetas): + rot = make_rotation(rx=-theta, ry=phi, order="xyz") + poses.append((rot @ trans).numpy()) + return poses + +def make_translation(t): + return make_4x4_pose(torch.eye(3), t) + +def make_rotation(rx=0, ry=0, rz=0, order="xyz"): + Rx = rotx(rx) + Ry = roty(ry) + Rz = rotz(rz) + if order == "xyz": + R = Rz @ Ry @ Rx + elif order == "xzy": + R = Ry @ Rz @ Rx + elif order == "yxz": + R = Rz @ Rx @ Ry + elif order == "yzx": + R = Rx @ Rz @ Ry + elif order == "zyx": + R = Rx @ Ry @ Rz + elif order == "zxy": + R = Ry @ Rx @ Rz + return make_4x4_pose(R, torch.zeros(3)) + +def make_4x4_pose(R, t): + """ + :param R (*, 3, 3) + :param t (*, 3) + return (*, 4, 4) + """ + dims = R.shape[:-2] + pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1) + bottom = ( + torch.tensor([0, 0, 0, 1], device=R.device) + .reshape(*(1,) * len(dims), 1, 4) + .expand(*dims, 1, 4) + ) + return torch.cat([pose_3x4, bottom], dim=-2) + + +def rotx(theta): + return torch.tensor( + [ + [1, 0, 0], + [0, np.cos(theta), -np.sin(theta)], + [0, np.sin(theta), np.cos(theta)], + ], + dtype=torch.float32, + ) + + +def roty(theta): + return torch.tensor( + [ + [np.cos(theta), 0, np.sin(theta)], + [0, 1, 0], + [-np.sin(theta), 0, np.cos(theta)], + ], + dtype=torch.float32, + ) + + +def rotz(theta): + return torch.tensor( + [ + [np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1], + ], + dtype=torch.float32, + ) + + +def create_raymond_lights() -> List[pyrender.Node]: + """ + Return raymond light nodes for the scene. + """ + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3,:3] = np.c_[x,y,z] + nodes.append(pyrender.Node( + light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0), + matrix=matrix + )) + + return nodes + +class Renderer: + + def __init__(self, cfg: CfgNode, faces: np.array): + """ + Wrapper around the pyrender renderer to render MANO meshes. + Args: + cfg (CfgNode): Model config file. + faces (np.array): Array of shape (F, 3) containing the mesh faces. 
+ """ + self.cfg = cfg + self.focal_length = cfg.EXTRA.FOCAL_LENGTH + self.img_res = cfg.MODEL.IMAGE_SIZE + + # add faces that make the hand mesh watertight + faces_new = np.array([[92, 38, 234], + [234, 38, 239], + [38, 122, 239], + [239, 122, 279], + [122, 118, 279], + [279, 118, 215], + [118, 117, 215], + [215, 117, 214], + [117, 119, 214], + [214, 119, 121], + [119, 120, 121], + [121, 120, 78], + [120, 108, 78], + [78, 108, 79]]) + faces = np.concatenate([faces, faces_new], axis=0) + + self.camera_center = [self.img_res // 2, self.img_res // 2] + self.faces = faces + self.faces_left = self.faces[:,[0,2,1]] + + def __call__(self, + vertices: np.array, + camera_translation: np.array, + image: torch.Tensor, + full_frame: bool = False, + imgname: Optional[str] = None, + side_view=False, rot_angle=90, + mesh_base_color=(1.0, 1.0, 0.9), + scene_bg_color=(0,0,0), + return_rgba=False, + ) -> np.array: + """ + Render meshes on input image + Args: + vertices (np.array): Array of shape (V, 3) containing the mesh vertices. + camera_translation (np.array): Array of shape (3,) with the camera translation. + image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values. + full_frame (bool): If True, then render on the full image. + imgname (Optional[str]): Contains the original image filenamee. Used only if full_frame == True. + """ + + if full_frame: + image = cv2.imread(imgname).astype(np.float32)[:, :, ::-1] / 255. + else: + image = image.clone() * torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3,1,1) + image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3,1,1) + image = image.permute(1, 2, 0).cpu().numpy() + + renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1], + viewport_height=image.shape[0], + point_size=1.0) + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.0, + alphaMode='OPAQUE', + baseColorFactor=(*mesh_base_color, 1.0)) + + camera_translation[0] *= -1. + + mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy()) + if side_view: + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), [0, 1, 0]) + mesh.apply_transform(rot) + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation + camera_center = [image.shape[1] / 2., image.shape[0] / 2.] 
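        # The intrinsic camera below uses the fixed focal length configured at construction time
        # (cfg.EXTRA.FOCAL_LENGTH) and places the principal point at the center of whichever image is
        # being rendered, so the viewport, principal point and background image stay aligned for both
        # the cropped and the full_frame code paths (descriptive comment added for clarity).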
+ camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length, + cx=camera_center[0], cy=camera_center[1], zfar=1e12) + scene.add(camera, pose=camera_pose) + + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + renderer.delete() + + if return_rgba: + return color + + valid_mask = (color[:, :, -1])[:, :, np.newaxis] + if not side_view: + output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image) + else: + output_img = color[:, :, :3] + + output_img = output_img.astype(np.float32) + return output_img + + def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9), + rot_axis=[1,0,0], rot_angle=0, is_right=1): + vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0]) + if is_right: + mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors) + else: + mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces_left.copy(), vertex_colors=vertex_colors) + + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), rot_axis) + mesh.apply_transform(rot) + + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + return mesh + + def render_rgba( + self, + vertices: np.array, + points: np.array, + cam_t = None, + rot=None, + rot_axis=[1,0,0], + rot_angle=0, + camera_z=3, + mesh_base_color=(1.0, 1.0, 0.9), + point_base_color=(1.0, 1.0, 0.9), + scene_bg_color=(0,0,0), + render_res=[256, 256], + focal_length=None, + is_right=None, + ): + + renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0], viewport_height=render_res[1], point_size=3) + focal_length = focal_length if focal_length is not None else self.focal_length + + if cam_t is not None: + camera_translation = cam_t.copy() + camera_translation[0] *= -1. + else: + camera_translation = np.array([0, 0, camera_z * focal_length/render_res[1]]) + + mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle, is_right=is_right) + mesh = pyrender.Mesh.from_trimesh(mesh) + rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]) + points = (rot[:3, :3] @ points.transpose(1, 0)).transpose(1, 0) + point_cloud = pyrender.Mesh.from_points(points=points, colors=point_base_color) + + scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + scene.add(point_cloud, 'pc') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation + camera_center = [render_res[0] / 2., render_res[1] / 2.] 
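        # Note (comment added for clarity): when cam_t is None, the camera placed above sits on the +z
        # axis at distance camera_z * focal_length / render_res[1]. Since the projected size of an
        # object at depth d is proportional to focal_length / d, this default keeps the apparent object
        # size independent of the chosen focal length and proportional to the viewport height.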
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, cx=camera_center[0], cy=camera_center[1], zfar=1e12) + + # Create camera node and add it to pyRender scene + camera_node = pyrender.Node(camera=camera, matrix=camera_pose) + scene.add_node(camera_node) + self.add_point_lighting(scene, camera_node) + self.add_lighting(scene, camera_node) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + renderer.delete() + + return color + + def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0): + # from phalp.visualize.py_renderer import get_light_poses + light_poses = get_light_poses() + light_poses.append(np.eye(4)) + cam_pose = scene.get_pose(cam_node) + for i, pose in enumerate(light_poses): + matrix = cam_pose @ pose + node = pyrender.Node( + name=f"light-{i:02d}", + light=pyrender.DirectionalLight(color=color, intensity=intensity), + matrix=matrix, + ) + if scene.has_node(node): + continue + scene.add_node(node) + + def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0): + # from phalp.visualize.py_renderer import get_light_poses + light_poses = get_light_poses(dist=0.5) + light_poses.append(np.eye(4)) + cam_pose = scene.get_pose(cam_node) + for i, pose in enumerate(light_poses): + matrix = cam_pose @ pose + # node = pyrender.Node( + # name=f"light-{i:02d}", + # light=pyrender.DirectionalLight(color=color, intensity=intensity), + # matrix=matrix, + # ) + node = pyrender.Node( + name=f"plight-{i:02d}", + light=pyrender.PointLight(color=color, intensity=intensity), + matrix=matrix, + ) + if scene.has_node(node): + continue + scene.add_node(node) diff --git a/mano_data/closed_fmano.npy b/mano_data/closed_fmano.npy new file mode 100644 index 0000000000000000000000000000000000000000..ee3ee84dc7ffc1f13da6a772633f24d3d4ff7b8c --- /dev/null +++ b/mano_data/closed_fmano.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e6842048a9a1bce51b07d9f148b676699d4ff607508fe6b104268a15f3560bf +size 37376 diff --git a/mano_data/mano/MANO_RIGHT.pkl b/mano_data/mano/MANO_RIGHT.pkl new file mode 100755 index 0000000000000000000000000000000000000000..8e7ac7faf64ad51096ec1da626ea13757ed7f665 --- /dev/null +++ b/mano_data/mano/MANO_RIGHT.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45d60aa3b27ef9107a7afd4e00808f307fd91111e1cfa35afd5c4a62de264767 +size 3821356 diff --git a/mano_data/mano_mean_params.npz b/mano_data/mano_mean_params.npz new file mode 100644 index 0000000000000000000000000000000000000000..dc294b01fb78a9cd6636c87a69b59cf82d28d15b --- /dev/null +++ b/mano_data/mano_mean_params.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988 +size 1178 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..e7dbe446102d6b63616219acf114e844b6c266ee --- /dev/null +++ b/packages.txt @@ -0,0 +1,12 @@ +libglfw3-dev +libgles2-mesa-dev +libgl1 +freeglut3-dev +unzip +ffmpeg +libsm6 +libxext6 +libgl1-mesa-dri +libegl1-mesa +libgbm1 +build-essential diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..25b2e53fafafb660ff38f893370cb12f88a68f1a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,46 @@ +opencv-python +numpy +trimesh +plyfile +pyyaml 
+scikit-image +scikit-learn +chumpy +tensorboard +kornia +loguru +pycocotools +yacs +lmdb +fire +setuptools +einops +tqdm +ipython +gym +transformers==4.44.2 +OmegaConf +matplotlib +gradio==5.0.2 +diffusers +rembg +segment_anything +jaxtyping +imageio +iopath +timm +open3d +pyrender +pytorch-lightning +smplx==0.1.28 +chumpy @ git+https://github.com/zerchen/chumpy +xtcocotools +pandas +hydra-core +hydra-submitit-launcher +hydra-colorlog +pyrootutils +rich +webdataset +ultralytics +dill diff --git a/wilor/configs/__init__.py b/wilor/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd9dcf90daa9627cb482d2c2e602288e09b1f1b --- /dev/null +++ b/wilor/configs/__init__.py @@ -0,0 +1,114 @@ +import os +from typing import Dict +from yacs.config import CfgNode as CN + +CACHE_DIR_PRETRAINED = "./pretrained_models/" + +def to_lower(x: Dict) -> Dict: + """ + Convert all dictionary keys to lowercase + Args: + x (dict): Input dictionary + Returns: + dict: Output dictionary with all keys converted to lowercase + """ + return {k.lower(): v for k, v in x.items()} + +_C = CN(new_allowed=True) + +_C.GENERAL = CN(new_allowed=True) +_C.GENERAL.RESUME = True +_C.GENERAL.TIME_TO_RUN = 3300 +_C.GENERAL.VAL_STEPS = 100 +_C.GENERAL.LOG_STEPS = 100 +_C.GENERAL.CHECKPOINT_STEPS = 20000 +_C.GENERAL.CHECKPOINT_DIR = "checkpoints" +_C.GENERAL.SUMMARY_DIR = "tensorboard" +_C.GENERAL.NUM_GPUS = 1 +_C.GENERAL.NUM_WORKERS = 4 +_C.GENERAL.MIXED_PRECISION = True +_C.GENERAL.ALLOW_CUDA = True +_C.GENERAL.PIN_MEMORY = False +_C.GENERAL.DISTRIBUTED = False +_C.GENERAL.LOCAL_RANK = 0 +_C.GENERAL.USE_SYNCBN = False +_C.GENERAL.WORLD_SIZE = 1 + +_C.TRAIN = CN(new_allowed=True) +_C.TRAIN.NUM_EPOCHS = 100 +_C.TRAIN.BATCH_SIZE = 32 +_C.TRAIN.SHUFFLE = True +_C.TRAIN.WARMUP = False +_C.TRAIN.NORMALIZE_PER_IMAGE = False +_C.TRAIN.CLIP_GRAD = False +_C.TRAIN.CLIP_GRAD_VALUE = 1.0 +_C.LOSS_WEIGHTS = CN(new_allowed=True) + +_C.DATASETS = CN(new_allowed=True) + +_C.MODEL = CN(new_allowed=True) +_C.MODEL.IMAGE_SIZE = 224 + +_C.EXTRA = CN(new_allowed=True) +_C.EXTRA.FOCAL_LENGTH = 5000 + +_C.DATASETS.CONFIG = CN(new_allowed=True) +_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3 +_C.DATASETS.CONFIG.ROT_FACTOR = 30 +_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02 +_C.DATASETS.CONFIG.COLOR_SCALE = 0.2 +_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6 +_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.DO_FLIP = False +_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10 + +def default_config() -> CN: + """ + Get a yacs CfgNode object with the default config values. + """ + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() + +def dataset_config(name='datasets_tar.yaml') -> CN: + """ + Get dataset config file + Returns: + CfgNode: Dataset config as a yacs CfgNode object. + """ + cfg = CN(new_allowed=True) + config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), name) + cfg.merge_from_file(config_file) + cfg.freeze() + return cfg + +def dataset_eval_config() -> CN: + return dataset_config('datasets_eval.yaml') + +def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN: + """ + Read a config file and optionally merge it with the default config file. + Args: + config_file (str): Path to config file. + merge (bool): Whether to merge with the default config or not. + Returns: + CfgNode: Config as a yacs CfgNode object. 
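+
+    Example (illustrative sketch; the path is hypothetical and the file is assumed to define the
+    MANO paths used when update_cachedir is True):
+        cfg = get_config('./pretrained_models/model_config.yaml', update_cachedir=True)
+        assert cfg.is_frozen()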
+ """ + if merge: + cfg = default_config() + else: + cfg = CN(new_allowed=True) + cfg.merge_from_file(config_file) + + if update_cachedir: + def update_path(path: str) -> str: + if os.path.isabs(path): + return path + return os.path.join(CACHE_DIR_PRETRAINED, path) + + cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH) + cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS) + + cfg.freeze() + return cfg \ No newline at end of file diff --git a/wilor/datasets/utils.py b/wilor/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b09722ef538f25f736ab387e15df259966ce7b7 --- /dev/null +++ b/wilor/datasets/utils.py @@ -0,0 +1,994 @@ +""" +Parts of the code are taken or adapted from +https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py +""" +import torch +import numpy as np +from skimage.transform import rotate, resize +from skimage.filters import gaussian +import random +import cv2 +from typing import List, Dict, Tuple +from yacs.config import CfgNode + +def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None): + """Increase the size of the bounding box to match the target shape.""" + if target_aspect_ratio is None: + return input_shape + + try: + w , h = input_shape + except (ValueError, TypeError): + return input_shape + + w_t, h_t = target_aspect_ratio + if h / w < h_t / w_t: + h_new = max(w * h_t / w_t, h) + w_new = w + else: + h_new = h + w_new = max(h * w_t / h_t, w) + if h_new < h or w_new < w: + breakpoint() + return np.array([w_new, h_new]) + +def do_augmentation(aug_config: CfgNode) -> Tuple: + """ + Compute random augmentation parameters. + Args: + aug_config (CfgNode): Config containing augmentation parameters. + Returns: + scale (float): Box rescaling factor. + rot (float): Random image rotation. + do_flip (bool): Whether to flip image or not. + do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT). + color_scale (List): Color rescaling factor + tx (float): Random translation along the x axis. + ty (float): Random translation along the y axis. + """ + + tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR + ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR + scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0 + rot = np.clip(np.random.randn(), -2.0, + 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0 + do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE + do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE + extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0) + # extreme_crop_lvl = 0 + c_up = 1.0 + aug_config.COLOR_SCALE + c_low = 1.0 - aug_config.COLOR_SCALE + color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)] + return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty + +def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array: + """ + Rotate a 2D point on the x-y plane. + Args: + pt_2d (np.array): Input 2D point with shape (2,). + rot_rad (float): Rotation angle + Returns: + np.array: Rotated 2D point. 
+ """ + x = pt_2d[0] + y = pt_2d[1] + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + xx = x * cs - y * sn + yy = x * sn + y * cs + return np.array([xx, yy], dtype=np.float32) + + +def gen_trans_from_patch_cv(c_x: float, c_y: float, + src_width: float, src_height: float, + dst_width: float, dst_height: float, + scale: float, rot: float) -> np.array: + """ + Create transformation matrix for the bounding box crop. + Args: + c_x (float): Bounding box center x coordinate in the original image. + c_y (float): Bounding box center y coordinate in the original image. + src_width (float): Bounding box width. + src_height (float): Bounding box height. + dst_width (float): Output box width. + dst_height (float): Output box height. + scale (float): Rescaling factor for the bounding box (augmentation). + rot (float): Random rotation applied to the box. + Returns: + trans (np.array): Target geometric transformation. + """ + # augment size with scale + src_w = src_width * scale + src_h = src_height * scale + src_center = np.zeros(2) + src_center[0] = c_x + src_center[1] = c_y + # augment rotation + rot_rad = np.pi * rot / 180 + src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad) + src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad) + + dst_w = dst_width + dst_h = dst_height + dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32) + dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32) + dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = src_center + src[1, :] = src_center + src_downdir + src[2, :] = src_center + src_rightdir + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = dst_center + dst[1, :] = dst_center + dst_downdir + dst[2, :] = dst_center + dst_rightdir + + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def trans_point2d(pt_2d: np.array, trans: np.array): + """ + Transform a 2D point using translation matrix trans. + Args: + pt_2d (np.array): Input 2D point with shape (2,). + trans (np.array): Transformation matrix. + Returns: + np.array: Transformed 2D point. 
+ """ + src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T + dst_pt = np.dot(trans, src_pt) + return dst_pt[0:2] + +def get_transform(center, scale, res, rot=0): + """Generate transformation matrix.""" + """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py""" + h = 200 * scale + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3, 3)) + rot_rad = rot * np.pi / 180 + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0, :2] = [cs, -sn] + rot_mat[1, :2] = [sn, cs] + rot_mat[2, 2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0, 2] = -res[1] / 2 + t_mat[1, 2] = -res[0] / 2 + t_inv = t_mat.copy() + t_inv[:2, 2] *= -1 + t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) + return t + + +def transform(pt, center, scale, res, invert=0, rot=0, as_int=True): + """Transform pixel location to different reference.""" + """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py""" + t = get_transform(center, scale, res, rot=rot) + if invert: + t = np.linalg.inv(t) + new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T + new_pt = np.dot(t, new_pt) + if as_int: + new_pt = new_pt.astype(int) + return new_pt[:2] + 1 + +def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0): + c_x = (ul[0] + br[0])/2 + c_y = (ul[1] + br[1])/2 + bb_width = patch_width = br[0] - ul[0] + bb_height = patch_height = br[1] - ul[1] + trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0) + img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=border_mode, + borderValue=border_value + ) + + # Force borderValue=cv2.BORDER_CONSTANT for alpha channel + if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT): + img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + ) + + return img_patch + +def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float, + bb_width: float, bb_height: float, + patch_width: float, patch_height: float, + do_flip: bool, scale: float, rot: float, + border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]: + """ + Crop image according to the supplied bounding box. + Args: + img (np.array): Input image of shape (H, W, 3) + c_x (float): Bounding box center x coordinate in the original image. + c_y (float): Bounding box center y coordinate in the original image. + bb_width (float): Bounding box width. + bb_height (float): Bounding box height. + patch_width (float): Output box width. + patch_height (float): Output box height. + do_flip (bool): Whether to flip image or not. + scale (float): Rescaling factor for the bounding box (augmentation). + rot (float): Random rotation applied to the box. + Returns: + img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3) + trans (np.array): Transformation matrix. 
+ """ + + img_height, img_width, img_channels = img.shape + if do_flip: + img = img[:, ::-1, :] + c_x = img_width - c_x - 1 + + trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot) + + #img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR) + + # skimage + center = np.zeros(2) + center[0] = c_x + center[1] = c_y + res = np.zeros(2) + res[0] = patch_width + res[1] = patch_height + # assumes bb_width = bb_height + # assumes patch_width = patch_height + assert bb_width == bb_height, f'{bb_width=} != {bb_height=}' + assert patch_width == patch_height, f'{patch_width=} != {patch_height=}' + scale1 = scale*bb_width/200. + + # Upper left point + ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1 + # Bottom right point + br = np.array(transform([res[0] + 1, + res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1 + + # Padding so that when rotated proper amount of context is included + try: + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1 + except: + breakpoint() + if not rot == 0: + ul -= pad + br += pad + + + if False: + # Old way of cropping image + ul_int = ul.astype(int) + br_int = br.astype(int) + new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape) + + # Range to fill new array + new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0] + new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1] + # Range to sample from original image + old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0]) + old_y = max(0, ul_int[1]), min(len(img), br_int[1]) + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], + old_x[0]:old_x[1]] + + # New way of cropping image + new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32) + + # print(f'{new_img.shape=}') + # print(f'{new_img1.shape=}') + # print(f'{np.allclose(new_img, new_img1)=}') + # print(f'{img.dtype=}') + + + if not rot == 0: + # Remove padding + + new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot) + new_img = new_img[pad:-pad, pad:-pad] + + if new_img.shape[0] < 1 or new_img.shape[1] < 1: + print(f'{img.shape=}') + print(f'{new_img.shape=}') + print(f'{ul=}') + print(f'{br=}') + print(f'{pad=}') + print(f'{rot=}') + + breakpoint() + + # resize image + new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res) + + new_img = np.clip(new_img, 0, 255).astype(np.uint8) + + return new_img, trans + + +def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float, + bb_width: float, bb_height: float, + patch_width: float, patch_height: float, + do_flip: bool, scale: float, rot: float, + border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]: + """ + Crop the input image and return the crop and the corresponding transformation matrix. + Args: + img (np.array): Input image of shape (H, W, 3) + c_x (float): Bounding box center x coordinate in the original image. + c_y (float): Bounding box center y coordinate in the original image. + bb_width (float): Bounding box width. + bb_height (float): Bounding box height. + patch_width (float): Output box width. + patch_height (float): Output box height. + do_flip (bool): Whether to flip image or not. + scale (float): Rescaling factor for the bounding box (augmentation). + rot (float): Random rotation applied to the box. 
+ Returns: + img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3) + trans (np.array): Transformation matrix. + """ + + img_height, img_width, img_channels = img.shape + if do_flip: + img = img[:, ::-1, :] + c_x = img_width - c_x - 1 + + + trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot) + + img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=border_mode, + borderValue=border_value, + ) + # Force borderValue=cv2.BORDER_CONSTANT for alpha channel + if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT): + img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + ) + + return img_patch, trans + + +def convert_cvimg_to_tensor(cvimg: np.array): + """ + Convert image from HWC to CHW format. + Args: + cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV. + Returns: + np.array: Output image of shape (3, H, W). + """ + # from h,w,c(OpenCV) to c,h,w + img = cvimg.copy() + img = np.transpose(img, (2, 0, 1)) + # from int to float + img = img.astype(np.float32) + return img + +def fliplr_params(mano_params: Dict, has_mano_params: Dict) -> Tuple[Dict, Dict]: + """ + Flip MANO parameters when flipping the image. + Args: + mano_params (Dict): MANO parameter annotations. + has_mano_params (Dict): Whether MANO annotations are valid. + Returns: + Dict, Dict: Flipped MANO parameters and valid flags. + """ + global_orient = mano_params['global_orient'].copy() + hand_pose = mano_params['hand_pose'].copy() + betas = mano_params['betas'].copy() + has_global_orient = has_mano_params['global_orient'].copy() + has_hand_pose = has_mano_params['hand_pose'].copy() + has_betas = has_mano_params['betas'].copy() + + global_orient[1::3] *= -1 + global_orient[2::3] *= -1 + hand_pose[1::3] *= -1 + hand_pose[2::3] *= -1 + + mano_params = {'global_orient': global_orient.astype(np.float32), + 'hand_pose': hand_pose.astype(np.float32), + 'betas': betas.astype(np.float32) + } + + has_mano_params = {'global_orient': has_global_orient, + 'hand_pose': has_hand_pose, + 'betas': has_betas + } + + return mano_params, has_mano_params + + +def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array: + """ + Flip 2D or 3D keypoints. + Args: + joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence. + flip_permutation (List): Permutation to apply after flipping. + Returns: + np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively. + """ + joints = joints.copy() + # Flip horizontal + joints[:, 0] = width - joints[:, 0] - 1 + joints = joints[flip_permutation, :] + + return joints + +def keypoint_3d_processing(keypoints_3d: np.array, flip_permutation: List[int], rot: float, do_flip: float) -> np.array: + """ + Process 3D keypoints (rotation/flipping). + Args: + keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence. + flip_permutation (List): Permutation to apply after flipping. + rot (float): Random rotation applied to the keypoints. + do_flip (bool): Whether to flip keypoints or not. + Returns: + np.array: Transformed 3D keypoints with shape (N, 4). 
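+
+    Example (illustrative; one keypoint rotated in-plane by 90 degrees, no flip):
+        >>> kp = np.array([[1., 0., 0., 1.]])
+        >>> keypoint_3d_processing(kp, flip_permutation=[0], rot=90., do_flip=False)   # ~[[0., -1., 0., 1.]]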
+ """ + if do_flip: + keypoints_3d = fliplr_keypoints(keypoints_3d, 1, flip_permutation) + # in-plane rotation + rot_mat = np.eye(3) + if not rot == 0: + rot_rad = -rot * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1]) + # flip the x coordinates + keypoints_3d = keypoints_3d.astype('float32') + return keypoints_3d + +def rot_aa(aa: np.array, rot: float) -> np.array: + """ + Rotate axis angle parameters. + Args: + aa (np.array): Axis-angle vector of shape (3,). + rot (np.array): Rotation angle in degrees. + Returns: + np.array: Rotated axis-angle vector. + """ + # pose parameters + R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], + [0, 0, 1]]) + # find the rotation of the hand in camera frame + per_rdg, _ = cv2.Rodrigues(aa) + # apply the global rotation to the global orientation + resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg)) + aa = (resrot.T)[0] + return aa.astype(np.float32) + +def mano_param_processing(mano_params: Dict, has_mano_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]: + """ + Apply random augmentations to the MANO parameters. + Args: + mano_params (Dict): MANO parameter annotations. + has_mano_params (Dict): Whether mano annotations are valid. + rot (float): Random rotation applied to the keypoints. + do_flip (bool): Whether to flip keypoints or not. + Returns: + Dict, Dict: Transformed MANO parameters and valid flags. + """ + if do_flip: + mano_params, has_mano_params = fliplr_params(mano_params, has_mano_params) + mano_params['global_orient'] = rot_aa(mano_params['global_orient'], rot) + return mano_params, has_mano_params + + + +def get_example(img_path: str|np.ndarray, center_x: float, center_y: float, + width: float, height: float, + keypoints_2d: np.array, keypoints_3d: np.array, + mano_params: Dict, has_mano_params: Dict, + flip_kp_permutation: List[int], + patch_width: int, patch_height: int, + mean: np.array, std: np.array, + do_augment: bool, is_right: bool, augm_config: CfgNode, + is_bgr: bool = True, + use_skimage_antialias: bool = False, + border_mode: int = cv2.BORDER_CONSTANT, + return_trans: bool = False) -> Tuple: + """ + Get an example from the dataset and (possibly) apply random augmentations. + Args: + img_path (str): Image filename + center_x (float): Bounding box center x coordinate in the original image. + center_y (float): Bounding box center y coordinate in the original image. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates. + keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints. + mano_params (Dict): MANO parameter annotations. + has_mano_params (Dict): Whether MANO annotations are valid. + flip_kp_permutation (List): Permutation to apply to the keypoints after flipping. + patch_width (float): Output box width. + patch_height (float): Output box height. + mean (np.array): Array of shape (3,) containing the mean for normalizing the input image. + std (np.array): Array of shape (3,) containing the std for normalizing the input image. + do_augment (bool): Whether to apply data augmentation or not. + aug_config (CfgNode): Config containing augmentation parameters. 
+ Returns: + return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size + img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_height) + keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints. + keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints. + mano_params (Dict): Transformed MANO parameters. + has_mano_params (Dict): Valid flag for transformed MANO parameters. + img_size (np.array): Image size of the original image. + """ + if isinstance(img_path, str): + # 1. load image + cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) + if not isinstance(cvimg, np.ndarray): + raise IOError("Fail to read %s" % img_path) + elif isinstance(img_path, np.ndarray): + cvimg = img_path + else: + raise TypeError('img_path must be either a string or a numpy array') + img_height, img_width, img_channels = cvimg.shape + + img_size = np.array([img_height, img_width]) + + # 2. get augmentation params + if do_augment: + scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config) + else: + scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0. + + # if it's a left hand, we flip + if not is_right: + do_flip = True + + if width < 1 or height < 1: + breakpoint() + + if do_extreme_crop: + if extreme_crop_lvl == 0: + center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d) + elif extreme_crop_lvl == 1: + center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d) + + THRESH = 4 + if width1 < THRESH or height1 < THRESH: + # print(f'{do_extreme_crop=}') + # print(f'width: {width}, height: {height}') + # print(f'width1: {width1}, height1: {height1}') + # print(f'center_x: {center_x}, center_y: {center_y}') + # print(f'center_x1: {center_x1}, center_y1: {center_y1}') + # print(f'keypoints_2d: {keypoints_2d}') + # print(f'\n\n', flush=True) + # breakpoint() + pass + # print(f'skip ==> width1: {width1}, height1: {height1}, width: {width}, height: {height}') + else: + center_x, center_y, width, height = center_x1, center_y1, width1, height1 + + center_x += width * tx + center_y += height * ty + + # Process 3D keypoints + keypoints_3d = keypoint_3d_processing(keypoints_3d, flip_kp_permutation, rot, do_flip) + + # 3. 
generate image patch + if use_skimage_antialias: + # Blur image to avoid aliasing artifacts + downsampling_factor = (patch_width / (width*scale)) + if downsampling_factor > 1.1: + cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True, truncate=3.0) + + img_patch_cv, trans = generate_image_patch_cv2(cvimg, + center_x, center_y, + width, height, + patch_width, patch_height, + do_flip, scale, rot, + border_mode=border_mode) + + # img_patch_cv, trans = generate_image_patch_skimage(cvimg, + # center_x, center_y, + # width, height, + # patch_width, patch_height, + # do_flip, scale, rot, + # border_mode=border_mode) + + image = img_patch_cv.copy() + if is_bgr: + image = image[:, :, ::-1] + img_patch_cv = image.copy() + img_patch = convert_cvimg_to_tensor(image) + + + mano_params, has_mano_params = mano_param_processing(mano_params, has_mano_params, rot, do_flip) + + # apply normalization + for n_c in range(min(img_channels, 3)): + img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255) + if mean is not None and std is not None: + img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c] + if do_flip: + keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, flip_kp_permutation) + + + for n_jt in range(len(keypoints_2d)): + keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans) + keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5 + + if not return_trans: + return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size + else: + return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size, trans + +def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple: + """ + Extreme cropping: Crop the box up to the hip locations. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5] + keypoints_2d[lower_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + + +def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array): + """ + Extreme cropping: Crop the box up to the shoulder locations. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. 
+ """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]] + keypoints_2d[lower_body_keypoints, :] = 0 + center, scale = get_bbox(keypoints_2d) + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.2 * scale[0] + height = 1.2 * scale[1] + return center_x, center_y, width, height + +def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array): + """ + Extreme cropping: Crop the box and keep on only the head. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]] + keypoints_2d[lower_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.3 * scale[0] + height = 1.3 * scale[1] + return center_x, center_y, width, height + +def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array): + """ + Extreme cropping: Crop the box and keep on only the torso. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]] + keypoints_2d[nontorso_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array): + """ + Extreme cropping: Crop the box and keep on only the right arm. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. 
+ """ + keypoints_2d = keypoints_2d.copy() + nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonrightarm_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array): + """ + Extreme cropping: Crop the box and keep on only the left arm. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonleftarm_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array): + """ + Extreme cropping: Crop the box and keep on only the legs. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]] + keypoints_2d[nonlegs_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array): + """ + Extreme cropping: Crop the box and keep on only the right leg. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. 
+ """ + keypoints_2d = keypoints_2d.copy() + nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonrightleg_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array): + """ + Extreme cropping: Crop the box and keep on only the left leg. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. + center_y (float): y coordinate of the new bounding box center. + width (float): New bounding box width. + height (float): New bounding box height. + """ + keypoints_2d = keypoints_2d.copy() + nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]] + keypoints_2d[nonleftleg_body_keypoints, :] = 0 + if keypoints_2d[:, -1].sum() > 1: + center, scale = get_bbox(keypoints_2d) + center_x = center[0] + center_y = center[1] + width = 1.1 * scale[0] + height = 1.1 * scale[1] + return center_x, center_y, width, height + +def full_body(keypoints_2d: np.array) -> bool: + """ + Check if all main body joints are visible. + Args: + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + bool: True if all main body joints are visible. + """ + + body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14] + body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]] + return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints) + +def upper_body(keypoints_2d: np.array): + """ + Check if all upper body joints are visible. + Args: + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + bool: True if all main body joints are visible. + """ + lower_body_keypoints_openpose = [10, 11, 13, 14] + lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]] + upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18] + upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18] + return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\ + and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2) + +def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple: + """ + Get center and scale for bounding box from openpose detections. + Args: + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + rescale (float): Scale factor to rescale bounding boxes computed from the keypoints. + Returns: + center (np.array): Array of shape (2,) containing the new bounding box center. + scale (float): New bounding box scale. 
+ """ + valid = keypoints_2d[:,-1] > 0 + valid_keypoints = keypoints_2d[valid][:,:-1] + center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0)) + bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)) + # adjust bounding box tightness + scale = bbox_size + scale *= rescale + return center, scale + +def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple: + """ + Perform extreme cropping + Args: + center_x (float): x coordinate of bounding box center. + center_y (float): y coordinate of bounding box center. + width (float): bounding box width. + height (float): bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + rescale (float): Scale factor to rescale bounding boxes computed from the keypoints. + Returns: + center_x (float): x coordinate of bounding box center. + center_y (float): y coordinate of bounding box center. + width (float): bounding box width. + height (float): bounding box height. + """ + p = torch.rand(1).item() + if full_body(keypoints_2d): + if p < 0.7: + center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d) + elif p < 0.9: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + elif upper_body(keypoints_2d): + if p < 0.9: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + + return center_x, center_y, max(width, height), max(width, height) + +def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple: + """ + Perform aggressive extreme cropping + Args: + center_x (float): x coordinate of bounding box center. + center_y (float): y coordinate of bounding box center. + width (float): bounding box width. + height (float): bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + rescale (float): Scale factor to rescale bounding boxes computed from the keypoints. + Returns: + center_x (float): x coordinate of bounding box center. + center_y (float): y coordinate of bounding box center. + width (float): bounding box width. + height (float): bounding box height. 
+ """ + p = torch.rand(1).item() + if full_body(keypoints_2d): + if p < 0.2: + center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d) + elif p < 0.3: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + elif p < 0.4: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + elif p < 0.5: + center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.6: + center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.7: + center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.8: + center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.9: + center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d) + elif upper_body(keypoints_2d): + if p < 0.2: + center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d) + elif p < 0.4: + center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d) + elif p < 0.6: + center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d) + elif p < 0.8: + center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d) + else: + center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d) + return center_x, center_y, max(width, height), max(width, height) diff --git a/wilor/datasets/vitdet_dataset.py b/wilor/datasets/vitdet_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1fd6e8b9bc52d927c98959174bc556bf796f7162 --- /dev/null +++ b/wilor/datasets/vitdet_dataset.py @@ -0,0 +1,95 @@ +from typing import Dict + +import cv2 +import numpy as np +from skimage.filters import gaussian +from yacs.config import CfgNode +import torch + +from .utils import (convert_cvimg_to_tensor, + expand_to_aspect_ratio, + generate_image_patch_cv2) + +DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406]) +DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225]) + +class ViTDetDataset(torch.utils.data.Dataset): + + def __init__(self, + cfg: CfgNode, + img_cv2: np.array, + boxes: np.array, + right: np.array, + rescale_factor=2.5, + train: bool = False, + **kwargs): + super().__init__() + self.cfg = cfg + self.img_cv2 = img_cv2 + # self.boxes = boxes + + assert train == False, "ViTDetDataset is only for inference" + self.train = train + self.img_size = cfg.MODEL.IMAGE_SIZE + self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN) + self.std = 255. 
* np.array(self.cfg.MODEL.IMAGE_STD) + + # Preprocess annotations + boxes = boxes.astype(np.float32) + self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0 + self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0 + self.personid = np.arange(len(boxes), dtype=np.int32) + self.right = right.astype(np.float32) + + def __len__(self) -> int: + return len(self.personid) + + def __getitem__(self, idx: int) -> Dict[str, np.array]: + + center = self.center[idx].copy() + center_x = center[0] + center_y = center[1] + + scale = self.scale[idx] + BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None) + bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max() + + patch_width = patch_height = self.img_size + + right = self.right[idx].copy() + flip = right == 0 + + # 3. generate image patch + # if use_skimage_antialias: + cvimg = self.img_cv2.copy() + if True: + # Blur image to avoid aliasing artifacts + downsampling_factor = ((bbox_size*1.0) / patch_width) + #print(f'{downsampling_factor=}') + downsampling_factor = downsampling_factor / 2.0 + if downsampling_factor > 1.1: + cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True) + + + img_patch_cv, trans = generate_image_patch_cv2(cvimg, + center_x, center_y, + bbox_size, bbox_size, + patch_width, patch_height, + flip, 1.0, 0, + border_mode=cv2.BORDER_CONSTANT) + img_patch_cv = img_patch_cv[:, :, ::-1] + img_patch = convert_cvimg_to_tensor(img_patch_cv) + + # apply normalization + for n_c in range(min(self.img_cv2.shape[2], 3)): + img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c] + + item = { + 'img': img_patch, + 'personid': int(self.personid[idx]), + } + item['box_center'] = self.center[idx].copy() + item['box_size'] = bbox_size + item['img_size'] = 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]]) + item['right'] = self.right[idx].copy() + return item diff --git a/wilor/models/__init__.py b/wilor/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94187b99b6c63ce2113e72d5b2ffa906ab673cc6 --- /dev/null +++ b/wilor/models/__init__.py @@ -0,0 +1,36 @@ +from .mano_wrapper import MANO +from .wilor import WiLoR + +from .discriminator import Discriminator + +def load_wilor(checkpoint_path, cfg_path): + from pathlib import Path + from wilor.configs import get_config + print('Loading ', checkpoint_path) + model_cfg = get_config(cfg_path, update_cachedir=True) + + # Override some config values, to crop bbox correctly + if ('vit' in model_cfg.MODEL.BACKBONE.TYPE) and ('BBOX_SHAPE' not in model_cfg.MODEL): + + model_cfg.defrost() + assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone" + model_cfg.MODEL.BBOX_SHAPE = [192,256] + model_cfg.freeze() + + # Update config to be compatible with demo + if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE): + model_cfg.defrost() + model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS') + model_cfg.freeze() + + # Update config to be compatible with demo + + if ('DATA_DIR' in model_cfg.MANO): + model_cfg.defrost() + model_cfg.MANO.DATA_DIR = './mano_data/' + model_cfg.MANO.MODEL_PATH = './mano_data/mano/' + model_cfg.MANO.MEAN_PARAMS = './mano_data/mano_mean_params.npz' + model_cfg.freeze() + + model = WiLoR.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg) + return model, model_cfg \ No newline at end of file diff --git a/wilor/models/backbones/__init__.py b/wilor/models/backbones/__init__.py new file mode 
100644 index 0000000000000000000000000000000000000000..0f4305a01fadf44319fd331cf7fd987f33a23510 --- /dev/null +++ b/wilor/models/backbones/__init__.py @@ -0,0 +1,17 @@ +from .vit import vit + +def create_backbone(cfg): + if cfg.MODEL.BACKBONE.TYPE == 'vit': + return vit(cfg) + elif cfg.MODEL.BACKBONE.TYPE == 'fast_vit': + import torch + import sys + from timm.models import create_model + #from models.modules.mobileone import reparameterize_model + fast_vit = create_model("fastvit_ma36", drop_path_rate=0.2) + checkpoint = torch.load('./pretrained_models/fastvit_ma36.pt') + fast_vit.load_state_dict(checkpoint['state_dict']) + return fast_vit + + else: + raise NotImplementedError('Backbone type is not implemented') diff --git a/wilor/models/backbones/vit.py b/wilor/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..704e4922f6d9688d9434fd8598acbcda4ca6c0e7 --- /dev/null +++ b/wilor/models/backbones/vit.py @@ -0,0 +1,410 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import numpy as np +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +def vit(cfg): + return ViT( + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + cfg = cfg + ) + +def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + cls_token = None + B, L, C = abs_pos.shape + if has_cls_token: + cls_token = abs_pos[:, 0:1] + abs_pos = abs_pos[:, 1:] + + if ori_h != h or ori_w != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).reshape(B, -1, C) + + else: + new_abs_pos = abs_pos + + if cls_token is not None: + new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1) + return new_abs_pos + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
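+
+    Example (illustrative):
+        >>> dp = DropPath(drop_prob=0.1)
+        >>> y = dp(torch.randn(2, 196, 1280))   # identity in eval mode, stochastic depth when training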
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None,): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, attn_head_dim=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) + self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) + self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class ViT(nn.Module): + + def __init__(self, + img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, + frozen_stages=-1, ratio=1, last_norm=True, + patch_padding='pad', freeze_attn=False, freeze_ffn=False,cfg=None, + ): + # Protect mutable default arguments + super(ViT, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) + num_patches = self.patch_embed.num_patches + + ########################################## + self.cfg = cfg + 
self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d') + self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type] + npose = self.joint_rep_dim * (cfg.MANO.NUM_HAND_JOINTS + 1) + self.npose = npose + mean_params = np.load(cfg.MANO.MEAN_PARAMS) + init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0) + self.register_buffer('init_cam', init_cam) + init_hand_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0) + init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0) + self.register_buffer('init_hand_pose', init_hand_pose) + self.register_buffer('init_betas', init_betas) + + self.pose_emb = nn.Linear(self.joint_rep_dim , embed_dim) + self.shape_emb = nn.Linear(10 , embed_dim) + self.cam_emb = nn.Linear(3 , embed_dim) + + self.decpose = nn.Linear(self.num_features, 6) + self.decshape = nn.Linear(self.num_features, 10) + self.deccam = nn.Linear(self.num_features, 3) + if cfg.MODEL.MANO_HEAD.get('INIT_DECODER_XAVIER', False): + # True by default in MLP. False by default in Transformer + nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) + nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) + + + ########################################## + + # since the pretraining model has class token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + ) + for i in range(depth)]) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. 
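+
+        Note: the `pretrained` argument described above is not used here; this method takes no
+        arguments and initialises weights in place (truncated normal for Linear layers, constants
+        for LayerNorm).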
+ """ + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + # X [B, 192, 1280] + # x cat [ mean_pose, mean_shape, mean_cam] tokens + pose_tokens = self.pose_emb(self.init_hand_pose.reshape(1, self.cfg.MANO.NUM_HAND_JOINTS + 1, self.joint_rep_dim)).repeat(B, 1, 1) + shape_tokens = self.shape_emb(self.init_betas).unsqueeze(1).repeat(B, 1, 1) + cam_tokens = self.cam_emb(self.init_cam).unsqueeze(1).repeat(B, 1, 1) + + x = torch.cat([pose_tokens, shape_tokens, cam_tokens, x], 1) + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.last_norm(x) + + + pose_feat = x[:, :(self.cfg.MANO.NUM_HAND_JOINTS + 1)] + shape_feat = x[:, (self.cfg.MANO.NUM_HAND_JOINTS + 1):1+(self.cfg.MANO.NUM_HAND_JOINTS + 1)] + cam_feat = x[:, 1+(self.cfg.MANO.NUM_HAND_JOINTS + 1):2+(self.cfg.MANO.NUM_HAND_JOINTS + 1)] + + #print(pose_feat.shape, shape_feat.shape, cam_feat.shape) + pred_hand_pose = self.decpose(pose_feat).reshape(B, -1) + self.init_hand_pose #B , 96 + pred_betas = self.decshape(shape_feat).reshape(B, -1) + self.init_betas #B , 10 + pred_cam = self.deccam(cam_feat).reshape(B, -1) + self.init_cam #B , 3 + + pred_mano_feats = {} + pred_mano_feats['hand_pose'] = pred_hand_pose + pred_mano_feats['betas'] = pred_betas + pred_mano_feats['cam'] = pred_cam + + + joint_conversion_fn = { + '6d': rot6d_to_rotmat, + 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous()) + }[self.joint_rep_type] + + pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3) + pred_mano_params = {'global_orient': pred_hand_pose[:, [0]], + 'hand_pose': pred_hand_pose[:, 1:], + 'betas': pred_betas} + + img_feat = x[:, 2+(self.cfg.MANO.NUM_HAND_JOINTS + 1):].reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2) + return pred_mano_params, pred_cam, pred_mano_feats, img_feat + + def forward(self, x): + x = self.forward_features(x) + return x + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() diff --git a/wilor/models/discriminator.py b/wilor/models/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..f1cb2d1a21fbab47e8fa10dcc603b3d2012686a7 --- /dev/null +++ b/wilor/models/discriminator.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn + +class Discriminator(nn.Module): + + def __init__(self): + """ + Pose + Shape discriminator proposed in HMR + """ + super(Discriminator, self).__init__() + + self.num_joints = 15 + # poses_alone + self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1) + nn.init.xavier_uniform_(self.D_conv1.weight) + nn.init.zeros_(self.D_conv1.bias) + self.relu = nn.ReLU(inplace=True) + self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1) + nn.init.xavier_uniform_(self.D_conv2.weight) + nn.init.zeros_(self.D_conv2.bias) + pose_out 
= []
+        for i in range(self.num_joints):
+            pose_out_temp = nn.Linear(32, 1)
+            nn.init.xavier_uniform_(pose_out_temp.weight)
+            nn.init.zeros_(pose_out_temp.bias)
+            pose_out.append(pose_out_temp)
+        self.pose_out = nn.ModuleList(pose_out)
+
+        # betas
+        self.betas_fc1 = nn.Linear(10, 10)
+        nn.init.xavier_uniform_(self.betas_fc1.weight)
+        nn.init.zeros_(self.betas_fc1.bias)
+        self.betas_fc2 = nn.Linear(10, 5)
+        nn.init.xavier_uniform_(self.betas_fc2.weight)
+        nn.init.zeros_(self.betas_fc2.bias)
+        self.betas_out = nn.Linear(5, 1)
+        nn.init.xavier_uniform_(self.betas_out.weight)
+        nn.init.zeros_(self.betas_out.bias)
+
+        # poses_joint
+        self.D_alljoints_fc1 = nn.Linear(32*self.num_joints, 1024)
+        nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
+        nn.init.zeros_(self.D_alljoints_fc1.bias)
+        self.D_alljoints_fc2 = nn.Linear(1024, 1024)
+        nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
+        nn.init.zeros_(self.D_alljoints_fc2.bias)
+        self.D_alljoints_out = nn.Linear(1024, 1)
+        nn.init.xavier_uniform_(self.D_alljoints_out.weight)
+        nn.init.zeros_(self.D_alljoints_out.bias)
+
+
+    def forward(self, poses: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the discriminator.
+        Args:
+            poses (torch.Tensor): Tensor of shape (B, 15, 3, 3) containing a batch of MANO hand poses (excluding the global orientation).
+            betas (torch.Tensor): Tensor of shape (B, 10) containing a batch of MANO beta coefficients.
+        Returns:
+            torch.Tensor: Discriminator output with shape (B, 17)
+        """
+        #bn = poses.shape[0]
+        # poses B x 135
+        #poses = poses.reshape(bn, -1)
+        # poses B x num_joints x 1 x 9
+        poses = poses.reshape(-1, self.num_joints, 1, 9)
+        bn = poses.shape[0]
+        # poses B x 9 x num_joints x 1
+        poses = poses.permute(0, 3, 1, 2).contiguous()
+
+        # poses_alone
+        poses = self.D_conv1(poses)
+        poses = self.relu(poses)
+        poses = self.D_conv2(poses)
+        poses = self.relu(poses)
+
+        poses_out = []
+        for i in range(self.num_joints):
+            poses_out_ = self.pose_out[i](poses[:, :, i, 0])
+            poses_out.append(poses_out_)
+        poses_out = torch.cat(poses_out, dim=1)
+
+        # betas
+        betas = self.betas_fc1(betas)
+        betas = self.relu(betas)
+        betas = self.betas_fc2(betas)
+        betas = self.relu(betas)
+        betas_out = self.betas_out(betas)
+
+        # poses_joint
+        poses = poses.reshape(bn,-1)
+        poses_all = self.D_alljoints_fc1(poses)
+        poses_all = self.relu(poses_all)
+        poses_all = self.D_alljoints_fc2(poses_all)
+        poses_all = self.relu(poses_all)
+        poses_all_out = self.D_alljoints_out(poses_all)
+
+        disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
+        return disc_out
diff --git a/wilor/models/heads/__init__.py b/wilor/models/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..40279d65814d344cd6a6f452356c4ac4e6a633b3
--- /dev/null
+++ b/wilor/models/heads/__init__.py
@@ -0,0 +1 @@
+from .refinement_net import RefineNet
\ No newline at end of file
diff --git a/wilor/models/heads/refinement_net.py b/wilor/models/heads/refinement_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..98cd8ef17617b2f5b5efbd0731aab103be77a361
--- /dev/null
+++ b/wilor/models/heads/refinement_net.py
@@ -0,0 +1,204 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
+from typing import Optional
+
+def make_linear_layers(feat_dims, relu_final=True, use_bn=False):
+    layers = []
+    for i in range(len(feat_dims)-1):
+        layers.append(nn.Linear(feat_dims[i], feat_dims[i+1]))
+
+        # Do not use 
ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and relu_final): + if use_bn: + layers.append(nn.BatchNorm1d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + +def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + nn.Conv2d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=kernel, + stride=stride, + padding=padding + )) + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + layers.append(nn.BatchNorm2d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + +def make_deconv_layers(feat_dims, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + nn.ConvTranspose2d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=4, + stride=2, + padding=1, + output_padding=0, + bias=False)) + + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + layers.append(nn.BatchNorm2d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + +def sample_joint_features(img_feat, joint_xy): + height, width = img_feat.shape[2:] + x = joint_xy[:, :, 0] / (width - 1) * 2 - 1 + y = joint_xy[:, :, 1] / (height - 1) * 2 - 1 + grid = torch.stack((x, y), 2)[:, :, None, :] + img_feat = F.grid_sample(img_feat, grid, align_corners=True)[:, :, :, 0] # batch_size, channel_dim, joint_num + img_feat = img_feat.permute(0, 2, 1).contiguous() # batch_size, joint_num, channel_dim + return img_feat + +def perspective_projection(points: torch.Tensor, + translation: torch.Tensor, + focal_length: torch.Tensor, + camera_center: Optional[torch.Tensor] = None, + rotation: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Computes the perspective projection of a set of 3D points. + Args: + points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points. + translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation. + focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels. + camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels. + rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation. + Returns: + torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points. + """ + batch_size = points.shape[0] + if rotation is None: + rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1) + if camera_center is None: + camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype) + # Populate intrinsic camera matrix K. + K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype) + K[:,0,0] = focal_length[:,0] + K[:,1,1] = focal_length[:,1] + K[:,2,2] = 1. 
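+    # K is the standard pinhole intrinsics matrix [[fx, 0, cx], [0, fy, cy], [0, 0, 1]];
+    # the principal point (camera_center) is written into its last column below.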
+ K[:,:-1, -1] = camera_center + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:,:,-1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + return projected_points[:, :, :-1] + +class DeConvNet(nn.Module): + def __init__(self, feat_dim=768, upscale=4): + super(DeConvNet, self).__init__() + self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False) + self.deconv = nn.ModuleList([]) + for i in range(int(math.log2(upscale))+1): + if i==0: + self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4])) + elif i==1: + self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8])) + elif i==2: + self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8, feat_dim//8])) + + def forward(self, img_feat): + + face_img_feats = [] + img_feat = self.first_conv(img_feat) + face_img_feats.append(img_feat) + for i, deconv in enumerate(self.deconv): + scale = 2**i + img_feat_i = deconv(img_feat) + face_img_feat = img_feat_i + face_img_feats.append(face_img_feat) + return face_img_feats[::-1] # high resolution -> low resolution + +class DeConvNet_v2(nn.Module): + def __init__(self, feat_dim=768): + super(DeConvNet_v2, self).__init__() + self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False) + self.deconv = nn.Sequential(*[nn.ConvTranspose2d(in_channels=feat_dim//2, out_channels=feat_dim//4, kernel_size=4, stride=4, padding=0, output_padding=0, bias=False), + nn.BatchNorm2d(feat_dim//4), + nn.ReLU(inplace=True)]) + + def forward(self, img_feat): + + face_img_feats = [] + img_feat = self.first_conv(img_feat) + img_feat = self.deconv(img_feat) + + return [img_feat] + +class RefineNet(nn.Module): + def __init__(self, cfg, feat_dim=1280, upscale=3): + super(RefineNet, self).__init__() + #self.deconv = DeConvNet_v2(feat_dim=feat_dim) + #self.out_dim = feat_dim//4 + + self.deconv = DeConvNet(feat_dim=feat_dim, upscale=upscale) + self.out_dim = feat_dim//8 + feat_dim//4 + feat_dim//2 + self.dec_pose = nn.Linear(self.out_dim, 96) + self.dec_cam = nn.Linear(self.out_dim, 3) + self.dec_shape = nn.Linear(self.out_dim, 10) + + self.cfg = cfg + self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d') + self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type] + + def forward(self, img_feat, verts_3d, pred_cam, pred_mano_feats, focal_length): + B = img_feat.shape[0] + + img_feats = self.deconv(img_feat) + + img_feat_sizes = [img_feat.shape[2] for img_feat in img_feats] + + temp_cams = [torch.stack([pred_cam[:, 1], pred_cam[:, 2], + 2*focal_length[:, 0]/(img_feat_size * pred_cam[:, 0] +1e-9)],dim=-1) for img_feat_size in img_feat_sizes] + + verts_2d = [perspective_projection(verts_3d, + translation=temp_cams[i], + focal_length=focal_length / img_feat_sizes[i]) for i in range(len(img_feat_sizes))] + + vert_feats = [sample_joint_features(img_feats[i], verts_2d[i]).max(1).values for i in range(len(img_feat_sizes))] + + vert_feats = torch.cat(vert_feats, dim=-1) + + delta_pose = self.dec_pose(vert_feats) + delta_betas = self.dec_shape(vert_feats) + delta_cam = self.dec_cam(vert_feats) + + + pred_hand_pose = pred_mano_feats['hand_pose'] + delta_pose + pred_betas = pred_mano_feats['betas'] + delta_betas + pred_cam = pred_mano_feats['cam'] + delta_cam + + 
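# Convert the refined pose vector (6D or axis-angle, depending on the configured
+        # joint representation) back to 3x3 rotation matrices before assembling the
+        # MANO parameter dictionary returned below.
+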
joint_conversion_fn = { + '6d': rot6d_to_rotmat, + 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous()) + }[self.joint_rep_type] + + pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3) + + pred_mano_params = {'global_orient': pred_hand_pose[:, [0]], + 'hand_pose': pred_hand_pose[:, 1:], + 'betas': pred_betas} + + return pred_mano_params, pred_cam + + \ No newline at end of file diff --git a/wilor/models/losses.py b/wilor/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e493c081a4d99b97b5641e85152c4d56072a58 --- /dev/null +++ b/wilor/models/losses.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn + +class Keypoint2DLoss(nn.Module): + + def __init__(self, loss_type: str = 'l1'): + """ + 2D keypoint loss module. + Args: + loss_type (str): Choose between l1 and l2 losses. + """ + super(Keypoint2DLoss, self).__init__() + if loss_type == 'l1': + self.loss_fn = nn.L1Loss(reduction='none') + elif loss_type == 'l2': + self.loss_fn = nn.MSELoss(reduction='none') + else: + raise NotImplementedError('Unsupported loss function') + + def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor: + """ + Compute 2D reprojection loss on the keypoints. + Args: + pred_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 2] containing projected 2D keypoints (B: batch_size, S: num_samples, N: num_keypoints) + gt_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the ground truth 2D keypoints and confidence. + Returns: + torch.Tensor: 2D keypoint loss. + """ + conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() + batch_size = conf.shape[0] + loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1,2)) + return loss.sum() + + +class Keypoint3DLoss(nn.Module): + + def __init__(self, loss_type: str = 'l1'): + """ + 3D keypoint loss module. + Args: + loss_type (str): Choose between l1 and l2 losses. + """ + super(Keypoint3DLoss, self).__init__() + if loss_type == 'l1': + self.loss_fn = nn.L1Loss(reduction='none') + elif loss_type == 'l2': + self.loss_fn = nn.MSELoss(reduction='none') + else: + raise NotImplementedError('Unsupported loss function') + + def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0): + """ + Compute 3D keypoint loss. + Args: + pred_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the predicted 3D keypoints (B: batch_size, S: num_samples, N: num_keypoints) + gt_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 4] containing the ground truth 3D keypoints and confidence. + Returns: + torch.Tensor: 3D keypoint loss. + """ + batch_size = pred_keypoints_3d.shape[0] + gt_keypoints_3d = gt_keypoints_3d.clone() + pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1) + gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1) + conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() + gt_keypoints_3d = gt_keypoints_3d[:, :, :-1] + loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1,2)) + return loss.sum() + +class ParameterLoss(nn.Module): + + def __init__(self): + """ + MANO parameter loss module. + """ + super(ParameterLoss, self).__init__() + self.loss_fn = nn.MSELoss(reduction='none') + + def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor): + """ + Compute MANO parameter loss. 
+ Args: + pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas) + gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth MANO parameters. + Returns: + torch.Tensor: L2 parameter loss loss. + """ + batch_size = pred_param.shape[0] + num_dims = len(pred_param.shape) + mask_dimension = [batch_size] + [1] * (num_dims-1) + has_param = has_param.type(pred_param.type()).view(*mask_dimension) + loss_param = (has_param * self.loss_fn(pred_param, gt_param)) + return loss_param.sum() diff --git a/wilor/models/mano_wrapper.py b/wilor/models/mano_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f0cc336098e9303d2514c571307c56baf3bc86 --- /dev/null +++ b/wilor/models/mano_wrapper.py @@ -0,0 +1,40 @@ +import torch +import numpy as np +import pickle +from typing import Optional +import smplx +from smplx.lbs import vertices2joints +from smplx.utils import MANOOutput, to_tensor +from smplx.vertex_ids import vertex_ids + + +class MANO(smplx.MANOLayer): + def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs): + """ + Extension of the official MANO implementation to support more joints. + Args: + Same as MANOLayer. + joint_regressor_extra (str): Path to extra joint regressor. + """ + super(MANO, self).__init__(*args, **kwargs) + mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20] + + #2, 3, 5, 4, 1 + if joint_regressor_extra is not None: + self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32)) + self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long)) + self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long)) + + def forward(self, *args, **kwargs) -> MANOOutput: + """ + Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified. + """ + mano_output = super(MANO, self).forward(*args, **kwargs) + extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs) + joints = torch.cat([mano_output.joints, extra_joints], dim=1) + joints = joints[:, self.joint_map, :] + if hasattr(self, 'joint_regressor_extra'): + extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices) + joints = torch.cat([joints, extra_joints], dim=1) + mano_output.joints = joints + return mano_output diff --git a/wilor/models/wilor.py b/wilor/models/wilor.py new file mode 100644 index 0000000000000000000000000000000000000000..e5306376229a56931a444e693a6b0d070cc75bfd --- /dev/null +++ b/wilor/models/wilor.py @@ -0,0 +1,376 @@ +import torch +import pytorch_lightning as pl +from typing import Any, Dict, Mapping, Tuple + +from yacs.config import CfgNode + +from ..utils import SkeletonRenderer, MeshRenderer +from ..utils.geometry import aa_to_rotmat, perspective_projection +from ..utils.pylogger import get_pylogger +from .backbones import create_backbone +from .heads import RefineNet +from .discriminator import Discriminator +from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss +from . 
import MANO
+
+log = get_pylogger(__name__)
+
+class WiLoR(pl.LightningModule):
+
+    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
+        """
+        Setup WiLoR model
+        Args:
+            cfg (CfgNode): Config file as a yacs CfgNode
+        """
+        super().__init__()
+
+        # Save hyperparameters
+        self.save_hyperparameters(logger=False, ignore=['init_renderer'])
+
+        self.cfg = cfg
+        # Create backbone feature extractor
+        self.backbone = create_backbone(cfg)
+        if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
+            log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
+            self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'], strict = False)
+
+        # Create RefineNet head
+        self.refine_net = RefineNet(cfg, feat_dim=1280, upscale=3)
+
+        # Create discriminator
+        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
+            self.discriminator = Discriminator()
+
+        # Define loss functions
+        self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
+        self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
+        self.mano_parameter_loss = ParameterLoss()
+
+        # Instantiate MANO model
+        mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()}
+        self.mano = MANO(**mano_cfg)
+
+        # Buffer that shows whether we need to initialize ActNorm layers
+        self.register_buffer('initialized', torch.tensor(False))
+        # Setup renderer for visualization
+        if init_renderer:
+            self.renderer = SkeletonRenderer(self.cfg)
+            self.mesh_renderer = MeshRenderer(self.cfg, faces=self.mano.faces)
+        else:
+            self.renderer = None
+            self.mesh_renderer = None
+
+
+        # Disable automatic optimization since we use adversarial training
+        self.automatic_optimization = False
+
+    def on_after_backward(self):
+        for name, param in self.named_parameters():
+            if param.grad is None:
+                print(param.shape)
+                print(name)
+
+
+    def get_parameters(self):
+        #all_params = list(self.mano_head.parameters())
+        all_params = list(self.backbone.parameters())
+        return all_params
+
+    def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
+        """
+        Setup model and discriminator optimizers
+        Returns:
+            Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
+        """
+        param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
+
+        optimizer = torch.optim.AdamW(params=param_groups,
+                                      # lr=self.cfg.TRAIN.LR,
+                                      weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
+        optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
+                                           lr=self.cfg.TRAIN.LR,
+                                           weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
+
+        return optimizer, optimizer_disc
+
+    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
+        """
+        Run a forward step of the network
+        Args:
+            batch (Dict): Dictionary containing batch data
+            train (bool): Flag indicating whether it is training or validation mode
+        Returns:
+            Dict: Dictionary containing the regression output
+        """
+        # Use RGB image as input
+        x = batch['img']
+        batch_size = x.shape[0]
+        # Compute conditioning features using the backbone
+        # if using ViT backbone, we need to use a different aspect ratio
+        temp_mano_params, pred_cam, pred_mano_feats, vit_out = self.backbone(x[:,:,:,32:-32]) # B, 1280, 16, 12
+
+
+        # Compute camera translation
+        device = temp_mano_params['hand_pose'].device
+        dtype = temp_mano_params['hand_pose'].dtype
+        focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
+
+
+        ## Temp MANO
+        temp_mano_params['global_orient'] = 
temp_mano_params['global_orient'].reshape(batch_size, -1, 3, 3) + temp_mano_params['hand_pose'] = temp_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3) + temp_mano_params['betas'] = temp_mano_params['betas'].reshape(batch_size, -1) + temp_mano_output = self.mano(**{k: v.float() for k,v in temp_mano_params.items()}, pose2rot=False) + #temp_keypoints_3d = temp_mano_output.joints + temp_vertices = temp_mano_output.vertices + + pred_mano_params, pred_cam = self.refine_net(vit_out, temp_vertices, pred_cam, pred_mano_feats, focal_length) + # Store useful regression outputs to the output dict + + + output = {} + output['pred_cam'] = pred_cam + output['pred_mano_params'] = {k: v.clone() for k,v in pred_mano_params.items()} + + pred_cam_t = torch.stack([pred_cam[:, 1], + pred_cam[:, 2], + 2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] +1e-9)],dim=-1) + output['pred_cam_t'] = pred_cam_t + output['focal_length'] = focal_length + + # Compute model vertices, joints and the projected joints + pred_mano_params['global_orient'] = pred_mano_params['global_orient'].reshape(batch_size, -1, 3, 3) + pred_mano_params['hand_pose'] = pred_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3) + pred_mano_params['betas'] = pred_mano_params['betas'].reshape(batch_size, -1) + mano_output = self.mano(**{k: v.float() for k,v in pred_mano_params.items()}, pose2rot=False) + pred_keypoints_3d = mano_output.joints + pred_vertices = mano_output.vertices + + output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3) + output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3) + pred_cam_t = pred_cam_t.reshape(-1, 3) + focal_length = focal_length.reshape(-1, 2) + + pred_keypoints_2d = perspective_projection(pred_keypoints_3d, + translation=pred_cam_t, + focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE) + output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2) + + return output + + def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor: + """ + Compute losses given the input batch and the regression output + Args: + batch (Dict): Dictionary containing batch data + output (Dict): Dictionary containing the regression output + train (bool): Flag indicating whether it is training or validation mode + Returns: + torch.Tensor : Total loss for current batch + """ + + pred_mano_params = output['pred_mano_params'] + pred_keypoints_2d = output['pred_keypoints_2d'] + pred_keypoints_3d = output['pred_keypoints_3d'] + + + batch_size = pred_mano_params['hand_pose'].shape[0] + device = pred_mano_params['hand_pose'].device + dtype = pred_mano_params['hand_pose'].dtype + + # Get annotations + gt_keypoints_2d = batch['keypoints_2d'] + gt_keypoints_3d = batch['keypoints_3d'] + gt_mano_params = batch['mano_params'] + has_mano_params = batch['has_mano_params'] + is_axis_angle = batch['mano_params_is_axis_angle'] + + # Compute 3D keypoint loss + loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d) + loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0) + + # Compute loss on MANO parameters + loss_mano_params = {} + for k, pred in pred_mano_params.items(): + gt = gt_mano_params[k].view(batch_size, -1) + if is_axis_angle[k].all(): + gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3) + has_gt = has_mano_params[k] + loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt) + + loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * 
loss_keypoints_3d+\ + self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\ + sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params]) + + + losses = dict(loss=loss.detach(), + loss_keypoints_2d=loss_keypoints_2d.detach(), + loss_keypoints_3d=loss_keypoints_3d.detach()) + + for k, v in loss_mano_params.items(): + losses['loss_' + k] = v.detach() + + output['losses'] = losses + + return loss + + # Tensoroboard logging should run from first rank only + @pl.utilities.rank_zero.rank_zero_only + def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None: + """ + Log results to Tensorboard + Args: + batch (Dict): Dictionary containing batch data + output (Dict): Dictionary containing the regression output + step_count (int): Global training step count + train (bool): Flag indicating whether it is training or validation mode + """ + + mode = 'train' if train else 'val' + batch_size = batch['keypoints_2d'].shape[0] + images = batch['img'] + images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1) + images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1) + #images = 255*images.permute(0, 2, 3, 1).cpu().numpy() + + pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3) + pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3) + focal_length = output['focal_length'].detach().reshape(batch_size, 2) + gt_keypoints_3d = batch['keypoints_3d'] + gt_keypoints_2d = batch['keypoints_2d'] + + losses = output['losses'] + pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3) + pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2) + if write_to_summary_writer: + summary_writer = self.logger.experiment + for loss_name, val in losses.items(): + summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count) + num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES) + + gt_keypoints_3d = batch['keypoints_3d'] + pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3) + + # We render the skeletons instead of the full mesh because rendering a lot of meshes will make the training slow. 
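+        # Note: the active code path below uses the mesh renderer, overlaying the
+        # predicted vertices and the predicted / ground-truth 2D keypoints on the
+        # input images; the skeleton-renderer call is kept commented out.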
+ #predictions = self.renderer(pred_keypoints_3d[:num_images], + # gt_keypoints_3d[:num_images], + # 2 * gt_keypoints_2d[:num_images], + # images=images[:num_images], + # camera_translation=pred_cam_t[:num_images]) + predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(), + pred_cam_t[:num_images].cpu().numpy(), + images[:num_images].cpu().numpy(), + pred_keypoints_2d[:num_images].cpu().numpy(), + gt_keypoints_2d[:num_images].cpu().numpy(), + focal_length=focal_length[:num_images].cpu().numpy()) + if write_to_summary_writer: + summary_writer.add_image('%s/predictions' % mode, predictions, step_count) + + return predictions + + def forward(self, batch: Dict) -> Dict: + """ + Run a forward step of the network in val mode + Args: + batch (Dict): Dictionary containing batch data + Returns: + Dict: Dictionary containing the regression output + """ + return self.forward_step(batch, train=False) + + def training_step_discriminator(self, batch: Dict, + hand_pose: torch.Tensor, + betas: torch.Tensor, + optimizer: torch.optim.Optimizer) -> torch.Tensor: + """ + Run a discriminator training step + Args: + batch (Dict): Dictionary containing mocap batch data + hand_pose (torch.Tensor): Regressed hand pose from current step + betas (torch.Tensor): Regressed betas from current step + optimizer (torch.optim.Optimizer): Discriminator optimizer + Returns: + torch.Tensor: Discriminator loss + """ + batch_size = hand_pose.shape[0] + gt_hand_pose = batch['hand_pose'] + gt_betas = batch['betas'] + gt_rotmat = aa_to_rotmat(gt_hand_pose.view(-1,3)).view(batch_size, -1, 3, 3) + disc_fake_out = self.discriminator(hand_pose.detach(), betas.detach()) + loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size + disc_real_out = self.discriminator(gt_rotmat, gt_betas) + loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size + loss_disc = loss_fake + loss_real + loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() + return loss_disc.detach() + + def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict: + """ + Run a full training step + Args: + joint_batch (Dict): Dictionary containing image and mocap batch data + batch_idx (int): Unused. + batch_idx (torch.Tensor): Unused. + Returns: + Dict: Dictionary containing regression output. 
+ """ + batch = joint_batch['img'] + mocap_batch = joint_batch['mocap'] + optimizer = self.optimizers(use_pl_optimizer=True) + if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: + optimizer, optimizer_disc = optimizer + + batch_size = batch['img'].shape[0] + output = self.forward_step(batch, train=True) + pred_mano_params = output['pred_mano_params'] + if self.cfg.get('UPDATE_GT_SPIN', False): + self.update_batch_gt_spin(batch, output) + loss = self.compute_loss(batch, output, train=True) + if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: + disc_out = self.discriminator(pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1)) + loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size + loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv + + # Error if Nan + if torch.isnan(loss): + raise ValueError('Loss is NaN') + + optimizer.zero_grad() + self.manual_backward(loss) + # Clip gradient + if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0: + gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True) + self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True) + optimizer.step() + if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: + loss_disc = self.training_step_discriminator(mocap_batch, pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1), optimizer_disc) + output['losses']['loss_gen'] = loss_adv + output['losses']['loss_disc'] = loss_disc + + if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0: + self.tensorboard_logging(batch, output, self.global_step, train=True) + + self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False) + + return output + + def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict: + """ + Run a validation step and log to Tensorboard + Args: + batch (Dict): Dictionary containing batch data + batch_idx (int): Unused. + Returns: + Dict: Dictionary containing regression output. + """ + # batch_size = batch['img'].shape[0] + output = self.forward_step(batch, train=False) + loss = self.compute_loss(batch, output, train=False) + output['loss'] = loss + self.tensorboard_logging(batch, output, self.global_step, train=False) + + return output diff --git a/wilor/utils/__init__.py b/wilor/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09e47cdf8cdb303432d64902fbe58b256273f88a --- /dev/null +++ b/wilor/utils/__init__.py @@ -0,0 +1,25 @@ +import torch +from typing import Any + +from .renderer import Renderer +from .mesh_renderer import MeshRenderer +from .skeleton_renderer import SkeletonRenderer +from .pose_utils import eval_pose, Evaluator + +def recursive_to(x: Any, target: torch.device): + """ + Recursively transfer a batch of data to the target device + Args: + x (Any): Batch of data. + target (torch.device): Target device. + Returns: + Batch of data where all tensors are transfered to the target device. 
+ """ + if isinstance(x, dict): + return {k: recursive_to(v, target) for k, v in x.items()} + elif isinstance(x, torch.Tensor): + return x.to(target) + elif isinstance(x, list): + return [recursive_to(i, target) for i in x] + else: + return x diff --git a/wilor/utils/geometry.py b/wilor/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..7929ef52608618a4682788487008e73c5736101b --- /dev/null +++ b/wilor/utils/geometry.py @@ -0,0 +1,102 @@ +from typing import Optional +import torch +from torch.nn import functional as F + +def aa_to_rotmat(theta: torch.Tensor): + """ + Convert axis-angle representation to rotation matrix. + Works by first converting it to a quaternion. + Args: + theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations. + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat_to_rotmat(quat) + +def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor: + """ + Convert quaternion representation to rotation matrix. + Args: + quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z). + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + + +def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor: + """ + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + x (torch.Tensor): (B,6) Batch of 6-D rotation representations. + Returns: + torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3). + """ + x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous() + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + +def perspective_projection(points: torch.Tensor, + translation: torch.Tensor, + focal_length: torch.Tensor, + camera_center: Optional[torch.Tensor] = None, + rotation: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Computes the perspective projection of a set of 3D points. + Args: + points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points. + translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation. + focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels. + camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels. + rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation. + Returns: + torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points. 
+ """ + batch_size = points.shape[0] + if rotation is None: + rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1) + if camera_center is None: + camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype) + # Populate intrinsic camera matrix K. + K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype) + K[:,0,0] = focal_length[:,0] + K[:,1,1] = focal_length[:,1] + K[:,2,2] = 1. + K[:,:-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:,:,-1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + return projected_points[:, :, :-1] \ No newline at end of file diff --git a/wilor/utils/mesh_renderer.py b/wilor/utils/mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3e8ed2e9aed8157ec852d06d5f13e8f4ff7c54 --- /dev/null +++ b/wilor/utils/mesh_renderer.py @@ -0,0 +1,149 @@ +import os +if 'PYOPENGL_PLATFORM' not in os.environ: + os.environ['PYOPENGL_PLATFORM'] = 'egl' +import torch +from torchvision.utils import make_grid +import numpy as np +import pyrender +import trimesh +import cv2 +import torch.nn.functional as F + +from .render_openpose import render_openpose + +def create_raymond_lights(): + import pyrender + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3,:3] = np.c_[x,y,z] + nodes.append(pyrender.Node( + light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0), + matrix=matrix + )) + + return nodes + +class MeshRenderer: + + def __init__(self, cfg, faces=None): + self.cfg = cfg + self.focal_length = cfg.EXTRA.FOCAL_LENGTH + self.img_res = cfg.MODEL.IMAGE_SIZE + self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res, + viewport_height=self.img_res, + point_size=1.0) + + self.camera_center = [self.img_res // 2, self.img_res // 2] + self.faces = faces + + def visualize(self, vertices, camera_translation, images, focal_length=None, nrow=3, padding=2): + images_np = np.transpose(images, (0,2,3,1)) + rend_imgs = [] + for i in range(vertices.shape[0]): + fl = self.focal_length + rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2,0,1))).float() + rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2,0,1))).float() + rend_imgs.append(torch.from_numpy(images[i])) + rend_imgs.append(rend_img) + rend_imgs.append(rend_img_side) + rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding) + return rend_imgs + + def visualize_tensorboard(self, vertices, camera_translation, images, pred_keypoints, gt_keypoints, focal_length=None, nrow=5, padding=2): + images_np = np.transpose(images, (0,2,3,1)) + rend_imgs = [] + pred_keypoints = np.concatenate((pred_keypoints, 
np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1) + pred_keypoints = self.img_res * (pred_keypoints + 0.5) + gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5) + #keypoint_matches = [(1, 12), (2, 8), (3, 7), (4, 6), (5, 9), (6, 10), (7, 11), (8, 14), (9, 2), (10, 1), (11, 0), (12, 3), (13, 4), (14, 5)] + for i in range(vertices.shape[0]): + fl = self.focal_length + rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2,0,1))).float() + rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2,0,1))).float() + hand_keypoints = pred_keypoints[i, :21] + #extra_keypoints = pred_keypoints[i, -19:] + #for pair in keypoint_matches: + # hand_keypoints[pair[0], :] = extra_keypoints[pair[1], :] + pred_keypoints_img = render_openpose(255 * images_np[i].copy(), hand_keypoints) / 255 + hand_keypoints = gt_keypoints[i, :21] + #extra_keypoints = gt_keypoints[i, -19:] + #for pair in keypoint_matches: + # if extra_keypoints[pair[1], -1] > 0 and hand_keypoints[pair[0], -1] == 0: + # hand_keypoints[pair[0], :] = extra_keypoints[pair[1], :] + gt_keypoints_img = render_openpose(255*images_np[i].copy(), hand_keypoints) / 255 + rend_imgs.append(torch.from_numpy(images[i])) + rend_imgs.append(rend_img) + rend_imgs.append(rend_img_side) + rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2,0,1)) + rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2,0,1)) + rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding) + return rend_imgs + + def __call__(self, vertices, camera_translation, image, focal_length=5000, text=None, resize=None, side_view=False, baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90): + renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1], + viewport_height=image.shape[0], + point_size=1.0) + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.0, + alphaMode='OPAQUE', + baseColorFactor=baseColorFactor) + + camera_translation[0] *= -1. + + mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy()) + if side_view: + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), [0, 1, 0]) + mesh.apply_transform(rot) + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation + camera_center = [image.shape[1] / 2., image.shape[0] / 2.] 
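+        # Render through a pyrender IntrinsicsCamera: fx = fy = focal_length, the
+        # principal point sits at the image centre, and the camera is posed at the
+        # (x-mirrored) predicted translation set above.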
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, + cx=camera_center[0], cy=camera_center[1]) + scene.add(camera, pose=camera_pose) + + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis] + if not side_view: + output_img = (color[:, :, :3] * valid_mask + + (1 - valid_mask) * image) + else: + output_img = color[:, :, :3] + if resize is not None: + output_img = cv2.resize(output_img, resize) + + output_img = output_img.astype(np.float32) + renderer.delete() + return output_img diff --git a/wilor/utils/misc.py b/wilor/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..ffcfe784872b305c264ce6ef67fd0a9e9ad3390f --- /dev/null +++ b/wilor/utils/misc.py @@ -0,0 +1,203 @@ +import time +import warnings +from importlib.util import find_spec +from pathlib import Path +from typing import Callable, List + +import hydra +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger +from pytorch_lightning.utilities import rank_zero_only + +from . import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that wraps the task function in extra utilities. + + Makes multirun more resistant to failure. + + Utilities: + - Calling the `utils.extras()` before the task is started + - Calling the `utils.close_loggers()` after the task is finished + - Logging the exception if occurs + - Logging the task total execution time + - Logging the output dir + """ + + def wrap(cfg: DictConfig): + + # apply extra utilities + extras(cfg) + + # execute the task + try: + start_time = time.time() + ret = task_func(cfg=cfg) + except Exception as ex: + log.exception("") # save exception to `.log` file + raise ex + finally: + path = Path(cfg.paths.output_dir, "exec_time.log") + content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" + save_file(path, content) # save task execution time (even if exception occurs) + close_loggers() # close loggers (even if exception occurs so multirun won't fail) + + log.info(f"Output dir: {cfg.paths.output_dir}") + + return ret + + return wrap + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! 
") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +@rank_zero_only +def save_file(path: str, content: str) -> None: + """Save file in rank zero mode (only on one process in multi-GPU setup).""" + with open(path, "w+") as file: + file.write(content) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("Callbacks config is empty.") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("Logger config is empty.") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + for k in cfg.keys(): + hparams[k] = cfg.get(k) + + # Resolve all interpolations + def _resolve(_cfg): + if isinstance(_cfg, DictConfig): + _cfg = OmegaConf.to_container(_cfg, resolve=True) + return _cfg + + hparams = {k: _resolve(v) for k, v in hparams.items()} + + # send hparams to all loggers + trainer.logger.log_hyperparams(hparams) + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value + + +def close_loggers() -> None: + """Makes sure all loggers closed properly (prevents logging failure during multirun).""" + + log.info("Closing loggers...") + + if find_spec("wandb"): # if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() diff --git a/wilor/utils/pose_utils.py b/wilor/utils/pose_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1b386e39aca64ce73ae10fea9ba2b767ce9e25b2 --- /dev/null +++ b/wilor/utils/pose_utils.py @@ -0,0 +1,352 @@ +""" +Code adapted from: https://github.com/akanazawa/hmr/blob/master/src/benchmark/eval_util.py +""" + +import torch +import numpy as np +from typing import Optional, Dict, List, Tuple + +def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor: + """ + Computes a similarity transform (sR, t) in a batched way that takes + a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3), + where R is a 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + Args: + S1 (torch.Tensor): First set of points of shape (B, N, 3). + S2 (torch.Tensor): Second set of points of shape (B, N, 3). + Returns: + (torch.Tensor): The first set of points after applying the similarity transformation. + """ + + batch_size = S1.shape[0] + S1 = S1.permute(0, 2, 1) + S2 = S2.permute(0, 2, 1) + # 1. Remove mean. + mu1 = S1.mean(dim=2, keepdim=True) + mu2 = S2.mean(dim=2, keepdim=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = (X1**2).sum(dim=(1,2)) + + # 3. The outer product of X1 and X2. + K = torch.matmul(X1, X2.permute(0, 2, 1)) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K. + U, s, V = torch.svd(K) + Vh = V.permute(0, 2, 1) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[1], device=U.device).unsqueeze(0).repeat(batch_size, 1, 1) + Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U, Vh))) + + # Construct R. + R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1)) + + # 5. Recover scale. + trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1) + scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1) + + # 6. Recover translation. + t = mu2 - scale*torch.matmul(R, mu1) + + # 7. Error: + S1_hat = scale*torch.matmul(R, S1) + t + + return S1_hat.permute(0, 2, 1) + +def reconstruction_error(S1, S2) -> np.array: + """ + Computes the mean Euclidean distance of 2 set of points S1, S2 after performing Procrustes alignment. + Args: + S1 (torch.Tensor): First set of points of shape (B, N, 3). + S2 (torch.Tensor): Second set of points of shape (B, N, 3). + Returns: + (np.array): Reconstruction error. + """ + S1_hat = compute_similarity_transform(S1, S2) + re = torch.sqrt( ((S1_hat - S2)** 2).sum(dim=-1)).mean(dim=-1) + return re + +def eval_pose(pred_joints, gt_joints) -> Tuple[np.array, np.array]: + """ + Compute joint errors in mm before and after Procrustes alignment. + Args: + pred_joints (torch.Tensor): Predicted 3D joints of shape (B, N, 3). + gt_joints (torch.Tensor): Ground truth 3D joints of shape (B, N, 3). + Returns: + Tuple[np.array, np.array]: Joint errors in mm before and after alignment. 
+ """ + # Absolute error (MPJPE) + mpjpe = torch.sqrt(((pred_joints - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() + + # Reconstruction_error + r_error = reconstruction_error(pred_joints, gt_joints).cpu().numpy() + return 1000 * mpjpe, 1000 * r_error + +class Evaluator: + + def __init__(self, + dataset_length: int, + dataset: str, + keypoint_list: List, + pelvis_ind: int, + metrics: List = ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re'], + preds: List = ['vertices', 'keypoints_3d'], + pck_thresholds: Optional[List] = None): + """ + Class used for evaluating trained models on different 3D pose datasets. + Args: + dataset_length (int): Total dataset length. + keypoint_list [List]: List of keypoints used for evaluation. + pelvis_ind (int): Index of pelvis keypoint; used for aligning the predictions and ground truth. + metrics [List]: List of evaluation metrics to record. + """ + self.dataset_length = dataset_length + self.dataset = dataset + self.keypoint_list = keypoint_list + self.pelvis_ind = pelvis_ind + self.metrics = metrics + self.preds = preds + if self.metrics is not None: + for metric in self.metrics: + setattr(self, metric, np.zeros((dataset_length,))) + if self.preds is not None: + for pred in self.preds: + if pred == 'vertices': + self.vertices = np.zeros((dataset_length, 778, 3)) + if pred == 'keypoints_3d': + self.keypoints_3d = np.zeros((dataset_length, 21, 3)) + self.counter = 0 + if pck_thresholds is None: + self.pck_evaluator = None + else: + self.pck_evaluator = EvaluatorPCK(pck_thresholds) + + def log(self): + """ + Print current evaluation metrics + """ + if self.counter == 0: + print('Evaluation has not started') + return + print(f'{self.counter} / {self.dataset_length} samples') + if self.pck_evaluator is not None: + self.pck_evaluator.log() + if self.metrics is not None: + for metric in self.metrics: + if metric in ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re']: + unit = 'mm' + else: + unit = '' + print(f'{metric}: {getattr(self, metric)[:self.counter].mean()} {unit}') + print('***') + + def get_metrics_dict(self) -> Dict: + """ + Returns: + Dict: Dictionary of evaluation metrics. + """ + d1 = {metric: getattr(self, metric)[:self.counter].mean() for metric in self.metrics} + if self.pck_evaluator is not None: + d2 = self.pck_evaluator.get_metrics_dict() + d1.update(d2) + return d1 + + def get_preds_dict(self) -> Dict: + """ + Returns: + Dict: Dictionary of evaluation preds. + """ + d1 = {pred: getattr(self, pred)[:self.counter] for pred in self.preds} + return d1 + + def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None): + """ + Evaluate current batch. + Args: + output (Dict): Regression output. + batch (Dict): Dictionary containing images and their corresponding annotations. + opt_output (Dict): Optimization output. 
+ """ + if self.pck_evaluator is not None: + self.pck_evaluator(output, batch, opt_output) + + pred_keypoints_3d = output['pred_keypoints_3d'].detach() + pred_keypoints_3d = pred_keypoints_3d[:,None,:,:] + batch_size = pred_keypoints_3d.shape[0] + num_samples = pred_keypoints_3d.shape[1] + gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1) + pred_vertices = output['pred_vertices'].detach() + + # Align predictions and ground truth such that the pelvis location is at the origin + pred_keypoints_3d -= pred_keypoints_3d[:, :, [self.pelvis_ind]] + gt_keypoints_3d -= gt_keypoints_3d[:, :, [self.pelvis_ind]] + + # Compute joint errors + mpjpe, re = eval_pose(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3)[:, self.keypoint_list], gt_keypoints_3d.reshape(batch_size * num_samples, -1 ,3)[:, self.keypoint_list]) + mpjpe = mpjpe.reshape(batch_size, num_samples) + re = re.reshape(batch_size, num_samples) + + # Compute 2d keypoint errors + bbox_expand_factor = batch['bbox_expand_factor'][:,None,None,None].detach() + pred_keypoints_2d = output['pred_keypoints_2d'].detach() + pred_keypoints_2d = pred_keypoints_2d[:,None,:,:]*bbox_expand_factor + gt_keypoints_2d = batch['keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1)*bbox_expand_factor + conf = gt_keypoints_2d[:, :, :, -1].clone() + kp_err = torch.nn.functional.mse_loss( + pred_keypoints_2d, + gt_keypoints_2d[:, :, :, :-1], + reduction='none' + ).sum(dim=3) + kp_l2_loss = (conf * kp_err).mean(dim=2) + kp_l2_loss = kp_l2_loss.detach().cpu().numpy() + + # Compute joint errors after optimization, if available. + if opt_output is not None: + opt_keypoints_3d = opt_output['model_joints'] + opt_keypoints_3d -= opt_keypoints_3d[:, [self.pelvis_ind]] + opt_mpjpe, opt_re = eval_pose(opt_keypoints_3d[:, self.keypoint_list], gt_keypoints_3d[:, 0, self.keypoint_list]) + + # The 0-th sample always corresponds to the mode + if hasattr(self, 'mode_mpjpe'): + mode_mpjpe = mpjpe[:, 0] + self.mode_mpjpe[self.counter:self.counter+batch_size] = mode_mpjpe + if hasattr(self, 'mode_re'): + mode_re = re[:, 0] + self.mode_re[self.counter:self.counter+batch_size] = mode_re + if hasattr(self, 'mode_kpl2'): + mode_kpl2 = kp_l2_loss[:, 0] + self.mode_kpl2[self.counter:self.counter+batch_size] = mode_kpl2 + if hasattr(self, 'min_mpjpe'): + min_mpjpe = mpjpe.min(axis=-1) + self.min_mpjpe[self.counter:self.counter+batch_size] = min_mpjpe + if hasattr(self, 'min_re'): + min_re = re.min(axis=-1) + self.min_re[self.counter:self.counter+batch_size] = min_re + if hasattr(self, 'min_kpl2'): + min_kpl2 = kp_l2_loss.min(axis=-1) + self.min_kpl2[self.counter:self.counter+batch_size] = min_kpl2 + if hasattr(self, 'opt_mpjpe'): + self.opt_mpjpe[self.counter:self.counter+batch_size] = opt_mpjpe + if hasattr(self, 'opt_re'): + self.opt_re[self.counter:self.counter+batch_size] = opt_re + if hasattr(self, 'vertices'): + self.vertices[self.counter:self.counter+batch_size] = pred_vertices.cpu().numpy() + if hasattr(self, 'keypoints_3d'): + if self.dataset == 'HO3D-VAL': + pred_keypoints_3d = pred_keypoints_3d[:,:,[0,5,6,7,9,10,11,17,18,19,13,14,15,1,2,3,4,8,12,16,20]] + self.keypoints_3d[self.counter:self.counter+batch_size] = pred_keypoints_3d.squeeze().cpu().numpy() + + self.counter += batch_size + + if hasattr(self, 'mode_mpjpe') and hasattr(self, 'mode_re'): + return { + 'mode_mpjpe': mode_mpjpe, + 'mode_re': mode_re, + } + else: + return {} + + +class EvaluatorPCK: + + def __init__(self, thresholds: List = [0.05, 0.1, 0.2, 0.3, 0.4, 
0.5],): + """ + Class used for evaluating trained models on different 3D pose datasets. + Args: + thresholds [List]: List of PCK thresholds to evaluate. + metrics [List]: List of evaluation metrics to record. + """ + self.thresholds = thresholds + self.pred_kp_2d = [] + self.gt_kp_2d = [] + self.gt_conf_2d = [] + self.scale = [] + self.counter = 0 + + def log(self): + """ + Print current evaluation metrics + """ + if self.counter == 0: + print('Evaluation has not started') + return + print(f'{self.counter} samples') + metrics_dict = self.get_metrics_dict() + for metric in metrics_dict: + print(f'{metric}: {metrics_dict[metric]}') + print('***') + + def get_metrics_dict(self) -> Dict: + """ + Returns: + Dict: Dictionary of evaluation metrics. + """ + pcks = self.compute_pcks() + metrics = {} + for thr, (acc,avg_acc,cnt) in zip(self.thresholds, pcks): + metrics.update({f'kp{i}_pck_{thr}': float(a) for i, a in enumerate(acc) if a>=0}) + metrics.update({f'kpAvg_pck_{thr}': float(avg_acc)}) + return metrics + + def compute_pcks(self): + pred_kp_2d = np.concatenate(self.pred_kp_2d, axis=0) + gt_kp_2d = np.concatenate(self.gt_kp_2d, axis=0) + gt_conf_2d = np.concatenate(self.gt_conf_2d, axis=0) + scale = np.concatenate(self.scale, axis=0) + assert pred_kp_2d.shape == gt_kp_2d.shape + assert pred_kp_2d[..., 0].shape == gt_conf_2d.shape + assert pred_kp_2d.shape[1] == 1 # num_samples + assert scale.shape[0] == gt_conf_2d.shape[0] # num_samples + + pcks = [ + self.keypoint_pck_accuracy( + pred_kp_2d[:, 0, :, :], + gt_kp_2d[:, 0, :, :], + gt_conf_2d[:, 0, :]>0.5, + thr=thr, + scale = scale[:,None] + ) + for thr in self.thresholds + ] + return pcks + + def keypoint_pck_accuracy(self, pred, gt, conf, thr, scale): + dist = np.sqrt(np.sum((pred-gt)**2, axis=2)) + all_joints = conf>0.5 + correct_joints = np.logical_and(dist<=scale*thr, all_joints) + pck = correct_joints.sum(axis=0)/all_joints.sum(axis=0) + return pck, pck.mean(), pck.shape[0] + + def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None): + """ + Evaluate current batch. + Args: + output (Dict): Regression output. + batch (Dict): Dictionary containing images and their corresponding annotations. + opt_output (Dict): Optimization output. 
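+            The x-coordinates are flipped for left hands via the 'right' flag, and the keypoints are mapped from crop-normalized coordinates back to the original image frame before being buffered.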
+ """ + pred_keypoints_2d = output['pred_keypoints_2d'].detach() + num_samples = 1 + batch_size = pred_keypoints_2d.shape[0] + + right = batch['right'].detach() + pred_keypoints_2d[:,:,0] = (2*right[:,None]-1)*pred_keypoints_2d[:,:,0] + box_size = batch['box_size'].detach() + box_center = batch['box_center'].detach() + bbox_expand_factor = batch['bbox_expand_factor'].detach() + scale = box_size/bbox_expand_factor + bbox_expand_factor = bbox_expand_factor[:,None,None,None] + pred_keypoints_2d = pred_keypoints_2d*box_size[:,None,None]+box_center[:,None] + pred_keypoints_2d = pred_keypoints_2d[:,None,:,:] + gt_keypoints_2d = batch['orig_keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1) + + self.pred_kp_2d.append(pred_keypoints_2d[:, :, :, :2].detach().cpu().numpy()) + self.gt_conf_2d.append(gt_keypoints_2d[:, :, :, -1].detach().cpu().numpy()) + self.gt_kp_2d.append(gt_keypoints_2d[:, :, :, :2].detach().cpu().numpy()) + self.scale.append(scale.detach().cpu().numpy()) + + self.counter += batch_size \ No newline at end of file diff --git a/wilor/utils/pylogger.py b/wilor/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..92ffa71893ec20acde65e44d899334a38d8d1333 --- /dev/null +++ b/wilor/utils/pylogger.py @@ -0,0 +1,17 @@ +import logging + +from pytorch_lightning.utilities import rank_zero_only + + +def get_pylogger(name=__name__) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger.""" + + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/wilor/utils/render_openpose.py b/wilor/utils/render_openpose.py new file mode 100644 index 0000000000000000000000000000000000000000..8e51ee8e15f40b85e2766e1f9da42183da0d3d46 --- /dev/null +++ b/wilor/utils/render_openpose.py @@ -0,0 +1,191 @@ +""" +Render OpenPose keypoints. +Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp +""" +import cv2 +import math +import numpy as np +from typing import List, Tuple + +def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]: + """ + Compute rectangle enclosing keypoints above the threshold. + Args: + keypoints (np.array): Keypoint array of shape (N, 3). + threshold (float): Confidence visualization threshold. + Returns: + Tuple[float, float, float]: Rectangle width, height and area. + """ + valid_ind = keypoints[:, -1] > threshold + if valid_ind.sum() > 0: + valid_keypoints = keypoints[valid_ind][:, :-1] + max_x = valid_keypoints[:,0].max() + max_y = valid_keypoints[:,1].max() + min_x = valid_keypoints[:,0].min() + min_y = valid_keypoints[:,1].min() + width = max_x - min_x + height = max_y - min_y + area = width * height + return width, height, area + else: + return 0,0,0 + +def render_keypoints(img: np.array, + keypoints: np.array, + pairs: List, + colors: List, + thickness_circle_ratio: float, + thickness_line_ratio_wrt_circle: float, + pose_scales: List, + threshold: float = 0.1, + alpha: float = 1.0) -> np.array: + """ + Render keypoints on input image. 
+ Args: + img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range. + keypoints (np.array): Keypoint array of shape (N, 3). + pairs (List): List of keypoint pairs per limb. + colors: (List): List of colors per keypoint. + thickness_circle_ratio (float): Circle thickness ratio. + thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle. + pose_scales (List): List of pose scales. + threshold (float): Only visualize keypoints with confidence above the threshold. + Returns: + (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image. + """ + img_orig = img.copy() + width, height = img.shape[1], img.shape[2] + area = width * height + + lineType = 8 + shift = 0 + numberColors = len(colors) + thresholdRectangle = 0.1 + + person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle) + if person_area > 0: + ratioAreas = min(1, max(person_width / width, person_height / height)) + thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2) + thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio)) + thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle)) + radius = thicknessRatio / 2 + + img = np.ascontiguousarray(img.copy()) + for i, pair in enumerate(pairs): + index1, index2 = pair + if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold: + thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0])) + colorIndex = index2 + color = colors[colorIndex % numberColors] + keypoint1 = keypoints[index1, :-1].astype(np.int_) + keypoint2 = keypoints[index2, :-1].astype(np.int_) + cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift) + for part in range(len(keypoints)): + faceIndex = part + if keypoints[faceIndex, -1] > threshold: + radiusScaled = int(round(radius[faceIndex] * pose_scales[0])) + thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0])) + colorIndex = part + color = colors[colorIndex % numberColors] + center = keypoints[faceIndex, :-1].astype(np.int_) + cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift) + return img + +def render_hand_keypoints(img, right_hand_keypoints, threshold=0.1, use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0): + if use_confidence and map_fn is not None: + #thicknessCircleRatioLeft = 1./50 * map_fn(left_hand_keypoints[:, -1]) + thicknessCircleRatioRight = 1./50 * map_fn(right_hand_keypoints[:, -1]) + else: + #thicknessCircleRatioLeft = 1./50 * np.ones(left_hand_keypoints.shape[0]) + thicknessCircleRatioRight = 1./50 * np.ones(right_hand_keypoints.shape[0]) + thicknessLineRatioWRTCircle = 0.75 + pairs = [0,1, 1,2, 2,3, 3,4, 0,5, 5,6, 6,7, 7,8, 0,9, 9,10, 10,11, 11,12, 0,13, 13,14, 14,15, 15,16, 0,17, 17,18, 18,19, 19,20] + pairs = np.array(pairs).reshape(-1,2) + + colors = [100., 100., 100., + 100., 0., 0., + 150., 0., 0., + 200., 0., 0., + 255., 0., 0., + 100., 100., 0., + 150., 150., 0., + 200., 200., 0., + 255., 255., 0., + 0., 100., 50., + 0., 150., 75., + 0., 200., 100., + 0., 255., 125., + 0., 50., 100., + 0., 75., 150., + 0., 100., 200., + 0., 125., 255., + 100., 0., 100., + 150., 0., 150., + 200., 0., 200., + 255., 0., 255.] 
+ colors = np.array(colors).reshape(-1,3) + #colors = np.zeros_like(colors) + poseScales = [1] + #img = render_keypoints(img, left_hand_keypoints, pairs, colors, thicknessCircleRatioLeft, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha) + img = render_keypoints(img, right_hand_keypoints, pairs, colors, thicknessCircleRatioRight, thicknessLineRatioWRTCircle, poseScales, threshold, alpha=alpha) + #img = render_keypoints(img, right_hand_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1) + return img + +def render_body_keypoints(img: np.array, + body_keypoints: np.array) -> np.array: + """ + Render OpenPose body keypoints on input image. + Args: + img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range. + body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence). + Returns: + (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image. + """ + + thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0]) + thickness_line_ratio_wrt_circle = 0.75 + pairs = [] + pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24] + pairs = np.array(pairs).reshape(-1,2) + colors = [255., 0., 85., + 255., 0., 0., + 255., 85., 0., + 255., 170., 0., + 255., 255., 0., + 170., 255., 0., + 85., 255., 0., + 0., 255., 0., + 255., 0., 0., + 0., 255., 85., + 0., 255., 170., + 0., 255., 255., + 0., 170., 255., + 0., 85., 255., + 0., 0., 255., + 255., 0., 170., + 170., 0., 255., + 255., 0., 255., + 85., 0., 255., + 0., 0., 255., + 0., 0., 255., + 0., 0., 255., + 0., 255., 255., + 0., 255., 255., + 0., 255., 255.] + colors = np.array(colors).reshape(-1,3) + pose_scales = [1] + return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1) + +def render_openpose(img: np.array, + hand_keypoints: np.array) -> np.array: + """ + Render keypoints in the OpenPose format on input image. + Args: + img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range. + body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence). + Returns: + (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image. + """ + #img = render_body_keypoints(img, body_keypoints) + img = render_hand_keypoints(img, hand_keypoints) + return img diff --git a/wilor/utils/renderer.py b/wilor/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..0e161bb05921e52a684427e3eb87c4f8739a5d89 --- /dev/null +++ b/wilor/utils/renderer.py @@ -0,0 +1,423 @@ +import os +if 'PYOPENGL_PLATFORM' not in os.environ: + os.environ['PYOPENGL_PLATFORM'] = 'egl' +import torch +import numpy as np +import pyrender +import trimesh +import cv2 +from yacs.config import CfgNode +from typing import List, Optional + +def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.): + # Convert cam_bbox to full image + img_w, img_h = img_size[:, 0], img_size[:, 1] + cx, cy, b = box_center[:, 0], box_center[:, 1], box_size + w_2, h_2 = img_w / 2., img_h / 2. 
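+    # cam_bbox holds the weak-perspective camera predicted in the crop: (scale s, tx, ty).
+    # The full-image depth follows from s and the box size b as tz = 2 * focal_length / (s * b),
+    # and tx / ty are shifted by the offset of the box center from the image center.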
+ bs = b * cam_bbox[:, 0] + 1e-9 + tz = 2 * focal_length / bs + tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1] + ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2] + full_cam = torch.stack([tx, ty, tz], dim=-1) + return full_cam + +def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12): + # get lights in a circle around origin at elevation + thetas = elevation * np.ones(n_lights) + phis = 2 * np.pi * np.arange(n_lights) / n_lights + poses = [] + trans = make_translation(torch.tensor([0, 0, dist])) + for phi, theta in zip(phis, thetas): + rot = make_rotation(rx=-theta, ry=phi, order="xyz") + poses.append((rot @ trans).numpy()) + return poses + +def make_translation(t): + return make_4x4_pose(torch.eye(3), t) + +def make_rotation(rx=0, ry=0, rz=0, order="xyz"): + Rx = rotx(rx) + Ry = roty(ry) + Rz = rotz(rz) + if order == "xyz": + R = Rz @ Ry @ Rx + elif order == "xzy": + R = Ry @ Rz @ Rx + elif order == "yxz": + R = Rz @ Rx @ Ry + elif order == "yzx": + R = Rx @ Rz @ Ry + elif order == "zyx": + R = Rx @ Ry @ Rz + elif order == "zxy": + R = Ry @ Rx @ Rz + return make_4x4_pose(R, torch.zeros(3)) + +def make_4x4_pose(R, t): + """ + :param R (*, 3, 3) + :param t (*, 3) + return (*, 4, 4) + """ + dims = R.shape[:-2] + pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1) + bottom = ( + torch.tensor([0, 0, 0, 1], device=R.device) + .reshape(*(1,) * len(dims), 1, 4) + .expand(*dims, 1, 4) + ) + return torch.cat([pose_3x4, bottom], dim=-2) + + +def rotx(theta): + return torch.tensor( + [ + [1, 0, 0], + [0, np.cos(theta), -np.sin(theta)], + [0, np.sin(theta), np.cos(theta)], + ], + dtype=torch.float32, + ) + + +def roty(theta): + return torch.tensor( + [ + [np.cos(theta), 0, np.sin(theta)], + [0, 1, 0], + [-np.sin(theta), 0, np.cos(theta)], + ], + dtype=torch.float32, + ) + + +def rotz(theta): + return torch.tensor( + [ + [np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1], + ], + dtype=torch.float32, + ) + + +def create_raymond_lights() -> List[pyrender.Node]: + """ + Return raymond light nodes for the scene. + """ + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3,:3] = np.c_[x,y,z] + nodes.append(pyrender.Node( + light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0), + matrix=matrix + )) + + return nodes + +class Renderer: + + def __init__(self, cfg: CfgNode, faces: np.array): + """ + Wrapper around the pyrender renderer to render MANO meshes. + Args: + cfg (CfgNode): Model config file. + faces (np.array): Array of shape (F, 3) containing the mesh faces. 
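+            The rendering focal length and crop resolution are taken from cfg.EXTRA.FOCAL_LENGTH and cfg.MODEL.IMAGE_SIZE.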
+ """ + self.cfg = cfg + self.focal_length = cfg.EXTRA.FOCAL_LENGTH + self.img_res = cfg.MODEL.IMAGE_SIZE + + # add faces that make the hand mesh watertight + faces_new = np.array([[92, 38, 234], + [234, 38, 239], + [38, 122, 239], + [239, 122, 279], + [122, 118, 279], + [279, 118, 215], + [118, 117, 215], + [215, 117, 214], + [117, 119, 214], + [214, 119, 121], + [119, 120, 121], + [121, 120, 78], + [120, 108, 78], + [78, 108, 79]]) + faces = np.concatenate([faces, faces_new], axis=0) + + self.camera_center = [self.img_res // 2, self.img_res // 2] + self.faces = faces + self.faces_left = self.faces[:,[0,2,1]] + + def __call__(self, + vertices: np.array, + camera_translation: np.array, + image: torch.Tensor, + full_frame: bool = False, + imgname: Optional[str] = None, + side_view=False, rot_angle=90, + mesh_base_color=(1.0, 1.0, 0.9), + scene_bg_color=(0,0,0), + return_rgba=False, + ) -> np.array: + """ + Render meshes on input image + Args: + vertices (np.array): Array of shape (V, 3) containing the mesh vertices. + camera_translation (np.array): Array of shape (3,) with the camera translation. + image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values. + full_frame (bool): If True, then render on the full image. + imgname (Optional[str]): Contains the original image filenamee. Used only if full_frame == True. + """ + + if full_frame: + image = cv2.imread(imgname).astype(np.float32)[:, :, ::-1] / 255. + else: + image = image.clone() * torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3,1,1) + image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3,1,1) + image = image.permute(1, 2, 0).cpu().numpy() + + renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1], + viewport_height=image.shape[0], + point_size=1.0) + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.0, + alphaMode='OPAQUE', + baseColorFactor=(*mesh_base_color, 1.0)) + + camera_translation[0] *= -1. + + mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy()) + if side_view: + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), [0, 1, 0]) + mesh.apply_transform(rot) + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation + camera_center = [image.shape[1] / 2., image.shape[0] / 2.] 
+ camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length, + cx=camera_center[0], cy=camera_center[1], zfar=1e12) + scene.add(camera, pose=camera_pose) + + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + renderer.delete() + + if return_rgba: + return color + + valid_mask = (color[:, :, -1])[:, :, np.newaxis] + if not side_view: + output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image) + else: + output_img = color[:, :, :3] + + output_img = output_img.astype(np.float32) + return output_img + + def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9), + rot_axis=[1,0,0], rot_angle=0, is_right=1): + # material = pyrender.MetallicRoughnessMaterial( + # metallicFactor=0.0, + # alphaMode='OPAQUE', + # baseColorFactor=(*mesh_base_color, 1.0)) + vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0]) + if is_right: + mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors) + else: + mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces_left.copy(), vertex_colors=vertex_colors) + # mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy()) + + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), rot_axis) + mesh.apply_transform(rot) + + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + return mesh + + def render_rgba( + self, + vertices: np.array, + cam_t = None, + rot=None, + rot_axis=[1,0,0], + rot_angle=0, + camera_z=3, + # camera_translation: np.array, + mesh_base_color=(1.0, 1.0, 0.9), + scene_bg_color=(0,0,0), + render_res=[256, 256], + focal_length=None, + is_right=None, + ): + + renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0], + viewport_height=render_res[1], + point_size=1.0) + # material = pyrender.MetallicRoughnessMaterial( + # metallicFactor=0.0, + # alphaMode='OPAQUE', + # baseColorFactor=(*mesh_base_color, 1.0)) + + focal_length = focal_length if focal_length is not None else self.focal_length + + if cam_t is not None: + camera_translation = cam_t.copy() + camera_translation[0] *= -1. + else: + camera_translation = np.array([0, 0, camera_z * focal_length/render_res[1]]) + + mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle, is_right=is_right) + mesh = pyrender.Mesh.from_trimesh(mesh) + # mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation + camera_center = [render_res[0] / 2., render_res[1] / 2.] 
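+        # Principal point at the center of the render resolution; zfar is set very large so the mesh is never far-clipped.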
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, + cx=camera_center[0], cy=camera_center[1], zfar=1e12) + + # Create camera node and add it to pyRender scene + camera_node = pyrender.Node(camera=camera, matrix=camera_pose) + scene.add_node(camera_node) + self.add_point_lighting(scene, camera_node) + self.add_lighting(scene, camera_node) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + renderer.delete() + + return color + + def render_rgba_multiple( + self, + vertices: List[np.array], + cam_t: List[np.array], + rot_axis=[1,0,0], + rot_angle=0, + mesh_base_color=(1.0, 1.0, 0.9), + scene_bg_color=(0,0,0), + render_res=[256, 256], + focal_length=None, + is_right=None, + ): + + renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0], + viewport_height=render_res[1], + point_size=1.0) + # material = pyrender.MetallicRoughnessMaterial( + # metallicFactor=0.0, + # alphaMode='OPAQUE', + # baseColorFactor=(*mesh_base_color, 1.0)) + + if is_right is None: + is_right = [1 for _ in range(len(vertices))] + + mesh_list = [pyrender.Mesh.from_trimesh(self.vertices_to_trimesh(vvv, ttt.copy(), mesh_base_color, rot_axis, rot_angle, is_right=sss)) for vvv,ttt,sss in zip(vertices, cam_t, is_right)] + + scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + for i,mesh in enumerate(mesh_list): + scene.add(mesh, f'mesh_{i}') + + camera_pose = np.eye(4) + # camera_pose[:3, 3] = camera_translation + camera_center = [render_res[0] / 2., render_res[1] / 2.] + focal_length = focal_length if focal_length is not None else self.focal_length + camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, + cx=camera_center[0], cy=camera_center[1], zfar=1e12) + + # Create camera node and add it to pyRender scene + camera_node = pyrender.Node(camera=camera, matrix=camera_pose) + scene.add_node(camera_node) + self.add_point_lighting(scene, camera_node) + self.add_lighting(scene, camera_node) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + renderer.delete() + + return color + + def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0): + # from phalp.visualize.py_renderer import get_light_poses + light_poses = get_light_poses() + light_poses.append(np.eye(4)) + cam_pose = scene.get_pose(cam_node) + for i, pose in enumerate(light_poses): + matrix = cam_pose @ pose + node = pyrender.Node( + name=f"light-{i:02d}", + light=pyrender.DirectionalLight(color=color, intensity=intensity), + matrix=matrix, + ) + if scene.has_node(node): + continue + scene.add_node(node) + + def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0): + # from phalp.visualize.py_renderer import get_light_poses + light_poses = get_light_poses(dist=0.5) + light_poses.append(np.eye(4)) + cam_pose = scene.get_pose(cam_node) + for i, pose in enumerate(light_poses): + matrix = cam_pose @ pose + # node = pyrender.Node( + # name=f"light-{i:02d}", + # light=pyrender.DirectionalLight(color=color, intensity=intensity), + # matrix=matrix, + # ) + node = pyrender.Node( + name=f"plight-{i:02d}", + light=pyrender.PointLight(color=color, intensity=intensity), + matrix=matrix, + ) + if scene.has_node(node): + continue + 
scene.add_node(node) diff --git a/wilor/utils/rich_utils.py b/wilor/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19f97494ed2958ec2c3d75c772360b5367f2dc7b --- /dev/null +++ b/wilor/utils/rich_utils.py @@ -0,0 +1,105 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning.utilities import rank_zero_only +from rich.prompt import Prompt + +from . import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "datamodule", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. 
Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) + + +if __name__ == "__main__": + from hydra import compose, initialize + + with initialize(version_base="1.2", config_path="../../configs"): + cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) + print_config_tree(cfg, resolve=False, save_to_file=False) diff --git a/wilor/utils/skeleton_renderer.py b/wilor/utils/skeleton_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..46a5df75bff887eab00984eeb5be3c1f6e752960 --- /dev/null +++ b/wilor/utils/skeleton_renderer.py @@ -0,0 +1,124 @@ +import torch +import numpy as np +import trimesh +from typing import Optional +from yacs.config import CfgNode + +from .geometry import perspective_projection +from .render_openpose import render_openpose + +class SkeletonRenderer: + + def __init__(self, cfg: CfgNode): + """ + Object used to render 3D keypoints. Faster for use during training. + Args: + cfg (CfgNode): Model config file. + """ + self.cfg = cfg + + def __call__(self, + pred_keypoints_3d: torch.Tensor, + gt_keypoints_3d: torch.Tensor, + gt_keypoints_2d: torch.Tensor, + images: Optional[np.array] = None, + camera_translation: Optional[torch.Tensor] = None) -> np.array: + """ + Render batch of 3D keypoints. + Args: + pred_keypoints_3d (torch.Tensor): Tensor of shape (B, S, N, 3) containing a batch of predicted 3D keypoints, with S samples per image. + gt_keypoints_3d (torch.Tensor): Tensor of shape (B, N, 4) containing corresponding ground truth 3D keypoints; last value is the confidence. + gt_keypoints_2d (torch.Tensor): Tensor of shape (B, N, 3) containing corresponding ground truth 2D keypoints. + images (torch.Tensor): Tensor of shape (B, H, W, 3) containing images with values in the [0,255] range. + camera_translation (torch.Tensor): Tensor of shape (B, 3) containing the camera translation. + Returns: + np.array : Image with the following layout. Each row contains the a) input image, + b) image with gt 2D keypoints, + c) image with projected gt 3D keypoints, + d_1, ... , d_S) image with projected predicted 3D keypoints, + e) gt 3D keypoints rendered from a side view, + f_1, ... 
, f_S) predicted 3D keypoints frorm a side view + """ + batch_size = pred_keypoints_3d.shape[0] +# num_samples = pred_keypoints_3d.shape[1] + pred_keypoints_3d = pred_keypoints_3d.clone().cpu().float() + gt_keypoints_3d = gt_keypoints_3d.clone().cpu().float() + gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, [0], :-1] + pred_keypoints_3d[:, [0]] + gt_keypoints_2d = gt_keypoints_2d.clone().cpu().float().numpy() + gt_keypoints_2d[:, :, :-1] = self.cfg.MODEL.IMAGE_SIZE * (gt_keypoints_2d[:, :, :-1] + 1.0) / 2.0 + + #openpose_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + #gt_indices = [12, 8, 7, 6, 9, 10, 11, 14, 2, 1, 0, 3, 4, 5] + #gt_indices = [25 + i for i in gt_indices] + openpose_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + gt_indices = openpose_indices + keypoints_to_render = torch.ones(batch_size, gt_keypoints_3d.shape[1], 1) + rotation = torch.eye(3).unsqueeze(0) + if camera_translation is None: + camera_translation = torch.tensor([0.0, 0.0, 2 * self.cfg.EXTRA.FOCAL_LENGTH / (0.8 * self.cfg.MODEL.IMAGE_SIZE)]).unsqueeze(0).repeat(batch_size, 1) + else: + camera_translation = camera_translation.cpu() + + if images is None: + images = np.zeros((batch_size, self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE, 3)) + focal_length = torch.tensor([self.cfg.EXTRA.FOCAL_LENGTH, self.cfg.EXTRA.FOCAL_LENGTH]).reshape(1, 2) + camera_center = torch.tensor([self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE], dtype=torch.float).reshape(1, 2) / 2. + gt_keypoints_3d_proj = perspective_projection(gt_keypoints_3d[:, :, :-1], rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation[:, :], focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1)) + pred_keypoints_3d_proj = perspective_projection(pred_keypoints_3d.reshape(batch_size, -1, 3), rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation.reshape(batch_size, -1), focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1)).reshape(batch_size, -1, 2) + gt_keypoints_3d_proj = torch.cat([gt_keypoints_3d_proj, gt_keypoints_3d[:, :, [-1]]], dim=-1).cpu().numpy() + pred_keypoints_3d_proj = torch.cat([pred_keypoints_3d_proj, keypoints_to_render.reshape(batch_size, -1, 1)], dim=-1).cpu().numpy() + rows = [] + # Rotate keypoints to visualize side view + R = torch.tensor(trimesh.transformations.rotation_matrix(np.radians(90), [0, 1, 0])[:3, :3]).float() + gt_keypoints_3d_side = gt_keypoints_3d.clone() + gt_keypoints_3d_side[:, :, :-1] = torch.einsum('bni,ij->bnj', gt_keypoints_3d_side[:, :, :-1], R) + pred_keypoints_3d_side = pred_keypoints_3d.clone() + pred_keypoints_3d_side = torch.einsum('bni,ij->bnj', pred_keypoints_3d_side, R) + gt_keypoints_3d_proj_side = perspective_projection(gt_keypoints_3d_side[:, :, :-1], rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation[:, :], focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1)) + pred_keypoints_3d_proj_side = perspective_projection(pred_keypoints_3d_side.reshape(batch_size, -1, 3), rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation.reshape(batch_size, -1), focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1)).reshape(batch_size, -1, 2) + gt_keypoints_3d_proj_side = torch.cat([gt_keypoints_3d_proj_side, gt_keypoints_3d_side[:, :, [-1]]], dim=-1).cpu().numpy() + 
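+        # GT and predicted joints receive the same 90-degree rotation about the y-axis, so the two side-view renderings are directly comparable.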
pred_keypoints_3d_proj_side = torch.cat([pred_keypoints_3d_proj_side, keypoints_to_render.reshape(batch_size, -1, 1)], dim=-1).cpu().numpy() + for i in range(batch_size): + img = images[i] + side_img = np.zeros((self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE, 3)) + # gt 2D keypoints + body_keypoints_2d = gt_keypoints_2d[i, :21].copy() + for op, gt in zip(openpose_indices, gt_indices): + if gt_keypoints_2d[i, gt, -1] > body_keypoints_2d[op, -1]: + body_keypoints_2d[op] = gt_keypoints_2d[i, gt] + gt_keypoints_img = render_openpose(img, body_keypoints_2d) / 255. + # gt 3D keypoints + body_keypoints_3d_proj = gt_keypoints_3d_proj[i, :21].copy() + for op, gt in zip(openpose_indices, gt_indices): + if gt_keypoints_3d_proj[i, gt, -1] > body_keypoints_3d_proj[op, -1]: + body_keypoints_3d_proj[op] = gt_keypoints_3d_proj[i, gt] + gt_keypoints_3d_proj_img = render_openpose(img, body_keypoints_3d_proj) / 255. + # gt 3D keypoints from the side + body_keypoints_3d_proj = gt_keypoints_3d_proj_side[i, :21].copy() + for op, gt in zip(openpose_indices, gt_indices): + if gt_keypoints_3d_proj_side[i, gt, -1] > body_keypoints_3d_proj[op, -1]: + body_keypoints_3d_proj[op] = gt_keypoints_3d_proj_side[i, gt] + gt_keypoints_3d_proj_img_side = render_openpose(side_img, body_keypoints_3d_proj) / 255. + # pred 3D keypoints + pred_keypoints_3d_proj_imgs = [] + body_keypoints_3d_proj = pred_keypoints_3d_proj[i, :21].copy() + for op, gt in zip(openpose_indices, gt_indices): + if pred_keypoints_3d_proj[i, gt, -1] >= body_keypoints_3d_proj[op, -1]: + body_keypoints_3d_proj[op] = pred_keypoints_3d_proj[i, gt] + pred_keypoints_3d_proj_imgs.append(render_openpose(img, body_keypoints_3d_proj) / 255.) + pred_keypoints_3d_proj_img = np.concatenate(pred_keypoints_3d_proj_imgs, axis=1) + # gt 3D keypoints from the side + pred_keypoints_3d_proj_imgs_side = [] + body_keypoints_3d_proj = pred_keypoints_3d_proj_side[i, :21].copy() + for op, gt in zip(openpose_indices, gt_indices): + if pred_keypoints_3d_proj_side[i, gt, -1] >= body_keypoints_3d_proj[op, -1]: + body_keypoints_3d_proj[op] = pred_keypoints_3d_proj_side[i, gt] + pred_keypoints_3d_proj_imgs_side.append(render_openpose(side_img, body_keypoints_3d_proj) / 255.) + pred_keypoints_3d_proj_img_side = np.concatenate(pred_keypoints_3d_proj_imgs_side, axis=1) + rows.append(np.concatenate((gt_keypoints_img, gt_keypoints_3d_proj_img, pred_keypoints_3d_proj_img, gt_keypoints_3d_proj_img_side, pred_keypoints_3d_proj_img_side), axis=1)) + # Concatenate images + img = np.concatenate(rows, axis=0) + img[:, ::self.cfg.MODEL.IMAGE_SIZE, :] = 1.0 + img[::self.cfg.MODEL.IMAGE_SIZE, :, :] = 1.0 + img[:, (1+1+1)*self.cfg.MODEL.IMAGE_SIZE, :] = 0.5 + return img
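As a quick sanity check for the evaluation utilities added in wilor/utils/pose_utils.py, the sketch below drives the Evaluator with random tensors. This is illustrative only and not part of the repository: the dictionary keys and tensor shapes are the ones Evaluator.__call__ reads (21 hand joints, 778 MANO vertices, confidences in the last channel of the ground-truth arrays), while the batch size, dataset name, and pelvis_ind choice are arbitrary placeholders.

import torch
from wilor.utils.pose_utils import Evaluator

B, J = 4, 21                                    # placeholder batch size, 21-joint hand layout
evaluator = Evaluator(dataset_length=B,
                      dataset='FREIHAND-VAL',   # any name other than 'HO3D-VAL' keeps the default joint order
                      keypoint_list=list(range(J)),
                      pelvis_ind=0)             # joint used as the alignment root

output = {                                      # keys and shapes mirror what Evaluator.__call__ expects
    'pred_keypoints_3d': torch.rand(B, J, 3),
    'pred_vertices':     torch.rand(B, 778, 3),
    'pred_keypoints_2d': torch.rand(B, J, 2),
}
batch = {
    'keypoints_3d':       torch.rand(B, J, 4),  # last channel is the per-joint confidence
    'keypoints_2d':       torch.rand(B, J, 3),  # last channel is the per-joint confidence
    'bbox_expand_factor': torch.ones(B),
}

evaluator(output, batch)                # accumulates mode/min MPJPE and PA-MPJPE (in mm)
evaluator.log()                         # prints the running means
metrics = evaluator.get_metrics_dict()  # e.g. {'mode_mpjpe': ..., 'mode_re': ..., ...}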