junbiao.chen committed · Commit cc0c59d · 1 Parent(s): f29eac5

Trellis update

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. trellis/datasets/__init__.py +58 -0
  2. trellis/datasets/components.py +137 -0
  3. trellis/datasets/sparse_feat2render.py +134 -0
  4. trellis/datasets/sparse_structure.py +107 -0
  5. trellis/datasets/sparse_structure_latent.py +189 -0
  6. trellis/datasets/structured_latent.py +218 -0
  7. trellis/datasets/structured_latent2render.py +160 -0
  8. trellis/models/__init__.py +29 -3
  9. trellis/models/sparse_elastic_mixin.py +24 -0
  10. trellis/models/structured_latent_flow.py +50 -36
  11. trellis/models/structured_latent_vae/__init__.py +4 -4
  12. trellis/models/structured_latent_vae/decoder_gs.py +9 -0
  13. trellis/models/structured_latent_vae/decoder_mesh.py +9 -0
  14. trellis/models/structured_latent_vae/decoder_rf.py +9 -0
  15. trellis/models/structured_latent_vae/encoder.py +8 -0
  16. trellis/pipelines/__init__.py +1 -0
  17. trellis/pipelines/base.py +6 -4
  18. trellis/pipelines/samplers/flow_euler.py +2 -0
  19. trellis/pipelines/trellis_image_to_3d.py +2 -3
  20. trellis/pipelines/trellis_text_to_3d.py +278 -0
  21. trellis/representations/mesh/cube2mesh.py +1 -8
  22. trellis/representations/mesh/flexicubes/LICENSE.txt +90 -0
  23. trellis/representations/mesh/flexicubes/README.md +110 -0
  24. trellis/representations/mesh/flexicubes/examples/data/inputmodels/block.obj +0 -0
  25. trellis/representations/mesh/flexicubes/examples/download_data.py +41 -0
  26. trellis/representations/mesh/flexicubes/examples/extraction.ipynb +0 -0
  27. trellis/representations/mesh/flexicubes/examples/loss.py +95 -0
  28. trellis/representations/mesh/flexicubes/examples/optimization.ipynb +0 -0
  29. trellis/representations/mesh/flexicubes/examples/optimize.py +150 -0
  30. trellis/representations/mesh/flexicubes/examples/render.py +267 -0
  31. trellis/representations/mesh/flexicubes/examples/util.py +122 -0
  32. trellis/representations/mesh/flexicubes/flexicubes.py +384 -0
  33. trellis/representations/mesh/flexicubes/images/ablate_L_dev.jpg +0 -0
  34. trellis/representations/mesh/flexicubes/images/block_final.png +3 -0
  35. trellis/representations/mesh/flexicubes/images/block_init.png +3 -0
  36. trellis/representations/mesh/flexicubes/images/teaser_top.png +3 -0
  37. trellis/representations/mesh/flexicubes/tables.py +791 -0
  38. trellis/representations/octree/octree_dfs.py +3 -18
  39. trellis/trainers/__init__.py +63 -0
  40. trellis/trainers/base.py +451 -0
  41. trellis/trainers/basic.py +438 -0
  42. trellis/trainers/flow_matching/flow_matching.py +353 -0
  43. trellis/trainers/flow_matching/mixins/classifier_free_guidance.py +59 -0
  44. trellis/trainers/flow_matching/mixins/image_conditioned.py +93 -0
  45. trellis/trainers/flow_matching/mixins/text_conditioned.py +68 -0
  46. trellis/trainers/flow_matching/sparse_flow_matching.py +286 -0
  47. trellis/trainers/utils.py +77 -0
  48. trellis/trainers/vae/sparse_structure_vae.py +130 -0
  49. trellis/trainers/vae/structured_latent_vae_gaussian.py +275 -0
  50. trellis/trainers/vae/structured_latent_vae_mesh_dec.py +382 -0
trellis/datasets/__init__.py ADDED
@@ -0,0 +1,58 @@
+ import importlib
+
+ __attributes = {
+     'SparseStructure': 'sparse_structure',
+
+     'SparseFeat2Render': 'sparse_feat2render',
+     'SLat2Render': 'structured_latent2render',
+     'Slat2RenderGeo': 'structured_latent2render',
+
+     'SparseStructureLatent': 'sparse_structure_latent',
+     'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
+     'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
+
+     'SLat': 'structured_latent',
+     'TextConditionedSLat': 'structured_latent',
+     'ImageConditionedSLat': 'structured_latent',
+ }
+
+ __submodules = []
+
+ __all__ = list(__attributes.keys()) + __submodules
+
+ def __getattr__(name):
+     if name not in globals():
+         if name in __attributes:
+             module_name = __attributes[name]
+             module = importlib.import_module(f".{module_name}", __name__)
+             globals()[name] = getattr(module, name)
+         elif name in __submodules:
+             module = importlib.import_module(f".{name}", __name__)
+             globals()[name] = module
+         else:
+             raise AttributeError(f"module {__name__} has no attribute {name}")
+     return globals()[name]
+
+
+ # For Pylance
+ if __name__ == '__main__':
+     from .sparse_structure import SparseStructure
+
+     from .sparse_feat2render import SparseFeat2Render
+     from .structured_latent2render import (
+         SLat2Render,
+         Slat2RenderGeo,
+     )
+
+     from .sparse_structure_latent import (
+         SparseStructureLatent,
+         TextConditionedSparseStructureLatent,
+         ImageConditionedSparseStructureLatent,
+     )
+
+     from .structured_latent import (
+         SLat,
+         TextConditionedSLat,
+         ImageConditionedSLat,
+     )
+
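Note: the module-level `__getattr__` above lazily imports each dataset class the first time it is requested, so `import trellis.datasets` stays cheap until a concrete dataset is actually used. A minimal illustration of how that resolution behaves (the access pattern below is illustrative, not part of the commit):

```python
# Illustrative only: resolving a dataset class through the lazy mapping in
# trellis/datasets/__init__.py (PEP 562 module-level __getattr__).
from trellis import datasets

# First attribute access triggers
# importlib.import_module('.structured_latent', 'trellis.datasets')
# and caches the class in the package globals; later accesses are plain lookups.
SLatDataset = datasets.SLat
print(SLatDataset.__module__)  # 'trellis.datasets.structured_latent'
```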
trellis/datasets/components.py ADDED
@@ -0,0 +1,137 @@
+ from typing import *
+ from abc import abstractmethod
+ import os
+ import json
+ import torch
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+
+ class StandardDatasetBase(Dataset):
+     """
+     Base class for standard datasets.
+
+     Args:
+         roots (str): paths to the dataset
+     """
+
+     def __init__(self,
+         roots: str,
+     ):
+         super().__init__()
+         self.roots = roots.split(',')
+         self.instances = []
+         self.metadata = pd.DataFrame()
+
+         self._stats = {}
+         for root in self.roots:
+             key = os.path.basename(root)
+             self._stats[key] = {}
+             metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
+             self._stats[key]['Total'] = len(metadata)
+             metadata, stats = self.filter_metadata(metadata)
+             self._stats[key].update(stats)
+             self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
+             metadata.set_index('sha256', inplace=True)
+             self.metadata = pd.concat([self.metadata, metadata])
+
+     @abstractmethod
+     def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
+         pass
+
+     @abstractmethod
+     def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
+         pass
+
+     def __len__(self):
+         return len(self.instances)
+
+     def __getitem__(self, index) -> Dict[str, Any]:
+         try:
+             root, instance = self.instances[index]
+             return self.get_instance(root, instance)
+         except Exception as e:
+             print(e)
+             return self.__getitem__(np.random.randint(0, len(self)))
+
+     def __str__(self):
+         lines = []
+         lines.append(self.__class__.__name__)
+         lines.append(f' - Total instances: {len(self)}')
+         lines.append(f' - Sources:')
+         for key, stats in self._stats.items():
+             lines.append(f' - {key}:')
+             for k, v in stats.items():
+                 lines.append(f' - {k}: {v}')
+         return '\n'.join(lines)
+
+
+ class TextConditionedMixin:
+     def __init__(self, roots, **kwargs):
+         super().__init__(roots, **kwargs)
+         self.captions = {}
+         for instance in self.instances:
+             sha256 = instance[1]
+             self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])
+
+     def filter_metadata(self, metadata):
+         metadata, stats = super().filter_metadata(metadata)
+         metadata = metadata[metadata['captions'].notna()]
+         stats['With captions'] = len(metadata)
+         return metadata, stats
+
+     def get_instance(self, root, instance):
+         pack = super().get_instance(root, instance)
+         text = np.random.choice(self.captions[instance])
+         pack['cond'] = text
+         return pack
+
+
+ class ImageConditionedMixin:
+     def __init__(self, roots, *, image_size=518, **kwargs):
+         self.image_size = image_size
+         super().__init__(roots, **kwargs)
+
+     def filter_metadata(self, metadata):
+         metadata, stats = super().filter_metadata(metadata)
+         metadata = metadata[metadata[f'cond_rendered']]
+         stats['Cond rendered'] = len(metadata)
+         return metadata, stats
+
+     def get_instance(self, root, instance):
+         pack = super().get_instance(root, instance)
+
+         image_root = os.path.join(root, 'renders_cond', instance)
+         with open(os.path.join(image_root, 'transforms.json')) as f:
+             metadata = json.load(f)
+         n_views = len(metadata['frames'])
+         view = np.random.randint(n_views)
+         metadata = metadata['frames'][view]
+
+         image_path = os.path.join(image_root, metadata['file_path'])
+         image = Image.open(image_path)
+
+         alpha = np.array(image.getchannel(3))
+         bbox = np.array(alpha).nonzero()
+         bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
+         center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
+         hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
+         aug_size_ratio = 1.2
+         aug_hsize = hsize * aug_size_ratio
+         aug_center_offset = [0, 0]
+         aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
+         aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
+         image = image.crop(aug_bbox)
+
+         image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+         alpha = image.getchannel(3)
+         image = image.convert('RGB')
+         image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
+         alpha = torch.tensor(np.array(alpha)).float() / 255.0
+         image = image * alpha.unsqueeze(0)
+         pack['cond'] = image
+
+         return pack
+
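Note: `TextConditionedMixin` and `ImageConditionedMixin` are cooperative mixins: they call `super().filter_metadata()` / `super().get_instance()`, so they only add behavior when listed before a concrete `StandardDatasetBase` subclass in the MRO. A small sketch of that composition pattern (the toy subclass is hypothetical, only the mixin and base classes come from this commit):

```python
# Hypothetical composition showing the cooperative-mixin pattern used above.
# The mixin must precede the dataset class so its overrides wrap the base
# implementations via super().
class MyDataset(StandardDatasetBase):
    def filter_metadata(self, metadata):
        return metadata, {}

    def get_instance(self, root, instance):
        return {'x_0': None}  # placeholder payload

class TextConditionedMyDataset(TextConditionedMixin, MyDataset):
    pass  # get_instance() now also attaches a randomly chosen caption as pack['cond']
```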
trellis/datasets/sparse_feat2render.py ADDED
@@ -0,0 +1,134 @@
+ import os
+ from PIL import Image
+ import json
+ import numpy as np
+ import pandas as pd
+ import torch
+ import utils3d.torch
+ from ..modules.sparse.basic import SparseTensor
+ from .components import StandardDatasetBase
+
+
+ class SparseFeat2Render(StandardDatasetBase):
+     """
+     SparseFeat2Render dataset.
+
+     Args:
+         roots (str): paths to the dataset
+         image_size (int): size of the image
+         model (str): model name
+         resolution (int): resolution of the data
+         min_aesthetic_score (float): minimum aesthetic score
+         max_num_voxels (int): maximum number of voxels
+     """
+     def __init__(
+         self,
+         roots: str,
+         image_size: int,
+         model: str = 'dinov2_vitl14_reg',
+         resolution: int = 64,
+         min_aesthetic_score: float = 5.0,
+         max_num_voxels: int = 32768,
+     ):
+         self.image_size = image_size
+         self.model = model
+         self.resolution = resolution
+         self.min_aesthetic_score = min_aesthetic_score
+         self.max_num_voxels = max_num_voxels
+         self.value_range = (0, 1)
+
+         super().__init__(roots)
+
+     def filter_metadata(self, metadata):
+         stats = {}
+         metadata = metadata[metadata[f'feature_{self.model}']]
+         stats['With features'] = len(metadata)
+         metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+         stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+         metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
+         stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
+         return metadata, stats
+
+     def _get_image(self, root, instance):
+         with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
+             metadata = json.load(f)
+         n_views = len(metadata['frames'])
+         view = np.random.randint(n_views)
+         metadata = metadata['frames'][view]
+         fov = metadata['camera_angle_x']
+         intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
+         c2w = torch.tensor(metadata['transform_matrix'])
+         c2w[:3, 1:3] *= -1
+         extrinsics = torch.inverse(c2w)
+
+         image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
+         image = Image.open(image_path)
+         alpha = image.getchannel(3)
+         image = image.convert('RGB')
+         image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+         alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+         image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
+         alpha = torch.tensor(np.array(alpha)).float() / 255.0
+
+         return {
+             'image': image,
+             'alpha': alpha,
+             'extrinsics': extrinsics,
+             'intrinsics': intrinsics,
+         }
+
+     def _get_feat(self, root, instance):
+         DATA_RESOLUTION = 64
+         feats_path = os.path.join(root, 'features', self.model, f'{instance}.npz')
+         feats = np.load(feats_path, allow_pickle=True)
+         coords = torch.tensor(feats['indices']).int()
+         feats = torch.tensor(feats['patchtokens']).float()
+
+         if self.resolution != DATA_RESOLUTION:
+             factor = DATA_RESOLUTION // self.resolution
+             coords = coords // factor
+             coords, idx = coords.unique(return_inverse=True, dim=0)
+             feats = torch.scatter_reduce(
+                 torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
+                 dim=0,
+                 index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
+                 src=feats,
+                 reduce='mean'
+             )
+
+         return {
+             'coords': coords,
+             'feats': feats,
+         }
+
+     @torch.no_grad()
+     def visualize_sample(self, sample: dict):
+         return sample['image']
+
+     @staticmethod
+     def collate_fn(batch):
+         pack = {}
+         coords = []
+         for i, b in enumerate(batch):
+             coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
+         coords = torch.cat(coords)
+         feats = torch.cat([b['feats'] for b in batch])
+         pack['feats'] = SparseTensor(
+             coords=coords,
+             feats=feats,
+         )
+
+         pack['image'] = torch.stack([b['image'] for b in batch])
+         pack['alpha'] = torch.stack([b['alpha'] for b in batch])
+         pack['extrinsics'] = torch.stack([b['extrinsics'] for b in batch])
+         pack['intrinsics'] = torch.stack([b['intrinsics'] for b in batch])
+
+         return pack
+
+     def get_instance(self, root, instance):
+         image = self._get_image(root, instance)
+         feat = self._get_feat(root, instance)
+         return {
+             **image,
+             **feat,
+         }
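Note: when the requested `resolution` is below the stored 64³ grid, `_get_feat` integer-divides the voxel coordinates and mean-pools the features of voxels that collapse onto the same coarse cell with `torch.scatter_reduce`. A standalone sketch of that pooling step (toy shapes, not the dataset's actual tensors; `include_self=False` is used here so the initial zeros do not enter the mean, whereas the default in the code above includes them):

```python
import torch

# Toy version of the downsampling/mean-pooling in SparseFeat2Render._get_feat:
# collapse 64^3 voxel coords onto a 32^3 grid and average colliding features.
coords = torch.randint(0, 64, (1000, 3))   # sparse voxel indices
feats = torch.randn(1000, 1024)            # per-voxel DINOv2 patch tokens (toy)
factor = 64 // 32
coords = coords // factor                  # map onto the coarser grid
coords, idx = coords.unique(return_inverse=True, dim=0)
pooled = torch.zeros(coords.shape[0], feats.shape[1]).scatter_reduce(
    0, idx.unsqueeze(-1).expand(-1, feats.shape[1]), feats,
    reduce='mean', include_self=False)
print(coords.shape, pooled.shape)          # same number of unique coarse voxels
```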
trellis/datasets/sparse_structure.py ADDED
@@ -0,0 +1,107 @@
+ import os
+ import json
+ from typing import Union
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch.utils.data import Dataset
+ import utils3d
+ from .components import StandardDatasetBase
+ from ..representations.octree import DfsOctree as Octree
+ from ..renderers import OctreeRenderer
+
+
+ class SparseStructure(StandardDatasetBase):
+     """
+     Sparse structure dataset
+
+     Args:
+         roots (str): path to the dataset
+         resolution (int): resolution of the voxel grid
+         min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
+     """
+
+     def __init__(self,
+         roots,
+         resolution: int = 64,
+         min_aesthetic_score: float = 5.0,
+     ):
+         self.resolution = resolution
+         self.min_aesthetic_score = min_aesthetic_score
+         self.value_range = (0, 1)
+
+         super().__init__(roots)
+
+     def filter_metadata(self, metadata):
+         stats = {}
+         metadata = metadata[metadata[f'voxelized']]
+         stats['Voxelized'] = len(metadata)
+         metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+         stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+         return metadata, stats
+
+     def get_instance(self, root, instance):
+         position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
+         coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
+         ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
+         ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
+         return {'ss': ss}
+
+     @torch.no_grad()
+     def visualize_sample(self, ss: Union[torch.Tensor, dict]):
+         ss = ss if isinstance(ss, torch.Tensor) else ss['ss']
+
+         renderer = OctreeRenderer()
+         renderer.rendering_options.resolution = 512
+         renderer.rendering_options.near = 0.8
+         renderer.rendering_options.far = 1.6
+         renderer.rendering_options.bg_color = (0, 0, 0)
+         renderer.rendering_options.ssaa = 4
+         renderer.pipe.primitive = 'voxel'
+
+         # Build camera
+         yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+         yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
+         yaws = [y + yaws_offset for y in yaws]
+         pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
+
+         exts = []
+         ints = []
+         for yaw, pitch in zip(yaws, pitch):
+             orig = torch.tensor([
+                 np.sin(yaw) * np.cos(pitch),
+                 np.cos(yaw) * np.cos(pitch),
+                 np.sin(pitch),
+             ]).float().cuda() * 2
+             fov = torch.deg2rad(torch.tensor(30)).cuda()
+             extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+             intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+             exts.append(extrinsics)
+             ints.append(intrinsics)
+
+         images = []
+
+         # Build each representation
+         ss = ss.cuda()
+         for i in range(ss.shape[0]):
+             representation = Octree(
+                 depth=10,
+                 aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
+                 device='cuda',
+                 primitive='voxel',
+                 sh_degree=0,
+                 primitive_config={'solid': True},
+             )
+             coords = torch.nonzero(ss[i, 0], as_tuple=False)
+             representation.position = coords.float() / self.resolution
+             representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
+
+             image = torch.zeros(3, 1024, 1024).cuda()
+             tile = [2, 2]
+             for j, (ext, intr) in enumerate(zip(exts, ints)):
+                 res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
+                 image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
+             images.append(image)
+
+         return torch.stack(images)
+
trellis/datasets/sparse_structure_latent.py ADDED
@@ -0,0 +1,189 @@
+ import os
+ import json
+ from typing import *
+ import numpy as np
+ import torch
+ import utils3d
+ from ..representations.octree import DfsOctree as Octree
+ from ..renderers import OctreeRenderer
+ from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
+ from .. import models
+ from ..utils.dist_utils import read_file_dist
+
+
+ class SparseStructureLatentVisMixin:
+     def __init__(
+         self,
+         *args,
+         pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
+         ss_dec_path: Optional[str] = None,
+         ss_dec_ckpt: Optional[str] = None,
+         **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+         self.ss_dec = None
+         self.pretrained_ss_dec = pretrained_ss_dec
+         self.ss_dec_path = ss_dec_path
+         self.ss_dec_ckpt = ss_dec_ckpt
+
+     def _loading_ss_dec(self):
+         if self.ss_dec is not None:
+             return
+         if self.ss_dec_path is not None:
+             cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r'))
+             decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
+             ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
+             decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
+         else:
+             decoder = models.from_pretrained(self.pretrained_ss_dec)
+         self.ss_dec = decoder.cuda().eval()
+
+     def _delete_ss_dec(self):
+         del self.ss_dec
+         self.ss_dec = None
+
+     @torch.no_grad()
+     def decode_latent(self, z, batch_size=4):
+         self._loading_ss_dec()
+         ss = []
+         if self.normalization is not None:
+             z = z * self.std.to(z.device) + self.mean.to(z.device)
+         for i in range(0, z.shape[0], batch_size):
+             ss.append(self.ss_dec(z[i:i+batch_size]))
+         ss = torch.cat(ss, dim=0)
+         self._delete_ss_dec()
+         return ss
+
+     @torch.no_grad()
+     def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
+         x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
+         x_0 = self.decode_latent(x_0.cuda())
+
+         renderer = OctreeRenderer()
+         renderer.rendering_options.resolution = 512
+         renderer.rendering_options.near = 0.8
+         renderer.rendering_options.far = 1.6
+         renderer.rendering_options.bg_color = (0, 0, 0)
+         renderer.rendering_options.ssaa = 4
+         renderer.pipe.primitive = 'voxel'
+
+         # Build camera
+         yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+         yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
+         yaws = [y + yaws_offset for y in yaws]
+         pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
+
+         exts = []
+         ints = []
+         for yaw, pitch in zip(yaws, pitch):
+             orig = torch.tensor([
+                 np.sin(yaw) * np.cos(pitch),
+                 np.cos(yaw) * np.cos(pitch),
+                 np.sin(pitch),
+             ]).float().cuda() * 2
+             fov = torch.deg2rad(torch.tensor(30)).cuda()
+             extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+             intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+             exts.append(extrinsics)
+             ints.append(intrinsics)
+
+         images = []
+
+         # Build each representation
+         x_0 = x_0.cuda()
+         for i in range(x_0.shape[0]):
+             representation = Octree(
+                 depth=10,
+                 aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
+                 device='cuda',
+                 primitive='voxel',
+                 sh_degree=0,
+                 primitive_config={'solid': True},
+             )
+             coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
+             resolution = x_0.shape[-1]
+             representation.position = coords.float() / resolution
+             representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(resolution)), dtype=torch.uint8, device='cuda')
+
+             image = torch.zeros(3, 1024, 1024).cuda()
+             tile = [2, 2]
+             for j, (ext, intr) in enumerate(zip(exts, ints)):
+                 res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
+                 image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
+             images.append(image)
+
+         return torch.stack(images)
+
+
+ class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
+     """
+     Sparse structure latent dataset
+
+     Args:
+         roots (str): path to the dataset
+         latent_model (str): name of the latent model
+         min_aesthetic_score (float): minimum aesthetic score
+         normalization (dict): normalization stats
+         pretrained_ss_dec (str): name of the pretrained sparse structure decoder
+         ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
+         ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
+     """
+     def __init__(self,
+         roots: str,
+         *,
+         latent_model: str,
+         min_aesthetic_score: float = 5.0,
+         normalization: Optional[dict] = None,
+         pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
+         ss_dec_path: Optional[str] = None,
+         ss_dec_ckpt: Optional[str] = None,
+     ):
+         self.latent_model = latent_model
+         self.min_aesthetic_score = min_aesthetic_score
+         self.normalization = normalization
+         self.value_range = (0, 1)
+
+         super().__init__(
+             roots,
+             pretrained_ss_dec=pretrained_ss_dec,
+             ss_dec_path=ss_dec_path,
+             ss_dec_ckpt=ss_dec_ckpt,
+         )
+
+         if self.normalization is not None:
+             self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
+             self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)
+
+     def filter_metadata(self, metadata):
+         stats = {}
+         metadata = metadata[metadata[f'ss_latent_{self.latent_model}']]
+         stats['With sparse structure latents'] = len(metadata)
+         metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+         stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+         return metadata, stats
+
+     def get_instance(self, root, instance):
+         latent = np.load(os.path.join(root, 'ss_latents', self.latent_model, f'{instance}.npz'))
+         z = torch.tensor(latent['mean']).float()
+         if self.normalization is not None:
+             z = (z - self.mean) / self.std
+
+         pack = {
+             'x_0': z,
+         }
+         return pack
+
+
+ class TextConditionedSparseStructureLatent(TextConditionedMixin, SparseStructureLatent):
+     """
+     Text-conditioned sparse structure dataset
+     """
+     pass
+
+
+ class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent):
+     """
+     Image-conditioned sparse structure dataset
+     """
+     pass
+
trellis/datasets/structured_latent.py ADDED
@@ -0,0 +1,218 @@
+ import json
+ import os
+ from typing import *
+ import numpy as np
+ import torch
+ import utils3d.torch
+ from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
+ from ..modules.sparse.basic import SparseTensor
+ from .. import models
+ from ..utils.render_utils import get_renderer
+ from ..utils.dist_utils import read_file_dist
+ from ..utils.data_utils import load_balanced_group_indices
+
+
+ class SLatVisMixin:
+     def __init__(
+         self,
+         *args,
+         pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
+         slat_dec_path: Optional[str] = None,
+         slat_dec_ckpt: Optional[str] = None,
+         **kwargs
+     ):
+         super().__init__(*args, **kwargs)
+         self.slat_dec = None
+         self.pretrained_slat_dec = pretrained_slat_dec
+         self.slat_dec_path = slat_dec_path
+         self.slat_dec_ckpt = slat_dec_ckpt
+
+     def _loading_slat_dec(self):
+         if self.slat_dec is not None:
+             return
+         if self.slat_dec_path is not None:
+             cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
+             decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
+             ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
+             decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
+         else:
+             decoder = models.from_pretrained(self.pretrained_slat_dec)
+         self.slat_dec = decoder.cuda().eval()
+
+     def _delete_slat_dec(self):
+         del self.slat_dec
+         self.slat_dec = None
+
+     @torch.no_grad()
+     def decode_latent(self, z, batch_size=4):
+         self._loading_slat_dec()
+         reps = []
+         if self.normalization is not None:
+             z = z * self.std.to(z.device) + self.mean.to(z.device)
+         for i in range(0, z.shape[0], batch_size):
+             reps.append(self.slat_dec(z[i:i+batch_size]))
+         reps = sum(reps, [])
+         self._delete_slat_dec()
+         return reps
+
+     @torch.no_grad()
+     def visualize_sample(self, x_0: Union[SparseTensor, dict]):
+         x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
+         reps = self.decode_latent(x_0.cuda())
+
+         # Build camera
+         yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+         yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
+         yaws = [y + yaws_offset for y in yaws]
+         pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
+
+         exts = []
+         ints = []
+         for yaw, pitch in zip(yaws, pitch):
+             orig = torch.tensor([
+                 np.sin(yaw) * np.cos(pitch),
+                 np.cos(yaw) * np.cos(pitch),
+                 np.sin(pitch),
+             ]).float().cuda() * 2
+             fov = torch.deg2rad(torch.tensor(40)).cuda()
+             extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+             intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+             exts.append(extrinsics)
+             ints.append(intrinsics)
+
+         renderer = get_renderer(reps[0])
+         images = []
+         for representation in reps:
+             image = torch.zeros(3, 1024, 1024).cuda()
+             tile = [2, 2]
+             for j, (ext, intr) in enumerate(zip(exts, ints)):
+                 res = renderer.render(representation, ext, intr)
+                 image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
+             images.append(image)
+         images = torch.stack(images)
+
+         return images
+
+
+ class SLat(SLatVisMixin, StandardDatasetBase):
+     """
+     structured latent dataset
+
+     Args:
+         roots (str): path to the dataset
+         latent_model (str): name of the latent model
+         min_aesthetic_score (float): minimum aesthetic score
+         max_num_voxels (int): maximum number of voxels
+         normalization (dict): normalization stats
+         pretrained_slat_dec (str): name of the pretrained slat decoder
+         slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
+         slat_dec_ckpt (str): name of the slat decoder checkpoint
+     """
+     def __init__(self,
+         roots: str,
+         *,
+         latent_model: str,
+         min_aesthetic_score: float = 5.0,
+         max_num_voxels: int = 32768,
+         normalization: Optional[dict] = None,
+         pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
+         slat_dec_path: Optional[str] = None,
+         slat_dec_ckpt: Optional[str] = None,
+     ):
+         self.normalization = normalization
+         self.latent_model = latent_model
+         self.min_aesthetic_score = min_aesthetic_score
+         self.max_num_voxels = max_num_voxels
+         self.value_range = (0, 1)
+
+         super().__init__(
+             roots,
+             pretrained_slat_dec=pretrained_slat_dec,
+             slat_dec_path=slat_dec_path,
+             slat_dec_ckpt=slat_dec_ckpt,
+         )
+
+         self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]
+
+         if self.normalization is not None:
+             self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
+             self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
+
+     def filter_metadata(self, metadata):
+         stats = {}
+         metadata = metadata[metadata[f'latent_{self.latent_model}']]
+         stats['With latent'] = len(metadata)
+         metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+         stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+         metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
+         stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
+         return metadata, stats
+
+     def get_instance(self, root, instance):
+         data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
+         coords = torch.tensor(data['coords']).int()
+         feats = torch.tensor(data['feats']).float()
+         if self.normalization is not None:
+             feats = (feats - self.mean) / self.std
+         return {
+             'coords': coords,
+             'feats': feats,
+         }
+
+     @staticmethod
+     def collate_fn(batch, split_size=None):
+         if split_size is None:
+             group_idx = [list(range(len(batch)))]
+         else:
+             group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
+         packs = []
+         for group in group_idx:
+             sub_batch = [batch[i] for i in group]
+             pack = {}
+             coords = []
+             feats = []
+             layout = []
+             start = 0
+             for i, b in enumerate(sub_batch):
+                 coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
+                 feats.append(b['feats'])
+                 layout.append(slice(start, start + b['coords'].shape[0]))
+                 start += b['coords'].shape[0]
+             coords = torch.cat(coords)
+             feats = torch.cat(feats)
+             pack['x_0'] = SparseTensor(
+                 coords=coords,
+                 feats=feats,
+             )
+             pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
+             pack['x_0'].register_spatial_cache('layout', layout)
+
+             # collate other data
+             keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
+             for k in keys:
+                 if isinstance(sub_batch[0][k], torch.Tensor):
+                     pack[k] = torch.stack([b[k] for b in sub_batch])
+                 elif isinstance(sub_batch[0][k], list):
+                     pack[k] = sum([b[k] for b in sub_batch], [])
+                 else:
+                     pack[k] = [b[k] for b in sub_batch]
+
+             packs.append(pack)
+
+         if split_size is None:
+             return packs[0]
+         return packs
+
+
+ class TextConditionedSLat(TextConditionedMixin, SLat):
+     """
+     Text conditioned structured latent dataset
+     """
+     pass
+
+
+ class ImageConditionedSLat(ImageConditionedMixin, SLat):
+     """
+     Image conditioned structured latent dataset
+     """
+     pass
trellis/datasets/structured_latent2render.py ADDED
@@ -0,0 +1,160 @@
+ import os
+ from PIL import Image
+ import json
+ import numpy as np
+ import torch
+ import utils3d.torch
+ from ..modules.sparse.basic import SparseTensor
+ from .components import StandardDatasetBase
+
+
+ class SLat2Render(StandardDatasetBase):
+     """
+     Dataset for Structured Latent and rendered images.
+
+     Args:
+         roots (str): paths to the dataset
+         image_size (int): size of the image
+         latent_model (str): latent model name
+         min_aesthetic_score (float): minimum aesthetic score
+         max_num_voxels (int): maximum number of voxels
+     """
+     def __init__(
+         self,
+         roots: str,
+         image_size: int,
+         latent_model: str,
+         min_aesthetic_score: float = 5.0,
+         max_num_voxels: int = 32768,
+     ):
+         self.image_size = image_size
+         self.latent_model = latent_model
+         self.min_aesthetic_score = min_aesthetic_score
+         self.max_num_voxels = max_num_voxels
+         self.value_range = (0, 1)
+
+         super().__init__(roots)
+
+     def filter_metadata(self, metadata):
+         stats = {}
+         metadata = metadata[metadata[f'latent_{self.latent_model}']]
+         stats['With latent'] = len(metadata)
+         metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
+         stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
+         metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
+         stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
+         return metadata, stats
+
+     def _get_image(self, root, instance):
+         with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
+             metadata = json.load(f)
+         n_views = len(metadata['frames'])
+         view = np.random.randint(n_views)
+         metadata = metadata['frames'][view]
+         fov = metadata['camera_angle_x']
+         intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
+         c2w = torch.tensor(metadata['transform_matrix'])
+         c2w[:3, 1:3] *= -1
+         extrinsics = torch.inverse(c2w)
+
+         image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
+         image = Image.open(image_path)
+         alpha = image.getchannel(3)
+         image = image.convert('RGB')
+         image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+         alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+         image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
+         alpha = torch.tensor(np.array(alpha)).float() / 255.0
+
+         return {
+             'image': image,
+             'alpha': alpha,
+             'extrinsics': extrinsics,
+             'intrinsics': intrinsics,
+         }
+
+     def _get_latent(self, root, instance):
+         data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
+         coords = torch.tensor(data['coords']).int()
+         feats = torch.tensor(data['feats']).float()
+         return {
+             'coords': coords,
+             'feats': feats,
+         }
+
+     @torch.no_grad()
+     def visualize_sample(self, sample: dict):
+         return sample['image']
+
+     @staticmethod
+     def collate_fn(batch):
+         pack = {}
+         coords = []
+         for i, b in enumerate(batch):
+             coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
+         coords = torch.cat(coords)
+         feats = torch.cat([b['feats'] for b in batch])
+         pack['latents'] = SparseTensor(
+             coords=coords,
+             feats=feats,
+         )
+
+         # collate other data
+         keys = [k for k in batch[0].keys() if k not in ['coords', 'feats']]
+         for k in keys:
+             if isinstance(batch[0][k], torch.Tensor):
+                 pack[k] = torch.stack([b[k] for b in batch])
+             elif isinstance(batch[0][k], list):
+                 pack[k] = sum([b[k] for b in batch], [])
+             else:
+                 pack[k] = [b[k] for b in batch]
+
+         return pack
+
+     def get_instance(self, root, instance):
+         image = self._get_image(root, instance)
+         latent = self._get_latent(root, instance)
+         return {
+             **image,
+             **latent,
+         }
+
+
+ class Slat2RenderGeo(SLat2Render):
+     def __init__(
+         self,
+         roots: str,
+         image_size: int,
+         latent_model: str,
+         min_aesthetic_score: float = 5.0,
+         max_num_voxels: int = 32768,
+     ):
+         super().__init__(
+             roots,
+             image_size,
+             latent_model,
+             min_aesthetic_score,
+             max_num_voxels,
+         )
+
+     def _get_geo(self, root, instance):
+         verts, face = utils3d.io.read_ply(os.path.join(root, 'renders', instance, 'mesh.ply'))
+         mesh = {
+             "vertices": torch.from_numpy(verts),
+             "faces": torch.from_numpy(face),
+         }
+         return {
+             "mesh": mesh,
+         }
+
+     def get_instance(self, root, instance):
+         image = self._get_image(root, instance)
+         latent = self._get_latent(root, instance)
+         geo = self._get_geo(root, instance)
+         return {
+             **image,
+             **latent,
+             **geo,
+         }
+
trellis/models/__init__.py CHANGED
@@ -3,12 +3,20 @@ import importlib
 __attributes = {
     'SparseStructureEncoder': 'sparse_structure_vae',
    'SparseStructureDecoder': 'sparse_structure_vae',
+
     'SparseStructureFlowModel': 'sparse_structure_flow',
+
     'SLatEncoder': 'structured_latent_vae',
     'SLatGaussianDecoder': 'structured_latent_vae',
     'SLatRadianceFieldDecoder': 'structured_latent_vae',
     'SLatMeshDecoder': 'structured_latent_vae',
+     'ElasticSLatEncoder': 'structured_latent_vae',
+     'ElasticSLatGaussianDecoder': 'structured_latent_vae',
+     'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
+     'ElasticSLatMeshDecoder': 'structured_latent_vae',
+
     'SLatFlowModel': 'structured_latent_flow',
+     'ElasticSLatFlowModel': 'structured_latent_flow',
 }

 __submodules = []
@@ -64,7 +72,25 @@ def from_pretrained(path: str, **kwargs):

 # For Pylance
 if __name__ == '__main__':
-     from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
+     from .sparse_structure_vae import (
+         SparseStructureEncoder,
+         SparseStructureDecoder,
+     )
+
     from .sparse_structure_flow import SparseStructureFlowModel
-     from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder
-     from .structured_latent_flow import SLatFlowModel
+
+     from .structured_latent_vae import (
+         SLatEncoder,
+         SLatGaussianDecoder,
+         SLatRadianceFieldDecoder,
+         SLatMeshDecoder,
+         ElasticSLatEncoder,
+         ElasticSLatGaussianDecoder,
+         ElasticSLatRadianceFieldDecoder,
+         ElasticSLatMeshDecoder,
+     )
+
+     from .structured_latent_flow import (
+         SLatFlowModel,
+         ElasticSLatFlowModel,
+     )
trellis/models/sparse_elastic_mixin.py ADDED
@@ -0,0 +1,24 @@
+ from contextlib import contextmanager
+ from typing import *
+ import math
+ from ..modules import sparse as sp
+ from ..utils.elastic_utils import ElasticModuleMixin
+
+
+ class SparseTransformerElasticMixin(ElasticModuleMixin):
+     def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
+         return x.feats.shape[0]
+
+     @contextmanager
+     def with_mem_ratio(self, mem_ratio=1.0):
+         if mem_ratio == 1.0:
+             yield 1.0
+             return
+         num_blocks = len(self.blocks)
+         num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
+         exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
+         for i in range(num_blocks):
+             self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
+         yield exact_mem_ratio
+         for i in range(num_blocks):
+             self.blocks[i].use_checkpoint = False
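Note: `with_mem_ratio` converts a requested activation-memory budget into a number of transformer blocks that run with gradient checkpointing, and yields the ratio actually achieved. A small numeric illustration of that formula (the helper function below is only for demonstration; block count 24 is an assumed example):

```python
import math

# How many of N transformer blocks get use_checkpoint=True for a requested
# mem_ratio, following SparseTransformerElasticMixin.with_mem_ratio above.
def checkpointed_blocks(num_blocks: int, mem_ratio: float):
    num_ckpt = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
    exact_mem_ratio = 1 - (num_ckpt - 1) / num_blocks
    return num_ckpt, exact_mem_ratio

print(checkpointed_blocks(24, 0.5))   # (13, 0.5): 13 of 24 blocks checkpointed
print(checkpointed_blocks(24, 0.25))  # (19, 0.25)
```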
trellis/models/structured_latent_flow.py CHANGED
@@ -9,6 +9,7 @@ from ..modules.norm import LayerNorm32
 from ..modules import sparse as sp
 from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
 from .sparse_structure_flow import TimestepEmbedder
+ from .sparse_elastic_mixin import SparseTransformerElasticMixin


 class SparseResBlock3d(nn.Module):
@@ -109,8 +110,9 @@ class SLatFlowModel(nn.Module):
         self.qk_rms_norm_cross = qk_rms_norm_cross
         self.dtype = torch.float16 if use_fp16 else torch.float32

-         assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
-         assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
+         if self.io_block_channels is not None:
+             assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
+             assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"

         self.t_embedder = TimestepEmbedder(model_channels)
         if share_mod:
@@ -122,25 +124,27 @@ class SLatFlowModel(nn.Module):
         if pe_mode == "ape":
             self.pos_embedder = AbsolutePositionEmbedder(model_channels)

-         self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
+         self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0])
+
         self.input_blocks = nn.ModuleList([])
-         for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
-             self.input_blocks.extend([
-                 SparseResBlock3d(
-                     chs,
-                     model_channels,
-                     out_channels=chs,
-                 )
-                 for _ in range(num_io_res_blocks-1)
-             ])
-             self.input_blocks.append(
-                 SparseResBlock3d(
-                     chs,
-                     model_channels,
-                     out_channels=next_chs,
-                     downsample=True,
-                 )
-             )
+         if io_block_channels is not None:
+             for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
+                 self.input_blocks.extend([
+                     SparseResBlock3d(
+                         chs,
+                         model_channels,
+                         out_channels=chs,
+                     )
+                     for _ in range(num_io_res_blocks-1)
+                 ])
+                 self.input_blocks.append(
+                     SparseResBlock3d(
+                         chs,
+                         model_channels,
+                         out_channels=next_chs,
+                         downsample=True,
+                     )
+                 )

         self.blocks = nn.ModuleList([
             ModulatedSparseTransformerCrossBlock(
@@ -159,24 +163,26 @@ class SLatFlowModel(nn.Module):
         ])

         self.out_blocks = nn.ModuleList([])
-         for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
-             self.out_blocks.append(
-                 SparseResBlock3d(
-                     prev_chs * 2 if self.use_skip_connection else prev_chs,
-                     model_channels,
-                     out_channels=chs,
-                     upsample=True,
-                 )
-             )
-             self.out_blocks.extend([
-                 SparseResBlock3d(
-                     chs * 2 if self.use_skip_connection else chs,
-                     model_channels,
-                     out_channels=chs,
-                 )
-                 for _ in range(num_io_res_blocks-1)
-             ])
-         self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels)
+         if io_block_channels is not None:
+             for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
+                 self.out_blocks.append(
+                     SparseResBlock3d(
+                         prev_chs * 2 if self.use_skip_connection else prev_chs,
+                         model_channels,
+                         out_channels=chs,
+                         upsample=True,
+                     )
+                 )
+                 self.out_blocks.extend([
+                     SparseResBlock3d(
+                         chs * 2 if self.use_skip_connection else chs,
+                         model_channels,
+                         out_channels=chs,
+                     )
+                     for _ in range(num_io_res_blocks-1)
+                 ])
+
+         self.out_layer = sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], out_channels)

         self.initialize_weights()
         if use_fp16:
@@ -260,3 +266,11 @@ class SLatFlowModel(nn.Module):
         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
         h = self.out_layer(h.type(x.dtype))
         return h
+
+
+ class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
+     """
+     SLat Flow Model with elastic memory management.
+     Used for training with low VRAM.
+     """
+     pass
trellis/models/structured_latent_vae/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from .encoder import SLatEncoder
- from .decoder_gs import SLatGaussianDecoder
- from .decoder_rf import SLatRadianceFieldDecoder
- from .decoder_mesh import SLatMeshDecoder
+ from .encoder import SLatEncoder, ElasticSLatEncoder
+ from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
+ from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
+ from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder
trellis/models/structured_latent_vae/decoder_gs.py CHANGED
@@ -6,6 +6,7 @@ from ...modules import sparse as sp
 from ...utils.random_utils import hammersley_sequence
 from .base import SparseTransformerBase
 from ...representations import Gaussian
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin


 class SLatGaussianDecoder(SparseTransformerBase):
@@ -120,3 +121,11 @@ class SLatGaussianDecoder(SparseTransformerBase):
         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
         h = self.out_layer(h)
         return self.to_representation(h)
+
+
+ class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder):
+     """
+     Slat VAE Gaussian decoder with elastic memory management.
+     Used for training with low VRAM.
+     """
+     pass
trellis/models/structured_latent_vae/decoder_mesh.py CHANGED
@@ -8,6 +8,7 @@ from ...modules import sparse as sp
 from .base import SparseTransformerBase
 from ...representations import MeshExtractResult
 from ...representations.mesh import SparseFeatures2Mesh
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin


 class SparseSubdivideBlock3d(nn.Module):
@@ -165,3 +166,11 @@ class SLatMeshDecoder(SparseTransformerBase):
         h = h.type(x.dtype)
         h = self.out_layer(h)
         return self.to_representation(h)
+
+
+ class ElasticSLatMeshDecoder(SparseTransformerElasticMixin, SLatMeshDecoder):
+     """
+     Slat VAE Mesh decoder with elastic memory management.
+     Used for training with low VRAM.
+     """
+     pass
trellis/models/structured_latent_vae/decoder_rf.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
 from ...modules import sparse as sp
 from .base import SparseTransformerBase
 from ...representations import Strivec
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin


 class SLatRadianceFieldDecoder(SparseTransformerBase):
@@ -102,3 +103,11 @@ class SLatRadianceFieldDecoder(SparseTransformerBase):
         h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
         h = self.out_layer(h)
         return self.to_representation(h)
+
+
+ class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder):
+     """
+     Slat VAE Radiance Field Decoder with elastic memory management.
+     Used for training with low VRAM.
+     """
+     pass
trellis/models/structured_latent_vae/encoder.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from ...modules import sparse as sp
 from .base import SparseTransformerBase
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin


 class SLatEncoder(SparseTransformerBase):
@@ -70,3 +71,10 @@ class SLatEncoder(SparseTransformerBase):
             return z, mean, logvar
         else:
             return z
+
+
+ class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder):
+     """
+     SLat VAE encoder with elastic memory management.
+     Used for training with low VRAM.
+     """
trellis/pipelines/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from . import samplers
 from .trellis_image_to_3d import TrellisImageTo3DPipeline
+ from .trellis_text_to_3d import TrellisTextTo3DPipeline


 def from_pretrained(path: str):
trellis/pipelines/base.py CHANGED
@@ -36,10 +36,12 @@ class Pipeline:
         with open(config_file, 'r') as f:
             args = json.load(f)['args']

-         _models = {
-             k: models.from_pretrained(f"{path}/{v}")
-             for k, v in args['models'].items()
-         }
+         _models = {}
+         for k, v in args['models'].items():
+             try:
+                 _models[k] = models.from_pretrained(f"{path}/{v}")
+             except:
+                 _models[k] = models.from_pretrained(v)

         new_pipeline = Pipeline(_models)
         new_pipeline._pretrained_args = args
trellis/pipelines/samplers/flow_euler.py CHANGED
@@ -37,6 +37,8 @@ class FlowEulerSampler(Sampler):

     def _inference_model(self, model, x_t, t, cond=None, **kwargs):
         t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
+         if cond is not None and cond.shape[0] == 1 and x_t.shape[0] > 1:
+             cond = cond.repeat(x_t.shape[0], *([1] * (len(cond.shape) - 1)))
         return model(x_t, t, cond, **kwargs)

     def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
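Note: the two added lines let a single conditioning tensor be shared across a batch of noisy samples by tiling it along the batch dimension before the model call. A toy shape check of that broadcast (tensor shapes below are illustrative only):

```python
import torch

# Illustrative shape check for the cond broadcast added to
# FlowEulerSampler._inference_model: one conditioning sequence is repeated
# to match a batch of noisy samples x_t.
x_t = torch.randn(4, 8, 64)        # batch of 4 noisy latents (toy shape)
cond = torch.randn(1, 77, 1024)    # single conditioning embedding (toy shape)
if cond is not None and cond.shape[0] == 1 and x_t.shape[0] > 1:
    cond = cond.repeat(x_t.shape[0], *([1] * (len(cond.shape) - 1)))
print(cond.shape)  # torch.Size([4, 77, 1024])
```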
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -4,15 +4,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
- from tqdm import tqdm
- from easydict import EasyDict as edict
 from torchvision import transforms
 from PIL import Image
 import rembg
 from .base import Pipeline
 from . import samplers
 from ..modules import sparse as sp
- from ..representations import Gaussian, Strivec, MeshExtractResult


 class TrellisImageTo3DPipeline(Pipeline):
@@ -271,8 +268,10 @@ class TrellisImageTo3DPipeline(Pipeline):
         Args:
             image (Image.Image): The image prompt.
             num_samples (int): The number of samples to generate.
+             seed (int): The random seed.
             sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
             slat_sampler_params (dict): Additional parameters for the structured latent sampler.
+             formats (List[str]): The formats to decode the structured latent to.
             preprocess_image (bool): Whether to preprocess the image.
         """
         if preprocess_image:
trellis/pipelines/trellis_text_to_3d.py ADDED
@@ -0,0 +1,278 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from transformers import CLIPTextModel, AutoTokenizer
6
+ import open3d as o3d
7
+ from .base import Pipeline
8
+ from . import samplers
9
+ from ..modules import sparse as sp
10
+
11
+
12
+ class TrellisTextTo3DPipeline(Pipeline):
13
+ """
14
+ Pipeline for inferring Trellis text-to-3D models.
15
+
16
+ Args:
17
+ models (dict[str, nn.Module]): The models to use in the pipeline.
18
+ sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
19
+ slat_sampler (samplers.Sampler): The sampler for the structured latent.
20
+ slat_normalization (dict): The normalization parameters for the structured latent.
21
+ text_cond_model (str): The name of the text conditioning model.
22
+ """
23
+ def __init__(
24
+ self,
25
+ models: dict[str, nn.Module] = None,
26
+ sparse_structure_sampler: samplers.Sampler = None,
27
+ slat_sampler: samplers.Sampler = None,
28
+ slat_normalization: dict = None,
29
+ text_cond_model: str = None,
30
+ ):
31
+ if models is None:
32
+ return
33
+ super().__init__(models)
34
+ self.sparse_structure_sampler = sparse_structure_sampler
35
+ self.slat_sampler = slat_sampler
36
+ self.sparse_structure_sampler_params = {}
37
+ self.slat_sampler_params = {}
38
+ self.slat_normalization = slat_normalization
39
+ self._init_text_cond_model(text_cond_model)
40
+
41
+ @staticmethod
42
+ def from_pretrained(path: str) -> "TrellisTextTo3DPipeline":
43
+ """
44
+ Load a pretrained model.
45
+
46
+ Args:
47
+ path (str): The path to the model. Can be either local path or a Hugging Face repository.
48
+ """
49
+ pipeline = super(TrellisTextTo3DPipeline, TrellisTextTo3DPipeline).from_pretrained(path)
50
+ new_pipeline = TrellisTextTo3DPipeline()
51
+ new_pipeline.__dict__ = pipeline.__dict__
52
+ args = pipeline._pretrained_args
53
+
54
+ new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
55
+ new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
56
+
57
+ new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
58
+ new_pipeline.slat_sampler_params = args['slat_sampler']['params']
59
+
60
+ new_pipeline.slat_normalization = args['slat_normalization']
61
+
62
+ new_pipeline._init_text_cond_model(args['text_cond_model'])
63
+
64
+ return new_pipeline
65
+
66
+ def _init_text_cond_model(self, name: str):
67
+ """
68
+ Initialize the text conditioning model.
69
+ """
70
+ # load model
71
+ model = CLIPTextModel.from_pretrained(name)
72
+ tokenizer = AutoTokenizer.from_pretrained(name)
73
+ model.eval()
74
+ model = model.cuda()
75
+ self.text_cond_model = {
76
+ 'model': model,
77
+ 'tokenizer': tokenizer,
78
+ }
79
+ self.text_cond_model['null_cond'] = self.encode_text([''])
80
+
81
+ @torch.no_grad()
82
+ def encode_text(self, text: List[str]) -> torch.Tensor:
83
+ """
84
+ Encode the text.
85
+ """
86
+ assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
87
+ encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
88
+ tokens = encoding['input_ids'].cuda()
89
+ embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
90
+
91
+ return embeddings
92
+
93
+ def get_cond(self, prompt: List[str]) -> dict:
94
+ """
95
+ Get the conditioning information for the model.
96
+
97
+ Args:
98
+ prompt (List[str]): The text prompt.
99
+
100
+ Returns:
101
+ dict: The conditioning information
102
+ """
103
+ cond = self.encode_text(prompt)
104
+ neg_cond = self.text_cond_model['null_cond']
105
+ return {
106
+ 'cond': cond,
107
+ 'neg_cond': neg_cond,
108
+ }
109
+
110
+ def sample_sparse_structure(
111
+ self,
112
+ cond: dict,
113
+ num_samples: int = 1,
114
+ sampler_params: dict = {},
115
+ ) -> torch.Tensor:
116
+ """
117
+ Sample sparse structures with the given conditioning.
118
+
119
+ Args:
120
+ cond (dict): The conditioning information.
121
+ num_samples (int): The number of samples to generate.
122
+ sampler_params (dict): Additional parameters for the sampler.
123
+ """
124
+ # Sample occupancy latent
125
+ flow_model = self.models['sparse_structure_flow_model']
126
+ reso = flow_model.resolution
127
+ noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
128
+ sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
129
+ z_s = self.sparse_structure_sampler.sample(
130
+ flow_model,
131
+ noise,
132
+ **cond,
133
+ **sampler_params,
134
+ verbose=True
135
+ ).samples
136
+
137
+ # Decode occupancy latent
138
+ decoder = self.models['sparse_structure_decoder']
139
+ coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
140
+
141
+ return coords
142
+
143
+ def decode_slat(
144
+ self,
145
+ slat: sp.SparseTensor,
146
+ formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
147
+ ) -> dict:
148
+ """
149
+ Decode the structured latent.
150
+
151
+ Args:
152
+ slat (sp.SparseTensor): The structured latent.
153
+ formats (List[str]): The formats to decode the structured latent to.
154
+
155
+ Returns:
156
+ dict: The decoded structured latent.
157
+ """
158
+ ret = {}
159
+ if 'mesh' in formats:
160
+ ret['mesh'] = self.models['slat_decoder_mesh'](slat)
161
+ if 'gaussian' in formats:
162
+ ret['gaussian'] = self.models['slat_decoder_gs'](slat)
163
+ if 'radiance_field' in formats:
164
+ ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
165
+ return ret
166
+
167
+ def sample_slat(
168
+ self,
169
+ cond: dict,
170
+ coords: torch.Tensor,
171
+ sampler_params: dict = {},
172
+ ) -> sp.SparseTensor:
173
+ """
174
+ Sample structured latent with the given conditioning.
175
+
176
+ Args:
177
+ cond (dict): The conditioning information.
178
+ coords (torch.Tensor): The coordinates of the sparse structure.
179
+ sampler_params (dict): Additional parameters for the sampler.
180
+ """
181
+ # Sample structured latent
182
+ flow_model = self.models['slat_flow_model']
183
+ noise = sp.SparseTensor(
184
+ feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
185
+ coords=coords,
186
+ )
187
+ sampler_params = {**self.slat_sampler_params, **sampler_params}
188
+ slat = self.slat_sampler.sample(
189
+ flow_model,
190
+ noise,
191
+ **cond,
192
+ **sampler_params,
193
+ verbose=True
194
+ ).samples
195
+
196
+ std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
197
+ mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
198
+ slat = slat * std + mean
199
+
200
+ return slat
201
+
202
+ @torch.no_grad()
203
+ def run(
204
+ self,
205
+ prompt: str,
206
+ num_samples: int = 1,
207
+ seed: int = 42,
208
+ sparse_structure_sampler_params: dict = {},
209
+ slat_sampler_params: dict = {},
210
+ formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
211
+ ) -> dict:
212
+ """
213
+ Run the pipeline.
214
+
215
+ Args:
216
+ prompt (str): The text prompt.
217
+ num_samples (int): The number of samples to generate.
218
+ seed (int): The random seed.
219
+ sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
220
+ slat_sampler_params (dict): Additional parameters for the structured latent sampler.
221
+ formats (List[str]): The formats to decode the structured latent to.
222
+ """
223
+ cond = self.get_cond([prompt])
224
+ torch.manual_seed(seed)
225
+ coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
226
+ slat = self.sample_slat(cond, coords, slat_sampler_params)
227
+ return self.decode_slat(slat, formats)
228
+
229
+ def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
230
+ """
231
+ Voxelize a mesh.
232
+
233
+ Args:
234
+ mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.
235
+ sha256 (str): The SHA256 hash of the mesh.
236
+ output_dir (str): The output directory.
237
+ """
238
+ vertices = np.asarray(mesh.vertices)
239
+ aabb = np.stack([vertices.min(0), vertices.max(0)])
240
+ center = (aabb[0] + aabb[1]) / 2
241
+ scale = (aabb[1] - aabb[0]).max()
242
+ vertices = (vertices - center) / scale
243
+ vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
244
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
245
+ voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
246
+ vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
247
+ return torch.tensor(vertices).int().cuda()
248
+
249
+ @torch.no_grad()
250
+ def run_variant(
251
+ self,
252
+ mesh: o3d.geometry.TriangleMesh,
253
+ prompt: str,
254
+ num_samples: int = 1,
255
+ seed: int = 42,
256
+ slat_sampler_params: dict = {},
257
+ formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
258
+ ) -> dict:
259
+ """
260
+ Run the pipeline for making variants of an asset.
261
+
262
+ Args:
263
+ mesh (o3d.geometry.TriangleMesh): The base mesh.
264
+ prompt (str): The text prompt.
265
+ num_samples (int): The number of samples to generate.
266
+ seed (int): The random seed.
267
+ slat_sampler_params (dict): Additional parameters for the structured latent sampler.
268
+ formats (List[str]): The formats to decode the structured latent to.
269
+ """
270
+ cond = self.get_cond([prompt])
271
+ coords = self.voxelize(mesh)
272
+ coords = torch.cat([
273
+ torch.arange(num_samples).repeat_interleave(coords.shape[0], 0)[:, None].int().cuda(),
274
+ coords.repeat(num_samples, 1)
275
+ ], 1)
276
+ torch.manual_seed(seed)
277
+ slat = self.sample_slat(cond, coords, slat_sampler_params)
278
+ return self.decode_slat(slat, formats)
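Taken together, the methods above implement text conditioning → sparse structure sampling → structured latent sampling → decoding. A minimal usage sketch of the new pipeline; the checkpoint path is a hypothetical placeholder and a CUDA device is assumed, since the text encoder is moved to `.cuda()` during initialization:

```python
from trellis.pipelines import TrellisTextTo3DPipeline

# Hypothetical checkpoint location; substitute a real local path or repository name.
pipeline = TrellisTextTo3DPipeline.from_pretrained("path/to/trellis-text-checkpoint")

outputs = pipeline.run(
    prompt="a weathered wooden treasure chest",
    num_samples=1,
    seed=42,
    formats=["mesh", "gaussian"],  # decode only the representations you need
)

# `outputs` maps each requested format to the corresponding decoder's result,
# e.g. outputs["mesh"] from slat_decoder_mesh and outputs["gaussian"] from slat_decoder_gs.
mesh_result = outputs["mesh"]
```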
trellis/representations/mesh/cube2mesh.py CHANGED
@@ -1,15 +1,8 @@
1
- # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
  import torch
9
  from ...modules.sparse import SparseTensor
10
  from easydict import EasyDict as edict
11
  from .utils_cube import *
12
- from .flexicube import FlexiCubes
13
 
14
 
15
  class MeshExtractResult:
 
 
 
 
 
 
 
 
1
  import torch
2
  from ...modules.sparse import SparseTensor
3
  from easydict import EasyDict as edict
4
  from .utils_cube import *
5
+ from .flexicubes.flexicubes import FlexiCubes
6
 
7
 
8
  class MeshExtractResult:
trellis/representations/mesh/flexicubes/LICENSE.txt ADDED
@@ -0,0 +1,90 @@
1
+ Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+
3
+
4
+ NVIDIA Source Code License for FlexiCubes
5
+
6
+
7
+ =======================================================================
8
+
9
+ 1. Definitions
10
+
11
+ “Licensor” means any person or entity that distributes its Work.
12
+
13
+ “Work” means (a) the original work of authorship made available under
14
+ this license, which may include software, documentation, or other files,
15
+ and (b) any additions to or derivative works thereof that are made
16
+ available under this license.
17
+
18
+ The terms “reproduce,” “reproduction,” “derivative works,” and
19
+ “distribution” have the meaning as provided under U.S. copyright law;
20
+ provided, however, that for the purposes of this license, derivative works
21
+ shall not include works that remain separable from, or merely link
22
+ (or bind by name) to the interfaces of, the Work.
23
+
24
+ Works are “made available” under this license by including in or with
25
+ the Work either (a) a copyright notice referencing the applicability of
26
+ this license to the Work, or (b) a copy of this license.
27
+
28
+ 2. License Grant
29
+
30
+ 2.1 Copyright Grant. Subject to the terms and conditions of this license,
31
+ each Licensor grants to you a perpetual, worldwide, non-exclusive,
32
+ royalty-free, copyright license to use, reproduce, prepare derivative
33
+ works of, publicly display, publicly perform, sublicense and distribute
34
+ its Work and any resulting derivative works in any form.
35
+
36
+ 3. Limitations
37
+
38
+ 3.1 Redistribution. You may reproduce or distribute the Work only if
39
+ (a) you do so under this license, (b) you include a complete copy of
40
+ this license with your distribution, and (c) you retain without
41
+ modification any copyright, patent, trademark, or attribution notices
42
+ that are present in the Work.
43
+
44
+ 3.2 Derivative Works. You may specify that additional or different terms
45
+ apply to the use, reproduction, and distribution of your derivative
46
+ works of the Work (“Your Terms”) only if (a) Your Terms provide that the
47
+ use limitation in Section 3.3 applies to your derivative works, and (b)
48
+ you identify the specific derivative works that are subject to Your Terms.
49
+ Notwithstanding Your Terms, this license (including the redistribution
50
+ requirements in Section 3.1) will continue to apply to the Work itself.
51
+
52
+ 3.3 Use Limitation. The Work and any derivative works thereof only may be
53
+ used or intended for use non-commercially. Notwithstanding the foregoing,
54
+ NVIDIA Corporation and its affiliates may use the Work and any derivative
55
+ works commercially. As used herein, “non-commercially” means for research
56
+ or evaluation purposes only.
57
+
58
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim against
59
+ any Licensor (including any claim, cross-claim or counterclaim in a lawsuit)
60
+ to enforce any patents that you allege are infringed by any Work, then your
61
+ rights under this license from such Licensor (including the grant in
62
+ Section 2.1) will terminate immediately.
63
+
64
+ 3.5 Trademarks. This license does not grant any rights to use any Licensor’s
65
+ or its affiliates’ names, logos, or trademarks, except as necessary to
66
+ reproduce the notices described in this license.
67
+
68
+ 3.6 Termination. If you violate any term of this license, then your rights
69
+ under this license (including the grant in Section 2.1) will terminate
70
+ immediately.
71
+
72
+ 4. Disclaimer of Warranty.
73
+
74
+ THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
75
+ EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
76
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
77
+ YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
78
+
79
+ 5. Limitation of Liability.
80
+
81
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY,
82
+ WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY
83
+ LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL,
84
+ INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE,
85
+ THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF
86
+ GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR
87
+ MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN
88
+ ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
89
+
90
+ =======================================================================
trellis/representations/mesh/flexicubes/README.md ADDED
@@ -0,0 +1,110 @@
1
+ ## Flexible Isosurface Extraction for Gradient-Based Mesh Optimization (FlexiCubes)<br><sub>Official PyTorch implementation </sub>
2
+
3
+ ![Teaser image](<images/teaser_top.png>)
4
+
5
+ FlexiCubes is a high-quality isosurface representation specifically designed for gradient-based mesh optimization with respect to geometric, visual, or even physical objectives. For more details, please refer to our [paper](https://arxiv.org/abs/2308.05371) and [project page](https://research.nvidia.com/labs/toronto-ai/flexicubes/).
6
+
7
+ ## Highlights
8
+ * [Getting started](https://github.com/nv-tlabs/FlexiCubes#getting-started)
9
+ * [Basic workflow](https://github.com/nv-tlabs/FlexiCubes#example-usage)
10
+ * [nvdiffrec: image-based reconstruction example](https://github.com/NVlabs/nvdiffrec#news)
11
+ * [GET3D: generative AI example](https://github.com/nv-tlabs/GET3D#employing-flexicubes)
12
+ * [Bibtex](https://github.com/nv-tlabs/FlexiCubes#citation)
13
+
14
+ ## Getting Started
15
+
16
+ The core functions of FlexiCubes are now in [Kaolin](https://github.com/NVIDIAGameWorks/kaolin/) starting from v0.15.0. See installation instructions [here](https://kaolin.readthedocs.io/en/latest/notes/installation.html) and API documentation [here](https://kaolin.readthedocs.io/en/latest/modules/kaolin.non_commercial.html#kaolin.non_commercial.FlexiCubes).
17
+
18
+ The original code of the paper is still visible in `flexicubes.py`.
19
+
20
+ ## Example Usage
21
+
22
+ ### Gradient-Based Mesh Optimization
23
+ We provide examples demonstrating how to use FlexiCubes for reconstructing unknown meshes through gradient-based optimization. Specifically, starting from randomly initialized SDF, we optimize the shape towards the reference mesh by minimizing their geometric difference, measured by multiview mask and depth losses. This workflow is a simplified version of `nvdiffrec` with code largely borrowed from the [nvdiffrec GitHub](https://github.com/NVlabs/nvdiffrec). We use the same pipeline to conduct the analysis in Section 3 and the main experiments described in Section 5 of our paper. We provide a detailed tutorial in `examples/optimization.ipynb`, along with an optimization script in `examples/optimize.py` which accepts command-line arguments.
24
+
25
+
26
+ To run the examples, it is suggested to install the Conda environment as detailed below:
27
+ ```sh
28
+ conda create -n flexicubes python=3.9
29
+ conda activate flexicubes
30
+ conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
31
+ pip install imageio trimesh tqdm matplotlib torch_scatter ninja
32
+ pip install git+https://github.com/NVlabs/nvdiffrast/
33
+ pip install kaolin==0.15.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.12.0_cu113.html
34
+ ```
35
+
36
+ Then download the dataset collected by [Myles et al.](https://vcg.isti.cnr.it/Publications/2014/MPZ14/) as follows. We include one shape in 'examples/data/inputmodels/block.obj' if you want to test without downloading the full dataset.
37
+
38
+ ```sh
39
+ cd examples
40
+ python download_data.py
41
+ ```
42
+
43
+ After downloading the data, run shape optimization with the following example command:
44
+ ```sh
45
+ python optimize.py --ref_mesh data/inputmodels/block.obj --out_dir out/block
46
+ ```
47
+ You can find visualizations and output meshes in `out/block`. Below, we show the initial and final shapes during optimization, with the reference shape on the right.
48
+
49
+ <img src="images/block_init.png" alt="block_init" width="80%" height="80%">
50
+
51
+ <img src="images/block_final.png" alt="block_final" width="80%" height="80%">
52
+
53
+
54
+ To further demonstrate the flexibility of our FlexiCubes representation, which can accommodate both reconstruction objectives and regularizers defined on the extracted mesh, you can add a developability regularizer (proposed by [Stein et al.](https://www.cs.cmu.edu/~kmcrane/Projects/DiscreteDevelopable/)) to the previous reconstruction pipeline to encourage fabricability from panels:
55
+ ```sh
56
+ python optimize.py --ref_mesh data/inputmodels/david.obj --out_dir out/david_dev --develop_reg True --iter=1250
57
+ ```
58
+
59
+ ### Extract mesh from known signed distance field
60
+ While not its designated use case, our function can extract a mesh from a known Signed Distance Field (SDF) without optimization. Please refer to the tutorial found in `examples/extraction.ipynb` for details.
61
+
62
+ ## Tips for using FlexiCubes
63
+ ### Regularization losses:
64
+ We commonly use three regularizers in our mesh optimization pipelines, referenced in lines `L104-L106` in `examples/optimize.py`. The weights of these regularizers should be scaled according to your application objectives. Initially, it is suggested to employ low weights because strong regularization can hinder convergence. You can incrementally increase the weights if you notice artifacts appearing in the optimized meshes. Specifically:
65
+
66
+ * The loss function at `L104` helps to remove floaters in areas of the shape that are not supervised by the application objective, such as internal faces when using image supervision only.
67
+ * The L_dev loss at `L105` can be increased if you observe artifacts in flat areas, as illustrated in the image below.
68
+ * Generally, the L1 regularizer on flexible weights at `L106` does not have a significant impact during the optimization of a single shape. However, we found it to be effective in stabilizing training in generative pipelines such as GET3D.
69
+ <img src="images/ablate_L_dev.jpg" alt="Ablating L_dev" width="80%" height="80%">
70
+
71
+ ### Resolution of voxel grid vs. tetrahedral grid:
72
+ If you are switching from our previous work, DMTet, it's important to note the difference in grid resolution when compared to FlexiCubes. In both implementations, the resolution is defined by the edge length: a grid resolution of `n` means the grid edge length is 1/n for both the voxel and tetrahedral grids. However, a tetrahedral grid with a resolution of `n` contains only `(n/2+1)³` grid vertices, in contrast to the `(n+1)³` vertices in a voxel grid. Consequently, if you are switching from DMTet to FlexiCubes while maintaining the same resolution, you will notice not only a denser output mesh but also a substantial increase in computational cost. To align the triangle count in the output meshes more closely, we recommend adopting a 4:5 resolution ratio between the voxel grid and the tetrahedral grid. For instance, in our paper, `64³` FlexiCubes generate approximately the same number of triangles as `80³` DMTet.
73
+
74
+ ## Applications
75
+ FlexiCubes is now integrated into NVIDIA applications as a drop-in replacement for DMTet. You can visit their GitHub pages to see how FlexiCubes is used in advanced photogrammetry and 3D generative pipelines.
76
+
77
+ [Extracting Triangular 3D Models, Materials, and Lighting From Images (nvdiffrec)](https://github.com/NVlabs/nvdiffrec#news)
78
+
79
+ [GET3D: A Generative Model of High Quality 3D Textured Shapes Learned from Images](https://github.com/nv-tlabs/GET3D#employing-flexicubes)
80
+
81
+
82
+
83
+ ## License
84
+ Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
85
+
86
+ This work is made available under the [Nvidia Source Code License](LICENSE.txt).
87
+
88
+ For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
89
+
90
+ ## Citation
91
+ ```bibtex
92
+ @article{shen2023flexicubes,
93
+ author = {Shen, Tianchang and Munkberg, Jacob and Hasselgren, Jon and Yin, Kangxue and Wang, Zian
94
+ and Chen, Wenzheng and Gojcic, Zan and Fidler, Sanja and Sharp, Nicholas and Gao, Jun},
95
+ title = {Flexible Isosurface Extraction for Gradient-Based Mesh Optimization},
96
+ year = {2023},
97
+ issue_date = {August 2023},
98
+ publisher = {Association for Computing Machinery},
99
+ address = {New York, NY, USA},
100
+ volume = {42},
101
+ number = {4},
102
+ issn = {0730-0301},
103
+ url = {https://doi.org/10.1145/3592430},
104
+ doi = {10.1145/3592430},
105
+ journal = {ACM Trans. Graph.},
106
+ month = {jul},
107
+ articleno = {37},
108
+ numpages = {16}
109
+ }
110
+ ```
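To make the voxel-versus-tetrahedral resolution comparison in the README concrete, a small arithmetic sketch (plain Python, purely illustrative of the vertex counts quoted above):

```python
# A voxel grid of resolution n has (n + 1)**3 grid vertices, while a tetrahedral
# grid with the same edge length has roughly (n // 2 + 1)**3, per the README.
def voxel_grid_vertices(n: int) -> int:
    return (n + 1) ** 3

def tet_grid_vertices(n: int) -> int:
    return (n // 2 + 1) ** 3

# The 4:5 ratio suggested above: 64^3 FlexiCubes vs 80^3 DMTet.
print(voxel_grid_vertices(64))  # 274625
print(tet_grid_vertices(80))    # 68921
```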
trellis/representations/mesh/flexicubes/examples/data/inputmodels/block.obj ADDED
The diff for this file is too large to render. See raw diff
 
trellis/representations/mesh/flexicubes/examples/download_data.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import requests
9
+ from zipfile import ZipFile
10
+ from tqdm import tqdm
11
+ import os
12
+
13
+ def download_file(url, output_path):
14
+ response = requests.get(url, stream=True)
15
+ response.raise_for_status()
16
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
17
+ block_size = 1024 #1 Kibibyte
18
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
19
+
20
+ with open(output_path, 'wb') as file:
21
+ for data in response.iter_content(block_size):
22
+ progress_bar.update(len(data))
23
+ file.write(data)
24
+ progress_bar.close()
25
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
26
+ raise Exception("ERROR, something went wrong")
27
+
28
+
29
+ url = "https://vcg.isti.cnr.it/Publications/2014/MPZ14/inputmodels.zip"
30
+ zip_file_path = './data/inputmodels.zip'
31
+
32
+ os.makedirs('./data', exist_ok=True)
33
+
34
+ download_file(url, zip_file_path)
35
+
36
+ with ZipFile(zip_file_path, 'r') as zip_ref:
37
+ zip_ref.extractall('./data')
38
+
39
+ os.remove(zip_file_path)
40
+
41
+ print("Download and extraction complete.")
trellis/representations/mesh/flexicubes/examples/extraction.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
trellis/representations/mesh/flexicubes/examples/loss.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import torch
9
+ import torch_scatter
10
+
11
+ ###############################################################################
12
+ # Pytorch implementation of the developability regularizer introduced in paper
13
+ # "Developability of Triangle Meshes" by Stein et al.
14
+ ###############################################################################
15
+ def mesh_developable_reg(mesh):
16
+
17
+ verts = mesh.vertices
18
+ tris = mesh.faces
19
+
20
+ device = verts.device
21
+ V = verts.shape[0]
22
+ F = tris.shape[0]
23
+
24
+ POS_EPS = 1e-6
25
+ REL_EPS = 1e-6
26
+
27
+ def normalize(vecs):
28
+ return vecs / (torch.linalg.norm(vecs, dim=-1, keepdim=True) + POS_EPS)
29
+
30
+ tri_pos = verts[tris]
31
+
32
+ vert_normal_covariance_sum = torch.zeros((V, 9), device=device)
33
+ vert_area = torch.zeros(V, device=device)
34
+ vert_degree = torch.zeros(V, dtype=torch.int32, device=device)
35
+
36
+ for iC in range(3): # loop over three corners of each triangle
37
+
38
+ # gather tri verts
39
+ pRoot = tri_pos[:, iC, :]
40
+ pA = tri_pos[:, (iC + 1) % 3, :]
41
+ pB = tri_pos[:, (iC + 2) % 3, :]
42
+
43
+ # compute the corner angle & normal
44
+ vA = pA - pRoot
45
+ vAn = normalize(vA)
46
+ vB = pB - pRoot
47
+ vBn = normalize(vB)
48
+ area_normal = torch.linalg.cross(vA, vB, dim=-1)
49
+ face_area = 0.5 * torch.linalg.norm(area_normal, dim=-1)
50
+ normal = normalize(area_normal)
51
+ corner_angle = torch.acos(torch.clamp(torch.sum(vAn * vBn, dim=-1), min=-1., max=1.))
52
+
53
+ # add up the contribution to the covariance matrix
54
+ outer = normal[:, :, None] @ normal[:, None, :]
55
+ contrib = corner_angle[:, None] * outer.reshape(-1, 9)
56
+
57
+ # scatter the result to the appropriate matrices
58
+ vert_normal_covariance_sum = torch_scatter.scatter_add(src=contrib,
59
+ index=tris[:, iC],
60
+ dim=-2,
61
+ out=vert_normal_covariance_sum)
62
+
63
+ vert_area = torch_scatter.scatter_add(src=face_area / 3.,
64
+ index=tris[:, iC],
65
+ dim=-1,
66
+ out=vert_area)
67
+
68
+ vert_degree = torch_scatter.scatter_add(src=torch.ones(F, dtype=torch.int32, device=device),
69
+ index=tris[:, iC],
70
+ dim=-1,
71
+ out=vert_degree)
72
+
73
+ # The energy is the smallest eigenvalue of the outer-product matrix
74
+ vert_normal_covariance_sum = vert_normal_covariance_sum.reshape(
75
+ -1, 3, 3) # reshape to a batch of matrices
76
+ vert_normal_covariance_sum = vert_normal_covariance_sum + torch.eye(
77
+ 3, device=device)[None, :, :] * REL_EPS
78
+
79
+ min_eigvals = torch.min(torch.linalg.eigvals(vert_normal_covariance_sum).abs(), dim=-1).values
80
+
81
+ # Mask out degree-3 vertices
82
+ vert_area = torch.where(vert_degree == 3, torch.tensor(0, dtype=vert_area.dtype,device=vert_area.device), vert_area)
83
+
84
+ # Adjust the vertex area weighting so it is unit-less, and 1 on average
85
+ vert_area = vert_area * (V / torch.sum(vert_area, dim=-1, keepdim=True))
86
+
87
+ return vert_area * min_eigvals
88
+
89
+ def sdf_reg_loss(sdf, all_edges):
90
+ sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
91
+ mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
92
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
93
+ sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
94
+ torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
95
+ return sdf_diff
trellis/representations/mesh/flexicubes/examples/optimization.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
trellis/representations/mesh/flexicubes/examples/optimize.py ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import argparse
9
+ import numpy as np
10
+ import torch
11
+ import nvdiffrast.torch as dr
12
+ import trimesh
13
+ import os
14
+ from util import *
15
+ import render
16
+ import loss
17
+ import imageio
18
+
19
+ import sys
20
+ sys.path.append('..')
21
+ from flexicubes import FlexiCubes
22
+
23
+ ###############################################################################
24
+ # Functions adapted from https://github.com/NVlabs/nvdiffrec
25
+ ###############################################################################
26
+
27
+ def lr_schedule(iter):
28
+ return max(0.0, 10**(-(iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.
29
+
30
+ if __name__ == "__main__":
31
+ parser = argparse.ArgumentParser(description='flexicubes optimization')
32
+ parser.add_argument('-o', '--out_dir', type=str, default=None)
33
+ parser.add_argument('-rm', '--ref_mesh', type=str)
34
+
35
+ parser.add_argument('-i', '--iter', type=int, default=1000)
36
+ parser.add_argument('-b', '--batch', type=int, default=8)
37
+ parser.add_argument('-r', '--train_res', nargs=2, type=int, default=[2048, 2048])
38
+ parser.add_argument('-lr', '--learning_rate', type=float, default=0.01)
39
+ parser.add_argument('--voxel_grid_res', type=int, default=64)
40
+
41
+ parser.add_argument('--sdf_loss', type=bool, default=True)
42
+ parser.add_argument('--develop_reg', type=bool, default=False)
43
+ parser.add_argument('--sdf_regularizer', type=float, default=0.2)
44
+
45
+ parser.add_argument('-dr', '--display_res', nargs=2, type=int, default=[512, 512])
46
+ parser.add_argument('-si', '--save_interval', type=int, default=20)
47
+ FLAGS = parser.parse_args()
48
+ device = 'cuda'
49
+
50
+ os.makedirs(FLAGS.out_dir, exist_ok=True)
51
+ glctx = dr.RasterizeGLContext()
52
+
53
+ # Load GT mesh
54
+ gt_mesh = load_mesh(FLAGS.ref_mesh, device)
55
+ gt_mesh.auto_normals() # compute face normals for visualization
56
+
57
+ # ==============================================================================================
58
+ # Create and initialize FlexiCubes
59
+ # ==============================================================================================
60
+ fc = FlexiCubes(device)
61
+ x_nx3, cube_fx8 = fc.construct_voxel_grid(FLAGS.voxel_grid_res)
62
+ x_nx3 *= 2 # scale up the grid so that it's larger than the target object
63
+
64
+ sdf = torch.rand_like(x_nx3[:,0]) - 0.1 # randomly init SDF
65
+ sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
66
+ # set per-cube learnable weights to zeros
67
+ weight = torch.zeros((cube_fx8.shape[0], 21), dtype=torch.float, device='cuda')
68
+ weight = torch.nn.Parameter(weight.clone().detach(), requires_grad=True)
69
+ deform = torch.nn.Parameter(torch.zeros_like(x_nx3), requires_grad=True)
70
+
71
+ # Retrieve all the edges of the voxel grid; these edges will be utilized to
72
+ # compute the regularization loss in subsequent steps of the process.
73
+ all_edges = cube_fx8[:, fc.cube_edges].reshape(-1, 2)
74
+ grid_edges = torch.unique(all_edges, dim=0)
75
+
76
+ # ==============================================================================================
77
+ # Setup optimizer
78
+ # ==============================================================================================
79
+ optimizer = torch.optim.Adam([sdf, weight,deform], lr=FLAGS.learning_rate)
80
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x))
81
+
82
+ # ==============================================================================================
83
+ # Train loop
84
+ # ==============================================================================================
85
+ for it in range(FLAGS.iter):
86
+ optimizer.zero_grad()
87
+ # sample random camera poses
88
+ mv, mvp = render.get_random_camera_batch(FLAGS.batch, iter_res=FLAGS.train_res, device=device, use_kaolin=False)
89
+ # render gt mesh
90
+ target = render.render_mesh_paper(gt_mesh, mv, mvp, FLAGS.train_res)
91
+ # extract and render FlexiCubes mesh
92
+ grid_verts = x_nx3 + (2-1e-8) / (FLAGS.voxel_grid_res * 2) * torch.tanh(deform)
93
+ vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:,:12], alpha_fx8=weight[:,12:20],
94
+ gamma_f=weight[:,20], training=True)
95
+ flexicubes_mesh = Mesh(vertices, faces)
96
+ buffers = render.render_mesh_paper(flexicubes_mesh, mv, mvp, FLAGS.train_res)
97
+
98
+ # evaluate reconstruction loss
99
+ mask_loss = (buffers['mask'] - target['mask']).abs().mean()
100
+ depth_loss = (((((buffers['depth'] - (target['depth']))* target['mask'])**2).sum(-1)+1e-8)).sqrt().mean() * 10
101
+
102
+ t_iter = it / FLAGS.iter
103
+ sdf_weight = FLAGS.sdf_regularizer - (FLAGS.sdf_regularizer - FLAGS.sdf_regularizer/20)*min(1.0, 4.0 * t_iter)
104
+ reg_loss = loss.sdf_reg_loss(sdf, grid_edges).mean() * sdf_weight # Loss to eliminate internal floaters that are not visible
105
+ reg_loss += L_dev.mean() * 0.5
106
+ reg_loss += (weight[:,:20]).abs().mean() * 0.1
107
+ total_loss = mask_loss + depth_loss + reg_loss
108
+
109
+ if FLAGS.sdf_loss: # optionally add SDF loss to eliminate internal structures
110
+ with torch.no_grad():
111
+ pts = sample_random_points(1000, gt_mesh)
112
+ gt_sdf = compute_sdf(pts, gt_mesh.vertices, gt_mesh.faces)
113
+ pred_sdf = compute_sdf(pts, flexicubes_mesh.vertices, flexicubes_mesh.faces)
114
+ total_loss += torch.nn.functional.mse_loss(pred_sdf, gt_sdf) * 2e3
115
+
116
+ # optionally add developability regularizer, as described in paper section 5.2
117
+ if FLAGS.develop_reg:
118
+ reg_weight = max(0, t_iter - 0.8) * 5
119
+ if reg_weight > 0: # only applied after shape converges
120
+ reg_loss = loss.mesh_developable_reg(flexicubes_mesh).mean() * 10
121
+ reg_loss += (deform).abs().mean()
122
+ reg_loss += (weight[:,:20]).abs().mean()
123
+ total_loss = mask_loss + depth_loss + reg_loss
124
+
125
+ total_loss.backward()
126
+ optimizer.step()
127
+ scheduler.step()
128
+
129
+ if (it % FLAGS.save_interval == 0 or it == (FLAGS.iter-1)): # save normal image for visualization
130
+ with torch.no_grad():
131
+ # extract mesh with training=False
132
+ vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:,:12], alpha_fx8=weight[:,12:20],
133
+ gamma_f=weight[:,20], training=False)
134
+ flexicubes_mesh = Mesh(vertices, faces)
135
+
136
+ flexicubes_mesh.auto_normals() # compute face normals for visualization
137
+ mv, mvp = render.get_rotate_camera(it//FLAGS.save_interval, iter_res=FLAGS.display_res, device=device,use_kaolin=False)
138
+ val_buffers = render.render_mesh_paper(flexicubes_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
139
+ val_image = ((val_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
140
+
141
+ gt_buffers = render.render_mesh_paper(gt_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
142
+ gt_image = ((gt_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
143
+ imageio.imwrite(os.path.join(FLAGS.out_dir, '{:04d}.png'.format(it)), np.concatenate([val_image, gt_image], 1))
144
+ print(f"Optimization Step [{it}/{FLAGS.iter}], Loss: {total_loss.item():.4f}")
145
+
146
+ # ==============================================================================================
147
+ # Save output
148
+ # ==============================================================================================
149
+ mesh_np = trimesh.Trimesh(vertices = vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy(), process=False)
150
+ mesh_np.export(os.path.join(FLAGS.out_dir, 'output_mesh.obj'))
trellis/representations/mesh/flexicubes/examples/render.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import numpy as np
9
+ import copy
10
+ import math
11
+ from ipywidgets import interactive, HBox, VBox, FloatLogSlider, IntSlider
12
+
13
+ import torch
14
+ import nvdiffrast.torch as dr
15
+ import kaolin as kal
16
+ import util
17
+
18
+ ###############################################################################
19
+ # Functions adapted from https://github.com/NVlabs/nvdiffrec
20
+ ###############################################################################
21
+
22
+ def get_random_camera_batch(batch_size, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
23
+ if use_kaolin:
24
+ camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(
25
+ *kal.ops.random.sample_spherical_coords((batch_size,), azimuth_low=0., azimuth_high=math.pi * 2,
26
+ elevation_low=-math.pi / 2., elevation_high=math.pi / 2., device='cuda'),
27
+ cam_radius
28
+ ), dim=-1)
29
+ return kal.render.camera.Camera.from_args(
30
+ eye=camera_pos + torch.rand((batch_size, 1), device='cuda') * 0.5 - 0.25,
31
+ at=torch.zeros(batch_size, 3),
32
+ up=torch.tensor([[0., 1., 0.]]),
33
+ fov=fovy,
34
+ near=cam_near_far[0], far=cam_near_far[1],
35
+ height=iter_res[0], width=iter_res[1],
36
+ device='cuda'
37
+ )
38
+ else:
39
+ def get_random_camera():
40
+ proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
41
+ mv = util.translate(0, 0, -cam_radius) @ util.random_rotation_translation(0.25)
42
+ mvp = proj_mtx @ mv
43
+ return mv, mvp
44
+ mv_batch = []
45
+ mvp_batch = []
46
+ for i in range(batch_size):
47
+ mv, mvp = get_random_camera()
48
+ mv_batch.append(mv)
49
+ mvp_batch.append(mvp)
50
+ return torch.stack(mv_batch).to(device), torch.stack(mvp_batch).to(device)
51
+
52
+ def get_rotate_camera(itr, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
53
+ if use_kaolin:
54
+ ang = (itr / 10) * np.pi * 2
55
+ camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(torch.tensor(ang), torch.tensor(0.4), -torch.tensor(cam_radius)))
56
+ return kal.render.camera.Camera.from_args(
57
+ eye=camera_pos,
58
+ at=torch.zeros(3),
59
+ up=torch.tensor([0., 1., 0.]),
60
+ fov=fovy,
61
+ near=cam_near_far[0], far=cam_near_far[1],
62
+ height=iter_res[0], width=iter_res[1],
63
+ device='cuda'
64
+ )
65
+ else:
66
+ proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
67
+
68
+ # Smooth rotation for display.
69
+ ang = (itr / 10) * np.pi * 2
70
+ mv = util.translate(0, 0, -cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
71
+ mvp = proj_mtx @ mv
72
+ return mv.to(device), mvp.to(device)
73
+
74
+ glctx = dr.RasterizeGLContext()
75
+ def render_mesh(mesh, camera, iter_res, return_types = ["mask", "depth"], white_bg=False, wireframe_thickness=0.4):
76
+ vertices_camera = camera.extrinsics.transform(mesh.vertices)
77
+ face_vertices_camera = kal.ops.mesh.index_vertices_by_faces(
78
+ vertices_camera, mesh.faces
79
+ )
80
+
81
+ # Projection: nvdiffrast take clip coordinates as input to apply barycentric perspective correction.
82
+ # Using `camera.intrinsics.transform(vertices_camera)` would return the normalized device coordinates.
83
+ proj = camera.projection_matrix().unsqueeze(1)
84
+ proj[:, :, 1, 1] = -proj[:, :, 1, 1]
85
+ homogeneous_vecs = kal.render.camera.up_to_homogeneous(
86
+ vertices_camera
87
+ )
88
+ vertices_clip = (proj @ homogeneous_vecs.unsqueeze(-1)).squeeze(-1)
89
+ faces_int = mesh.faces.int()
90
+
91
+ rast, _ = dr.rasterize(
92
+ glctx, vertices_clip, faces_int, iter_res)
93
+
94
+ out_dict = {}
95
+ for type in return_types:
96
+ if type == "mask" :
97
+ img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
98
+ elif type == "depth":
99
+ img = dr.interpolate(homogeneous_vecs, rast, faces_int)[0]
100
+ elif type == "wireframe":
101
+ img = torch.logical_or(
102
+ torch.logical_or(rast[..., 0] < wireframe_thickness, rast[..., 1] < wireframe_thickness),
103
+ (rast[..., 0] + rast[..., 1]) > (1. - wireframe_thickness)
104
+ ).unsqueeze(-1)
105
+ elif type == "normals" :
106
+ img = dr.interpolate(
107
+ mesh.face_normals.reshape(len(mesh), -1, 3), rast,
108
+ torch.arange(mesh.faces.shape[0] * 3, device='cuda', dtype=torch.int).reshape(-1, 3)
109
+ )[0]
110
+ if white_bg:
111
+ bg = torch.ones_like(img)
112
+ alpha = (rast[..., -1:] > 0).float()
113
+ img = torch.lerp(bg, img, alpha)
114
+ out_dict[type] = img
115
+
116
+
117
+ return out_dict
118
+
119
+ def render_mesh_paper(mesh, mv, mvp, iter_res, return_types = ["mask", "depth"], white_bg=False):
120
+ '''
121
+ The rendering function used to produce the results in the paper.
122
+ '''
123
+ v_pos_clip = util.xfm_points(mesh.vertices.unsqueeze(0), mvp) # Rotate it to camera coordinates
124
+ rast, db = dr.rasterize(
125
+ dr.RasterizeGLContext(), v_pos_clip, mesh.faces.int(), iter_res)
126
+
127
+ out_dict = {}
128
+ for type in return_types:
129
+ if type == "mask" :
130
+ img = dr.antialias((rast[..., -1:] > 0).float(), rast, v_pos_clip, mesh.faces.int())
131
+ elif type == "depth":
132
+ v_pos_cam = util.xfm_points(mesh.vertices.unsqueeze(0), mv)
133
+ img, _ = util.interpolate(v_pos_cam, rast, mesh.faces.int())
134
+ elif type == "normal" :
135
+ normal_indices = (torch.arange(0, mesh.nrm.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
136
+ img, _ = util.interpolate(mesh.nrm.unsqueeze(0).contiguous(), rast, normal_indices.int())
137
+ elif type == "vertex_normal":
138
+ img, _ = util.interpolate(mesh.v_nrm.unsqueeze(0).contiguous(), rast, mesh.faces.int())
139
+ img = dr.antialias((img + 1) * 0.5, rast, v_pos_clip, mesh.faces.int())
140
+ if white_bg:
141
+ bg = torch.ones_like(img)
142
+ alpha = (rast[..., -1:] > 0).float()
143
+ img = torch.lerp(bg, img, alpha)
144
+ out_dict[type] = img
145
+ return out_dict
146
+
147
+ class SplitVisualizer():
148
+ def __init__(self, lh_mesh, rh_mesh, height, width):
149
+ self.lh_mesh = lh_mesh
150
+ self.rh_mesh = rh_mesh
151
+ self.height = height
152
+ self.width = width
153
+ self.wireframe_thickness = 0.4
154
+
155
+
156
+ def render(self, camera):
157
+ lh_outputs = render_mesh(
158
+ self.lh_mesh, camera, (self.height, self.width),
159
+ return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
160
+ )
161
+ rh_outputs = render_mesh(
162
+ self.rh_mesh, camera, (self.height, self.width),
163
+ return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
164
+ )
165
+ outputs = {
166
+ k: torch.cat(
167
+ [lh_outputs[k][0].permute(1, 0, 2), rh_outputs[k][0].permute(1, 0, 2)],
168
+ dim=0
169
+ ).permute(1, 0, 2) for k in ["normals", "wireframe"]
170
+ }
171
+ return {
172
+ 'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8),
173
+ 'normals': outputs['normals']
174
+ }
175
+
176
+ def show(self, init_camera):
177
+ visualizer = kal.visualize.IpyTurntableVisualizer(
178
+ self.height, self.width * 2, copy.deepcopy(init_camera), self.render,
179
+ max_fps=24, world_up_axis=1)
180
+
181
+ def slider_callback(new_wireframe_thickness):
182
+ """ipywidgets sliders callback"""
183
+ with visualizer.out: # This is in case of bug
184
+ self.wireframe_thickness = new_wireframe_thickness
185
+ # this is how we request a new update
186
+ visualizer.render_update()
187
+
188
+ wireframe_thickness_slider = FloatLogSlider(
189
+ value=self.wireframe_thickness,
190
+ base=10,
191
+ min=-3,
192
+ max=-0.4,
193
+ step=0.1,
194
+ description='wireframe_thickness',
195
+ continuous_update=True,
196
+ readout=True,
197
+ readout_format='.3f',
198
+ )
199
+
200
+ interactive_slider = interactive(
201
+ slider_callback,
202
+ new_wireframe_thickness=wireframe_thickness_slider,
203
+ )
204
+
205
+ full_output = VBox([visualizer.canvas, interactive_slider])
206
+ display(full_output, visualizer.out)
207
+
208
+ class TimelineVisualizer():
209
+ def __init__(self, meshes, height, width):
210
+ self.meshes = meshes
211
+ self.height = height
212
+ self.width = width
213
+ self.wireframe_thickness = 0.4
214
+ self.idx = len(meshes) - 1
215
+
216
+ def render(self, camera):
217
+ outputs = render_mesh(
218
+ self.meshes[self.idx], camera, (self.height, self.width),
219
+ return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
220
+ )
221
+
222
+ return {
223
+ 'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8)[0],
224
+ 'normals': outputs['normals'][0]
225
+ }
226
+
227
+ def show(self, init_camera):
228
+ visualizer = kal.visualize.IpyTurntableVisualizer(
229
+ self.height, self.width, copy.deepcopy(init_camera), self.render,
230
+ max_fps=24, world_up_axis=1)
231
+
232
+ def slider_callback(new_wireframe_thickness, new_idx):
233
+ """ipywidgets sliders callback"""
234
+ with visualizer.out: # This is in case of bug
235
+ self.wireframe_thickness = new_wireframe_thickness
236
+ self.idx = new_idx
237
+ # this is how we request a new update
238
+ visualizer.render_update()
239
+
240
+ wireframe_thickness_slider = FloatLogSlider(
241
+ value=self.wireframe_thickness,
242
+ base=10,
243
+ min=-3,
244
+ max=-0.4,
245
+ step=0.1,
246
+ description='wireframe_thickness',
247
+ continuous_update=True,
248
+ readout=True,
249
+ readout_format='.3f',
250
+ )
251
+
252
+ idx_slider = IntSlider(
253
+ value=self.idx,
254
+ min=0,
255
+ max=len(self.meshes) - 1,
256
+ description='idx',
257
+ continuous_update=True,
258
+ readout=True
259
+ )
260
+
261
+ interactive_slider = interactive(
262
+ slider_callback,
263
+ new_wireframe_thickness=wireframe_thickness_slider,
264
+ new_idx=idx_slider
265
+ )
266
+ full_output = HBox([visualizer.canvas, interactive_slider])
267
+ display(full_output, visualizer.out)
trellis/representations/mesh/flexicubes/examples/util.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import numpy as np
9
+ import torch
10
+ import trimesh
11
+ import kaolin
12
+ import nvdiffrast.torch as dr
13
+
14
+ ###############################################################################
15
+ # Functions adapted from https://github.com/NVlabs/nvdiffrec
16
+ ###############################################################################
17
+
18
+ def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
19
+ return torch.sum(x*y, -1, keepdim=True)
20
+
21
+ def length(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor:
22
+ return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
23
+
24
+ def safe_normalize(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor:
25
+ return x / length(x, eps)
26
+
27
+ def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
28
+ y = np.tan(fovy / 2)
29
+ return torch.tensor([[1/(y*aspect), 0, 0, 0],
30
+ [ 0, 1/-y, 0, 0],
31
+ [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
32
+ [ 0, 0, -1, 0]], dtype=torch.float32, device=device)
33
+
34
+ def translate(x, y, z, device=None):
35
+ return torch.tensor([[1, 0, 0, x],
36
+ [0, 1, 0, y],
37
+ [0, 0, 1, z],
38
+ [0, 0, 0, 1]], dtype=torch.float32, device=device)
39
+
40
+ @torch.no_grad()
41
+ def random_rotation_translation(t, device=None):
42
+ m = np.random.normal(size=[3, 3])
43
+ m[1] = np.cross(m[0], m[2])
44
+ m[2] = np.cross(m[0], m[1])
45
+ m = m / np.linalg.norm(m, axis=1, keepdims=True)
46
+ m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
47
+ m[3, 3] = 1.0
48
+ m[:3, 3] = np.random.uniform(-t, t, size=[3])
49
+ return torch.tensor(m, dtype=torch.float32, device=device)
50
+
51
+ def rotate_x(a, device=None):
52
+ s, c = np.sin(a), np.cos(a)
53
+ return torch.tensor([[1, 0, 0, 0],
54
+ [0, c, s, 0],
55
+ [0, -s, c, 0],
56
+ [0, 0, 0, 1]], dtype=torch.float32, device=device)
57
+
58
+ def rotate_y(a, device=None):
59
+ s, c = np.sin(a), np.cos(a)
60
+ return torch.tensor([[ c, 0, s, 0],
61
+ [ 0, 1, 0, 0],
62
+ [-s, 0, c, 0],
63
+ [ 0, 0, 0, 1]], dtype=torch.float32, device=device)
64
+
65
+ class Mesh:
66
+ def __init__(self, vertices, faces):
67
+ self.vertices = vertices
68
+ self.faces = faces
69
+
70
+ def auto_normals(self):
71
+ v0 = self.vertices[self.faces[:, 0], :]
72
+ v1 = self.vertices[self.faces[:, 1], :]
73
+ v2 = self.vertices[self.faces[:, 2], :]
74
+ nrm = safe_normalize(torch.cross(v1 - v0, v2 - v0))
75
+ self.nrm = nrm
76
+
77
+ def load_mesh(path, device):
78
+ mesh_np = trimesh.load(path)
79
+ vertices = torch.tensor(mesh_np.vertices, device=device, dtype=torch.float)
80
+ faces = torch.tensor(mesh_np.faces, device=device, dtype=torch.long)
81
+
82
+ # Normalize
83
+ vmin, vmax = vertices.min(dim=0)[0], vertices.max(dim=0)[0]
84
+ scale = 1.8 / torch.max(vmax - vmin).item()
85
+ vertices = vertices - (vmax + vmin) / 2 # Center mesh on origin
86
+ vertices = vertices * scale # Rescale to [-0.9, 0.9]
87
+ return Mesh(vertices, faces)
88
+
89
+ def compute_sdf(points, vertices, faces):
90
+ face_vertices = kaolin.ops.mesh.index_vertices_by_faces(vertices.clone().unsqueeze(0), faces)
91
+ distance = kaolin.metrics.trianglemesh.point_to_mesh_distance(points.unsqueeze(0), face_vertices)[0]
92
+ with torch.no_grad():
93
+ sign = (kaolin.ops.mesh.check_sign(vertices.unsqueeze(0), faces, points.unsqueeze(0))<1).float() * 2 - 1
94
+ sdf = (sign*distance).squeeze(0)
95
+ return sdf
96
+
97
+ def sample_random_points(n, mesh):
98
+ pts_random = (torch.rand((n//2,3),device='cuda') - 0.5) * 2
99
+ pts_surface = kaolin.ops.mesh.sample_points(mesh.vertices.unsqueeze(0), mesh.faces, 500)[0].squeeze(0)
100
+ pts_surface += torch.randn_like(pts_surface) * 0.05
101
+ pts = torch.cat([pts_random, pts_surface])
102
+ return pts
103
+
104
+ def xfm_points(points, matrix):
105
+ '''Transform points.
106
+ Args:
107
+ points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
108
+ matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
109
+ use_python: Use PyTorch's torch.matmul (for validation)
110
+ Returns:
111
+ Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
112
+ '''
113
+ out = torch.matmul(
114
+ torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
115
+ if torch.is_anomaly_enabled():
116
+ assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
117
+ return out
118
+
119
+ def interpolate(attr, rast, attr_idx, rast_db=None):
120
+ return dr.interpolate(
121
+ attr, rast, attr_idx, rast_db=rast_db,
122
+ diff_attrs=None if rast_db is None else 'all')
trellis/representations/mesh/flexicubes/flexicubes.py ADDED
@@ -0,0 +1,384 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from .tables import *
11
+ from kaolin.utils.testing import check_tensor
12
+
13
+ __all__ = [
14
+ 'FlexiCubes'
15
+ ]
16
+
17
+
18
+ class FlexiCubes:
19
+ def __init__(self, device="cuda"):
20
+
21
+ self.device = device
22
+ self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
23
+ self.num_vd_table = torch.tensor(num_vd_table,
24
+ dtype=torch.long, device=device, requires_grad=False)
25
+ self.check_table = torch.tensor(
26
+ check_table,
27
+ dtype=torch.long, device=device, requires_grad=False)
28
+
29
+ self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
30
+ self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
31
+ self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
32
+ self.quad_split_train = torch.tensor(
33
+ [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
34
+
35
+ self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
36
+ 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
37
+ self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
38
+ self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
39
+ 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
40
+
41
+ self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
42
+ dtype=torch.long, device=device)
43
+ self.dir_faces_table = torch.tensor([
44
+ [[5, 4], [3, 2], [4, 5], [2, 3]],
45
+ [[5, 4], [1, 0], [4, 5], [0, 1]],
46
+ [[3, 2], [1, 0], [2, 3], [0, 1]]
47
+ ], dtype=torch.long, device=device)
48
+ self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
49
+
50
+ def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3,
51
+ weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False):
52
+ assert torch.is_tensor(voxelgrid_vertices) and \
53
+ check_tensor(voxelgrid_vertices, (None, 3), throw=False), \
54
+ "'voxelgrid_vertices' should be a tensor of shape (num_vertices, 3)"
55
+ num_vertices = voxelgrid_vertices.shape[0]
56
+ assert torch.is_tensor(scalar_field) and \
57
+ check_tensor(scalar_field, (num_vertices,), throw=False), \
58
+ "'scalar_field' should be a tensor of shape (num_vertices,)"
59
+ assert torch.is_tensor(cube_idx) and \
60
+ check_tensor(cube_idx, (None, 8), throw=False), \
61
+ "'cube_idx' should be a tensor of shape (num_cubes, 8)"
62
+ num_cubes = cube_idx.shape[0]
63
+ assert beta is None or (
64
+ torch.is_tensor(beta) and
65
+ check_tensor(beta, (num_cubes, 12), throw=False)
66
+ ), "'beta' should be a tensor of shape (num_cubes, 12)"
67
+ assert alpha is None or (
68
+ torch.is_tensor(alpha) and
69
+ check_tensor(alpha, (num_cubes, 8), throw=False)
70
+ ), "'alpha' should be a tensor of shape (num_cubes, 8)"
71
+ assert gamma_f is None or (
72
+ torch.is_tensor(gamma_f) and
73
+ check_tensor(gamma_f, (num_cubes,), throw=False)
74
+ ), "'gamma_f' should be a tensor of shape (num_cubes,)"
75
+
76
+ surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx)
77
+ if surf_cubes.sum() == 0:
78
+ return (
79
+ torch.zeros((0, 3), device=self.device),
80
+ torch.zeros((0, 3), dtype=torch.long, device=self.device),
81
+ torch.zeros((0), device=self.device),
82
+ torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None
83
+ )
84
+ beta, alpha, gamma_f = self._normalize_weights(
85
+ beta, alpha, gamma_f, surf_cubes, weight_scale)
86
+
87
+ if voxelgrid_colors is not None:
88
+ voxelgrid_colors = torch.sigmoid(voxelgrid_colors)
89
+
90
+ case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution)
91
+
92
+ surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
93
+ scalar_field, cube_idx, surf_cubes
94
+ )
95
+
96
+ vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd(
97
+ voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field,
98
+ case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors)
99
+ vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate(
100
+ scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map,
101
+ vd_idx_map, surf_edges_mask, training, vd_color)
102
+ return vertices, faces, L_dev, vertices_color
103
+
104
+ def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
105
+ """
106
+ Regularizer L_dev as in Equation 8
107
+ """
108
+ dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
109
+ mean_l2 = torch.zeros_like(vd[:, 0])
110
+ mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
111
+ mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
112
+ return mad
113
+
114
+ def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale):
115
+ """
116
+ Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
117
+ """
118
+ n_cubes = surf_cubes.shape[0]
119
+
120
+ if beta is not None:
121
+ beta = (torch.tanh(beta) * weight_scale + 1)
122
+ else:
123
+ beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
124
+
125
+ if alpha is not None:
126
+ alpha = (torch.tanh(alpha) * weight_scale + 1)
127
+ else:
128
+ alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
129
+
130
+ if gamma_f is not None:
131
+ gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2
132
+ else:
133
+ gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
134
+
135
+ return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes]
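During optimization the per-cube weights would typically be created as learnable parameters with the shapes checked in __call__; a sketch under that assumption (not part of the diff, 'cube_idx' taken from the caller):

    num_cubes = cube_idx.shape[0]
    beta    = torch.nn.Parameter(torch.zeros(num_cubes, 12, device='cuda'))   # per-edge interpolation weights
    alpha   = torch.nn.Parameter(torch.zeros(num_cubes, 8, device='cuda'))    # per-corner weights
    gamma_f = torch.nn.Parameter(torch.zeros(num_cubes, device='cuda'))       # per-cube quad-split weight
    # zeros are the neutral values here: tanh(0)*scale + 1 = 1 and sigmoid(0)*scale + (1-scale)/2 = 0.5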
136
+
137
+ @torch.no_grad()
138
+ def _get_case_id(self, occ_fx8, surf_cubes, res):
139
+ """
140
+ Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
141
+ ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
142
+ supplementary material. It should be noted that this function assumes a regular grid.
143
+ """
144
+ case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
145
+
146
+ problem_config = self.check_table.to(self.device)[case_ids]
147
+ to_check = problem_config[..., 0] == 1
148
+ problem_config = problem_config[to_check]
149
+ if not isinstance(res, (list, tuple)):
150
+ res = [res, res, res]
151
+
152
+ # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
153
+ # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
154
+ # This allows efficient checking on adjacent cubes.
155
+ problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
156
+ vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
157
+ vol_idx_problem = vol_idx[surf_cubes][to_check]
158
+ problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
159
+ vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
160
+
161
+ within_range = (
162
+ vol_idx_problem_adj[..., 0] >= 0) & (
163
+ vol_idx_problem_adj[..., 0] < res[0]) & (
164
+ vol_idx_problem_adj[..., 1] >= 0) & (
165
+ vol_idx_problem_adj[..., 1] < res[1]) & (
166
+ vol_idx_problem_adj[..., 2] >= 0) & (
167
+ vol_idx_problem_adj[..., 2] < res[2])
168
+
169
+ vol_idx_problem = vol_idx_problem[within_range]
170
+ vol_idx_problem_adj = vol_idx_problem_adj[within_range]
171
+ problem_config = problem_config[within_range]
172
+ problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
173
+ vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
174
+ # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
175
+ to_invert = (problem_config_adj[..., 0] == 1)
176
+ idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
177
+ case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
178
+ return case_ids
179
+
180
+ @torch.no_grad()
181
+ def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes):
182
+ """
183
+ Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
184
+ can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
185
+ and marks the cube edges with this index.
186
+ """
187
+ occ_n = scalar_field < 0
188
+ all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2)
189
+ unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
190
+
191
+ unique_edges = unique_edges.long()
192
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
193
+
194
+ surf_edges_mask = mask_edges[_idx_map]
195
+ counts = counts[_idx_map]
196
+
197
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1
198
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device)
199
+ # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
200
+ # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
201
+ idx_map = mapping[_idx_map]
202
+ surf_edges = unique_edges[mask_edges]
203
+ return surf_edges, idx_map, counts, surf_edges_mask
204
+
205
+ @torch.no_grad()
206
+ def _identify_surf_cubes(self, scalar_field, cube_idx):
207
+ """
208
+ Identifies grid cubes that intersect with the underlying surface by checking if the signs at
209
+ all corners are not identical.
210
+ """
211
+ occ_n = scalar_field < 0
212
+ occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8)
213
+ _occ_sum = torch.sum(occ_fx8, -1)
214
+ surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
215
+ return surf_cubes, occ_fx8
216
+
217
+ def _linear_interp(self, edges_weight, edges_x):
218
+ """
219
+ Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
220
+ """
221
+ edge_dim = edges_weight.dim() - 2
222
+ assert edges_weight.shape[edge_dim] == 2
223
+ edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
224
+ torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)]
225
+ , edge_dim)
226
+ denominator = edges_weight.sum(edge_dim)
227
+ ue = (edges_x * edges_weight).sum(edge_dim) / denominator
228
+ return ue
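For reference, the closed form computed here for a single edge (illustrative comment, not part of the diff):

    # given endpoint positions x0, x1 and signed scalar values s0, s1 with opposite signs,
    # the zero crossing returned by _linear_interp is
    #     ue = (s1 * x0 - s0 * x1) / (s1 - s0)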
229
+
230
+ def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale):
231
+ p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
232
+ norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
233
+ c_bx3 = c_bx3.reshape(-1, 3)
234
+ A = norm_bxnx3
235
+ B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
236
+
237
+ A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
238
+ B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1)
239
+ A = torch.cat([A, A_reg], 1)
240
+ B = torch.cat([B, B_reg], 1)
241
+ dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
242
+ return dual_verts
243
+
244
+ def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field,
245
+ case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors):
246
+ """
247
+ Computes the location of dual vertices as described in Section 4.2
248
+ """
249
+ alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
250
+ surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
251
+ surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
252
+ zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
253
+
254
+ if voxelgrid_colors is not None:
255
+ C = voxelgrid_colors.shape[-1]
256
+ surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C)
257
+
258
+ idx_map = idx_map.reshape(-1, 12)
259
+ num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
260
+ edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
261
+
262
+ # if color is not None:
263
+ # vd_color = []
264
+
265
+ total_num_vd = 0
266
+ vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
267
+
268
+ for num in torch.unique(num_vd):
269
+ cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
270
+ curr_num_vd = cur_cubes.sum() * num
271
+ curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
272
+ curr_edge_group_to_vd = torch.arange(
273
+ curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
274
+ total_num_vd += curr_num_vd
275
+ curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
276
+ cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
277
+
278
+ curr_mask = (curr_edge_group != -1)
279
+ edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
280
+ edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
281
+ edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
282
+ vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
283
+ vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
284
+ # if color is not None:
285
+ # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3))
286
+
287
+ edge_group = torch.cat(edge_group)
288
+ edge_group_to_vd = torch.cat(edge_group_to_vd)
289
+ edge_group_to_cube = torch.cat(edge_group_to_cube)
290
+ vd_num_edges = torch.cat(vd_num_edges)
291
+ vd_gamma = torch.cat(vd_gamma)
292
+ # if color is not None:
293
+ # vd_color = torch.cat(vd_color)
294
+ # else:
295
+ # vd_color = None
296
+
297
+ vd = torch.zeros((total_num_vd, 3), device=self.device)
298
+ beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
299
+
300
+ idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
301
+
302
+ x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
303
+ s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
304
+
305
+
306
+ zero_crossing_group = torch.index_select(
307
+ input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
308
+
309
+ alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
310
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
311
+ ue_group = self._linear_interp(s_group * alpha_group, x_group)
312
+
313
+ beta_group = torch.gather(input=beta.reshape(-1), dim=0,
314
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
315
+ beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
316
+ vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
317
+
318
+ '''
319
+        Interpolate colors using the same method as for the dual vertices.
320
+ '''
321
+ if voxelgrid_colors is not None:
322
+ vd_color = torch.zeros((total_num_vd, C), device=self.device)
323
+ c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C)
324
+ uc_group = self._linear_interp(s_group * alpha_group, c_group)
325
+ vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum
326
+ else:
327
+ vd_color = None
328
+
329
+ L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
330
+
331
+ v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
332
+
333
+ vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
334
+ 12 + edge_group, src=v_idx[edge_group_to_vd])
335
+
336
+ return vd, L_dev, vd_gamma, vd_idx_map, vd_color
337
+
338
+ def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color):
339
+ """
340
+ Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
341
+ triangles based on the gamma parameter, as described in Section 4.3.
342
+ """
343
+ with torch.no_grad():
344
+ group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
345
+ group = idx_map.reshape(-1)[group_mask]
346
+ vd_idx = vd_idx_map[group_mask]
347
+ edge_indices, indices = torch.sort(group, stable=True)
348
+ quad_vd_idx = vd_idx[indices].reshape(-1, 4)
349
+
350
+ # Ensure all face directions point towards the positive SDF to maintain consistent winding.
351
+ s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
352
+ flip_mask = s_edges[:, 0] > 0
353
+ quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
354
+ quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
355
+
356
+ quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
357
+ gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2]
358
+ gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3]
359
+ if not training:
360
+ mask = (gamma_02 > gamma_13)
361
+ faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
362
+ faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
363
+ faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
364
+ faces = faces.reshape(-1, 3)
365
+ else:
366
+ vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
367
+ vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2
368
+ vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2
369
+ weight_sum = (gamma_02 + gamma_13) + 1e-8
370
+ vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
371
+
372
+ if vd_color is not None:
373
+ color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1])
374
+ color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2
375
+ color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2
376
+ color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
377
+ vd_color = torch.cat([vd_color, color_center])
378
+
379
+
380
+ vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
381
+ vd = torch.cat([vd, vd_center])
382
+ faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
383
+ faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
384
+ return vd, faces, s_edges, edge_indices, vd_color
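A minimal end-to-end sketch (not part of the diff; the grid construction is an assumption made for illustration) that extracts a sphere mesh from a dense grid with default weights:

    import torch
    fc = FlexiCubes(device='cuda')
    res = 16
    coords = torch.linspace(-0.5, 0.5, res + 1, device='cuda')
    gx, gy, gz = torch.meshgrid(coords, coords, coords, indexing='ij')
    verts = torch.stack([gx, gy, gz], dim=-1).reshape(-1, 3)                  # (res+1)^3 grid corners
    def vid(i, j, k):                                                         # flat index of corner (i, j, k)
        return (i * (res + 1) + j) * (res + 1) + k
    ci, cj, ck = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res), indexing='ij')
    offsets = fc.cube_corners.long().cpu()                                    # (8, 3) corner offsets, same order as the tables
    cube_idx = torch.stack([vid(ci + dx, cj + dy, ck + dz) for dx, dy, dz in offsets], dim=-1).reshape(-1, 8).cuda()
    sdf = verts.norm(dim=-1) - 0.35                                           # sphere of radius 0.35
    vertices, faces, L_dev, _ = fc(verts, sdf, cube_idx, res)                 # beta/alpha/gamma_f default to ones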
trellis/representations/mesh/flexicubes/images/ablate_L_dev.jpg ADDED
trellis/representations/mesh/flexicubes/images/block_final.png ADDED

Git LFS Details

  • SHA256: d030fee195d332f63ef80805486ed0b4074b1afc34efcba621e385aca9ae9135
  • Pointer size: 130 Bytes
  • Size of remote file: 56 kB
trellis/representations/mesh/flexicubes/images/block_init.png ADDED

Git LFS Details

  • SHA256: 699ba21d95cce9d1504d31fca3694ba339f21703ac0bc3240c87df6ac2d2db3e
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
trellis/representations/mesh/flexicubes/images/teaser_top.png ADDED

Git LFS Details

  • SHA256: 71c27efaeeb7fc3357440607b34805495fc34acf39be00bb70dd315b5b25a71d
  • Pointer size: 132 Bytes
  • Size of remote file: 3.56 MB
trellis/representations/mesh/flexicubes/tables.py ADDED
@@ -0,0 +1,791 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ dmc_table = [
9
+ [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
10
+ [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
11
+ [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
12
+ [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
13
+ [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
14
+ [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
15
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
16
+ [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
17
+ [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
18
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
19
+ [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
20
+ [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
21
+ [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
22
+ [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
23
+ [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
24
+ [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
25
+ [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
26
+ [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
27
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
28
+ [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
29
+ [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
30
+ [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
31
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
32
+ [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
33
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
34
+ [[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
35
+ [[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
36
+ [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
37
+ [[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
38
+ [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
39
+ [[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
40
+ [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
41
+ [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
42
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
43
+ [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
44
+ [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
45
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
46
+ [[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
47
+ [[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
48
+ [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
49
+ [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
50
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
51
+ [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
52
+ [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
53
+ [[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
54
+ [[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
55
+ [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
56
+ [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
57
+ [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
58
+ [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
59
+ [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
60
+ [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
61
+ [[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
62
+ [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
63
+ [[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
64
+ [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
65
+ [[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
66
+ [[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
67
+ [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
68
+ [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
69
+ [[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
70
+ [[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
71
+ [[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
72
+ [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
73
+ [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
74
+ [[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
75
+ [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
76
+ [[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
77
+ [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
78
+ [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
79
+ [[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
80
+ [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
81
+ [[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
82
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
83
+ [[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
84
+ [[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
85
+ [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
86
+ [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
87
+ [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
88
+ [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
89
+ [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
90
+ [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
91
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
92
+ [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
93
+ [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
94
+ [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
95
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
96
+ [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
97
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
98
+ [[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
99
+ [[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
100
+ [[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
101
+ [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
102
+ [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
103
+ [[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
104
+ [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
105
+ [[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
106
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
107
+ [[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
108
+ [[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
109
+ [[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
110
+ [[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
111
+ [[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
112
+ [[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
113
+ [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
114
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
115
+ [[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
116
+ [[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
117
+ [[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
118
+ [[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
119
+ [[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
120
+ [[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
121
+ [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
122
+ [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
123
+ [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
124
+ [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
125
+ [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
126
+ [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
127
+ [[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
128
+ [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
129
+ [[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
130
+ [[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
131
+ [[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
132
+ [[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
133
+ [[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
134
+ [[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
135
+ [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
136
+ [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
137
+ [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
138
+ [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
139
+ [[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
140
+ [[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
141
+ [[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
142
+ [[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
143
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
144
+ [[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
145
+ [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
146
+ [[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
147
+ [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
148
+ [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
149
+ [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
150
+ [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
151
+ [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
152
+ [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
153
+ [[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
154
+ [[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
155
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
156
+ [[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
157
+ [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
158
+ [[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
159
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
160
+ [[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
161
+ [[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
162
+ [[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
163
+ [[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
164
+ [[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
165
+ [[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
166
+ [[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
167
+ [[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
168
+ [[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
169
+ [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
170
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
171
+ [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
172
+ [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
173
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
174
+ [[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
175
+ [[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
176
+ [[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
177
+ [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
178
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
179
+ [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
180
+ [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
181
+ [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
182
+ [[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
183
+ [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
184
+ [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
185
+ [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
186
+ [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
187
+ [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
188
+ [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
189
+ [[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
190
+ [[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
191
+ [[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
192
+ [[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
193
+ [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
194
+ [[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
195
+ [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
196
+ [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
197
+ [[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
198
+ [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
199
+ [[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
200
+ [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
201
+ [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
202
+ [[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
203
+ [[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
204
+ [[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
205
+ [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
206
+ [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
207
+ [[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
208
+ [[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
209
+ [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
210
+ [[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
211
+ [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
212
+ [[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
213
+ [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
214
+ [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
215
+ [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
216
+ [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
217
+ [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
218
+ [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
219
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
220
+ [[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
221
+ [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
222
+ [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
223
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
224
+ [[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
225
+ [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
226
+ [[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
227
+ [[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
228
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
229
+ [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
230
+ [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
231
+ [[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
232
+ [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
233
+ [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
234
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
235
+ [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
236
+ [[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
237
+ [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
238
+ [[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
239
+ [[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
240
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
241
+ [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
242
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
243
+ [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
244
+ [[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
245
+ [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
246
+ [[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
247
+ [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
248
+ [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
249
+ [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
250
+ [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
251
+ [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
252
+ [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
253
+ [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
254
+ [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
255
+ [[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
256
+ [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
257
+ [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
258
+ [[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
259
+ [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
260
+ [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
261
+ [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
262
+ [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
263
+ [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
264
+ [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
265
+ ]
266
+ num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
267
+ 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
268
+ 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
269
+ 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
270
+ 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
271
+ 3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
272
+ 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
273
+ 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
274
+ 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
275
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
276
+ check_table = [
277
+ [0, 0, 0, 0, 0],
278
+ [0, 0, 0, 0, 0],
279
+ [0, 0, 0, 0, 0],
280
+ [0, 0, 0, 0, 0],
281
+ [0, 0, 0, 0, 0],
282
+ [0, 0, 0, 0, 0],
283
+ [0, 0, 0, 0, 0],
284
+ [0, 0, 0, 0, 0],
285
+ [0, 0, 0, 0, 0],
286
+ [0, 0, 0, 0, 0],
287
+ [0, 0, 0, 0, 0],
288
+ [0, 0, 0, 0, 0],
289
+ [0, 0, 0, 0, 0],
290
+ [0, 0, 0, 0, 0],
291
+ [0, 0, 0, 0, 0],
292
+ [0, 0, 0, 0, 0],
293
+ [0, 0, 0, 0, 0],
294
+ [0, 0, 0, 0, 0],
295
+ [0, 0, 0, 0, 0],
296
+ [0, 0, 0, 0, 0],
297
+ [0, 0, 0, 0, 0],
298
+ [0, 0, 0, 0, 0],
299
+ [0, 0, 0, 0, 0],
300
+ [0, 0, 0, 0, 0],
301
+ [0, 0, 0, 0, 0],
302
+ [0, 0, 0, 0, 0],
303
+ [0, 0, 0, 0, 0],
304
+ [0, 0, 0, 0, 0],
305
+ [0, 0, 0, 0, 0],
306
+ [0, 0, 0, 0, 0],
307
+ [0, 0, 0, 0, 0],
308
+ [0, 0, 0, 0, 0],
309
+ [0, 0, 0, 0, 0],
310
+ [0, 0, 0, 0, 0],
311
+ [0, 0, 0, 0, 0],
312
+ [0, 0, 0, 0, 0],
313
+ [0, 0, 0, 0, 0],
314
+ [0, 0, 0, 0, 0],
315
+ [0, 0, 0, 0, 0],
316
+ [0, 0, 0, 0, 0],
317
+ [0, 0, 0, 0, 0],
318
+ [0, 0, 0, 0, 0],
319
+ [0, 0, 0, 0, 0],
320
+ [0, 0, 0, 0, 0],
321
+ [0, 0, 0, 0, 0],
322
+ [0, 0, 0, 0, 0],
323
+ [0, 0, 0, 0, 0],
324
+ [0, 0, 0, 0, 0],
325
+ [0, 0, 0, 0, 0],
326
+ [0, 0, 0, 0, 0],
327
+ [0, 0, 0, 0, 0],
328
+ [0, 0, 0, 0, 0],
329
+ [0, 0, 0, 0, 0],
330
+ [0, 0, 0, 0, 0],
331
+ [0, 0, 0, 0, 0],
332
+ [0, 0, 0, 0, 0],
333
+ [0, 0, 0, 0, 0],
334
+ [0, 0, 0, 0, 0],
335
+ [0, 0, 0, 0, 0],
336
+ [0, 0, 0, 0, 0],
337
+ [0, 0, 0, 0, 0],
338
+ [1, 1, 0, 0, 194],
339
+ [1, -1, 0, 0, 193],
340
+ [0, 0, 0, 0, 0],
341
+ [0, 0, 0, 0, 0],
342
+ [0, 0, 0, 0, 0],
343
+ [0, 0, 0, 0, 0],
344
+ [0, 0, 0, 0, 0],
345
+ [0, 0, 0, 0, 0],
346
+ [0, 0, 0, 0, 0],
347
+ [0, 0, 0, 0, 0],
348
+ [0, 0, 0, 0, 0],
349
+ [0, 0, 0, 0, 0],
350
+ [0, 0, 0, 0, 0],
351
+ [0, 0, 0, 0, 0],
352
+ [0, 0, 0, 0, 0],
353
+ [0, 0, 0, 0, 0],
354
+ [0, 0, 0, 0, 0],
355
+ [0, 0, 0, 0, 0],
356
+ [0, 0, 0, 0, 0],
357
+ [0, 0, 0, 0, 0],
358
+ [0, 0, 0, 0, 0],
359
+ [0, 0, 0, 0, 0],
360
+ [0, 0, 0, 0, 0],
361
+ [0, 0, 0, 0, 0],
362
+ [0, 0, 0, 0, 0],
363
+ [0, 0, 0, 0, 0],
364
+ [0, 0, 0, 0, 0],
365
+ [0, 0, 0, 0, 0],
366
+ [0, 0, 0, 0, 0],
367
+ [0, 0, 0, 0, 0],
368
+ [1, 0, 1, 0, 164],
369
+ [0, 0, 0, 0, 0],
370
+ [0, 0, 0, 0, 0],
371
+ [1, 0, -1, 0, 161],
372
+ [0, 0, 0, 0, 0],
373
+ [0, 0, 0, 0, 0],
374
+ [0, 0, 0, 0, 0],
375
+ [0, 0, 0, 0, 0],
376
+ [0, 0, 0, 0, 0],
377
+ [0, 0, 0, 0, 0],
378
+ [0, 0, 0, 0, 0],
379
+ [0, 0, 0, 0, 0],
380
+ [1, 0, 0, 1, 152],
381
+ [0, 0, 0, 0, 0],
382
+ [0, 0, 0, 0, 0],
383
+ [0, 0, 0, 0, 0],
384
+ [0, 0, 0, 0, 0],
385
+ [0, 0, 0, 0, 0],
386
+ [0, 0, 0, 0, 0],
387
+ [1, 0, 0, 1, 145],
388
+ [1, 0, 0, 1, 144],
389
+ [0, 0, 0, 0, 0],
390
+ [0, 0, 0, 0, 0],
391
+ [0, 0, 0, 0, 0],
392
+ [0, 0, 0, 0, 0],
393
+ [0, 0, 0, 0, 0],
394
+ [0, 0, 0, 0, 0],
395
+ [1, 0, 0, -1, 137],
396
+ [0, 0, 0, 0, 0],
397
+ [0, 0, 0, 0, 0],
398
+ [0, 0, 0, 0, 0],
399
+ [1, 0, 1, 0, 133],
400
+ [1, 0, 1, 0, 132],
401
+ [1, 1, 0, 0, 131],
402
+ [1, 1, 0, 0, 130],
403
+ [0, 0, 0, 0, 0],
404
+ [0, 0, 0, 0, 0],
405
+ [0, 0, 0, 0, 0],
406
+ [0, 0, 0, 0, 0],
407
+ [0, 0, 0, 0, 0],
408
+ [0, 0, 0, 0, 0],
409
+ [0, 0, 0, 0, 0],
410
+ [0, 0, 0, 0, 0],
411
+ [0, 0, 0, 0, 0],
412
+ [0, 0, 0, 0, 0],
413
+ [0, 0, 0, 0, 0],
414
+ [0, 0, 0, 0, 0],
415
+ [0, 0, 0, 0, 0],
416
+ [0, 0, 0, 0, 0],
417
+ [0, 0, 0, 0, 0],
418
+ [0, 0, 0, 0, 0],
419
+ [0, 0, 0, 0, 0],
420
+ [0, 0, 0, 0, 0],
421
+ [0, 0, 0, 0, 0],
422
+ [0, 0, 0, 0, 0],
423
+ [0, 0, 0, 0, 0],
424
+ [0, 0, 0, 0, 0],
425
+ [0, 0, 0, 0, 0],
426
+ [0, 0, 0, 0, 0],
427
+ [0, 0, 0, 0, 0],
428
+ [0, 0, 0, 0, 0],
429
+ [0, 0, 0, 0, 0],
430
+ [0, 0, 0, 0, 0],
431
+ [0, 0, 0, 0, 0],
432
+ [1, 0, 0, 1, 100],
433
+ [0, 0, 0, 0, 0],
434
+ [1, 0, 0, 1, 98],
435
+ [0, 0, 0, 0, 0],
436
+ [1, 0, 0, 1, 96],
437
+ [0, 0, 0, 0, 0],
438
+ [0, 0, 0, 0, 0],
439
+ [0, 0, 0, 0, 0],
440
+ [0, 0, 0, 0, 0],
441
+ [0, 0, 0, 0, 0],
442
+ [0, 0, 0, 0, 0],
443
+ [0, 0, 0, 0, 0],
444
+ [1, 0, 1, 0, 88],
445
+ [0, 0, 0, 0, 0],
446
+ [0, 0, 0, 0, 0],
447
+ [0, 0, 0, 0, 0],
448
+ [0, 0, 0, 0, 0],
449
+ [0, 0, 0, 0, 0],
450
+ [1, 0, -1, 0, 82],
451
+ [0, 0, 0, 0, 0],
452
+ [0, 0, 0, 0, 0],
453
+ [0, 0, 0, 0, 0],
454
+ [0, 0, 0, 0, 0],
455
+ [0, 0, 0, 0, 0],
456
+ [0, 0, 0, 0, 0],
457
+ [0, 0, 0, 0, 0],
458
+ [1, 0, 1, 0, 74],
459
+ [0, 0, 0, 0, 0],
460
+ [1, 0, 1, 0, 72],
461
+ [0, 0, 0, 0, 0],
462
+ [1, 0, 0, -1, 70],
463
+ [0, 0, 0, 0, 0],
464
+ [0, 0, 0, 0, 0],
465
+ [1, -1, 0, 0, 67],
466
+ [0, 0, 0, 0, 0],
467
+ [1, -1, 0, 0, 65],
468
+ [0, 0, 0, 0, 0],
469
+ [0, 0, 0, 0, 0],
470
+ [0, 0, 0, 0, 0],
471
+ [0, 0, 0, 0, 0],
472
+ [0, 0, 0, 0, 0],
473
+ [0, 0, 0, 0, 0],
474
+ [0, 0, 0, 0, 0],
475
+ [0, 0, 0, 0, 0],
476
+ [1, 1, 0, 0, 56],
477
+ [0, 0, 0, 0, 0],
478
+ [0, 0, 0, 0, 0],
479
+ [0, 0, 0, 0, 0],
480
+ [1, -1, 0, 0, 52],
481
+ [0, 0, 0, 0, 0],
482
+ [0, 0, 0, 0, 0],
483
+ [0, 0, 0, 0, 0],
484
+ [0, 0, 0, 0, 0],
485
+ [0, 0, 0, 0, 0],
486
+ [0, 0, 0, 0, 0],
487
+ [0, 0, 0, 0, 0],
488
+ [1, 1, 0, 0, 44],
489
+ [0, 0, 0, 0, 0],
490
+ [0, 0, 0, 0, 0],
491
+ [0, 0, 0, 0, 0],
492
+ [1, 1, 0, 0, 40],
493
+ [0, 0, 0, 0, 0],
494
+ [1, 0, 0, -1, 38],
495
+ [1, 0, -1, 0, 37],
496
+ [0, 0, 0, 0, 0],
497
+ [0, 0, 0, 0, 0],
498
+ [0, 0, 0, 0, 0],
499
+ [1, 0, -1, 0, 33],
500
+ [0, 0, 0, 0, 0],
501
+ [0, 0, 0, 0, 0],
502
+ [0, 0, 0, 0, 0],
503
+ [0, 0, 0, 0, 0],
504
+ [1, -1, 0, 0, 28],
505
+ [0, 0, 0, 0, 0],
506
+ [1, 0, -1, 0, 26],
507
+ [1, 0, 0, -1, 25],
508
+ [0, 0, 0, 0, 0],
509
+ [0, 0, 0, 0, 0],
510
+ [0, 0, 0, 0, 0],
511
+ [0, 0, 0, 0, 0],
512
+ [1, -1, 0, 0, 20],
513
+ [0, 0, 0, 0, 0],
514
+ [1, 0, -1, 0, 18],
515
+ [0, 0, 0, 0, 0],
516
+ [0, 0, 0, 0, 0],
517
+ [0, 0, 0, 0, 0],
518
+ [0, 0, 0, 0, 0],
519
+ [0, 0, 0, 0, 0],
520
+ [0, 0, 0, 0, 0],
521
+ [0, 0, 0, 0, 0],
522
+ [0, 0, 0, 0, 0],
523
+ [1, 0, 0, -1, 9],
524
+ [0, 0, 0, 0, 0],
525
+ [0, 0, 0, 0, 0],
526
+ [1, 0, 0, -1, 6],
527
+ [0, 0, 0, 0, 0],
528
+ [0, 0, 0, 0, 0],
529
+ [0, 0, 0, 0, 0],
530
+ [0, 0, 0, 0, 0],
531
+ [0, 0, 0, 0, 0],
532
+ [0, 0, 0, 0, 0]
533
+ ]
534
+ tet_table = [
535
+ [-1, -1, -1, -1, -1, -1],
536
+ [0, 0, 0, 0, 0, 0],
537
+ [0, 0, 0, 0, 0, 0],
538
+ [1, 1, 1, 1, 1, 1],
539
+ [4, 4, 4, 4, 4, 4],
540
+ [0, 0, 0, 0, 0, 0],
541
+ [4, 0, 0, 4, 4, -1],
542
+ [1, 1, 1, 1, 1, 1],
543
+ [4, 4, 4, 4, 4, 4],
544
+ [0, 4, 0, 4, 4, -1],
545
+ [0, 0, 0, 0, 0, 0],
546
+ [1, 1, 1, 1, 1, 1],
547
+ [5, 5, 5, 5, 5, 5],
548
+ [0, 0, 0, 0, 0, 0],
549
+ [0, 0, 0, 0, 0, 0],
550
+ [1, 1, 1, 1, 1, 1],
551
+ [2, 2, 2, 2, 2, 2],
552
+ [0, 0, 0, 0, 0, 0],
553
+ [2, 0, 2, -1, 0, 2],
554
+ [1, 1, 1, 1, 1, 1],
555
+ [2, -1, 2, 4, 4, 2],
556
+ [0, 0, 0, 0, 0, 0],
557
+ [2, 0, 2, 4, 4, 2],
558
+ [1, 1, 1, 1, 1, 1],
559
+ [2, 4, 2, 4, 4, 2],
560
+ [0, 4, 0, 4, 4, 0],
561
+ [2, 0, 2, 0, 0, 2],
562
+ [1, 1, 1, 1, 1, 1],
563
+ [2, 5, 2, 5, 5, 2],
564
+ [0, 0, 0, 0, 0, 0],
565
+ [2, 0, 2, 0, 0, 2],
566
+ [1, 1, 1, 1, 1, 1],
567
+ [1, 1, 1, 1, 1, 1],
568
+ [0, 1, 1, -1, 0, 1],
569
+ [0, 0, 0, 0, 0, 0],
570
+ [2, 2, 2, 2, 2, 2],
571
+ [4, 1, 1, 4, 4, 1],
572
+ [0, 1, 1, 0, 0, 1],
573
+ [4, 0, 0, 4, 4, 0],
574
+ [2, 2, 2, 2, 2, 2],
575
+ [-1, 1, 1, 4, 4, 1],
576
+ [0, 1, 1, 4, 4, 1],
577
+ [0, 0, 0, 0, 0, 0],
578
+ [2, 2, 2, 2, 2, 2],
579
+ [5, 1, 1, 5, 5, 1],
580
+ [0, 1, 1, 0, 0, 1],
581
+ [0, 0, 0, 0, 0, 0],
582
+ [2, 2, 2, 2, 2, 2],
583
+ [1, 1, 1, 1, 1, 1],
584
+ [0, 0, 0, 0, 0, 0],
585
+ [0, 0, 0, 0, 0, 0],
586
+ [8, 8, 8, 8, 8, 8],
587
+ [1, 1, 1, 4, 4, 1],
588
+ [0, 0, 0, 0, 0, 0],
589
+ [4, 0, 0, 4, 4, 0],
590
+ [4, 4, 4, 4, 4, 4],
591
+ [1, 1, 1, 4, 4, 1],
592
+ [0, 4, 0, 4, 4, 0],
593
+ [0, 0, 0, 0, 0, 0],
594
+ [4, 4, 4, 4, 4, 4],
595
+ [1, 1, 1, 5, 5, 1],
596
+ [0, 0, 0, 0, 0, 0],
597
+ [0, 0, 0, 0, 0, 0],
598
+ [5, 5, 5, 5, 5, 5],
599
+ [6, 6, 6, 6, 6, 6],
600
+ [6, -1, 0, 6, 0, 6],
601
+ [6, 0, 0, 6, 0, 6],
602
+ [6, 1, 1, 6, 1, 6],
603
+ [4, 4, 4, 4, 4, 4],
604
+ [0, 0, 0, 0, 0, 0],
605
+ [4, 0, 0, 4, 4, 4],
606
+ [1, 1, 1, 1, 1, 1],
607
+ [6, 4, -1, 6, 4, 6],
608
+ [6, 4, 0, 6, 4, 6],
609
+ [6, 0, 0, 6, 0, 6],
610
+ [6, 1, 1, 6, 1, 6],
611
+ [5, 5, 5, 5, 5, 5],
612
+ [0, 0, 0, 0, 0, 0],
613
+ [0, 0, 0, 0, 0, 0],
614
+ [1, 1, 1, 1, 1, 1],
615
+ [2, 2, 2, 2, 2, 2],
616
+ [0, 0, 0, 0, 0, 0],
617
+ [2, 0, 2, 2, 0, 2],
618
+ [1, 1, 1, 1, 1, 1],
619
+ [2, 2, 2, 2, 2, 2],
620
+ [0, 0, 0, 0, 0, 0],
621
+ [2, 0, 2, 2, 2, 2],
622
+ [1, 1, 1, 1, 1, 1],
623
+ [2, 4, 2, 2, 4, 2],
624
+ [0, 4, 0, 4, 4, 0],
625
+ [2, 0, 2, 2, 0, 2],
626
+ [1, 1, 1, 1, 1, 1],
627
+ [2, 2, 2, 2, 2, 2],
628
+ [0, 0, 0, 0, 0, 0],
629
+ [0, 0, 0, 0, 0, 0],
630
+ [1, 1, 1, 1, 1, 1],
631
+ [6, 1, 1, 6, -1, 6],
632
+ [6, 1, 1, 6, 0, 6],
633
+ [6, 0, 0, 6, 0, 6],
634
+ [6, 2, 2, 6, 2, 6],
635
+ [4, 1, 1, 4, 4, 1],
636
+ [0, 1, 1, 0, 0, 1],
637
+ [4, 0, 0, 4, 4, 4],
638
+ [2, 2, 2, 2, 2, 2],
639
+ [6, 1, 1, 6, 4, 6],
640
+ [6, 1, 1, 6, 4, 6],
641
+ [6, 0, 0, 6, 0, 6],
642
+ [6, 2, 2, 6, 2, 6],
643
+ [5, 1, 1, 5, 5, 1],
644
+ [0, 1, 1, 0, 0, 1],
645
+ [0, 0, 0, 0, 0, 0],
646
+ [2, 2, 2, 2, 2, 2],
647
+ [1, 1, 1, 1, 1, 1],
648
+ [0, 0, 0, 0, 0, 0],
649
+ [0, 0, 0, 0, 0, 0],
650
+ [6, 6, 6, 6, 6, 6],
651
+ [1, 1, 1, 1, 1, 1],
652
+ [0, 0, 0, 0, 0, 0],
653
+ [0, 0, 0, 0, 0, 0],
654
+ [4, 4, 4, 4, 4, 4],
655
+ [1, 1, 1, 1, 4, 1],
656
+ [0, 4, 0, 4, 4, 0],
657
+ [0, 0, 0, 0, 0, 0],
658
+ [4, 4, 4, 4, 4, 4],
659
+ [1, 1, 1, 1, 1, 1],
660
+ [0, 0, 0, 0, 0, 0],
661
+ [0, 5, 0, 5, 0, 5],
662
+ [5, 5, 5, 5, 5, 5],
663
+ [5, 5, 5, 5, 5, 5],
664
+ [0, 5, 0, 5, 0, 5],
665
+ [-1, 5, 0, 5, 0, 5],
666
+ [1, 5, 1, 5, 1, 5],
667
+ [4, 5, -1, 5, 4, 5],
668
+ [0, 5, 0, 5, 0, 5],
669
+ [4, 5, 0, 5, 4, 5],
670
+ [1, 5, 1, 5, 1, 5],
671
+ [4, 4, 4, 4, 4, 4],
672
+ [0, 4, 0, 4, 4, 4],
673
+ [0, 0, 0, 0, 0, 0],
674
+ [1, 1, 1, 1, 1, 1],
675
+ [6, 6, 6, 6, 6, 6],
676
+ [0, 0, 0, 0, 0, 0],
677
+ [0, 0, 0, 0, 0, 0],
678
+ [1, 1, 1, 1, 1, 1],
679
+ [2, 5, 2, 5, -1, 5],
680
+ [0, 5, 0, 5, 0, 5],
681
+ [2, 5, 2, 5, 0, 5],
682
+ [1, 5, 1, 5, 1, 5],
683
+ [2, 5, 2, 5, 4, 5],
684
+ [0, 5, 0, 5, 0, 5],
685
+ [2, 5, 2, 5, 4, 5],
686
+ [1, 5, 1, 5, 1, 5],
687
+ [2, 4, 2, 4, 4, 2],
688
+ [0, 4, 0, 4, 4, 4],
689
+ [2, 0, 2, 0, 0, 2],
690
+ [1, 1, 1, 1, 1, 1],
691
+ [2, 6, 2, 6, 6, 2],
692
+ [0, 0, 0, 0, 0, 0],
693
+ [2, 0, 2, 0, 0, 2],
694
+ [1, 1, 1, 1, 1, 1],
695
+ [1, 1, 1, 1, 1, 1],
696
+ [0, 1, 1, 1, 0, 1],
697
+ [0, 0, 0, 0, 0, 0],
698
+ [2, 2, 2, 2, 2, 2],
699
+ [4, 1, 1, 1, 4, 1],
700
+ [0, 1, 1, 1, 0, 1],
701
+ [4, 0, 0, 4, 4, 0],
702
+ [2, 2, 2, 2, 2, 2],
703
+ [1, 1, 1, 1, 1, 1],
704
+ [0, 1, 1, 1, 1, 1],
705
+ [0, 0, 0, 0, 0, 0],
706
+ [2, 2, 2, 2, 2, 2],
707
+ [1, 1, 1, 1, 1, 1],
708
+ [0, 0, 0, 0, 0, 0],
709
+ [0, 0, 0, 0, 0, 0],
710
+ [2, 2, 2, 2, 2, 2],
711
+ [1, 1, 1, 1, 1, 1],
712
+ [0, 0, 0, 0, 0, 0],
713
+ [0, 0, 0, 0, 0, 0],
714
+ [5, 5, 5, 5, 5, 5],
715
+ [1, 1, 1, 1, 4, 1],
716
+ [0, 0, 0, 0, 0, 0],
717
+ [4, 0, 0, 4, 4, 0],
718
+ [4, 4, 4, 4, 4, 4],
719
+ [1, 1, 1, 1, 1, 1],
720
+ [0, 0, 0, 0, 0, 0],
721
+ [0, 0, 0, 0, 0, 0],
722
+ [4, 4, 4, 4, 4, 4],
723
+ [1, 1, 1, 1, 1, 1],
724
+ [6, 0, 0, 6, 0, 6],
725
+ [0, 0, 0, 0, 0, 0],
726
+ [6, 6, 6, 6, 6, 6],
727
+ [5, 5, 5, 5, 5, 5],
728
+ [5, 5, 0, 5, 0, 5],
729
+ [5, 5, 0, 5, 0, 5],
730
+ [5, 5, 1, 5, 1, 5],
731
+ [4, 4, 4, 4, 4, 4],
732
+ [0, 0, 0, 0, 0, 0],
733
+ [4, 4, 0, 4, 4, 4],
734
+ [1, 1, 1, 1, 1, 1],
735
+ [4, 4, 4, 4, 4, 4],
736
+ [4, 4, 0, 4, 4, 4],
737
+ [0, 0, 0, 0, 0, 0],
738
+ [1, 1, 1, 1, 1, 1],
739
+ [8, 8, 8, 8, 8, 8],
740
+ [0, 0, 0, 0, 0, 0],
741
+ [0, 0, 0, 0, 0, 0],
742
+ [1, 1, 1, 1, 1, 1],
743
+ [2, 2, 2, 2, 2, 2],
744
+ [0, 0, 0, 0, 0, 0],
745
+ [2, 2, 2, 2, 0, 2],
746
+ [1, 1, 1, 1, 1, 1],
747
+ [2, 2, 2, 2, 2, 2],
748
+ [0, 0, 0, 0, 0, 0],
749
+ [2, 2, 2, 2, 2, 2],
750
+ [1, 1, 1, 1, 1, 1],
751
+ [2, 2, 2, 2, 2, 2],
752
+ [0, 0, 0, 0, 0, 0],
753
+ [0, 0, 0, 0, 0, 0],
754
+ [4, 1, 1, 4, 4, 1],
755
+ [2, 2, 2, 2, 2, 2],
756
+ [0, 0, 0, 0, 0, 0],
757
+ [0, 0, 0, 0, 0, 0],
758
+ [1, 1, 1, 1, 1, 1],
759
+ [1, 1, 1, 1, 1, 1],
760
+ [1, 1, 1, 1, 0, 1],
761
+ [0, 0, 0, 0, 0, 0],
762
+ [2, 2, 2, 2, 2, 2],
763
+ [1, 1, 1, 1, 1, 1],
764
+ [0, 0, 0, 0, 0, 0],
765
+ [0, 0, 0, 0, 0, 0],
766
+ [2, 4, 2, 4, 4, 2],
767
+ [1, 1, 1, 1, 1, 1],
768
+ [1, 1, 1, 1, 1, 1],
769
+ [0, 0, 0, 0, 0, 0],
770
+ [2, 2, 2, 2, 2, 2],
771
+ [1, 1, 1, 1, 1, 1],
772
+ [0, 0, 0, 0, 0, 0],
773
+ [0, 0, 0, 0, 0, 0],
774
+ [2, 2, 2, 2, 2, 2],
775
+ [1, 1, 1, 1, 1, 1],
776
+ [0, 0, 0, 0, 0, 0],
777
+ [0, 0, 0, 0, 0, 0],
778
+ [5, 5, 5, 5, 5, 5],
779
+ [1, 1, 1, 1, 1, 1],
780
+ [0, 0, 0, 0, 0, 0],
781
+ [0, 0, 0, 0, 0, 0],
782
+ [4, 4, 4, 4, 4, 4],
783
+ [1, 1, 1, 1, 1, 1],
784
+ [0, 0, 0, 0, 0, 0],
785
+ [0, 0, 0, 0, 0, 0],
786
+ [4, 4, 4, 4, 4, 4],
787
+ [1, 1, 1, 1, 1, 1],
788
+ [0, 0, 0, 0, 0, 0],
789
+ [0, 0, 0, 0, 0, 0],
790
+ [12, 12, 12, 12, 12, 12]
791
+ ]
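The tables above are indexed by the 256 possible inside/outside configurations of a cell's eight corners. A minimal lookup sketch (assumptions: the module path below matches this file, and the corner-to-bit ordering is only illustrative):

    from trellis.representations.mesh.flexicubes.tables import num_vd_table

    def case_id(corner_inside):
        # corner_inside: eight booleans, one per cube corner
        return sum(int(b) << i for i, b in enumerate(corner_inside))

    # A cell with all corners inside (case 255) yields no dual vertices.
    print(num_vd_table[case_id([True] * 8)])  # 0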
trellis/representations/octree/octree_dfs.py CHANGED
@@ -3,21 +3,6 @@ import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
 
6
- DEFAULT_TRIVEC_CONFIG = {
7
- 'dim': 8,
8
- 'rank': 8,
9
- }
10
-
11
- DEFAULT_VOXEL_CONFIG = {
12
- 'solid': False,
13
- }
14
-
15
- DEFAULT_DECOPOLY_CONFIG = {
16
- 'degree': 8,
17
- 'rank': 16,
18
- }
19
-
20
-
21
  class DfsOctree:
22
  """
23
  Sparse Voxel Octree (SVO) implementation for PyTorch.
@@ -145,8 +130,8 @@ class DfsOctree:
145
 
146
  @property
147
  def get_density(self):
148
- if self.primitive == 'voxel' and self.voxel_config['solid']:
149
- return torch.full((self.position.shape[0], 1), 1000, dtype=torch.float32, device=self.device)
150
  return self.density_activation(self.density)
151
 
152
  @property
@@ -172,7 +157,7 @@ class DfsOctree:
172
  return torch.cat([self.features_dc, self.features_ac], dim=-2)
173
 
174
  def state_dict(self):
175
- ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'trivec_config': self.trivec_config, 'voxel_config': self.voxel_config, 'primitive': self.primitive}
176
  if hasattr(self, 'density_shift'):
177
  ret['density_shift'] = self.density_shift
178
  for data in set(self.data + self.param_names):
 
3
  import torch.nn.functional as F
4
 
5
 
6
  class DfsOctree:
7
  """
8
  Sparse Voxel Octree (SVO) implementation for PyTorch.
 
130
 
131
  @property
132
  def get_density(self):
133
+ if self.primitive == 'voxel' and self.primitive_config.get('solid', False):
134
+ return torch.full((self.position.shape[0], 1), torch.finfo(torch.float32).max, dtype=torch.float32, device=self.device)
135
  return self.density_activation(self.density)
136
 
137
  @property
 
157
  return torch.cat([self.features_dc, self.features_ac], dim=-2)
158
 
159
  def state_dict(self):
160
+ ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'primitive_config': self.primitive_config, 'primitive': self.primitive}
161
  if hasattr(self, 'density_shift'):
162
  ret['density_shift'] = self.density_shift
163
  for data in set(self.data + self.param_names):
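The functional change here is twofold: get_density now fills solid voxels with the largest finite float32 value instead of a hard-coded 1000, and state_dict() stores a single primitive_config in place of the separate trivec_config/voxel_config entries. A standalone sketch of just the new density constant (assumes only torch):

    import torch

    # Value assigned to every solid voxel by the updated get_density branch.
    solid_density = torch.finfo(torch.float32).max  # about 3.4028e+38
    densities = torch.full((4, 1), solid_density, dtype=torch.float32)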
trellis/trainers/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ import importlib
2
+
3
+ __attributes = {
4
+ 'BasicTrainer': 'basic',
5
+
6
+ 'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
7
+
8
+ 'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
9
+ 'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
10
+ 'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',
11
+
12
+ 'FlowMatchingTrainer': 'flow_matching.flow_matching',
13
+ 'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
14
+ 'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
15
+ 'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
16
+
17
+ 'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
18
+ 'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
19
+ 'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
20
+ 'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
21
+ }
22
+
23
+ __submodules = []
24
+
25
+ __all__ = list(__attributes.keys()) + __submodules
26
+
27
+ def __getattr__(name):
28
+ if name not in globals():
29
+ if name in __attributes:
30
+ module_name = __attributes[name]
31
+ module = importlib.import_module(f".{module_name}", __name__)
32
+ globals()[name] = getattr(module, name)
33
+ elif name in __submodules:
34
+ module = importlib.import_module(f".{name}", __name__)
35
+ globals()[name] = module
36
+ else:
37
+ raise AttributeError(f"module {__name__} has no attribute {name}")
38
+ return globals()[name]
39
+
40
+
41
+ # For Pylance
42
+ if __name__ == '__main__':
43
+ from .basic import BasicTrainer
44
+
45
+ from .vae.sparse_structure_vae import SparseStructureVaeTrainer
46
+
47
+ from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
48
+ from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
49
+ from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer
50
+
51
+ from .flow_matching.flow_matching import (
52
+ FlowMatchingTrainer,
53
+ FlowMatchingCFGTrainer,
54
+ TextConditionedFlowMatchingCFGTrainer,
55
+ ImageConditionedFlowMatchingCFGTrainer,
56
+ )
57
+
58
+ from .flow_matching.sparse_flow_matching import (
59
+ SparseFlowMatchingTrainer,
60
+ SparseFlowMatchingCFGTrainer,
61
+ TextConditionedSparseFlowMatchingCFGTrainer,
62
+ ImageConditionedSparseFlowMatchingCFGTrainer,
63
+ )
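This __init__.py only maps attribute names to submodules; the real import happens lazily in __getattr__ on first access. A minimal usage sketch (assumes the trellis package is importable):

    from trellis import trainers

    # First attribute access triggers importlib.import_module('.flow_matching.sparse_flow_matching', ...)
    TrainerCls = trainers.ImageConditionedSparseFlowMatchingCFGTrainer
    print(TrainerCls.__name__)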
trellis/trainers/base.py ADDED
@@ -0,0 +1,451 @@
1
+ from abc import abstractmethod
2
+ import os
3
+ import time
4
+ import json
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ from torch.utils.data import DataLoader
9
+ import numpy as np
10
+
11
+ from torchvision import utils
12
+ from torch.utils.tensorboard import SummaryWriter
13
+
14
+ from .utils import *
15
+ from ..utils.general_utils import *
16
+ from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
17
+
18
+
19
+ class Trainer:
20
+ """
21
+ Base class for training.
22
+ """
23
+ def __init__(self,
24
+ models,
25
+ dataset,
26
+ *,
27
+ output_dir,
28
+ load_dir,
29
+ step,
30
+ max_steps,
31
+ batch_size=None,
32
+ batch_size_per_gpu=None,
33
+ batch_split=None,
34
+ optimizer={},
35
+ lr_scheduler=None,
36
+ elastic=None,
37
+ grad_clip=None,
38
+ ema_rate=0.9999,
39
+ fp16_mode='inflat_all',
40
+ fp16_scale_growth=1e-3,
41
+ finetune_ckpt=None,
42
+ log_param_stats=False,
43
+ prefetch_data=True,
44
+ i_print=1000,
45
+ i_log=500,
46
+ i_sample=10000,
47
+ i_save=10000,
48
+ i_ddpcheck=10000,
49
+ **kwargs
50
+ ):
51
+ assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
52
+
53
+ self.models = models
54
+ self.dataset = dataset
55
+ self.batch_split = batch_split if batch_split is not None else 1
56
+ self.max_steps = max_steps
57
+ self.optimizer_config = optimizer
58
+ self.lr_scheduler_config = lr_scheduler
59
+ self.elastic_controller_config = elastic
60
+ self.grad_clip = grad_clip
61
+ self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
62
+ self.fp16_mode = fp16_mode
63
+ self.fp16_scale_growth = fp16_scale_growth
64
+ self.log_param_stats = log_param_stats
65
+ self.prefetch_data = prefetch_data
66
+ if self.prefetch_data:
67
+ self._data_prefetched = None
68
+
69
+ self.output_dir = output_dir
70
+ self.i_print = i_print
71
+ self.i_log = i_log
72
+ self.i_sample = i_sample
73
+ self.i_save = i_save
74
+ self.i_ddpcheck = i_ddpcheck
75
+
76
+ if dist.is_initialized():
77
+ # Multi-GPU params
78
+ self.world_size = dist.get_world_size()
79
+ self.rank = dist.get_rank()
80
+ self.local_rank = dist.get_rank() % torch.cuda.device_count()
81
+ self.is_master = self.rank == 0
82
+ else:
83
+ # Single-GPU params
84
+ self.world_size = 1
85
+ self.rank = 0
86
+ self.local_rank = 0
87
+ self.is_master = True
88
+
89
+ self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
90
+ self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
91
+ assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
92
+ assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
93
+
94
+ self.init_models_and_more(**kwargs)
95
+ self.prepare_dataloader(**kwargs)
96
+
97
+ # Load checkpoint
98
+ self.step = 0
99
+ if load_dir is not None and step is not None:
100
+ self.load(load_dir, step)
101
+ elif finetune_ckpt is not None:
102
+ self.finetune_from(finetune_ckpt)
103
+
104
+ if self.is_master:
105
+ os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
106
+ os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
107
+ self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
108
+
109
+ if self.world_size > 1:
110
+ self.check_ddp()
111
+
112
+ if self.is_master:
113
+ print('\n\nTrainer initialized.')
114
+ print(self)
115
+
116
+ @property
117
+ def device(self):
118
+ for _, model in self.models.items():
119
+ if hasattr(model, 'device'):
120
+ return model.device
121
+ return next(list(self.models.values())[0].parameters()).device
122
+
123
+ @abstractmethod
124
+ def init_models_and_more(self, **kwargs):
125
+ """
126
+ Initialize models and more.
127
+ """
128
+ pass
129
+
130
+ def prepare_dataloader(self, **kwargs):
131
+ """
132
+ Prepare dataloader.
133
+ """
134
+ self.data_sampler = ResumableSampler(
135
+ self.dataset,
136
+ shuffle=True,
137
+ )
138
+ self.dataloader = DataLoader(
139
+ self.dataset,
140
+ batch_size=self.batch_size_per_gpu,
141
+ num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
142
+ pin_memory=True,
143
+ drop_last=True,
144
+ persistent_workers=True,
145
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
146
+ sampler=self.data_sampler,
147
+ )
148
+ self.data_iterator = cycle(self.dataloader)
149
+
150
+ @abstractmethod
151
+ def load(self, load_dir, step=0):
152
+ """
153
+ Load a checkpoint.
154
+ Should be called by all processes.
155
+ """
156
+ pass
157
+
158
+ @abstractmethod
159
+ def save(self):
160
+ """
161
+ Save a checkpoint.
162
+ Should be called only by the rank 0 process.
163
+ """
164
+ pass
165
+
166
+ @abstractmethod
167
+ def finetune_from(self, finetune_ckpt):
168
+ """
169
+ Finetune from a checkpoint.
170
+ Should be called by all processes.
171
+ """
172
+ pass
173
+
174
+ @abstractmethod
175
+ def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
176
+ """
177
+ Run a snapshot of the model.
178
+ """
179
+ pass
180
+
181
+ @torch.no_grad()
182
+ def visualize_sample(self, sample):
183
+ """
184
+ Convert a sample to an image.
185
+ """
186
+ if hasattr(self.dataset, 'visualize_sample'):
187
+ return self.dataset.visualize_sample(sample)
188
+ else:
189
+ return sample
190
+
191
+ @torch.no_grad()
192
+ def snapshot_dataset(self, num_samples=100):
193
+ """
194
+ Sample images from the dataset.
195
+ """
196
+ dataloader = torch.utils.data.DataLoader(
197
+ self.dataset,
198
+ batch_size=num_samples,
199
+ num_workers=0,
200
+ shuffle=True,
201
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
202
+ )
203
+ data = next(iter(dataloader))
204
+ data = recursive_to_device(data, self.device)
205
+ vis = self.visualize_sample(data)
206
+ if isinstance(vis, dict):
207
+ save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
208
+ else:
209
+ save_cfg = [('dataset', vis)]
210
+ for name, image in save_cfg:
211
+ utils.save_image(
212
+ image,
213
+ os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
214
+ nrow=int(np.sqrt(num_samples)),
215
+ normalize=True,
216
+ value_range=self.dataset.value_range,
217
+ )
218
+
219
+ @torch.no_grad()
220
+ def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
221
+ """
222
+ Sample images from the model.
223
+ NOTE: This function should be called by all processes.
224
+ """
225
+ if self.is_master:
226
+ print(f'\nSampling {num_samples} images...', end='')
227
+
228
+ if suffix is None:
229
+ suffix = f'step{self.step:07d}'
230
+
231
+ # Assign tasks
232
+ num_samples_per_process = int(np.ceil(num_samples / self.world_size))
233
+ samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
234
+
235
+ # Preprocess images
236
+ for key in list(samples.keys()):
237
+ if samples[key]['type'] == 'sample':
238
+ vis = self.visualize_sample(samples[key]['value'])
239
+ if isinstance(vis, dict):
240
+ for k, v in vis.items():
241
+ samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
242
+ del samples[key]
243
+ else:
244
+ samples[key] = {'value': vis, 'type': 'image'}
245
+
246
+ # Gather results
247
+ if self.world_size > 1:
248
+ for key in samples.keys():
249
+ samples[key]['value'] = samples[key]['value'].contiguous()
250
+ if self.is_master:
251
+ all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
252
+ else:
253
+ all_images = []
254
+ dist.gather(samples[key]['value'], all_images, dst=0)
255
+ if self.is_master:
256
+ samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
257
+
258
+ # Save images
259
+ if self.is_master:
260
+ os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
261
+ for key in samples.keys():
262
+ if samples[key]['type'] == 'image':
263
+ utils.save_image(
264
+ samples[key]['value'],
265
+ os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
266
+ nrow=int(np.sqrt(num_samples)),
267
+ normalize=True,
268
+ value_range=self.dataset.value_range,
269
+ )
270
+ elif samples[key]['type'] == 'number':
271
+ min = samples[key]['value'].min()
272
+ max = samples[key]['value'].max()
273
+ images = (samples[key]['value'] - min) / (max - min)
274
+ images = utils.make_grid(
275
+ images,
276
+ nrow=int(np.sqrt(num_samples)),
277
+ normalize=False,
278
+ )
279
+ save_image_with_notes(
280
+ images,
281
+ os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
282
+ notes=f'{key} min: {min}, max: {max}',
283
+ )
284
+
285
+ if self.is_master:
286
+ print(' Done.')
287
+
288
+ @abstractmethod
289
+ def update_ema(self):
290
+ """
291
+ Update exponential moving average.
292
+ Should only be called by the rank 0 process.
293
+ """
294
+ pass
295
+
296
+ @abstractmethod
297
+ def check_ddp(self):
298
+ """
299
+ Check if DDP is working properly.
300
+ Should be called by all processes.
301
+ """
302
+ pass
303
+
304
+ @abstractmethod
305
+ def training_losses(self, **mb_data):
306
+ """
307
+ Compute training losses.
308
+ """
309
+ pass
310
+
311
+ def load_data(self):
312
+ """
313
+ Load data.
314
+ """
315
+ if self.prefetch_data:
316
+ if self._data_prefetched is None:
317
+ self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
318
+ data = self._data_prefetched
319
+ self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
320
+ else:
321
+ data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
322
+
323
+ # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
324
+ if isinstance(data, dict):
325
+ if self.batch_split == 1:
326
+ data_list = [data]
327
+ else:
328
+ batch_size = list(data.values())[0].shape[0]
329
+ data_list = [
330
+ {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
331
+ for i in range(self.batch_split)
332
+ ]
333
+ elif isinstance(data, list):
334
+ data_list = data
335
+ else:
336
+ raise ValueError('Data must be a dict or a list of dicts.')
337
+
338
+ return data_list
339
+
340
+ @abstractmethod
341
+ def run_step(self, data_list):
342
+ """
343
+ Run a training step.
344
+ """
345
+ pass
346
+
347
+ def run(self):
348
+ """
349
+ Run training.
350
+ """
351
+ if self.is_master:
352
+ print('\nStarting training...')
353
+ self.snapshot_dataset()
354
+ if self.step == 0:
355
+ self.snapshot(suffix='init')
356
+ else: # resume
357
+ self.snapshot(suffix=f'resume_step{self.step:07d}')
358
+
359
+ log = []
360
+ time_last_print = 0.0
361
+ time_elapsed = 0.0
362
+ while self.step < self.max_steps:
363
+ time_start = time.time()
364
+
365
+ data_list = self.load_data()
366
+ step_log = self.run_step(data_list)
367
+
368
+ time_end = time.time()
369
+ time_elapsed += time_end - time_start
370
+
371
+ self.step += 1
372
+
373
+ # Print progress
374
+ if self.is_master and self.step % self.i_print == 0:
375
+ speed = self.i_print / (time_elapsed - time_last_print) * 3600
376
+ columns = [
377
+ f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
378
+ f'Elapsed: {time_elapsed / 3600:.2f} h',
379
+ f'Speed: {speed:.2f} steps/h',
380
+ f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
381
+ ]
382
+ print(' | '.join([c.ljust(25) for c in columns]), flush=True)
383
+ time_last_print = time_elapsed
384
+
385
+ # Check ddp
386
+ if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
387
+ self.check_ddp()
388
+
389
+ # Sample images
390
+ if self.step % self.i_sample == 0:
391
+ self.snapshot()
392
+
393
+ if self.is_master:
394
+ log.append((self.step, {}))
395
+
396
+ # Log time
397
+ log[-1][1]['time'] = {
398
+ 'step': time_end - time_start,
399
+ 'elapsed': time_elapsed,
400
+ }
401
+
402
+ # Log losses
403
+ if step_log is not None:
404
+ log[-1][1].update(step_log)
405
+
406
+ # Log scale
407
+ if self.fp16_mode == 'amp':
408
+ log[-1][1]['scale'] = self.scaler.get_scale()
409
+ elif self.fp16_mode == 'inflat_all':
410
+ log[-1][1]['log_scale'] = self.log_scale
411
+
412
+ # Save log
413
+ if self.step % self.i_log == 0:
414
+ ## save to log file
415
+ log_str = '\n'.join([
416
+ f'{step}: {json.dumps(log)}' for step, log in log
417
+ ])
418
+ with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
419
+ log_file.write(log_str + '\n')
420
+
421
+ # show with tensorboard
422
+ log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))]
423
+ log_show = dict_reduce(log_show, lambda x: np.mean(x))
424
+ log_show = dict_flatten(log_show, sep='/')
425
+ for key, value in log_show.items():
426
+ self.writer.add_scalar(key, value, self.step)
427
+ log = []
428
+
429
+ # Save checkpoint
430
+ if self.step % self.i_save == 0:
431
+ self.save()
432
+
433
+ if self.is_master:
434
+ self.snapshot(suffix='final')
435
+ self.writer.close()
436
+ print('Training finished.')
437
+
438
+ def profile(self, wait=2, warmup=3, active=5):
439
+ """
440
+ Profile the training loop.
441
+ """
442
+ with torch.profiler.profile(
443
+ schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
444
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
445
+ profile_memory=True,
446
+ with_stack=True,
447
+ ) as prof:
448
+ for _ in range(wait + warmup + active):
449
+ self.run_step(self.load_data())
450
+ prof.step()
451
+
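The constructor above resolves three related quantities: a global batch size, a per-GPU batch size, and a micro-batch size used by load_data for gradient accumulation. A worked example with hypothetical numbers:

    world_size = 4          # GPUs reported by dist.get_world_size()
    batch_size = 64         # global batch size passed to the trainer
    batch_split = 2         # gradient-accumulation splits per step

    batch_size_per_gpu = batch_size // world_size    # 16 samples per GPU per step
    micro_batch = batch_size_per_gpu // batch_split  # 8 samples per forward/backward pass
    assert batch_size % world_size == 0 and batch_size_per_gpu % batch_split == 0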
trellis/trainers/basic.py ADDED
@@ -0,0 +1,438 @@
1
+ import os
2
+ import copy
3
+ from functools import partial
4
+ from contextlib import nullcontext
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ import numpy as np
10
+
11
+ from .utils import *
12
+ from .base import Trainer
13
+ from ..utils.general_utils import *
14
+ from ..utils.dist_utils import *
15
+ from ..utils import grad_clip_utils, elastic_utils
16
+
17
+
18
+ class BasicTrainer(Trainer):
19
+ """
20
+ Trainer for basic training loop.
21
+
22
+ Args:
23
+ models (dict[str, nn.Module]): Models to train.
24
+ dataset (torch.utils.data.Dataset): Dataset.
25
+ output_dir (str): Output directory.
26
+ load_dir (str): Load directory.
27
+ step (int): Step to load.
28
+ batch_size (int): Batch size.
29
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
30
+ batch_split (int): Split batch with gradient accumulation.
31
+ max_steps (int): Max steps.
32
+ optimizer (dict): Optimizer config.
33
+ lr_scheduler (dict): Learning rate scheduler config.
34
+ elastic (dict): Elastic memory management config.
35
+ grad_clip (float or dict): Gradient clip config.
36
+ ema_rate (float or list): Exponential moving average rates.
37
+ fp16_mode (str): FP16 mode.
38
+ - None: No FP16.
39
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
40
+ - 'amp': Automatic mixed precision.
41
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
42
+ finetune_ckpt (dict): Finetune checkpoint.
43
+ log_param_stats (bool): Log parameter stats.
44
+ i_print (int): Print interval.
45
+ i_log (int): Log interval.
46
+ i_sample (int): Sample interval.
47
+ i_save (int): Save interval.
48
+ i_ddpcheck (int): DDP check interval.
49
+ """
50
+
51
+ def __str__(self):
52
+ lines = []
53
+ lines.append(self.__class__.__name__)
54
+ lines.append(f' - Models:')
55
+ for name, model in self.models.items():
56
+ lines.append(f' - {name}: {model.__class__.__name__}')
57
+ lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
58
+ lines.append(f' - Dataloader:')
59
+ lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
60
+ lines.append(f' - Num workers: {self.dataloader.num_workers}')
61
+ lines.append(f' - Number of steps: {self.max_steps}')
62
+ lines.append(f' - Number of GPUs: {self.world_size}')
63
+ lines.append(f' - Batch size: {self.batch_size}')
64
+ lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
65
+ lines.append(f' - Batch split: {self.batch_split}')
66
+ lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
67
+ lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
68
+ if self.lr_scheduler_config is not None:
69
+ lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
70
+ if self.elastic_controller_config is not None:
71
+ lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
72
+ if self.grad_clip is not None:
73
+ lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
74
+ lines.append(f' - EMA rate: {self.ema_rate}')
75
+ lines.append(f' - FP16 mode: {self.fp16_mode}')
76
+ return '\n'.join(lines)
77
+
78
+ def init_models_and_more(self, **kwargs):
79
+ """
80
+ Initialize models and more.
81
+ """
82
+ if self.world_size > 1:
83
+ # Prepare distributed data parallel
84
+ self.training_models = {
85
+ name: DDP(
86
+ model,
87
+ device_ids=[self.local_rank],
88
+ output_device=self.local_rank,
89
+ bucket_cap_mb=128,
90
+ find_unused_parameters=False
91
+ )
92
+ for name, model in self.models.items()
93
+ }
94
+ else:
95
+ self.training_models = self.models
96
+
97
+ # Build master params
98
+ self.model_params = sum(
99
+ [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
100
+ , [])
101
+ if self.fp16_mode == 'amp':
102
+ self.master_params = self.model_params
103
+ self.scaler = torch.GradScaler() if self.fp16_mode == 'amp' else None
104
+ elif self.fp16_mode == 'inflat_all':
105
+ self.master_params = make_master_params(self.model_params)
106
+ self.fp16_scale_growth = self.fp16_scale_growth
107
+ self.log_scale = 20.0
108
+ elif self.fp16_mode is None:
109
+ self.master_params = self.model_params
110
+ else:
111
+ raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.')
112
+
113
+ # Build EMA params
114
+ if self.is_master:
115
+ self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
116
+
117
+ # Initialize optimizer
118
+ if hasattr(torch.optim, self.optimizer_config['name']):
119
+ self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
120
+ else:
121
+ self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
122
+
123
+ # Initialize learning rate scheduler
124
+ if self.lr_scheduler_config is not None:
125
+ if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
126
+ self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
127
+ else:
128
+ self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
129
+
130
+ # Initialize elastic memory controller
131
+ if self.elastic_controller_config is not None:
132
+ assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
133
+ 'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
134
+ self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
135
+ for model in self.models.values():
136
+ if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
137
+ model.register_memory_controller(self.elastic_controller)
138
+
139
+ # Initialize gradient clipper
140
+ if self.grad_clip is not None:
141
+ if isinstance(self.grad_clip, (float, int)):
142
+ self.grad_clip = float(self.grad_clip)
143
+ else:
144
+ self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
145
+
146
+ def _master_params_to_state_dicts(self, master_params):
147
+ """
148
+ Convert master params to dict of state_dicts.
149
+ """
150
+ if self.fp16_mode == 'inflat_all':
151
+ master_params = unflatten_master_params(self.model_params, master_params)
152
+ state_dicts = {name: model.state_dict() for name, model in self.models.items()}
153
+ master_params_names = sum(
154
+ [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
155
+ , [])
156
+ for i, (model_name, param_name) in enumerate(master_params_names):
157
+ state_dicts[model_name][param_name] = master_params[i]
158
+ return state_dicts
159
+
160
+ def _state_dicts_to_master_params(self, master_params, state_dicts):
161
+ """
162
+ Convert a state_dict to master params.
163
+ """
164
+ master_params_names = sum(
165
+ [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
166
+ , [])
167
+ params = [state_dicts[name][param_name] for name, param_name in master_params_names]
168
+ if self.fp16_mode == 'inflat_all':
169
+ model_params_to_master_params(params, master_params)
170
+ else:
171
+ for i, param in enumerate(params):
172
+ master_params[i].data.copy_(param.data)
173
+
174
+ def load(self, load_dir, step=0):
175
+ """
176
+ Load a checkpoint.
177
+ Should be called by all processes.
178
+ """
179
+ if self.is_master:
180
+ print(f'\nLoading checkpoint from step {step}...', end='')
181
+
182
+ model_ckpts = {}
183
+ for name, model in self.models.items():
184
+ model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
185
+ model_ckpts[name] = model_ckpt
186
+ model.load_state_dict(model_ckpt)
187
+ if self.fp16_mode == 'inflat_all':
188
+ model.convert_to_fp16()
189
+ self._state_dicts_to_master_params(self.master_params, model_ckpts)
190
+ del model_ckpts
191
+
192
+ if self.is_master:
193
+ for i, ema_rate in enumerate(self.ema_rate):
194
+ ema_ckpts = {}
195
+ for name, model in self.models.items():
196
+ ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
197
+ ema_ckpts[name] = ema_ckpt
198
+ self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
199
+ del ema_ckpts
200
+
201
+ misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
202
+ self.optimizer.load_state_dict(misc_ckpt['optimizer'])
203
+ self.step = misc_ckpt['step']
204
+ self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
205
+ if self.fp16_mode == 'amp':
206
+ self.scaler.load_state_dict(misc_ckpt['scaler'])
207
+ elif self.fp16_mode == 'inflat_all':
208
+ self.log_scale = misc_ckpt['log_scale']
209
+ if self.lr_scheduler_config is not None:
210
+ self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
211
+ if self.elastic_controller_config is not None:
212
+ self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
213
+ if self.grad_clip is not None and not isinstance(self.grad_clip, float):
214
+ self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
215
+ del misc_ckpt
216
+
217
+ if self.world_size > 1:
218
+ dist.barrier()
219
+ if self.is_master:
220
+ print(' Done.')
221
+
222
+ if self.world_size > 1:
223
+ self.check_ddp()
224
+
225
+ def save(self):
226
+ """
227
+ Save a checkpoint.
228
+ Should be called only by the rank 0 process.
229
+ """
230
+ assert self.is_master, 'save() should be called only by the rank 0 process.'
231
+ print(f'\nSaving checkpoint at step {self.step}...', end='')
232
+
233
+ model_ckpts = self._master_params_to_state_dicts(self.master_params)
234
+ for name, model_ckpt in model_ckpts.items():
235
+ torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
236
+
237
+ for i, ema_rate in enumerate(self.ema_rate):
238
+ ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
239
+ for name, ema_ckpt in ema_ckpts.items():
240
+ torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
241
+
242
+ misc_ckpt = {
243
+ 'optimizer': self.optimizer.state_dict(),
244
+ 'step': self.step,
245
+ 'data_sampler': self.data_sampler.state_dict(),
246
+ }
247
+ if self.fp16_mode == 'amp':
248
+ misc_ckpt['scaler'] = self.scaler.state_dict()
249
+ elif self.fp16_mode == 'inflat_all':
250
+ misc_ckpt['log_scale'] = self.log_scale
251
+ if self.lr_scheduler_config is not None:
252
+ misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
253
+ if self.elastic_controller_config is not None:
254
+ misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
255
+ if self.grad_clip is not None and not isinstance(self.grad_clip, float):
256
+ misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
257
+ torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
258
+ print(' Done.')
259
+
260
+ def finetune_from(self, finetune_ckpt):
261
+ """
262
+ Finetune from a checkpoint.
263
+ Should be called by all processes.
264
+ """
265
+ if self.is_master:
266
+ print('\nFinetuning from:')
267
+ for name, path in finetune_ckpt.items():
268
+ print(f' - {name}: {path}')
269
+
270
+ model_ckpts = {}
271
+ for name, model in self.models.items():
272
+ model_state_dict = model.state_dict()
273
+ if name in finetune_ckpt:
274
+ model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
275
+ for k, v in model_ckpt.items():
276
+ if model_ckpt[k].shape != model_state_dict[k].shape:
277
+ if self.is_master:
278
+ print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
279
+ model_ckpt[k] = model_state_dict[k]
280
+ model_ckpts[name] = model_ckpt
281
+ model.load_state_dict(model_ckpt)
282
+ if self.fp16_mode == 'inflat_all':
283
+ model.convert_to_fp16()
284
+ else:
285
+ if self.is_master:
286
+ print(f'Warning: {name} not found in finetune_ckpt, skipped.')
287
+ model_ckpts[name] = model_state_dict
288
+ self._state_dicts_to_master_params(self.master_params, model_ckpts)
289
+ del model_ckpts
290
+
291
+ if self.world_size > 1:
292
+ dist.barrier()
293
+ if self.is_master:
294
+ print('Done.')
295
+
296
+ if self.world_size > 1:
297
+ self.check_ddp()
298
+
299
+ def update_ema(self):
300
+ """
301
+ Update exponential moving average.
302
+ Should only be called by the rank 0 process.
303
+ """
304
+ assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
305
+ for i, ema_rate in enumerate(self.ema_rate):
306
+ for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
307
+ ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
308
+
309
+ def check_ddp(self):
310
+ """
311
+ Check if DDP is working properly.
312
+ Should be called by all processes.
313
+ """
314
+ if self.is_master:
315
+ print('\nPerforming DDP check...')
316
+
317
+ if self.is_master:
318
+ print('Checking if parameters are consistent across processes...')
319
+ dist.barrier()
320
+ try:
321
+ for p in self.master_params:
322
+ # split to avoid OOM
323
+ for i in range(0, p.numel(), 10000000):
324
+ sub_size = min(10000000, p.numel() - i)
325
+ sub_p = p.detach().view(-1)[i:i+sub_size]
326
+ # gather from all processes
327
+ sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
328
+ dist.all_gather(sub_p_gather, sub_p)
329
+ # check if equal
330
+ assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
331
+ except AssertionError as e:
332
+ if self.is_master:
333
+ print(f'\n\033[91mError: {e}\033[0m')
334
+ print('DDP check failed.')
335
+ raise e
336
+
337
+ dist.barrier()
338
+ if self.is_master:
339
+ print('Done.')
340
+
341
+ def run_step(self, data_list):
342
+ """
343
+ Run a training step.
344
+ """
345
+ step_log = {'loss': {}, 'status': {}}
346
+ amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext
347
+ elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
348
+
349
+ # Train
350
+ losses = []
351
+ statuses = []
352
+ elastic_controller_logs = []
353
+ zero_grad(self.model_params)
354
+ for i, mb_data in enumerate(data_list):
355
+ ## sync at the end of each batch split
356
+ sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
357
+ with nested_contexts(*sync_contexts), elastic_controller_context():
358
+ with amp_context():
359
+ loss, status = self.training_losses(**mb_data)
360
+ l = loss['loss'] / len(data_list)
361
+ ## backward
362
+ if self.fp16_mode == 'amp':
363
+ self.scaler.scale(l).backward()
364
+ elif self.fp16_mode == 'inflat_all':
365
+ scaled_l = l * (2 ** self.log_scale)
366
+ scaled_l.backward()
367
+ else:
368
+ l.backward()
369
+ ## log
370
+ losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
371
+ statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
372
+ if self.elastic_controller_config is not None:
373
+ elastic_controller_logs.append(self.elastic_controller.log())
374
+ ## gradient clip
375
+ if self.grad_clip is not None:
376
+ if self.fp16_mode == 'amp':
377
+ self.scaler.unscale_(self.optimizer)
378
+ elif self.fp16_mode == 'inflat_all':
379
+ model_grads_to_master_grads(self.model_params, self.master_params)
380
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
381
+ if isinstance(self.grad_clip, float):
382
+ grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
383
+ else:
384
+ grad_norm = self.grad_clip(self.master_params)
385
+ if torch.isfinite(grad_norm):
386
+ statuses[-1]['grad_norm'] = grad_norm.item()
387
+ ## step
388
+ if self.fp16_mode == 'amp':
389
+ prev_scale = self.scaler.get_scale()
390
+ self.scaler.step(self.optimizer)
391
+ self.scaler.update()
392
+ elif self.fp16_mode == 'inflat_all':
393
+ prev_scale = 2 ** self.log_scale
394
+ if not any(not p.grad.isfinite().all() for p in self.model_params):
395
+ if self.grad_clip is None:
396
+ model_grads_to_master_grads(self.model_params, self.master_params)
397
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
398
+ self.optimizer.step()
399
+ master_params_to_model_params(self.model_params, self.master_params)
400
+ self.log_scale += self.fp16_scale_growth
401
+ else:
402
+ self.log_scale -= 1
403
+ else:
404
+ prev_scale = 1.0
405
+ if not any(not p.grad.isfinite().all() for p in self.model_params):
406
+ self.optimizer.step()
407
+ else:
408
+ print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
409
+ ## adjust learning rate
410
+ if self.lr_scheduler_config is not None:
411
+ statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
412
+ self.lr_scheduler.step()
413
+
414
+ # Logs
415
+ step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
416
+ step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
417
+ if self.elastic_controller_config is not None:
418
+ step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
419
+ if self.grad_clip is not None:
420
+ step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
421
+
422
+ # Check grad and norm of each param
423
+ if self.log_param_stats:
424
+ param_norms = {}
425
+ param_grads = {}
426
+ for name, param in self.backbone.named_parameters():
427
+ if param.requires_grad:
428
+ param_norms[name] = param.norm().item()
429
+ if param.grad is not None and torch.isfinite(param.grad).all():
430
+ param_grads[name] = param.grad.norm().item() / prev_scale
431
+ step_log['param_norms'] = param_norms
432
+ step_log['param_grads'] = param_grads
433
+
434
+ # Update exponential moving average
435
+ if self.is_master:
436
+ self.update_ema()
437
+
438
+ return step_log
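update_ema applies the usual exponential moving average, ema <- rate * ema + (1 - rate) * param, to the flat master parameters on rank 0. A standalone sketch of the same in-place update (assumes only torch):

    import torch

    ema_rate = 0.9999
    master_param = torch.tensor([1.0, 2.0])
    ema_param = torch.zeros_like(master_param)

    # Mirrors ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
    ema_param.mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)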
trellis/trainers/flow_matching/flow_matching.py ADDED
@@ -0,0 +1,353 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ import numpy as np
7
+ from easydict import EasyDict as edict
8
+
9
+ from ..basic import BasicTrainer
10
+ from ...pipelines import samplers
11
+ from ...utils.general_utils import dict_reduce
12
+ from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
13
+ from .mixins.text_conditioned import TextConditionedMixin
14
+ from .mixins.image_conditioned import ImageConditionedMixin
15
+
16
+
17
+ class FlowMatchingTrainer(BasicTrainer):
18
+ """
19
+ Trainer for a diffusion model with the flow matching objective.
20
+
21
+ Args:
22
+ models (dict[str, nn.Module]): Models to train.
23
+ dataset (torch.utils.data.Dataset): Dataset.
24
+ output_dir (str): Output directory.
25
+ load_dir (str): Load directory.
26
+ step (int): Step to load.
27
+ batch_size (int): Batch size.
28
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
29
+ batch_split (int): Split batch with gradient accumulation.
30
+ max_steps (int): Max steps.
31
+ optimizer (dict): Optimizer config.
32
+ lr_scheduler (dict): Learning rate scheduler config.
33
+ elastic (dict): Elastic memory management config.
34
+ grad_clip (float or dict): Gradient clip config.
35
+ ema_rate (float or list): Exponential moving average rates.
36
+ fp16_mode (str): FP16 mode.
37
+ - None: No FP16.
38
+ - 'inflat_all': Hold a inflated fp32 master param for all params.
39
+ - 'amp': Automatic mixed precision.
40
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
41
+ finetune_ckpt (dict): Finetune checkpoint.
42
+ log_param_stats (bool): Log parameter stats.
43
+ i_print (int): Print interval.
44
+ i_log (int): Log interval.
45
+ i_sample (int): Sample interval.
46
+ i_save (int): Save interval.
47
+ i_ddpcheck (int): DDP check interval.
48
+
49
+ t_schedule (dict): Time schedule for flow matching.
50
+ sigma_min (float): Minimum noise level.
51
+ """
52
+ def __init__(
53
+ self,
54
+ *args,
55
+ t_schedule: dict = {
56
+ 'name': 'logitNormal',
57
+ 'args': {
58
+ 'mean': 0.0,
59
+ 'std': 1.0,
60
+ }
61
+ },
62
+ sigma_min: float = 1e-5,
63
+ **kwargs
64
+ ):
65
+ super().__init__(*args, **kwargs)
66
+ self.t_schedule = t_schedule
67
+ self.sigma_min = sigma_min
68
+
69
+ def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
70
+ """
71
+ Diffuse the data for a given number of diffusion steps.
72
+ In other words, sample from q(x_t | x_0).
73
+
74
+ Args:
75
+ x_0: The [N x C x ...] tensor of noiseless inputs.
76
+ t: The [N] tensor of diffusion steps [0-1].
77
+ noise: If specified, use this noise instead of generating new noise.
78
+
79
+ Returns:
80
+ x_t, the noisy version of x_0 under timestep t.
81
+ """
82
+ if noise is None:
83
+ noise = torch.randn_like(x_0)
84
+ assert noise.shape == x_0.shape, "noise must have same shape as x_0"
85
+
86
+ t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
87
+ x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
88
+
89
+ return x_t
90
+
91
+ def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
92
+ """
93
+ Get original image from noisy version under timestep t.
94
+ """
95
+ assert noise.shape == x_t.shape, "noise must have same shape as x_t"
96
+ t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
97
+ x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
98
+ return x_0
99
+
100
+ def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
101
+ """
102
+ Compute the velocity of the diffusion process at time t.
103
+ """
104
+ return (1 - self.sigma_min) * noise - x_0
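The interpolation used in diffuse is x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise, so its time derivative is exactly the value returned by get_v. A small numerical sanity check of that relation (a standalone sketch, not part of the trainer code):

    import torch

    sigma_min, t, dt = 1e-5, 0.3, 1e-4
    x_0, noise = torch.randn(4), torch.randn(4)

    x_t  = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise
    x_dt = (1 - (t + dt)) * x_0 + (sigma_min + (1 - sigma_min) * (t + dt)) * noise
    v = (1 - sigma_min) * noise - x_0  # what get_v returns

    assert torch.allclose((x_dt - x_t) / dt, v, atol=1e-3)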
105
+
106
+ def get_cond(self, cond, **kwargs):
107
+ """
108
+ Get the conditioning data.
109
+ """
110
+ return cond
111
+
112
+ def get_inference_cond(self, cond, **kwargs):
113
+ """
114
+ Get the conditioning data for inference.
115
+ """
116
+ return {'cond': cond, **kwargs}
117
+
118
+ def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
119
+ """
120
+ Get the sampler for the diffusion process.
121
+ """
122
+ return samplers.FlowEulerSampler(self.sigma_min)
123
+
124
+ def vis_cond(self, **kwargs):
125
+ """
126
+ Visualize the conditioning data.
127
+ """
128
+ return {}
129
+
130
+ def sample_t(self, batch_size: int) -> torch.Tensor:
131
+ """
132
+ Sample timesteps.
133
+ """
134
+ if self.t_schedule['name'] == 'uniform':
135
+ t = torch.rand(batch_size)
136
+ elif self.t_schedule['name'] == 'logitNormal':
137
+ mean = self.t_schedule['args']['mean']
138
+ std = self.t_schedule['args']['std']
139
+ t = torch.sigmoid(torch.randn(batch_size) * std + mean)
140
+ else:
141
+ raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
142
+ return t
143
+
144
+ def training_losses(
145
+ self,
146
+ x_0: torch.Tensor,
147
+ cond=None,
148
+ **kwargs
149
+ ) -> Tuple[Dict, Dict]:
150
+ """
151
+ Compute training losses for a single timestep.
152
+
153
+ Args:
154
+ x_0: The [N x C x ...] tensor of noiseless inputs.
155
+ cond: The [N x ...] tensor of additional conditions.
156
+ kwargs: Additional arguments to pass to the backbone.
157
+
158
+ Returns:
159
+ a dict with the key "loss" containing a tensor of shape [N].
160
+ may also contain other keys for different terms.
161
+ """
162
+ noise = torch.randn_like(x_0)
163
+ t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
164
+ x_t = self.diffuse(x_0, t, noise=noise)
165
+ cond = self.get_cond(cond, **kwargs)
166
+
167
+ pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
168
+ assert pred.shape == noise.shape == x_0.shape
169
+ target = self.get_v(x_0, noise, t)
170
+ terms = edict()
171
+ terms["mse"] = F.mse_loss(pred, target)
172
+ terms["loss"] = terms["mse"]
173
+
174
+ # log loss with time bins
175
+ mse_per_instance = np.array([
176
+ F.mse_loss(pred[i], target[i]).item()
177
+ for i in range(x_0.shape[0])
178
+ ])
179
+ time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
180
+ for i in range(10):
181
+ if (time_bin == i).sum() != 0:
182
+ terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
183
+
184
+ return terms, {}
185
+
186
+ @torch.no_grad()
187
+ def run_snapshot(
188
+ self,
189
+ num_samples: int,
190
+ batch_size: int,
191
+ verbose: bool = False,
192
+ ) -> Dict:
193
+ dataloader = DataLoader(
194
+ copy.deepcopy(self.dataset),
195
+ batch_size=batch_size,
196
+ shuffle=True,
197
+ num_workers=0,
198
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
199
+ )
200
+
201
+ # inference
202
+ sampler = self.get_sampler()
203
+ sample_gt = []
204
+ sample = []
205
+ cond_vis = []
206
+ for i in range(0, num_samples, batch_size):
207
+ batch = min(batch_size, num_samples - i)
208
+ data = next(iter(dataloader))
209
+ data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
210
+ noise = torch.randn_like(data['x_0'])
211
+ sample_gt.append(data['x_0'])
212
+ cond_vis.append(self.vis_cond(**data))
213
+ del data['x_0']
214
+ args = self.get_inference_cond(**data)
215
+ res = sampler.sample(
216
+ self.models['denoiser'],
217
+ noise=noise,
218
+ **args,
219
+ steps=50, cfg_strength=3.0, verbose=verbose,
220
+ )
221
+ sample.append(res.samples)
222
+
223
+ sample_gt = torch.cat(sample_gt, dim=0)
224
+ sample = torch.cat(sample, dim=0)
225
+ sample_dict = {
226
+ 'sample_gt': {'value': sample_gt, 'type': 'sample'},
227
+ 'sample': {'value': sample, 'type': 'sample'},
228
+ }
229
+ sample_dict.update(dict_reduce(cond_vis, None, {
230
+ 'value': lambda x: torch.cat(x, dim=0),
231
+ 'type': lambda x: x[0],
232
+ }))
233
+
234
+ return sample_dict
235
+
236
+
237
+ class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
238
+ """
239
+ Trainer for diffusion model with flow matching objective and classifier-free guidance.
240
+
241
+ Args:
242
+ models (dict[str, nn.Module]): Models to train.
243
+ dataset (torch.utils.data.Dataset): Dataset.
244
+ output_dir (str): Output directory.
245
+ load_dir (str): Load directory.
246
+ step (int): Step to load.
247
+ batch_size (int): Batch size.
248
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
249
+ batch_split (int): Split batch with gradient accumulation.
250
+ max_steps (int): Max steps.
251
+ optimizer (dict): Optimizer config.
252
+ lr_scheduler (dict): Learning rate scheduler config.
253
+ elastic (dict): Elastic memory management config.
254
+ grad_clip (float or dict): Gradient clip config.
255
+ ema_rate (float or list): Exponential moving average rates.
256
+ fp16_mode (str): FP16 mode.
257
+ - None: No FP16.
258
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
259
+ - 'amp': Automatic mixed precision.
260
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
261
+ finetune_ckpt (dict): Finetune checkpoint.
262
+ log_param_stats (bool): Log parameter stats.
263
+ i_print (int): Print interval.
264
+ i_log (int): Log interval.
265
+ i_sample (int): Sample interval.
266
+ i_save (int): Save interval.
267
+ i_ddpcheck (int): DDP check interval.
268
+
269
+ t_schedule (dict): Time schedule for flow matching.
270
+ sigma_min (float): Minimum noise level.
271
+ p_uncond (float): Probability of dropping conditions.
272
+ """
273
+ pass
274
+
275
+
276
+ class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
277
+ """
278
+ Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
279
+
280
+ Args:
281
+ models (dict[str, nn.Module]): Models to train.
282
+ dataset (torch.utils.data.Dataset): Dataset.
283
+ output_dir (str): Output directory.
284
+ load_dir (str): Load directory.
285
+ step (int): Step to load.
286
+ batch_size (int): Batch size.
287
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
288
+ batch_split (int): Split batch with gradient accumulation.
289
+ max_steps (int): Max steps.
290
+ optimizer (dict): Optimizer config.
291
+ lr_scheduler (dict): Learning rate scheduler config.
292
+ elastic (dict): Elastic memory management config.
293
+ grad_clip (float or dict): Gradient clip config.
294
+ ema_rate (float or list): Exponential moving average rates.
295
+ fp16_mode (str): FP16 mode.
296
+ - None: No FP16.
297
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
298
+ - 'amp': Automatic mixed precision.
299
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
300
+ finetune_ckpt (dict): Finetune checkpoint.
301
+ log_param_stats (bool): Log parameter stats.
302
+ i_print (int): Print interval.
303
+ i_log (int): Log interval.
304
+ i_sample (int): Sample interval.
305
+ i_save (int): Save interval.
306
+ i_ddpcheck (int): DDP check interval.
307
+
308
+ t_schedule (dict): Time schedule for flow matching.
309
+ sigma_min (float): Minimum noise level.
310
+ p_uncond (float): Probability of dropping conditions.
311
+ text_cond_model (str): Text conditioning model.
312
+ """
313
+ pass
314
+
315
+
316
+ class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
317
+ """
318
+ Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
319
+
320
+ Args:
321
+ models (dict[str, nn.Module]): Models to train.
322
+ dataset (torch.utils.data.Dataset): Dataset.
323
+ output_dir (str): Output directory.
324
+ load_dir (str): Load directory.
325
+ step (int): Step to load.
326
+ batch_size (int): Batch size.
327
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
328
+ batch_split (int): Split batch with gradient accumulation.
329
+ max_steps (int): Max steps.
330
+ optimizer (dict): Optimizer config.
331
+ lr_scheduler (dict): Learning rate scheduler config.
332
+ elastic (dict): Elastic memory management config.
333
+ grad_clip (float or dict): Gradient clip config.
334
+ ema_rate (float or list): Exponential moving average rates.
335
+ fp16_mode (str): FP16 mode.
336
+ - None: No FP16.
337
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
338
+ - 'amp': Automatic mixed precision.
339
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
340
+ finetune_ckpt (dict): Finetune checkpoint.
341
+ log_param_stats (bool): Log parameter stats.
342
+ i_print (int): Print interval.
343
+ i_log (int): Log interval.
344
+ i_sample (int): Sample interval.
345
+ i_save (int): Save interval.
346
+ i_ddpcheck (int): DDP check interval.
347
+
348
+ t_schedule (dict): Time schedule for flow matching.
349
+ sigma_min (float): Minimum noise level.
350
+ p_uncond (float): Probability of dropping conditions.
351
+ image_cond_model (str): Image conditioning model.
352
+ """
353
+ pass
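A minimal standalone sketch (not part of the commit) of the flow-matching forward process implemented by diffuse / get_v / reverse_diffuse above, using plain dense tensors; the toy shapes are assumptions, the schedule and sigma_min mirror the defaults:

import torch

sigma_min = 1e-5
x_0 = torch.randn(4, 8)                              # clean latents (toy shape)
noise = torch.randn_like(x_0)                        # epsilon
t = torch.sigmoid(torch.randn(4)).clamp(max=0.99)    # logitNormal(0, 1) schedule, kept away from t=1
t_ = t.view(-1, 1)

# forward interpolation: x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps
x_t = (1 - t_) * x_0 + (sigma_min + (1 - sigma_min) * t_) * noise

# regression target for the denoiser: v = (1 - sigma_min) * eps - x_0
v = (1 - sigma_min) * noise - x_0

# with the same noise, reverse_diffuse recovers x_0 up to numerical error
x_0_rec = (x_t - (sigma_min + (1 - sigma_min) * t_) * noise) / (1 - t_)
assert torch.allclose(x_0_rec, x_0, atol=1e-3)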
trellis/trainers/flow_matching/mixins/classifier_free_guidance.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ import numpy as np
3
+ from ....utils.general_utils import dict_foreach
4
+ from ....pipelines import samplers
5
+
6
+
7
+ class ClassifierFreeGuidanceMixin:
8
+ def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+ self.p_uncond = p_uncond
11
+
12
+ def get_cond(self, cond, neg_cond=None, **kwargs):
13
+ """
14
+ Get the conditioning data.
15
+ """
16
+ assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
17
+
18
+ if self.p_uncond > 0:
19
+ # randomly drop the class label
20
+ def get_batch_size(cond):
21
+ if isinstance(cond, torch.Tensor):
22
+ return cond.shape[0]
23
+ elif isinstance(cond, list):
24
+ return len(cond)
25
+ else:
26
+ raise ValueError(f"Unsupported type of cond: {type(cond)}")
27
+
28
+ ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
29
+ B = get_batch_size(ref_cond)
30
+
31
+ def select(cond, neg_cond, mask):
32
+ if isinstance(cond, torch.Tensor):
33
+ mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
34
+ return torch.where(mask, neg_cond, cond)
35
+ elif isinstance(cond, list):
36
+ return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
37
+ else:
38
+ raise ValueError(f"Unsupported type of cond: {type(cond)}")
39
+
40
+ mask = list(np.random.rand(B) < self.p_uncond)
41
+ if not isinstance(cond, dict):
42
+ cond = select(cond, neg_cond, mask)
43
+ else:
44
+ cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
45
+
46
+ return cond
47
+
48
+ def get_inference_cond(self, cond, neg_cond=None, **kwargs):
49
+ """
50
+ Get the conditioning data for inference.
51
+ """
52
+ assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
53
+ return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
54
+
55
+ def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
56
+ """
57
+ Get the sampler for the diffusion process.
58
+ """
59
+ return samplers.FlowEulerCfgSampler(self.sigma_min)
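An illustrative sketch (toy values and shapes, not the trainer's API) of the per-sample condition dropping performed in get_cond: each batch element is replaced by its negative/null condition with probability p_uncond, broadcasting the boolean mask over the remaining dimensions:

import numpy as np
import torch

p_uncond = 0.1
cond = torch.randn(8, 77, 768)        # per-sample conditioning tokens (assumed shape)
neg_cond = torch.zeros_like(cond)     # null condition

mask = torch.from_numpy(np.random.rand(8) < p_uncond)    # [B] bool, True = drop
mask = mask.view(-1, 1, 1)                                # broadcast over token dims
cond_dropped = torch.where(mask, neg_cond, cond)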
trellis/trainers/flow_matching/mixins/image_conditioned.py ADDED
@@ -0,0 +1,93 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ from ....utils import dist_utils
9
+
10
+
11
+ class ImageConditionedMixin:
12
+ """
13
+ Mixin for image-conditioned models.
14
+
15
+ Args:
16
+ image_cond_model: The image conditioning model.
17
+ """
18
+ def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
19
+ super().__init__(*args, **kwargs)
20
+ self.image_cond_model_name = image_cond_model
21
+ self.image_cond_model = None # the model is init lazily
22
+
23
+ @staticmethod
24
+ def prepare_for_training(image_cond_model: str, **kwargs):
25
+ """
26
+ Prepare for training.
27
+ """
28
+ if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
29
+ super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)
30
+ # download the model
31
+ torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)
32
+
33
+ def _init_image_cond_model(self):
34
+ """
35
+ Initialize the image conditioning model.
36
+ """
37
+ with dist_utils.local_master_first():
38
+ dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
39
+ dinov2_model.eval().cuda()
40
+ transform = transforms.Compose([
41
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
42
+ ])
43
+ self.image_cond_model = {
44
+ 'model': dinov2_model,
45
+ 'transform': transform,
46
+ }
47
+
48
+ @torch.no_grad()
49
+ def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
50
+ """
51
+ Encode the image.
52
+ """
53
+ if isinstance(image, torch.Tensor):
54
+ assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
55
+ elif isinstance(image, list):
56
+ assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
57
+ image = [i.resize((518, 518), Image.LANCZOS) for i in image]
58
+ image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
59
+ image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
60
+ image = torch.stack(image).cuda()
61
+ else:
62
+ raise ValueError(f"Unsupported type of image: {type(image)}")
63
+
64
+ if self.image_cond_model is None:
65
+ self._init_image_cond_model()
66
+ image = self.image_cond_model['transform'](image).cuda()
67
+ features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
68
+ patchtokens = F.layer_norm(features, features.shape[-1:])
69
+ return patchtokens
70
+
71
+ def get_cond(self, cond, **kwargs):
72
+ """
73
+ Get the conditioning data.
74
+ """
75
+ cond = self.encode_image(cond)
76
+ kwargs['neg_cond'] = torch.zeros_like(cond)
77
+ cond = super().get_cond(cond, **kwargs)
78
+ return cond
79
+
80
+ def get_inference_cond(self, cond, **kwargs):
81
+ """
82
+ Get the conditioning data for inference.
83
+ """
84
+ cond = self.encode_image(cond)
85
+ kwargs['neg_cond'] = torch.zeros_like(cond)
86
+ cond = super().get_inference_cond(cond, **kwargs)
87
+ return cond
88
+
89
+ def vis_cond(self, cond, **kwargs):
90
+ """
91
+ Visualize the conditioning data.
92
+ """
93
+ return {'image': {'value': cond, 'type': 'image'}}
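A self-contained sketch of the preprocessing that encode_image applies to PIL inputs before the lazily loaded DINOv2 model is called; the model call itself is omitted and the gray test images are placeholders:

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

images = [Image.new('RGB', (640, 480), color=(128, 128, 128)) for _ in range(2)]
images = [im.resize((518, 518), Image.LANCZOS) for im in images]
arrays = [np.asarray(im, dtype=np.float32) / 255.0 for im in images]
batch = torch.stack([torch.from_numpy(a).permute(2, 0, 1) for a in arrays])   # [2, 3, 518, 518]
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
batch = normalize(batch)
# 'batch' is what the mixin feeds to DINOv2; the model's 'x_prenorm' output is
# layer-normalized and used as conditioning tokens, with zeros as the negative condition.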
trellis/trainers/flow_matching/mixins/text_conditioned.py ADDED
@@ -0,0 +1,68 @@
1
+ from typing import *
2
+ import os
3
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
4
+ import torch
5
+ from transformers import AutoTokenizer, CLIPTextModel
6
+
7
+ from ....utils import dist_utils
8
+
9
+
10
+ class TextConditionedMixin:
11
+ """
12
+ Mixin for text-conditioned models.
13
+
14
+ Args:
15
+ text_cond_model: The text conditioning model.
16
+ """
17
+ def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+ self.text_cond_model_name = text_cond_model
20
+ self.text_cond_model = None # the model is init lazily
21
+
22
+ def _init_text_cond_model(self):
23
+ """
24
+ Initialize the text conditioning model.
25
+ """
26
+ # load model
27
+ with dist_utils.local_master_first():
28
+ model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
29
+ tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
30
+ model.eval()
31
+ model = model.cuda()
32
+ self.text_cond_model = {
33
+ 'model': model,
34
+ 'tokenizer': tokenizer,
35
+ }
36
+ self.text_cond_model['null_cond'] = self.encode_text([''])
37
+
38
+ @torch.no_grad()
39
+ def encode_text(self, text: List[str]) -> torch.Tensor:
40
+ """
41
+ Encode the text.
42
+ """
43
+ assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
44
+ if self.text_cond_model is None:
45
+ self._init_text_cond_model()
46
+ encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
47
+ tokens = encoding['input_ids'].cuda()
48
+ embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
49
+
50
+ return embeddings
51
+
52
+ def get_cond(self, cond, **kwargs):
53
+ """
54
+ Get the conditioning data.
55
+ """
56
+ cond = self.encode_text(cond)
57
+ kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
58
+ cond = super().get_cond(cond, **kwargs)
59
+ return cond
60
+
61
+ def get_inference_cond(self, cond, **kwargs):
62
+ """
63
+ Get the conditioning data for inference.
64
+ """
65
+ cond = self.encode_text(cond)
66
+ kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
67
+ cond = super().get_inference_cond(cond, **kwargs)
68
+ return cond
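A standalone sketch of the tokenize-and-encode step in encode_text, run on CPU without any trainer state; the checkpoint name matches the default above and the empty string provides the null condition:

import torch
from transformers import AutoTokenizer, CLIPTextModel

name = 'openai/clip-vit-large-patch14'
tokenizer = AutoTokenizer.from_pretrained(name)
model = CLIPTextModel.from_pretrained(name).eval()

with torch.no_grad():
    enc = tokenizer(['a photo of a chair', ''], max_length=77,
                    padding='max_length', truncation=True, return_tensors='pt')
    embeddings = model(input_ids=enc['input_ids']).last_hidden_state   # [2, 77, hidden_dim]
# embeddings[1] (the empty prompt) plays the role of the null condition for CFG.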
trellis/trainers/flow_matching/sparse_flow_matching.py ADDED
@@ -0,0 +1,286 @@
1
+ from typing import *
2
+ import os
3
+ import copy
4
+ import functools
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import DataLoader
8
+ import numpy as np
9
+ from easydict import EasyDict as edict
10
+
11
+ from ...modules import sparse as sp
12
+ from ...utils.general_utils import dict_reduce
13
+ from ...utils.data_utils import cycle, BalancedResumableSampler
14
+ from .flow_matching import FlowMatchingTrainer
15
+ from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
16
+ from .mixins.text_conditioned import TextConditionedMixin
17
+ from .mixins.image_conditioned import ImageConditionedMixin
18
+
19
+
20
+ class SparseFlowMatchingTrainer(FlowMatchingTrainer):
21
+ """
22
+ Trainer for sparse diffusion model with flow matching objective.
23
+
24
+ Args:
25
+ models (dict[str, nn.Module]): Models to train.
26
+ dataset (torch.utils.data.Dataset): Dataset.
27
+ output_dir (str): Output directory.
28
+ load_dir (str): Load directory.
29
+ step (int): Step to load.
30
+ batch_size (int): Batch size.
31
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
32
+ batch_split (int): Split batch with gradient accumulation.
33
+ max_steps (int): Max steps.
34
+ optimizer (dict): Optimizer config.
35
+ lr_scheduler (dict): Learning rate scheduler config.
36
+ elastic (dict): Elastic memory management config.
37
+ grad_clip (float or dict): Gradient clip config.
38
+ ema_rate (float or list): Exponential moving average rates.
39
+ fp16_mode (str): FP16 mode.
40
+ - None: No FP16.
41
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
42
+ - 'amp': Automatic mixed precision.
43
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
44
+ finetune_ckpt (dict): Finetune checkpoint.
45
+ log_param_stats (bool): Log parameter stats.
46
+ i_print (int): Print interval.
47
+ i_log (int): Log interval.
48
+ i_sample (int): Sample interval.
49
+ i_save (int): Save interval.
50
+ i_ddpcheck (int): DDP check interval.
51
+
52
+ t_schedule (dict): Time schedule for flow matching.
53
+ sigma_min (float): Minimum noise level.
54
+ """
55
+
56
+ def prepare_dataloader(self, **kwargs):
57
+ """
58
+ Prepare dataloader.
59
+ """
60
+ self.data_sampler = BalancedResumableSampler(
61
+ self.dataset,
62
+ shuffle=True,
63
+ batch_size=self.batch_size_per_gpu,
64
+ )
65
+ self.dataloader = DataLoader(
66
+ self.dataset,
67
+ batch_size=self.batch_size_per_gpu,
68
+ num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
69
+ pin_memory=True,
70
+ drop_last=True,
71
+ persistent_workers=True,
72
+ collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
73
+ sampler=self.data_sampler,
74
+ )
75
+ self.data_iterator = cycle(self.dataloader)
76
+
77
+ def training_losses(
78
+ self,
79
+ x_0: sp.SparseTensor,
80
+ cond=None,
81
+ **kwargs
82
+ ) -> Tuple[Dict, Dict]:
83
+ """
84
+ Compute training losses for a single timestep.
85
+
86
+ Args:
87
+ x_0: The [N x ... x C] sparse tensor of the inputs.
88
+ cond: The [N x ...] tensor of additional conditions.
89
+ kwargs: Additional arguments to pass to the backbone.
90
+
91
+ Returns:
92
+ a dict with the key "loss" containing a scalar loss tensor.
93
+ may also contain other keys for different terms.
94
+ """
95
+ noise = x_0.replace(torch.randn_like(x_0.feats))
96
+ t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
97
+ x_t = self.diffuse(x_0, t, noise=noise)
98
+ cond = self.get_cond(cond, **kwargs)
99
+
100
+ pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
101
+ assert pred.shape == noise.shape == x_0.shape
102
+ target = self.get_v(x_0, noise, t)
103
+ terms = edict()
104
+ terms["mse"] = F.mse_loss(pred.feats, target.feats)
105
+ terms["loss"] = terms["mse"]
106
+
107
+ # log loss with time bins
108
+ mse_per_instance = np.array([
109
+ F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
110
+ for i in range(x_0.shape[0])
111
+ ])
112
+ time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
113
+ for i in range(10):
114
+ if (time_bin == i).sum() != 0:
115
+ terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
116
+
117
+ return terms, {}
118
+
119
+ @torch.no_grad()
120
+ def run_snapshot(
121
+ self,
122
+ num_samples: int,
123
+ batch_size: int,
124
+ verbose: bool = False,
125
+ ) -> Dict:
126
+ dataloader = DataLoader(
127
+ copy.deepcopy(self.dataset),
128
+ batch_size=batch_size,
129
+ shuffle=True,
130
+ num_workers=0,
131
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
132
+ )
133
+
134
+ # inference
135
+ sampler = self.get_sampler()
136
+ sample_gt = []
137
+ sample = []
138
+ cond_vis = []
139
+ for i in range(0, num_samples, batch_size):
140
+ batch = min(batch_size, num_samples - i)
141
+ data = next(iter(dataloader))
142
+ data = {k: v[:batch].cuda() if not isinstance(v, list) else v[:batch] for k, v in data.items()}
143
+ noise = data['x_0'].replace(torch.randn_like(data['x_0'].feats))
144
+ sample_gt.append(data['x_0'])
145
+ cond_vis.append(self.vis_cond(**data))
146
+ del data['x_0']
147
+ args = self.get_inference_cond(**data)
148
+ res = sampler.sample(
149
+ self.models['denoiser'],
150
+ noise=noise,
151
+ **args,
152
+ steps=50, cfg_strength=3.0, verbose=verbose,
153
+ )
154
+ sample.append(res.samples)
155
+
156
+ sample_gt = sp.sparse_cat(sample_gt)
157
+ sample = sp.sparse_cat(sample)
158
+ sample_dict = {
159
+ 'sample_gt': {'value': sample_gt, 'type': 'sample'},
160
+ 'sample': {'value': sample, 'type': 'sample'},
161
+ }
162
+ sample_dict.update(dict_reduce(cond_vis, None, {
163
+ 'value': lambda x: torch.cat(x, dim=0),
164
+ 'type': lambda x: x[0],
165
+ }))
166
+
167
+ return sample_dict
168
+
169
+
170
+ class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
171
+ """
172
+ Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
173
+
174
+ Args:
175
+ models (dict[str, nn.Module]): Models to train.
176
+ dataset (torch.utils.data.Dataset): Dataset.
177
+ output_dir (str): Output directory.
178
+ load_dir (str): Load directory.
179
+ step (int): Step to load.
180
+ batch_size (int): Batch size.
181
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
182
+ batch_split (int): Split batch with gradient accumulation.
183
+ max_steps (int): Max steps.
184
+ optimizer (dict): Optimizer config.
185
+ lr_scheduler (dict): Learning rate scheduler config.
186
+ elastic (dict): Elastic memory management config.
187
+ grad_clip (float or dict): Gradient clip config.
188
+ ema_rate (float or list): Exponential moving average rates.
189
+ fp16_mode (str): FP16 mode.
190
+ - None: No FP16.
191
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
192
+ - 'amp': Automatic mixed precision.
193
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
194
+ finetune_ckpt (dict): Finetune checkpoint.
195
+ log_param_stats (bool): Log parameter stats.
196
+ i_print (int): Print interval.
197
+ i_log (int): Log interval.
198
+ i_sample (int): Sample interval.
199
+ i_save (int): Save interval.
200
+ i_ddpcheck (int): DDP check interval.
201
+
202
+ t_schedule (dict): Time schedule for flow matching.
203
+ sigma_min (float): Minimum noise level.
204
+ p_uncond (float): Probability of dropping conditions.
205
+ """
206
+ pass
207
+
208
+
209
+ class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
210
+ """
211
+ Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
212
+
213
+ Args:
214
+ models (dict[str, nn.Module]): Models to train.
215
+ dataset (torch.utils.data.Dataset): Dataset.
216
+ output_dir (str): Output directory.
217
+ load_dir (str): Load directory.
218
+ step (int): Step to load.
219
+ batch_size (int): Batch size.
220
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
221
+ batch_split (int): Split batch with gradient accumulation.
222
+ max_steps (int): Max steps.
223
+ optimizer (dict): Optimizer config.
224
+ lr_scheduler (dict): Learning rate scheduler config.
225
+ elastic (dict): Elastic memory management config.
226
+ grad_clip (float or dict): Gradient clip config.
227
+ ema_rate (float or list): Exponential moving average rates.
228
+ fp16_mode (str): FP16 mode.
229
+ - None: No FP16.
230
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
231
+ - 'amp': Automatic mixed precision.
232
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
233
+ finetune_ckpt (dict): Finetune checkpoint.
234
+ log_param_stats (bool): Log parameter stats.
235
+ i_print (int): Print interval.
236
+ i_log (int): Log interval.
237
+ i_sample (int): Sample interval.
238
+ i_save (int): Save interval.
239
+ i_ddpcheck (int): DDP check interval.
240
+
241
+ t_schedule (dict): Time schedule for flow matching.
242
+ sigma_min (float): Minimum noise level.
243
+ p_uncond (float): Probability of dropping conditions.
244
+ text_cond_model (str): Text conditioning model.
245
+ """
246
+ pass
247
+
248
+
249
+ class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
250
+ """
251
+ Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
252
+
253
+ Args:
254
+ models (dict[str, nn.Module]): Models to train.
255
+ dataset (torch.utils.data.Dataset): Dataset.
256
+ output_dir (str): Output directory.
257
+ load_dir (str): Load directory.
258
+ step (int): Step to load.
259
+ batch_size (int): Batch size.
260
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
261
+ batch_split (int): Split batch with gradient accumulation.
262
+ max_steps (int): Max steps.
263
+ optimizer (dict): Optimizer config.
264
+ lr_scheduler (dict): Learning rate scheduler config.
265
+ elastic (dict): Elastic memory management config.
266
+ grad_clip (float or dict): Gradient clip config.
267
+ ema_rate (float or list): Exponential moving average rates.
268
+ fp16_mode (str): FP16 mode.
269
+ - None: No FP16.
270
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
271
+ - 'amp': Automatic mixed precision.
272
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
273
+ finetune_ckpt (dict): Finetune checkpoint.
274
+ log_param_stats (bool): Log parameter stats.
275
+ i_print (int): Print interval.
276
+ i_log (int): Log interval.
277
+ i_sample (int): Sample interval.
278
+ i_save (int): Save interval.
279
+ i_ddpcheck (int): DDP check interval.
280
+
281
+ t_schedule (dict): Time schedule for flow matching.
282
+ sigma_min (float): Minimum noise level.
283
+ p_uncond (float): Probability of dropping conditions.
284
+ image_cond_model (str): Image conditioning model.
285
+ """
286
+ pass
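The composed trainers above are empty subclasses; their behaviour comes entirely from cooperative super() calls along the method resolution order, with the conditioning mixins placed left of the base trainer so their hooks run first. A toy sketch with hypothetical class names showing how the get_cond hooks chain:

class FlowBase:
    def get_cond(self, cond, **kwargs):
        return cond                                   # base trainer: pass-through

class CFGMixin:
    def get_cond(self, cond, neg_cond=None, **kwargs):
        print('CFGMixin: may swap cond for neg_cond with prob p_uncond')
        return super().get_cond(cond, **kwargs)

class TextMixin:
    def get_cond(self, cond, **kwargs):
        print('TextMixin: encode text, attach null condition')
        kwargs['neg_cond'] = 'null'
        return super().get_cond(cond, **kwargs)

class TextCFGTrainer(TextMixin, CFGMixin, FlowBase):
    pass

print([c.__name__ for c in TextCFGTrainer.__mro__])
TextCFGTrainer().get_cond('a chair')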
trellis/trainers/utils.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ # FP16 utils
5
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
6
+
7
+ def make_master_params(model_params):
8
+ """
9
+ Copy model parameters into an inflated tensor of full-precision parameters.
10
+ """
11
+ master_params = _flatten_dense_tensors(
12
+ [param.detach().float() for param in model_params]
13
+ )
14
+ master_params = nn.Parameter(master_params)
15
+ master_params.requires_grad = True
16
+ return [master_params]
17
+
18
+
19
+ def unflatten_master_params(model_params, master_params):
20
+ """
21
+ Unflatten the master parameters to look like model_params.
22
+ """
23
+ return _unflatten_dense_tensors(master_params[0].detach(), model_params)
24
+
25
+
26
+ def model_params_to_master_params(model_params, master_params):
27
+ """
28
+ Copy the model parameter data into the master parameters.
29
+ """
30
+ master_params[0].detach().copy_(
31
+ _flatten_dense_tensors([param.detach().float() for param in model_params])
32
+ )
33
+
34
+
35
+ def master_params_to_model_params(model_params, master_params):
36
+ """
37
+ Copy the master parameter data back into the model parameters.
38
+ """
39
+ for param, master_param in zip(
40
+ model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
41
+ ):
42
+ param.detach().copy_(master_param)
43
+
44
+
45
+ def model_grads_to_master_grads(model_params, master_params):
46
+ """
47
+ Copy the gradients from the model parameters into the master parameters
48
+ from make_master_params().
49
+ """
50
+ master_params[0].grad = _flatten_dense_tensors(
51
+ [param.grad.data.detach().float() for param in model_params]
52
+ )
53
+
54
+
55
+ def zero_grad(model_params):
56
+ for param in model_params:
57
+ if param.grad is not None:
58
+ if param.grad.grad_fn is not None:
59
+ param.grad.detach_()
60
+ else:
61
+ param.grad.requires_grad_(False)
62
+ param.grad.zero_()
63
+
64
+
65
+ # LR Schedulers
66
+ from torch.optim.lr_scheduler import LambdaLR
67
+
68
+ class LinearWarmupLRScheduler(LambdaLR):
69
+ def __init__(self, optimizer, warmup_steps, last_epoch=-1):
70
+ self.warmup_steps = warmup_steps
71
+ super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
72
+
73
+ def lr_lambda(self, current_step):
74
+ if current_step < self.warmup_steps:
75
+ return float(current_step + 1) / self.warmup_steps
76
+ return 1.0
77
+
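A short usage sketch for the warmup scheduler defined above (the class is redeclared so the snippet runs on its own); the optimizer choice and step counts are arbitrary:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

class LinearWarmupLRScheduler(LambdaLR):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, current_step):
        if current_step < self.warmup_steps:
            return float(current_step + 1) / self.warmup_steps
        return 1.0

model = nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
sched = LinearWarmupLRScheduler(opt, warmup_steps=1000)
for step in range(3):
    opt.step()
    sched.step()              # lr ramps linearly from lr/warmup_steps up to lr
    print(sched.get_last_lr())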
trellis/trainers/vae/sparse_structure_vae.py ADDED
@@ -0,0 +1,130 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ from easydict import EasyDict as edict
7
+
8
+ from ..basic import BasicTrainer
9
+
10
+
11
+ class SparseStructureVaeTrainer(BasicTrainer):
12
+ """
13
+ Trainer for Sparse Structure VAE.
14
+
15
+ Args:
16
+ models (dict[str, nn.Module]): Models to train.
17
+ dataset (torch.utils.data.Dataset): Dataset.
18
+ output_dir (str): Output directory.
19
+ load_dir (str): Load directory.
20
+ step (int): Step to load.
21
+ batch_size (int): Batch size.
22
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
23
+ batch_split (int): Split batch with gradient accumulation.
24
+ max_steps (int): Max steps.
25
+ optimizer (dict): Optimizer config.
26
+ lr_scheduler (dict): Learning rate scheduler config.
27
+ elastic (dict): Elastic memory management config.
28
+ grad_clip (float or dict): Gradient clip config.
29
+ ema_rate (float or list): Exponential moving average rates.
30
+ fp16_mode (str): FP16 mode.
31
+ - None: No FP16.
32
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
33
+ - 'amp': Automatic mixed precision.
34
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
35
+ finetune_ckpt (dict): Finetune checkpoint.
36
+ log_param_stats (bool): Log parameter stats.
37
+ i_print (int): Print interval.
38
+ i_log (int): Log interval.
39
+ i_sample (int): Sample interval.
40
+ i_save (int): Save interval.
41
+ i_ddpcheck (int): DDP check interval.
42
+
43
+ loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
44
+ lambda_kl (float): KL divergence loss weight.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ *args,
50
+ loss_type='bce',
51
+ lambda_kl=1e-6,
52
+ **kwargs
53
+ ):
54
+ super().__init__(*args, **kwargs)
55
+ self.loss_type = loss_type
56
+ self.lambda_kl = lambda_kl
57
+
58
+ def training_losses(
59
+ self,
60
+ ss: torch.Tensor,
61
+ **kwargs
62
+ ) -> Tuple[Dict, Dict]:
63
+ """
64
+ Compute training losses.
65
+
66
+ Args:
67
+ ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
68
+
69
+ Returns:
70
+ a dict with the key "loss" containing a scalar tensor.
71
+ may also contain other keys for different terms.
72
+ """
73
+ z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True)
74
+ logits = self.training_models['decoder'](z)
75
+
76
+ terms = edict(loss = 0.0)
77
+ if self.loss_type == 'bce':
78
+ terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean')
79
+ terms["loss"] = terms["loss"] + terms["bce"]
80
+ elif self.loss_type == 'l1':
81
+ terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean')
82
+ terms["loss"] = terms["loss"] + terms["l1"]
83
+ elif self.loss_type == 'dice':
84
+ logits = F.sigmoid(logits)
85
+ terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
86
+ terms["loss"] = terms["loss"] + terms["dice"]
87
+ else:
88
+ raise ValueError(f'Invalid loss type {self.loss_type}')
89
+ terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
90
+ terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
91
+
92
+ return terms, {}
93
+
94
+ @torch.no_grad()
95
+ def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
96
+ super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)
97
+
98
+ @torch.no_grad()
99
+ def run_snapshot(
100
+ self,
101
+ num_samples: int,
102
+ batch_size: int,
103
+ verbose: bool = False,
104
+ ) -> Dict:
105
+ dataloader = DataLoader(
106
+ copy.deepcopy(self.dataset),
107
+ batch_size=batch_size,
108
+ shuffle=True,
109
+ num_workers=0,
110
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
111
+ )
112
+
113
+ # inference
114
+ gts = []
115
+ recons = []
116
+ for i in range(0, num_samples, batch_size):
117
+ batch = min(batch_size, num_samples - i)
118
+ data = next(iter(dataloader))
119
+ args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
120
+ z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
121
+ logits = self.models['decoder'](z)
122
+ recon = (logits > 0).long()
123
+ gts.append(args['ss'])
124
+ recons.append(recon)
125
+
126
+ sample_dict = {
127
+ 'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
128
+ 'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
129
+ }
130
+ return sample_dict
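A toy numeric sketch of the loss terms assembled in SparseStructureVaeTrainer.training_losses; the grid and latent shapes are placeholder assumptions:

import torch
import torch.nn.functional as F

ss = (torch.rand(2, 1, 16, 16, 16) > 0.5).float()     # binary occupancy grid (toy)
logits = torch.randn_like(ss)                          # decoder output (toy)
mean, logvar = torch.randn(2, 8, 4, 4, 4), torch.randn(2, 8, 4, 4, 4)

bce = F.binary_cross_entropy_with_logits(logits, ss, reduction='mean')
probs = torch.sigmoid(logits)
dice = 1 - (2 * (probs * ss).sum() + 1) / (probs.sum() + ss.sum() + 1)
kl = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
loss = bce + 1e-6 * kl        # lambda_kl = 1e-6, as in the default above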
trellis/trainers/vae/structured_latent_vae_gaussian.py ADDED
@@ -0,0 +1,275 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import numpy as np
6
+ from easydict import EasyDict as edict
7
+ import utils3d.torch
8
+
9
+ from ..basic import BasicTrainer
10
+ from ...representations import Gaussian
11
+ from ...renderers import GaussianRenderer
12
+ from ...modules.sparse import SparseTensor
13
+ from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
14
+
15
+
16
+ class SLatVaeGaussianTrainer(BasicTrainer):
17
+ """
18
+ Trainer for structured latent VAE.
19
+
20
+ Args:
21
+ models (dict[str, nn.Module]): Models to train.
22
+ dataset (torch.utils.data.Dataset): Dataset.
23
+ output_dir (str): Output directory.
24
+ load_dir (str): Load directory.
25
+ step (int): Step to load.
26
+ batch_size (int): Batch size.
27
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
28
+ batch_split (int): Split batch with gradient accumulation.
29
+ max_steps (int): Max steps.
30
+ optimizer (dict): Optimizer config.
31
+ lr_scheduler (dict): Learning rate scheduler config.
32
+ elastic (dict): Elastic memory management config.
33
+ grad_clip (float or dict): Gradient clip config.
34
+ ema_rate (float or list): Exponential moving average rates.
35
+ fp16_mode (str): FP16 mode.
36
+ - None: No FP16.
37
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
38
+ - 'amp': Automatic mixed precision.
39
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
40
+ finetune_ckpt (dict): Finetune checkpoint.
41
+ log_param_stats (bool): Log parameter stats.
42
+ i_print (int): Print interval.
43
+ i_log (int): Log interval.
44
+ i_sample (int): Sample interval.
45
+ i_save (int): Save interval.
46
+ i_ddpcheck (int): DDP check interval.
47
+
48
+ loss_type (str): Loss type. Can be 'l1', 'l2'
49
+ lambda_ssim (float): SSIM loss weight.
50
+ lambda_lpips (float): LPIPS loss weight.
51
+ lambda_kl (float): KL loss weight.
52
+ regularizations (dict): Regularization config.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ *args,
58
+ loss_type: str = 'l1',
59
+ lambda_ssim: float = 0.2,
60
+ lambda_lpips: float = 0.2,
61
+ lambda_kl: float = 1e-6,
62
+ regularizations: Dict = {},
63
+ **kwargs
64
+ ):
65
+ super().__init__(*args, **kwargs)
66
+ self.loss_type = loss_type
67
+ self.lambda_ssim = lambda_ssim
68
+ self.lambda_lpips = lambda_lpips
69
+ self.lambda_kl = lambda_kl
70
+ self.regularizations = regularizations
71
+
72
+ self._init_renderer()
73
+
74
+ def _init_renderer(self):
75
+ rendering_options = {"near" : 0.8,
76
+ "far" : 1.6,
77
+ "bg_color" : 'random'}
78
+ self.renderer = GaussianRenderer(rendering_options)
79
+ self.renderer.pipe.kernel_size = self.models['decoder'].rep_config['2d_filter_kernel_size']
80
+
81
+ def _render_batch(self, reps: List[Gaussian], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
82
+ """
83
+ Render a batch of representations.
84
+
85
+ Args:
86
+ reps: The list of representations.
87
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
88
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
89
+ """
90
+ ret = None
91
+ for i, representation in enumerate(reps):
92
+ render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
93
+ if ret is None:
94
+ ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
95
+ for k, v in render_pack.items():
96
+ ret[k].append(v)
97
+ ret['bg_color'].append(self.renderer.bg_color)
98
+ for k, v in ret.items():
99
+ ret[k] = torch.stack(v, dim=0)
100
+ return ret
101
+
102
+ @torch.no_grad()
103
+ def _get_status(self, z: SparseTensor, reps: List[Gaussian]) -> Dict:
104
+ xyz = torch.cat([g.get_xyz for g in reps], dim=0)
105
+ xyz_base = (z.coords[:, 1:].float() + 0.5) / self.models['decoder'].resolution - 0.5
106
+ offset = xyz - xyz_base.unsqueeze(1).expand(-1, self.models['decoder'].rep_config['num_gaussians'], -1).reshape(-1, 3)
107
+ status = {
108
+ 'xyz': xyz,
109
+ 'offset': offset,
110
+ 'scale': torch.cat([g.get_scaling for g in reps], dim=0),
111
+ 'opacity': torch.cat([g.get_opacity for g in reps], dim=0),
112
+ }
113
+
114
+ for k in list(status.keys()):
115
+ status[k] = {
116
+ 'mean': status[k].mean().item(),
117
+ 'max': status[k].max().item(),
118
+ 'min': status[k].min().item(),
119
+ }
120
+
121
+ return status
122
+
123
+ def _get_regularization_loss(self, reps: List[Gaussian]) -> Tuple[torch.Tensor, Dict]:
124
+ loss = 0.0
125
+ terms = {}
126
+ if 'lambda_vol' in self.regularizations:
127
+ scales = torch.cat([g.get_scaling for g in reps], dim=0) # [N x 3]
128
+ volume = torch.prod(scales, dim=1) # [N]
129
+ terms[f'reg_vol'] = volume.mean()
130
+ loss = loss + self.regularizations['lambda_vol'] * terms[f'reg_vol']
131
+ if 'lambda_opacity' in self.regularizations:
132
+ opacity = torch.cat([g.get_opacity for g in reps], dim=0)
133
+ terms[f'reg_opacity'] = (opacity - 1).pow(2).mean()
134
+ loss = loss + self.regularizations['lambda_opacity'] * terms[f'reg_opacity']
135
+ return loss, terms
136
+
137
+ def training_losses(
138
+ self,
139
+ feats: SparseTensor,
140
+ image: torch.Tensor,
141
+ alpha: torch.Tensor,
142
+ extrinsics: torch.Tensor,
143
+ intrinsics: torch.Tensor,
144
+ return_aux: bool = False,
145
+ **kwargs
146
+ ) -> Tuple[Dict, Dict]:
147
+ """
148
+ Compute training losses.
149
+
150
+ Args:
151
+ feats: The [N x * x C] sparse tensor of features.
152
+ image: The [N x 3 x H x W] tensor of images.
153
+ alpha: The [N x H x W] tensor of alpha channels.
154
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
155
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
156
+ return_aux: Whether to return auxiliary information.
157
+
158
+ Returns:
159
+ a dict with the key "loss" containing a scalar tensor.
160
+ may also contain other keys for different terms.
161
+ """
162
+ z, mean, logvar = self.training_models['encoder'](feats, sample_posterior=True, return_raw=True)
163
+ reps = self.training_models['decoder'](z)
164
+ self.renderer.rendering_options.resolution = image.shape[-1]
165
+ render_results = self._render_batch(reps, extrinsics, intrinsics)
166
+
167
+ terms = edict(loss = 0.0, rec = 0.0)
168
+
169
+ rec_image = render_results['color']
170
+ gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
171
+
172
+ if self.loss_type == 'l1':
173
+ terms["l1"] = l1_loss(rec_image, gt_image)
174
+ terms["rec"] = terms["rec"] + terms["l1"]
175
+ elif self.loss_type == 'l2':
176
+ terms["l2"] = l2_loss(rec_image, gt_image)
177
+ terms["rec"] = terms["rec"] + terms["l2"]
178
+ else:
179
+ raise ValueError(f"Invalid loss type: {self.loss_type}")
180
+ if self.lambda_ssim > 0:
181
+ terms["ssim"] = 1 - ssim(rec_image, gt_image)
182
+ terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
183
+ if self.lambda_lpips > 0:
184
+ terms["lpips"] = lpips(rec_image, gt_image)
185
+ terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
186
+ terms["loss"] = terms["loss"] + terms["rec"]
187
+
188
+ terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
189
+ terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
190
+
191
+ reg_loss, reg_terms = self._get_regularization_loss(reps)
192
+ terms.update(reg_terms)
193
+ terms["loss"] = terms["loss"] + reg_loss
194
+
195
+ status = self._get_status(z, reps)
196
+
197
+ if return_aux:
198
+ return terms, status, {'rec_image': rec_image, 'gt_image': gt_image}
199
+ return terms, status
200
+
201
+ @torch.no_grad()
202
+ def run_snapshot(
203
+ self,
204
+ num_samples: int,
205
+ batch_size: int,
206
+ verbose: bool = False,
207
+ ) -> Dict:
208
+ dataloader = DataLoader(
209
+ copy.deepcopy(self.dataset),
210
+ batch_size=batch_size,
211
+ shuffle=True,
212
+ num_workers=0,
213
+ collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
214
+ )
215
+
216
+ # inference
217
+ ret_dict = {}
218
+ gt_images = []
219
+ exts = []
220
+ ints = []
221
+ reps = []
222
+ for i in range(0, num_samples, batch_size):
223
+ batch = min(batch_size, num_samples - i)
224
+ data = next(iter(dataloader))
225
+ args = {k: v[:batch].cuda() for k, v in data.items()}
226
+ gt_images.append(args['image'] * args['alpha'][:, None])
227
+ exts.append(args['extrinsics'])
228
+ ints.append(args['intrinsics'])
229
+ z = self.models['encoder'](args['feats'], sample_posterior=True, return_raw=False)
230
+ reps.extend(self.models['decoder'](z))
231
+ gt_images = torch.cat(gt_images, dim=0)
232
+ ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
233
+
234
+ # render single view
235
+ exts = torch.cat(exts, dim=0)
236
+ ints = torch.cat(ints, dim=0)
237
+ self.renderer.rendering_options.bg_color = (0, 0, 0)
238
+ self.renderer.rendering_options.resolution = gt_images.shape[-1]
239
+ render_results = self._render_batch(reps, exts, ints)
240
+ ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
241
+
242
+ # render multiview
243
+ self.renderer.rendering_options.resolution = 512
244
+ ## Build camera
245
+ yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
246
+ yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
247
+ yaws = [y + yaws_offset for y in yaws]
248
+ pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
249
+
250
+ ## render each view
251
+ multiview_images = []
252
+ for yaw, pitch in zip(yaws, pitch):
253
+ orig = torch.tensor([
254
+ np.sin(yaw) * np.cos(pitch),
255
+ np.cos(yaw) * np.cos(pitch),
256
+ np.sin(pitch),
257
+ ]).float().cuda() * 2
258
+ fov = torch.deg2rad(torch.tensor(30)).cuda()
259
+ extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
260
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
261
+ extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
262
+ intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
263
+ render_results = self._render_batch(reps, extrinsics, intrinsics)
264
+ multiview_images.append(render_results['color'])
265
+
266
+ ## Concatenate views
267
+ multiview_images = torch.cat([
268
+ torch.cat(multiview_images[:2], dim=-2),
269
+ torch.cat(multiview_images[2:], dim=-2),
270
+ ], dim=-1)
271
+ ret_dict.update({f'multiview_image': {'value': multiview_images, 'type': 'image'}})
272
+
273
+ self.renderer.rendering_options.bg_color = 'random'
274
+
275
+ return ret_dict
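A minimal sketch of the yaw/pitch orbit-camera construction used in the multiview snapshot above, on CPU and without the renderer; the utils3d calls are the same ones that appear in the diff:

import numpy as np
import torch
import utils3d.torch

yaw, pitch = np.pi / 4, np.pi / 8
orig = torch.tensor([
    np.sin(yaw) * np.cos(pitch),
    np.cos(yaw) * np.cos(pitch),
    np.sin(pitch),
]).float() * 2                                         # camera origin at radius 2
fov = torch.deg2rad(torch.tensor(30.0))
extrinsics = utils3d.torch.extrinsics_look_at(
    orig, torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 1.]))
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)   # [3, 3]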
trellis/trainers/vae/structured_latent_vae_mesh_dec.py ADDED
@@ -0,0 +1,382 @@
1
+ from typing import *
2
+ import copy
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import numpy as np
6
+ from easydict import EasyDict as edict
7
+ import utils3d.torch
8
+
9
+ from ..basic import BasicTrainer
10
+ from ...representations import MeshExtractResult
11
+ from ...renderers import MeshRenderer
12
+ from ...modules.sparse import SparseTensor
13
+ from ...utils.loss_utils import l1_loss, smooth_l1_loss, ssim, lpips
14
+ from ...utils.data_utils import recursive_to_device
15
+
16
+
17
+ class SLatVaeMeshDecoderTrainer(BasicTrainer):
18
+ """
19
+ Trainer for structured latent VAE Mesh Decoder.
20
+
21
+ Args:
22
+ models (dict[str, nn.Module]): Models to train.
23
+ dataset (torch.utils.data.Dataset): Dataset.
24
+ output_dir (str): Output directory.
25
+ load_dir (str): Load directory.
26
+ step (int): Step to load.
27
+ batch_size (int): Batch size.
28
+ batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
29
+ batch_split (int): Split batch with gradient accumulation.
30
+ max_steps (int): Max steps.
31
+ optimizer (dict): Optimizer config.
32
+ lr_scheduler (dict): Learning rate scheduler config.
33
+ elastic (dict): Elastic memory management config.
34
+ grad_clip (float or dict): Gradient clip config.
35
+ ema_rate (float or list): Exponential moving average rates.
36
+ fp16_mode (str): FP16 mode.
37
+ - None: No FP16.
38
+ - 'inflat_all': Hold an inflated fp32 master param for all params.
39
+ - 'amp': Automatic mixed precision.
40
+ fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
41
+ finetune_ckpt (dict): Finetune checkpoint.
42
+ log_param_stats (bool): Log parameter stats.
43
+ i_print (int): Print interval.
44
+ i_log (int): Log interval.
45
+ i_sample (int): Sample interval.
46
+ i_save (int): Save interval.
47
+ i_ddpcheck (int): DDP check interval.
48
+
49
+ depth_loss_type (str): Depth loss type. Can be 'l1', 'smooth_l1'
50
+ lambda_ssim (float): SSIM loss weight.
51
+ lambda_lpips (float): LPIPS loss weight.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ *args,
57
+ depth_loss_type: str = 'l1',
58
+ lambda_depth: int = 1,
59
+ lambda_ssim: float = 0.2,
60
+ lambda_lpips: float = 0.2,
61
+ lambda_tsdf: float = 0.01,
62
+ lambda_color: float = 0.1,
63
+ **kwargs
64
+ ):
65
+ super().__init__(*args, **kwargs)
66
+ self.depth_loss_type = depth_loss_type
67
+ self.lambda_depth = lambda_depth
68
+ self.lambda_ssim = lambda_ssim
69
+ self.lambda_lpips = lambda_lpips
70
+ self.lambda_tsdf = lambda_tsdf
71
+ self.lambda_color = lambda_color
72
+ self.use_color = self.lambda_color > 0
73
+
74
+ self._init_renderer()
75
+
76
+ def _init_renderer(self):
77
+ rendering_options = {"near" : 1,
78
+ "far" : 3}
79
+ self.renderer = MeshRenderer(rendering_options, device=self.device)
80
+
81
+ def _render_batch(self, reps: List[MeshExtractResult], extrinsics: torch.Tensor, intrinsics: torch.Tensor,
82
+ return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
83
+ """
84
+ Render a batch of representations.
85
+
86
+ Args:
87
+ reps: The list of representations.
88
+ extrinsics: The [N x 4 x 4] tensor of extrinsics.
89
+ intrinsics: The [N x 3 x 3] tensor of intrinsics.
90
+ return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
91
+
92
+ Returns:
93
+ a dict with
94
+ reg_loss : [N] tensor of regularization losses
95
+ mask : [N x 1 x H x W] tensor of rendered masks
96
+ normal : [N x 3 x H x W] tensor of rendered normals
97
+ depth : [N x 1 x H x W] tensor of rendered depths
98
+ """
99
+ ret = {k : [] for k in return_types}
100
+ for i, rep in enumerate(reps):
101
+ out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
102
+ for k in out_dict:
103
+ ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
104
+ for k in ret:
105
+ ret[k] = torch.stack(ret[k])
106
+ return ret
107
+
108
+ @staticmethod
109
+ def _tsdf_reg_loss(rep: MeshExtractResult, depth_map: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
110
+ # Calculate tsdf
111
+ with torch.no_grad():
112
+ # Project points to camera and calculate pseudo-sdf as difference between gt depth and projected depth
113
+ projected_pts, pts_depth = utils3d.torch.project_cv(extrinsics=extrinsics, intrinsics=intrinsics, points=rep.tsdf_v)
114
+ projected_pts = (projected_pts - 0.5) * 2.0
115
+ depth_map_res = depth_map.shape[1]
116
+ gt_depth = torch.nn.functional.grid_sample(depth_map.reshape(1, 1, depth_map_res, depth_map_res),
117
+ projected_pts.reshape(1, 1, -1, 2), mode='bilinear', padding_mode='border', align_corners=True)
118
+ pseudo_sdf = gt_depth.flatten() - pts_depth.flatten()
119
+ # Truncate pseudo-sdf
120
+ delta = 1 / rep.res * 3.0
121
+ trunc_mask = pseudo_sdf > -delta
122
+
123
+ # Loss
124
+ gt_tsdf = pseudo_sdf[trunc_mask]
125
+ tsdf = rep.tsdf_s.flatten()[trunc_mask]
126
+ gt_tsdf = torch.clamp(gt_tsdf, -delta, delta)
127
+ return torch.mean((tsdf - gt_tsdf) ** 2)
128
+
129
+ def _calc_tsdf_loss(self, reps : list[MeshExtractResult], depth_maps, extrinsics, intrinsics) -> torch.Tensor:
130
+ tsdf_loss = 0.0
131
+ for i, rep in enumerate(reps):
132
+ tsdf_loss += self._tsdf_reg_loss(rep, depth_maps[i], extrinsics[i], intrinsics[i])
133
+ return tsdf_loss / len(reps)
134
+
135
+ @torch.no_grad()
136
+ def _flip_normal(self, normal: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
137
+ """
138
+ Flip normal to align with camera.
139
+ """
140
+ normal = normal * 2.0 - 1.0
141
+ R = torch.zeros_like(extrinsics)
142
+ R[:, :3, :3] = extrinsics[:, :3, :3]
143
+ R[:, 3, 3] = 1.0
144
+ view_dir = utils3d.torch.unproject_cv(
145
+ utils3d.torch.image_uv(*normal.shape[-2:], device=self.device).reshape(1, -1, 2),
146
+ torch.ones(*normal.shape[-2:], device=self.device).reshape(1, -1),
147
+ R, intrinsics
148
+ ).reshape(-1, *normal.shape[-2:], 3).permute(0, 3, 1, 2)
149
+ unflip = (normal * view_dir).sum(1, keepdim=True) < 0
150
+ normal *= unflip * 2.0 - 1.0
151
+ return (normal + 1.0) / 2.0
152
+
153
+ def _perceptual_loss(self, gt: torch.Tensor, pred: torch.Tensor, name: str) -> Dict[str, torch.Tensor]:
154
+ """
155
+ Combination of L1, SSIM, and LPIPS loss.
156
+ """
157
+ if gt.shape[1] != 3:
158
+ assert gt.shape[-1] == 3
159
+ gt = gt.permute(0, 3, 1, 2)
160
+ if pred.shape[1] != 3:
161
+ assert pred.shape[-1] == 3
162
+ pred = pred.permute(0, 3, 1, 2)
163
+ terms = {
164
+ f"{name}_loss" : l1_loss(gt, pred),
165
+ f"{name}_loss_ssim" : 1 - ssim(gt, pred),
166
+ f"{name}_loss_lpips" : lpips(gt, pred)
167
+ }
168
+ terms[f"{name}_loss_perceptual"] = terms[f"{name}_loss"] + terms[f"{name}_loss_ssim"] * self.lambda_ssim + terms[f"{name}_loss_lpips"] * self.lambda_lpips
169
+ return terms
170
+
171
+ def geometry_losses(
172
+ self,
173
+ reps: List[MeshExtractResult],
174
+ mesh: List[Dict],
175
+ normal_map: torch.Tensor,
176
+ extrinsics: torch.Tensor,
177
+ intrinsics: torch.Tensor,
178
+ ):
179
+ with torch.no_grad():
180
+ gt_meshes = []
181
+ for i in range(len(reps)):
182
+ gt_mesh = MeshExtractResult(mesh[i]['vertices'].to(self.device), mesh[i]['faces'].to(self.device))
183
+ gt_meshes.append(gt_mesh)
184
+ target = self._render_batch(gt_meshes, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal'])
185
+ target['normal'] = self._flip_normal(target['normal'], extrinsics, intrinsics)
186
+
187
+ terms = edict(geo_loss = 0.0)
188
+ if self.lambda_tsdf > 0:
189
+ tsdf_loss = self._calc_tsdf_loss(reps, target['depth'], extrinsics, intrinsics)
190
+ terms['tsdf_loss'] = tsdf_loss
191
+ terms['geo_loss'] += tsdf_loss * self.lambda_tsdf
192
+
193
+ return_types = ['mask', 'depth', 'normal', 'normal_map'] if self.use_color else ['mask', 'depth', 'normal']
194
+ buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
195
+
196
+ success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
197
+ if success_mask.sum() != 0:
198
+ for k, v in buffer.items():
199
+ buffer[k] = v[success_mask]
200
+ for k, v in target.items():
201
+ target[k] = v[success_mask]
202
+
203
+ terms['mask_loss'] = l1_loss(buffer['mask'], target['mask'])
204
+ if self.depth_loss_type == 'l1':
205
+ terms['depth_loss'] = l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'])
206
+ elif self.depth_loss_type == 'smooth_l1':
207
+ terms['depth_loss'] = smooth_l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'], beta=1.0 / (2 * reps[0].res))
208
+ else:
209
+ raise ValueError(f"Unsupported depth loss type: {self.depth_loss_type}")
210
+ terms.update(self._perceptual_loss(buffer['normal'] * target['mask'], target['normal'] * target['mask'], 'normal'))
211
+ terms['geo_loss'] = terms['geo_loss'] + terms['mask_loss'] + terms['depth_loss'] * self.lambda_depth + terms['normal_loss_perceptual']
212
+ if self.use_color and normal_map is not None:
213
+ terms.update(self._perceptual_loss(normal_map[success_mask], buffer['normal_map'], 'normal_map'))
214
+ terms['geo_loss'] = terms['geo_loss'] + terms['normal_map_loss_perceptual'] * self.lambda_color
215
+
216
+ return terms
217
+
218
+ def color_losses(self, reps, image, alpha, extrinsics, intrinsics):
219
+ terms = edict(color_loss = torch.tensor(0.0, device=self.device))
220
+ buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=['color'])
221
+ success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
222
+ if success_mask.sum() != 0:
223
+ terms.update(self._perceptual_loss((image * alpha[:, None])[success_mask], buffer['color'][success_mask], 'color'))
224
+ terms['color_loss'] = terms['color_loss'] + terms['color_loss_perceptual'] * self.lambda_color
225
+ return terms
226
+
227
+     def training_losses(
+         self,
+         latents: SparseTensor,
+         image: torch.Tensor,
+         alpha: torch.Tensor,
+         mesh: List[Dict],
+         extrinsics: torch.Tensor,
+         intrinsics: torch.Tensor,
+         normal_map: torch.Tensor = None,
+     ) -> Tuple[Dict, Dict]:
+         """
+         Compute training losses.
+
+         Args:
+             latents: The [N x * x C] sparse latents.
+             image: The [N x 3 x H x W] tensor of images.
+             alpha: The [N x H x W] tensor of alpha channels.
+             mesh: The list of mesh dictionaries (with 'vertices' and 'faces').
+             extrinsics: The [N x 4 x 4] tensor of extrinsics.
+             intrinsics: The [N x 3 x 3] tensor of intrinsics.
+             normal_map: Optional tensor of ground-truth normal maps, used when color supervision is enabled.
+
+         Returns:
+             A dict with the key "loss" containing a scalar tensor, possibly with other keys for
+             individual loss terms, and a dict of auxiliary states (empty here).
+         """
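+         # Decode the sparse latents into meshes and render at the ground-truth image resolution.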
+         reps = self.training_models['decoder'](latents)
+         self.renderer.rendering_options.resolution = image.shape[-1]
+
+         terms = edict(loss = 0.0, rec = 0.0)
+
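+         # Regularization reported by the mesh decoder, averaged over the batch.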
+         terms['reg_loss'] = sum([rep.reg_loss for rep in reps]) / len(reps)
+         terms['loss'] = terms['loss'] + terms['reg_loss']
+
+         geo_terms = self.geometry_losses(reps, mesh, normal_map, extrinsics, intrinsics)
+         terms.update(geo_terms)
+         terms['loss'] = terms['loss'] + terms['geo_loss']
+
+         if self.use_color:
+             color_terms = self.color_losses(reps, image, alpha, extrinsics, intrinsics)
+             terms.update(color_terms)
+             terms['loss'] = terms['loss'] + terms['color_loss']
+
+         return terms, {}
+
+     @torch.no_grad()
+     def run_snapshot(
+         self,
+         num_samples: int,
+         batch_size: int,
+         verbose: bool = False,
+     ) -> Dict:
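+         # Draw random batches from a copy of the dataset for visualization.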
+         dataloader = DataLoader(
+             copy.deepcopy(self.dataset),
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=0,
+             collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
+         )
+
+         # inference
+         ret_dict = {}
+         gt_images = []
+         gt_normal_maps = []
+         gt_meshes = []
+         exts = []
+         ints = []
+         reps = []
+         for i in range(0, num_samples, batch_size):
+             batch = min(batch_size, num_samples - i)
+             data = next(iter(dataloader))
+             args = recursive_to_device(data, 'cuda')
+             gt_images.append(args['image'] * args['alpha'][:, None])
+             if self.use_color and 'normal_map' in data:
+                 gt_normal_maps.append(args['normal_map'])
+             gt_meshes.extend(args['mesh'])
+             exts.append(args['extrinsics'])
+             ints.append(args['intrinsics'])
+             reps.extend(self.models['decoder'](args['latents']))
+         gt_images = torch.cat(gt_images, dim=0)
+         ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
+         if self.use_color and gt_normal_maps:
+             gt_normal_maps = torch.cat(gt_normal_maps, dim=0)
+             ret_dict.update({f'gt_normal_map': {'value': gt_normal_maps, 'type': 'image'}})
+
+         # render single view
+         exts = torch.cat(exts, dim=0)
+         ints = torch.cat(ints, dim=0)
+         self.renderer.rendering_options.bg_color = (0, 0, 0)
+         self.renderer.rendering_options.resolution = gt_images.shape[-1]
+         gt_render_results = self._render_batch([
+             MeshExtractResult(vertices=mesh['vertices'].to(self.device), faces=mesh['faces'].to(self.device))
+             for mesh in gt_meshes
+         ], exts, ints, return_types=['normal'])
+         ret_dict.update({f'gt_normal': {'value': self._flip_normal(gt_render_results['normal'], exts, ints), 'type': 'image'}})
+         return_types = ['normal']
+         if self.use_color:
+             return_types.append('color')
+             if 'normal_map' in data:
+                 return_types.append('normal_map')
+         render_results = self._render_batch(reps, exts, ints, return_types=return_types)
+         ret_dict.update({f'rec_normal': {'value': render_results['normal'], 'type': 'image'}})
+         if 'color' in return_types:
+             ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
+         if 'normal_map' in return_types:
+             ret_dict.update({f'rec_normal_map': {'value': render_results['normal_map'], 'type': 'image'}})
+
+         # render multiview
+         self.renderer.rendering_options.resolution = 512
+         ## Build camera
+         yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+         yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
+         yaws = [y + yaws_offset for y in yaws]
+         pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
+
+         ## render each view
+         multiview_normals = []
+         multiview_normal_maps = []
+         multiview_images = []
+         for yaw, pitch in zip(yaws, pitches):
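+             # Place the camera on a sphere of radius 2 around the origin, looking at the center
+             # with +Z as the up axis and a 30-degree field of view.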
+             orig = torch.tensor([
+                 np.sin(yaw) * np.cos(pitch),
+                 np.cos(yaw) * np.cos(pitch),
+                 np.sin(pitch),
+             ]).float().cuda() * 2
+             fov = torch.deg2rad(torch.tensor(30)).cuda()
+             extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+             intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+             extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
+             intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
+             render_results = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
+             multiview_normals.append(render_results['normal'])
+             if 'color' in return_types:
+                 multiview_images.append(render_results['color'])
+             if 'normal_map' in return_types:
+                 multiview_normal_maps.append(render_results['normal_map'])
+
+         ## Concatenate views
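+         # Tile the four views into a 2x2 grid: pairs of views are stacked along the height axis
+         # and the two pairs are placed side by side along the width axis.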
+         multiview_normals = torch.cat([
+             torch.cat(multiview_normals[:2], dim=-2),
+             torch.cat(multiview_normals[2:], dim=-2),
+         ], dim=-1)
+         ret_dict.update({f'multiview_normal': {'value': multiview_normals, 'type': 'image'}})
+         if 'color' in return_types:
+             multiview_images = torch.cat([
+                 torch.cat(multiview_images[:2], dim=-2),
+                 torch.cat(multiview_images[2:], dim=-2),
+             ], dim=-1)
+             ret_dict.update({f'multiview_image': {'value': multiview_images, 'type': 'image'}})
+         if 'normal_map' in return_types:
+             multiview_normal_maps = torch.cat([
+                 torch.cat(multiview_normal_maps[:2], dim=-2),
+                 torch.cat(multiview_normal_maps[2:], dim=-2),
+             ], dim=-1)
+             ret_dict.update({f'multiview_normal_map': {'value': multiview_normal_maps, 'type': 'image'}})
+
+         return ret_dict