Spaces:
Running
on
Zero
Running
on
Zero
junbiao.chen
commited on
Commit
·
cc0c59d
1
Parent(s):
f29eac5
Trellis update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- trellis/datasets/__init__.py +58 -0
- trellis/datasets/components.py +137 -0
- trellis/datasets/sparse_feat2render.py +134 -0
- trellis/datasets/sparse_structure.py +107 -0
- trellis/datasets/sparse_structure_latent.py +189 -0
- trellis/datasets/structured_latent.py +218 -0
- trellis/datasets/structured_latent2render.py +160 -0
- trellis/models/__init__.py +29 -3
- trellis/models/sparse_elastic_mixin.py +24 -0
- trellis/models/structured_latent_flow.py +50 -36
- trellis/models/structured_latent_vae/__init__.py +4 -4
- trellis/models/structured_latent_vae/decoder_gs.py +9 -0
- trellis/models/structured_latent_vae/decoder_mesh.py +9 -0
- trellis/models/structured_latent_vae/decoder_rf.py +9 -0
- trellis/models/structured_latent_vae/encoder.py +8 -0
- trellis/pipelines/__init__.py +1 -0
- trellis/pipelines/base.py +6 -4
- trellis/pipelines/samplers/flow_euler.py +2 -0
- trellis/pipelines/trellis_image_to_3d.py +2 -3
- trellis/pipelines/trellis_text_to_3d.py +278 -0
- trellis/representations/mesh/cube2mesh.py +1 -8
- trellis/representations/mesh/flexicubes/LICENSE.txt +90 -0
- trellis/representations/mesh/flexicubes/README.md +110 -0
- trellis/representations/mesh/flexicubes/examples/data/inputmodels/block.obj +0 -0
- trellis/representations/mesh/flexicubes/examples/download_data.py +41 -0
- trellis/representations/mesh/flexicubes/examples/extraction.ipynb +0 -0
- trellis/representations/mesh/flexicubes/examples/loss.py +95 -0
- trellis/representations/mesh/flexicubes/examples/optimization.ipynb +0 -0
- trellis/representations/mesh/flexicubes/examples/optimize.py +150 -0
- trellis/representations/mesh/flexicubes/examples/render.py +267 -0
- trellis/representations/mesh/flexicubes/examples/util.py +122 -0
- trellis/representations/mesh/flexicubes/flexicubes.py +384 -0
- trellis/representations/mesh/flexicubes/images/ablate_L_dev.jpg +0 -0
- trellis/representations/mesh/flexicubes/images/block_final.png +3 -0
- trellis/representations/mesh/flexicubes/images/block_init.png +3 -0
- trellis/representations/mesh/flexicubes/images/teaser_top.png +3 -0
- trellis/representations/mesh/flexicubes/tables.py +791 -0
- trellis/representations/octree/octree_dfs.py +3 -18
- trellis/trainers/__init__.py +63 -0
- trellis/trainers/base.py +451 -0
- trellis/trainers/basic.py +438 -0
- trellis/trainers/flow_matching/flow_matching.py +353 -0
- trellis/trainers/flow_matching/mixins/classifier_free_guidance.py +59 -0
- trellis/trainers/flow_matching/mixins/image_conditioned.py +93 -0
- trellis/trainers/flow_matching/mixins/text_conditioned.py +68 -0
- trellis/trainers/flow_matching/sparse_flow_matching.py +286 -0
- trellis/trainers/utils.py +77 -0
- trellis/trainers/vae/sparse_structure_vae.py +130 -0
- trellis/trainers/vae/structured_latent_vae_gaussian.py +275 -0
- trellis/trainers/vae/structured_latent_vae_mesh_dec.py +382 -0
trellis/datasets/__init__.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
__attributes = {
|
4 |
+
'SparseStructure': 'sparse_structure',
|
5 |
+
|
6 |
+
'SparseFeat2Render': 'sparse_feat2render',
|
7 |
+
'SLat2Render':'structured_latent2render',
|
8 |
+
'Slat2RenderGeo':'structured_latent2render',
|
9 |
+
|
10 |
+
'SparseStructureLatent': 'sparse_structure_latent',
|
11 |
+
'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
|
12 |
+
'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
|
13 |
+
|
14 |
+
'SLat': 'structured_latent',
|
15 |
+
'TextConditionedSLat': 'structured_latent',
|
16 |
+
'ImageConditionedSLat': 'structured_latent',
|
17 |
+
}
|
18 |
+
|
19 |
+
__submodules = []
|
20 |
+
|
21 |
+
__all__ = list(__attributes.keys()) + __submodules
|
22 |
+
|
23 |
+
def __getattr__(name):
|
24 |
+
if name not in globals():
|
25 |
+
if name in __attributes:
|
26 |
+
module_name = __attributes[name]
|
27 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
28 |
+
globals()[name] = getattr(module, name)
|
29 |
+
elif name in __submodules:
|
30 |
+
module = importlib.import_module(f".{name}", __name__)
|
31 |
+
globals()[name] = module
|
32 |
+
else:
|
33 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
34 |
+
return globals()[name]
|
35 |
+
|
36 |
+
|
37 |
+
# For Pylance
|
38 |
+
if __name__ == '__main__':
|
39 |
+
from .sparse_structure import SparseStructure
|
40 |
+
|
41 |
+
from .sparse_feat2render import SparseFeat2Render
|
42 |
+
from .structured_latent2render import (
|
43 |
+
SLat2Render,
|
44 |
+
Slat2RenderGeo,
|
45 |
+
)
|
46 |
+
|
47 |
+
from .sparse_structure_latent import (
|
48 |
+
SparseStructureLatent,
|
49 |
+
TextConditionedSparseStructureLatent,
|
50 |
+
ImageConditionedSparseStructureLatent,
|
51 |
+
)
|
52 |
+
|
53 |
+
from .structured_latent import (
|
54 |
+
SLat,
|
55 |
+
TextConditionedSLat,
|
56 |
+
ImageConditionedSLat,
|
57 |
+
)
|
58 |
+
|
trellis/datasets/components.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
from abc import abstractmethod
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
from PIL import Image
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
|
11 |
+
|
12 |
+
class StandardDatasetBase(Dataset):
|
13 |
+
"""
|
14 |
+
Base class for standard datasets.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
roots (str): paths to the dataset
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
roots: str,
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
self.roots = roots.split(',')
|
25 |
+
self.instances = []
|
26 |
+
self.metadata = pd.DataFrame()
|
27 |
+
|
28 |
+
self._stats = {}
|
29 |
+
for root in self.roots:
|
30 |
+
key = os.path.basename(root)
|
31 |
+
self._stats[key] = {}
|
32 |
+
metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
|
33 |
+
self._stats[key]['Total'] = len(metadata)
|
34 |
+
metadata, stats = self.filter_metadata(metadata)
|
35 |
+
self._stats[key].update(stats)
|
36 |
+
self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
|
37 |
+
metadata.set_index('sha256', inplace=True)
|
38 |
+
self.metadata = pd.concat([self.metadata, metadata])
|
39 |
+
|
40 |
+
@abstractmethod
|
41 |
+
def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
|
42 |
+
pass
|
43 |
+
|
44 |
+
@abstractmethod
|
45 |
+
def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
|
46 |
+
pass
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.instances)
|
50 |
+
|
51 |
+
def __getitem__(self, index) -> Dict[str, Any]:
|
52 |
+
try:
|
53 |
+
root, instance = self.instances[index]
|
54 |
+
return self.get_instance(root, instance)
|
55 |
+
except Exception as e:
|
56 |
+
print(e)
|
57 |
+
return self.__getitem__(np.random.randint(0, len(self)))
|
58 |
+
|
59 |
+
def __str__(self):
|
60 |
+
lines = []
|
61 |
+
lines.append(self.__class__.__name__)
|
62 |
+
lines.append(f' - Total instances: {len(self)}')
|
63 |
+
lines.append(f' - Sources:')
|
64 |
+
for key, stats in self._stats.items():
|
65 |
+
lines.append(f' - {key}:')
|
66 |
+
for k, v in stats.items():
|
67 |
+
lines.append(f' - {k}: {v}')
|
68 |
+
return '\n'.join(lines)
|
69 |
+
|
70 |
+
|
71 |
+
class TextConditionedMixin:
|
72 |
+
def __init__(self, roots, **kwargs):
|
73 |
+
super().__init__(roots, **kwargs)
|
74 |
+
self.captions = {}
|
75 |
+
for instance in self.instances:
|
76 |
+
sha256 = instance[1]
|
77 |
+
self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])
|
78 |
+
|
79 |
+
def filter_metadata(self, metadata):
|
80 |
+
metadata, stats = super().filter_metadata(metadata)
|
81 |
+
metadata = metadata[metadata['captions'].notna()]
|
82 |
+
stats['With captions'] = len(metadata)
|
83 |
+
return metadata, stats
|
84 |
+
|
85 |
+
def get_instance(self, root, instance):
|
86 |
+
pack = super().get_instance(root, instance)
|
87 |
+
text = np.random.choice(self.captions[instance])
|
88 |
+
pack['cond'] = text
|
89 |
+
return pack
|
90 |
+
|
91 |
+
|
92 |
+
class ImageConditionedMixin:
|
93 |
+
def __init__(self, roots, *, image_size=518, **kwargs):
|
94 |
+
self.image_size = image_size
|
95 |
+
super().__init__(roots, **kwargs)
|
96 |
+
|
97 |
+
def filter_metadata(self, metadata):
|
98 |
+
metadata, stats = super().filter_metadata(metadata)
|
99 |
+
metadata = metadata[metadata[f'cond_rendered']]
|
100 |
+
stats['Cond rendered'] = len(metadata)
|
101 |
+
return metadata, stats
|
102 |
+
|
103 |
+
def get_instance(self, root, instance):
|
104 |
+
pack = super().get_instance(root, instance)
|
105 |
+
|
106 |
+
image_root = os.path.join(root, 'renders_cond', instance)
|
107 |
+
with open(os.path.join(image_root, 'transforms.json')) as f:
|
108 |
+
metadata = json.load(f)
|
109 |
+
n_views = len(metadata['frames'])
|
110 |
+
view = np.random.randint(n_views)
|
111 |
+
metadata = metadata['frames'][view]
|
112 |
+
|
113 |
+
image_path = os.path.join(image_root, metadata['file_path'])
|
114 |
+
image = Image.open(image_path)
|
115 |
+
|
116 |
+
alpha = np.array(image.getchannel(3))
|
117 |
+
bbox = np.array(alpha).nonzero()
|
118 |
+
bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
|
119 |
+
center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
|
120 |
+
hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
|
121 |
+
aug_size_ratio = 1.2
|
122 |
+
aug_hsize = hsize * aug_size_ratio
|
123 |
+
aug_center_offset = [0, 0]
|
124 |
+
aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
|
125 |
+
aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
|
126 |
+
image = image.crop(aug_bbox)
|
127 |
+
|
128 |
+
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
129 |
+
alpha = image.getchannel(3)
|
130 |
+
image = image.convert('RGB')
|
131 |
+
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
|
132 |
+
alpha = torch.tensor(np.array(alpha)).float() / 255.0
|
133 |
+
image = image * alpha.unsqueeze(0)
|
134 |
+
pack['cond'] = image
|
135 |
+
|
136 |
+
return pack
|
137 |
+
|
trellis/datasets/sparse_feat2render.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
import utils3d.torch
|
8 |
+
from ..modules.sparse.basic import SparseTensor
|
9 |
+
from .components import StandardDatasetBase
|
10 |
+
|
11 |
+
|
12 |
+
class SparseFeat2Render(StandardDatasetBase):
|
13 |
+
"""
|
14 |
+
SparseFeat2Render dataset.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
roots (str): paths to the dataset
|
18 |
+
image_size (int): size of the image
|
19 |
+
model (str): model name
|
20 |
+
resolution (int): resolution of the data
|
21 |
+
min_aesthetic_score (float): minimum aesthetic score
|
22 |
+
max_num_voxels (int): maximum number of voxels
|
23 |
+
"""
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
roots: str,
|
27 |
+
image_size: int,
|
28 |
+
model: str = 'dinov2_vitl14_reg',
|
29 |
+
resolution: int = 64,
|
30 |
+
min_aesthetic_score: float = 5.0,
|
31 |
+
max_num_voxels: int = 32768,
|
32 |
+
):
|
33 |
+
self.image_size = image_size
|
34 |
+
self.model = model
|
35 |
+
self.resolution = resolution
|
36 |
+
self.min_aesthetic_score = min_aesthetic_score
|
37 |
+
self.max_num_voxels = max_num_voxels
|
38 |
+
self.value_range = (0, 1)
|
39 |
+
|
40 |
+
super().__init__(roots)
|
41 |
+
|
42 |
+
def filter_metadata(self, metadata):
|
43 |
+
stats = {}
|
44 |
+
metadata = metadata[metadata[f'feature_{self.model}']]
|
45 |
+
stats['With features'] = len(metadata)
|
46 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
47 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
48 |
+
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
|
49 |
+
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
|
50 |
+
return metadata, stats
|
51 |
+
|
52 |
+
def _get_image(self, root, instance):
|
53 |
+
with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
|
54 |
+
metadata = json.load(f)
|
55 |
+
n_views = len(metadata['frames'])
|
56 |
+
view = np.random.randint(n_views)
|
57 |
+
metadata = metadata['frames'][view]
|
58 |
+
fov = metadata['camera_angle_x']
|
59 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
|
60 |
+
c2w = torch.tensor(metadata['transform_matrix'])
|
61 |
+
c2w[:3, 1:3] *= -1
|
62 |
+
extrinsics = torch.inverse(c2w)
|
63 |
+
|
64 |
+
image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
|
65 |
+
image = Image.open(image_path)
|
66 |
+
alpha = image.getchannel(3)
|
67 |
+
image = image.convert('RGB')
|
68 |
+
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
69 |
+
alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
70 |
+
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
|
71 |
+
alpha = torch.tensor(np.array(alpha)).float() / 255.0
|
72 |
+
|
73 |
+
return {
|
74 |
+
'image': image,
|
75 |
+
'alpha': alpha,
|
76 |
+
'extrinsics': extrinsics,
|
77 |
+
'intrinsics': intrinsics,
|
78 |
+
}
|
79 |
+
|
80 |
+
def _get_feat(self, root, instance):
|
81 |
+
DATA_RESOLUTION = 64
|
82 |
+
feats_path = os.path.join(root, 'features', self.model, f'{instance}.npz')
|
83 |
+
feats = np.load(feats_path, allow_pickle=True)
|
84 |
+
coords = torch.tensor(feats['indices']).int()
|
85 |
+
feats = torch.tensor(feats['patchtokens']).float()
|
86 |
+
|
87 |
+
if self.resolution != DATA_RESOLUTION:
|
88 |
+
factor = DATA_RESOLUTION // self.resolution
|
89 |
+
coords = coords // factor
|
90 |
+
coords, idx = coords.unique(return_inverse=True, dim=0)
|
91 |
+
feats = torch.scatter_reduce(
|
92 |
+
torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
|
93 |
+
dim=0,
|
94 |
+
index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
|
95 |
+
src=feats,
|
96 |
+
reduce='mean'
|
97 |
+
)
|
98 |
+
|
99 |
+
return {
|
100 |
+
'coords': coords,
|
101 |
+
'feats': feats,
|
102 |
+
}
|
103 |
+
|
104 |
+
@torch.no_grad()
|
105 |
+
def visualize_sample(self, sample: dict):
|
106 |
+
return sample['image']
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def collate_fn(batch):
|
110 |
+
pack = {}
|
111 |
+
coords = []
|
112 |
+
for i, b in enumerate(batch):
|
113 |
+
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
|
114 |
+
coords = torch.cat(coords)
|
115 |
+
feats = torch.cat([b['feats'] for b in batch])
|
116 |
+
pack['feats'] = SparseTensor(
|
117 |
+
coords=coords,
|
118 |
+
feats=feats,
|
119 |
+
)
|
120 |
+
|
121 |
+
pack['image'] = torch.stack([b['image'] for b in batch])
|
122 |
+
pack['alpha'] = torch.stack([b['alpha'] for b in batch])
|
123 |
+
pack['extrinsics'] = torch.stack([b['extrinsics'] for b in batch])
|
124 |
+
pack['intrinsics'] = torch.stack([b['intrinsics'] for b in batch])
|
125 |
+
|
126 |
+
return pack
|
127 |
+
|
128 |
+
def get_instance(self, root, instance):
|
129 |
+
image = self._get_image(root, instance)
|
130 |
+
feat = self._get_feat(root, instance)
|
131 |
+
return {
|
132 |
+
**image,
|
133 |
+
**feat,
|
134 |
+
}
|
trellis/datasets/sparse_structure.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from typing import Union
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
import utils3d
|
9 |
+
from .components import StandardDatasetBase
|
10 |
+
from ..representations.octree import DfsOctree as Octree
|
11 |
+
from ..renderers import OctreeRenderer
|
12 |
+
|
13 |
+
|
14 |
+
class SparseStructure(StandardDatasetBase):
|
15 |
+
"""
|
16 |
+
Sparse structure dataset
|
17 |
+
|
18 |
+
Args:
|
19 |
+
roots (str): path to the dataset
|
20 |
+
resolution (int): resolution of the voxel grid
|
21 |
+
min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self,
|
25 |
+
roots,
|
26 |
+
resolution: int = 64,
|
27 |
+
min_aesthetic_score: float = 5.0,
|
28 |
+
):
|
29 |
+
self.resolution = resolution
|
30 |
+
self.min_aesthetic_score = min_aesthetic_score
|
31 |
+
self.value_range = (0, 1)
|
32 |
+
|
33 |
+
super().__init__(roots)
|
34 |
+
|
35 |
+
def filter_metadata(self, metadata):
|
36 |
+
stats = {}
|
37 |
+
metadata = metadata[metadata[f'voxelized']]
|
38 |
+
stats['Voxelized'] = len(metadata)
|
39 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
40 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
41 |
+
return metadata, stats
|
42 |
+
|
43 |
+
def get_instance(self, root, instance):
|
44 |
+
position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
|
45 |
+
coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
|
46 |
+
ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
|
47 |
+
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
|
48 |
+
return {'ss': ss}
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def visualize_sample(self, ss: Union[torch.Tensor, dict]):
|
52 |
+
ss = ss if isinstance(ss, torch.Tensor) else ss['ss']
|
53 |
+
|
54 |
+
renderer = OctreeRenderer()
|
55 |
+
renderer.rendering_options.resolution = 512
|
56 |
+
renderer.rendering_options.near = 0.8
|
57 |
+
renderer.rendering_options.far = 1.6
|
58 |
+
renderer.rendering_options.bg_color = (0, 0, 0)
|
59 |
+
renderer.rendering_options.ssaa = 4
|
60 |
+
renderer.pipe.primitive = 'voxel'
|
61 |
+
|
62 |
+
# Build camera
|
63 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
64 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
65 |
+
yaws = [y + yaws_offset for y in yaws]
|
66 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
67 |
+
|
68 |
+
exts = []
|
69 |
+
ints = []
|
70 |
+
for yaw, pitch in zip(yaws, pitch):
|
71 |
+
orig = torch.tensor([
|
72 |
+
np.sin(yaw) * np.cos(pitch),
|
73 |
+
np.cos(yaw) * np.cos(pitch),
|
74 |
+
np.sin(pitch),
|
75 |
+
]).float().cuda() * 2
|
76 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
77 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
78 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
79 |
+
exts.append(extrinsics)
|
80 |
+
ints.append(intrinsics)
|
81 |
+
|
82 |
+
images = []
|
83 |
+
|
84 |
+
# Build each representation
|
85 |
+
ss = ss.cuda()
|
86 |
+
for i in range(ss.shape[0]):
|
87 |
+
representation = Octree(
|
88 |
+
depth=10,
|
89 |
+
aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
|
90 |
+
device='cuda',
|
91 |
+
primitive='voxel',
|
92 |
+
sh_degree=0,
|
93 |
+
primitive_config={'solid': True},
|
94 |
+
)
|
95 |
+
coords = torch.nonzero(ss[i, 0], as_tuple=False)
|
96 |
+
representation.position = coords.float() / self.resolution
|
97 |
+
representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
|
98 |
+
|
99 |
+
image = torch.zeros(3, 1024, 1024).cuda()
|
100 |
+
tile = [2, 2]
|
101 |
+
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
102 |
+
res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
|
103 |
+
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
104 |
+
images.append(image)
|
105 |
+
|
106 |
+
return torch.stack(images)
|
107 |
+
|
trellis/datasets/sparse_structure_latent.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from typing import *
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import utils3d
|
7 |
+
from ..representations.octree import DfsOctree as Octree
|
8 |
+
from ..renderers import OctreeRenderer
|
9 |
+
from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
|
10 |
+
from .. import models
|
11 |
+
from ..utils.dist_utils import read_file_dist
|
12 |
+
|
13 |
+
|
14 |
+
class SparseStructureLatentVisMixin:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
*args,
|
18 |
+
pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
|
19 |
+
ss_dec_path: Optional[str] = None,
|
20 |
+
ss_dec_ckpt: Optional[str] = None,
|
21 |
+
**kwargs
|
22 |
+
):
|
23 |
+
super().__init__(*args, **kwargs)
|
24 |
+
self.ss_dec = None
|
25 |
+
self.pretrained_ss_dec = pretrained_ss_dec
|
26 |
+
self.ss_dec_path = ss_dec_path
|
27 |
+
self.ss_dec_ckpt = ss_dec_ckpt
|
28 |
+
|
29 |
+
def _loading_ss_dec(self):
|
30 |
+
if self.ss_dec is not None:
|
31 |
+
return
|
32 |
+
if self.ss_dec_path is not None:
|
33 |
+
cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r'))
|
34 |
+
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
|
35 |
+
ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
|
36 |
+
decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
|
37 |
+
else:
|
38 |
+
decoder = models.from_pretrained(self.pretrained_ss_dec)
|
39 |
+
self.ss_dec = decoder.cuda().eval()
|
40 |
+
|
41 |
+
def _delete_ss_dec(self):
|
42 |
+
del self.ss_dec
|
43 |
+
self.ss_dec = None
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
def decode_latent(self, z, batch_size=4):
|
47 |
+
self._loading_ss_dec()
|
48 |
+
ss = []
|
49 |
+
if self.normalization is not None:
|
50 |
+
z = z * self.std.to(z.device) + self.mean.to(z.device)
|
51 |
+
for i in range(0, z.shape[0], batch_size):
|
52 |
+
ss.append(self.ss_dec(z[i:i+batch_size]))
|
53 |
+
ss = torch.cat(ss, dim=0)
|
54 |
+
self._delete_ss_dec()
|
55 |
+
return ss
|
56 |
+
|
57 |
+
@torch.no_grad()
|
58 |
+
def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
|
59 |
+
x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
|
60 |
+
x_0 = self.decode_latent(x_0.cuda())
|
61 |
+
|
62 |
+
renderer = OctreeRenderer()
|
63 |
+
renderer.rendering_options.resolution = 512
|
64 |
+
renderer.rendering_options.near = 0.8
|
65 |
+
renderer.rendering_options.far = 1.6
|
66 |
+
renderer.rendering_options.bg_color = (0, 0, 0)
|
67 |
+
renderer.rendering_options.ssaa = 4
|
68 |
+
renderer.pipe.primitive = 'voxel'
|
69 |
+
|
70 |
+
# Build camera
|
71 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
72 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
73 |
+
yaws = [y + yaws_offset for y in yaws]
|
74 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
75 |
+
|
76 |
+
exts = []
|
77 |
+
ints = []
|
78 |
+
for yaw, pitch in zip(yaws, pitch):
|
79 |
+
orig = torch.tensor([
|
80 |
+
np.sin(yaw) * np.cos(pitch),
|
81 |
+
np.cos(yaw) * np.cos(pitch),
|
82 |
+
np.sin(pitch),
|
83 |
+
]).float().cuda() * 2
|
84 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
85 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
86 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
87 |
+
exts.append(extrinsics)
|
88 |
+
ints.append(intrinsics)
|
89 |
+
|
90 |
+
images = []
|
91 |
+
|
92 |
+
# Build each representation
|
93 |
+
x_0 = x_0.cuda()
|
94 |
+
for i in range(x_0.shape[0]):
|
95 |
+
representation = Octree(
|
96 |
+
depth=10,
|
97 |
+
aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
|
98 |
+
device='cuda',
|
99 |
+
primitive='voxel',
|
100 |
+
sh_degree=0,
|
101 |
+
primitive_config={'solid': True},
|
102 |
+
)
|
103 |
+
coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
|
104 |
+
resolution = x_0.shape[-1]
|
105 |
+
representation.position = coords.float() / resolution
|
106 |
+
representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(resolution)), dtype=torch.uint8, device='cuda')
|
107 |
+
|
108 |
+
image = torch.zeros(3, 1024, 1024).cuda()
|
109 |
+
tile = [2, 2]
|
110 |
+
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
111 |
+
res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
|
112 |
+
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
113 |
+
images.append(image)
|
114 |
+
|
115 |
+
return torch.stack(images)
|
116 |
+
|
117 |
+
|
118 |
+
class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
|
119 |
+
"""
|
120 |
+
Sparse structure latent dataset
|
121 |
+
|
122 |
+
Args:
|
123 |
+
roots (str): path to the dataset
|
124 |
+
latent_model (str): name of the latent model
|
125 |
+
min_aesthetic_score (float): minimum aesthetic score
|
126 |
+
normalization (dict): normalization stats
|
127 |
+
pretrained_ss_dec (str): name of the pretrained sparse structure decoder
|
128 |
+
ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
|
129 |
+
ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
|
130 |
+
"""
|
131 |
+
def __init__(self,
|
132 |
+
roots: str,
|
133 |
+
*,
|
134 |
+
latent_model: str,
|
135 |
+
min_aesthetic_score: float = 5.0,
|
136 |
+
normalization: Optional[dict] = None,
|
137 |
+
pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
|
138 |
+
ss_dec_path: Optional[str] = None,
|
139 |
+
ss_dec_ckpt: Optional[str] = None,
|
140 |
+
):
|
141 |
+
self.latent_model = latent_model
|
142 |
+
self.min_aesthetic_score = min_aesthetic_score
|
143 |
+
self.normalization = normalization
|
144 |
+
self.value_range = (0, 1)
|
145 |
+
|
146 |
+
super().__init__(
|
147 |
+
roots,
|
148 |
+
pretrained_ss_dec=pretrained_ss_dec,
|
149 |
+
ss_dec_path=ss_dec_path,
|
150 |
+
ss_dec_ckpt=ss_dec_ckpt,
|
151 |
+
)
|
152 |
+
|
153 |
+
if self.normalization is not None:
|
154 |
+
self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
|
155 |
+
self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)
|
156 |
+
|
157 |
+
def filter_metadata(self, metadata):
|
158 |
+
stats = {}
|
159 |
+
metadata = metadata[metadata[f'ss_latent_{self.latent_model}']]
|
160 |
+
stats['With sparse structure latents'] = len(metadata)
|
161 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
162 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
163 |
+
return metadata, stats
|
164 |
+
|
165 |
+
def get_instance(self, root, instance):
|
166 |
+
latent = np.load(os.path.join(root, 'ss_latents', self.latent_model, f'{instance}.npz'))
|
167 |
+
z = torch.tensor(latent['mean']).float()
|
168 |
+
if self.normalization is not None:
|
169 |
+
z = (z - self.mean) / self.std
|
170 |
+
|
171 |
+
pack = {
|
172 |
+
'x_0': z,
|
173 |
+
}
|
174 |
+
return pack
|
175 |
+
|
176 |
+
|
177 |
+
class TextConditionedSparseStructureLatent(TextConditionedMixin, SparseStructureLatent):
|
178 |
+
"""
|
179 |
+
Text-conditioned sparse structure dataset
|
180 |
+
"""
|
181 |
+
pass
|
182 |
+
|
183 |
+
|
184 |
+
class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent):
|
185 |
+
"""
|
186 |
+
Image-conditioned sparse structure dataset
|
187 |
+
"""
|
188 |
+
pass
|
189 |
+
|
trellis/datasets/structured_latent.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from typing import *
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import utils3d.torch
|
7 |
+
from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
|
8 |
+
from ..modules.sparse.basic import SparseTensor
|
9 |
+
from .. import models
|
10 |
+
from ..utils.render_utils import get_renderer
|
11 |
+
from ..utils.dist_utils import read_file_dist
|
12 |
+
from ..utils.data_utils import load_balanced_group_indices
|
13 |
+
|
14 |
+
|
15 |
+
class SLatVisMixin:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
*args,
|
19 |
+
pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
|
20 |
+
slat_dec_path: Optional[str] = None,
|
21 |
+
slat_dec_ckpt: Optional[str] = None,
|
22 |
+
**kwargs
|
23 |
+
):
|
24 |
+
super().__init__(*args, **kwargs)
|
25 |
+
self.slat_dec = None
|
26 |
+
self.pretrained_slat_dec = pretrained_slat_dec
|
27 |
+
self.slat_dec_path = slat_dec_path
|
28 |
+
self.slat_dec_ckpt = slat_dec_ckpt
|
29 |
+
|
30 |
+
def _loading_slat_dec(self):
|
31 |
+
if self.slat_dec is not None:
|
32 |
+
return
|
33 |
+
if self.slat_dec_path is not None:
|
34 |
+
cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
|
35 |
+
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
|
36 |
+
ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
|
37 |
+
decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
|
38 |
+
else:
|
39 |
+
decoder = models.from_pretrained(self.pretrained_slat_dec)
|
40 |
+
self.slat_dec = decoder.cuda().eval()
|
41 |
+
|
42 |
+
def _delete_slat_dec(self):
|
43 |
+
del self.slat_dec
|
44 |
+
self.slat_dec = None
|
45 |
+
|
46 |
+
@torch.no_grad()
|
47 |
+
def decode_latent(self, z, batch_size=4):
|
48 |
+
self._loading_slat_dec()
|
49 |
+
reps = []
|
50 |
+
if self.normalization is not None:
|
51 |
+
z = z * self.std.to(z.device) + self.mean.to(z.device)
|
52 |
+
for i in range(0, z.shape[0], batch_size):
|
53 |
+
reps.append(self.slat_dec(z[i:i+batch_size]))
|
54 |
+
reps = sum(reps, [])
|
55 |
+
self._delete_slat_dec()
|
56 |
+
return reps
|
57 |
+
|
58 |
+
@torch.no_grad()
|
59 |
+
def visualize_sample(self, x_0: Union[SparseTensor, dict]):
|
60 |
+
x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
|
61 |
+
reps = self.decode_latent(x_0.cuda())
|
62 |
+
|
63 |
+
# Build camera
|
64 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
65 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
66 |
+
yaws = [y + yaws_offset for y in yaws]
|
67 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
68 |
+
|
69 |
+
exts = []
|
70 |
+
ints = []
|
71 |
+
for yaw, pitch in zip(yaws, pitch):
|
72 |
+
orig = torch.tensor([
|
73 |
+
np.sin(yaw) * np.cos(pitch),
|
74 |
+
np.cos(yaw) * np.cos(pitch),
|
75 |
+
np.sin(pitch),
|
76 |
+
]).float().cuda() * 2
|
77 |
+
fov = torch.deg2rad(torch.tensor(40)).cuda()
|
78 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
79 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
80 |
+
exts.append(extrinsics)
|
81 |
+
ints.append(intrinsics)
|
82 |
+
|
83 |
+
renderer = get_renderer(reps[0])
|
84 |
+
images = []
|
85 |
+
for representation in reps:
|
86 |
+
image = torch.zeros(3, 1024, 1024).cuda()
|
87 |
+
tile = [2, 2]
|
88 |
+
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
89 |
+
res = renderer.render(representation, ext, intr)
|
90 |
+
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
91 |
+
images.append(image)
|
92 |
+
images = torch.stack(images)
|
93 |
+
|
94 |
+
return images
|
95 |
+
|
96 |
+
|
97 |
+
class SLat(SLatVisMixin, StandardDatasetBase):
|
98 |
+
"""
|
99 |
+
structured latent dataset
|
100 |
+
|
101 |
+
Args:
|
102 |
+
roots (str): path to the dataset
|
103 |
+
latent_model (str): name of the latent model
|
104 |
+
min_aesthetic_score (float): minimum aesthetic score
|
105 |
+
max_num_voxels (int): maximum number of voxels
|
106 |
+
normalization (dict): normalization stats
|
107 |
+
pretrained_slat_dec (str): name of the pretrained slat decoder
|
108 |
+
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
|
109 |
+
slat_dec_ckpt (str): name of the slat decoder checkpoint
|
110 |
+
"""
|
111 |
+
def __init__(self,
|
112 |
+
roots: str,
|
113 |
+
*,
|
114 |
+
latent_model: str,
|
115 |
+
min_aesthetic_score: float = 5.0,
|
116 |
+
max_num_voxels: int = 32768,
|
117 |
+
normalization: Optional[dict] = None,
|
118 |
+
pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
|
119 |
+
slat_dec_path: Optional[str] = None,
|
120 |
+
slat_dec_ckpt: Optional[str] = None,
|
121 |
+
):
|
122 |
+
self.normalization = normalization
|
123 |
+
self.latent_model = latent_model
|
124 |
+
self.min_aesthetic_score = min_aesthetic_score
|
125 |
+
self.max_num_voxels = max_num_voxels
|
126 |
+
self.value_range = (0, 1)
|
127 |
+
|
128 |
+
super().__init__(
|
129 |
+
roots,
|
130 |
+
pretrained_slat_dec=pretrained_slat_dec,
|
131 |
+
slat_dec_path=slat_dec_path,
|
132 |
+
slat_dec_ckpt=slat_dec_ckpt,
|
133 |
+
)
|
134 |
+
|
135 |
+
self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]
|
136 |
+
|
137 |
+
if self.normalization is not None:
|
138 |
+
self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
|
139 |
+
self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
|
140 |
+
|
141 |
+
def filter_metadata(self, metadata):
|
142 |
+
stats = {}
|
143 |
+
metadata = metadata[metadata[f'latent_{self.latent_model}']]
|
144 |
+
stats['With latent'] = len(metadata)
|
145 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
146 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
147 |
+
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
|
148 |
+
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
|
149 |
+
return metadata, stats
|
150 |
+
|
151 |
+
def get_instance(self, root, instance):
|
152 |
+
data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
|
153 |
+
coords = torch.tensor(data['coords']).int()
|
154 |
+
feats = torch.tensor(data['feats']).float()
|
155 |
+
if self.normalization is not None:
|
156 |
+
feats = (feats - self.mean) / self.std
|
157 |
+
return {
|
158 |
+
'coords': coords,
|
159 |
+
'feats': feats,
|
160 |
+
}
|
161 |
+
|
162 |
+
@staticmethod
|
163 |
+
def collate_fn(batch, split_size=None):
|
164 |
+
if split_size is None:
|
165 |
+
group_idx = [list(range(len(batch)))]
|
166 |
+
else:
|
167 |
+
group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
|
168 |
+
packs = []
|
169 |
+
for group in group_idx:
|
170 |
+
sub_batch = [batch[i] for i in group]
|
171 |
+
pack = {}
|
172 |
+
coords = []
|
173 |
+
feats = []
|
174 |
+
layout = []
|
175 |
+
start = 0
|
176 |
+
for i, b in enumerate(sub_batch):
|
177 |
+
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
|
178 |
+
feats.append(b['feats'])
|
179 |
+
layout.append(slice(start, start + b['coords'].shape[0]))
|
180 |
+
start += b['coords'].shape[0]
|
181 |
+
coords = torch.cat(coords)
|
182 |
+
feats = torch.cat(feats)
|
183 |
+
pack['x_0'] = SparseTensor(
|
184 |
+
coords=coords,
|
185 |
+
feats=feats,
|
186 |
+
)
|
187 |
+
pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
|
188 |
+
pack['x_0'].register_spatial_cache('layout', layout)
|
189 |
+
|
190 |
+
# collate other data
|
191 |
+
keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
|
192 |
+
for k in keys:
|
193 |
+
if isinstance(sub_batch[0][k], torch.Tensor):
|
194 |
+
pack[k] = torch.stack([b[k] for b in sub_batch])
|
195 |
+
elif isinstance(sub_batch[0][k], list):
|
196 |
+
pack[k] = sum([b[k] for b in sub_batch], [])
|
197 |
+
else:
|
198 |
+
pack[k] = [b[k] for b in sub_batch]
|
199 |
+
|
200 |
+
packs.append(pack)
|
201 |
+
|
202 |
+
if split_size is None:
|
203 |
+
return packs[0]
|
204 |
+
return packs
|
205 |
+
|
206 |
+
|
207 |
+
class TextConditionedSLat(TextConditionedMixin, SLat):
|
208 |
+
"""
|
209 |
+
Text conditioned structured latent dataset
|
210 |
+
"""
|
211 |
+
pass
|
212 |
+
|
213 |
+
|
214 |
+
class ImageConditionedSLat(ImageConditionedMixin, SLat):
|
215 |
+
"""
|
216 |
+
Image conditioned structured latent dataset
|
217 |
+
"""
|
218 |
+
pass
|
trellis/datasets/structured_latent2render.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import utils3d.torch
|
7 |
+
from ..modules.sparse.basic import SparseTensor
|
8 |
+
from .components import StandardDatasetBase
|
9 |
+
|
10 |
+
|
11 |
+
class SLat2Render(StandardDatasetBase):
|
12 |
+
"""
|
13 |
+
Dataset for Structured Latent and rendered images.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
roots (str): paths to the dataset
|
17 |
+
image_size (int): size of the image
|
18 |
+
latent_model (str): latent model name
|
19 |
+
min_aesthetic_score (float): minimum aesthetic score
|
20 |
+
max_num_voxels (int): maximum number of voxels
|
21 |
+
"""
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
roots: str,
|
25 |
+
image_size: int,
|
26 |
+
latent_model: str,
|
27 |
+
min_aesthetic_score: float = 5.0,
|
28 |
+
max_num_voxels: int = 32768,
|
29 |
+
):
|
30 |
+
self.image_size = image_size
|
31 |
+
self.latent_model = latent_model
|
32 |
+
self.min_aesthetic_score = min_aesthetic_score
|
33 |
+
self.max_num_voxels = max_num_voxels
|
34 |
+
self.value_range = (0, 1)
|
35 |
+
|
36 |
+
super().__init__(roots)
|
37 |
+
|
38 |
+
def filter_metadata(self, metadata):
|
39 |
+
stats = {}
|
40 |
+
metadata = metadata[metadata[f'latent_{self.latent_model}']]
|
41 |
+
stats['With latent'] = len(metadata)
|
42 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
43 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
44 |
+
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
|
45 |
+
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
|
46 |
+
return metadata, stats
|
47 |
+
|
48 |
+
def _get_image(self, root, instance):
|
49 |
+
with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
|
50 |
+
metadata = json.load(f)
|
51 |
+
n_views = len(metadata['frames'])
|
52 |
+
view = np.random.randint(n_views)
|
53 |
+
metadata = metadata['frames'][view]
|
54 |
+
fov = metadata['camera_angle_x']
|
55 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
|
56 |
+
c2w = torch.tensor(metadata['transform_matrix'])
|
57 |
+
c2w[:3, 1:3] *= -1
|
58 |
+
extrinsics = torch.inverse(c2w)
|
59 |
+
|
60 |
+
image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
|
61 |
+
image = Image.open(image_path)
|
62 |
+
alpha = image.getchannel(3)
|
63 |
+
image = image.convert('RGB')
|
64 |
+
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
65 |
+
alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
66 |
+
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
|
67 |
+
alpha = torch.tensor(np.array(alpha)).float() / 255.0
|
68 |
+
|
69 |
+
return {
|
70 |
+
'image': image,
|
71 |
+
'alpha': alpha,
|
72 |
+
'extrinsics': extrinsics,
|
73 |
+
'intrinsics': intrinsics,
|
74 |
+
}
|
75 |
+
|
76 |
+
def _get_latent(self, root, instance):
|
77 |
+
data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
|
78 |
+
coords = torch.tensor(data['coords']).int()
|
79 |
+
feats = torch.tensor(data['feats']).float()
|
80 |
+
return {
|
81 |
+
'coords': coords,
|
82 |
+
'feats': feats,
|
83 |
+
}
|
84 |
+
|
85 |
+
@torch.no_grad()
|
86 |
+
def visualize_sample(self, sample: dict):
|
87 |
+
return sample['image']
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def collate_fn(batch):
|
91 |
+
pack = {}
|
92 |
+
coords = []
|
93 |
+
for i, b in enumerate(batch):
|
94 |
+
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
|
95 |
+
coords = torch.cat(coords)
|
96 |
+
feats = torch.cat([b['feats'] for b in batch])
|
97 |
+
pack['latents'] = SparseTensor(
|
98 |
+
coords=coords,
|
99 |
+
feats=feats,
|
100 |
+
)
|
101 |
+
|
102 |
+
# collate other data
|
103 |
+
keys = [k for k in batch[0].keys() if k not in ['coords', 'feats']]
|
104 |
+
for k in keys:
|
105 |
+
if isinstance(batch[0][k], torch.Tensor):
|
106 |
+
pack[k] = torch.stack([b[k] for b in batch])
|
107 |
+
elif isinstance(batch[0][k], list):
|
108 |
+
pack[k] = sum([b[k] for b in batch], [])
|
109 |
+
else:
|
110 |
+
pack[k] = [b[k] for b in batch]
|
111 |
+
|
112 |
+
return pack
|
113 |
+
|
114 |
+
def get_instance(self, root, instance):
|
115 |
+
image = self._get_image(root, instance)
|
116 |
+
latent = self._get_latent(root, instance)
|
117 |
+
return {
|
118 |
+
**image,
|
119 |
+
**latent,
|
120 |
+
}
|
121 |
+
|
122 |
+
|
123 |
+
class Slat2RenderGeo(SLat2Render):
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
roots: str,
|
127 |
+
image_size: int,
|
128 |
+
latent_model: str,
|
129 |
+
min_aesthetic_score: float = 5.0,
|
130 |
+
max_num_voxels: int = 32768,
|
131 |
+
):
|
132 |
+
super().__init__(
|
133 |
+
roots,
|
134 |
+
image_size,
|
135 |
+
latent_model,
|
136 |
+
min_aesthetic_score,
|
137 |
+
max_num_voxels,
|
138 |
+
)
|
139 |
+
|
140 |
+
def _get_geo(self, root, instance):
|
141 |
+
verts, face = utils3d.io.read_ply(os.path.join(root, 'renders', instance, 'mesh.ply'))
|
142 |
+
mesh = {
|
143 |
+
"vertices" : torch.from_numpy(verts),
|
144 |
+
"faces" : torch.from_numpy(face),
|
145 |
+
}
|
146 |
+
return {
|
147 |
+
"mesh" : mesh,
|
148 |
+
}
|
149 |
+
|
150 |
+
def get_instance(self, root, instance):
|
151 |
+
image = self._get_image(root, instance)
|
152 |
+
latent = self._get_latent(root, instance)
|
153 |
+
geo = self._get_geo(root, instance)
|
154 |
+
return {
|
155 |
+
**image,
|
156 |
+
**latent,
|
157 |
+
**geo,
|
158 |
+
}
|
159 |
+
|
160 |
+
|
trellis/models/__init__.py
CHANGED
@@ -3,12 +3,20 @@ import importlib
|
|
3 |
__attributes = {
|
4 |
'SparseStructureEncoder': 'sparse_structure_vae',
|
5 |
'SparseStructureDecoder': 'sparse_structure_vae',
|
|
|
6 |
'SparseStructureFlowModel': 'sparse_structure_flow',
|
|
|
7 |
'SLatEncoder': 'structured_latent_vae',
|
8 |
'SLatGaussianDecoder': 'structured_latent_vae',
|
9 |
'SLatRadianceFieldDecoder': 'structured_latent_vae',
|
10 |
'SLatMeshDecoder': 'structured_latent_vae',
|
|
|
|
|
|
|
|
|
|
|
11 |
'SLatFlowModel': 'structured_latent_flow',
|
|
|
12 |
}
|
13 |
|
14 |
__submodules = []
|
@@ -64,7 +72,25 @@ def from_pretrained(path: str, **kwargs):
|
|
64 |
|
65 |
# For Pylance
|
66 |
if __name__ == '__main__':
|
67 |
-
from .sparse_structure_vae import
|
|
|
|
|
|
|
|
|
68 |
from .sparse_structure_flow import SparseStructureFlowModel
|
69 |
-
|
70 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
__attributes = {
|
4 |
'SparseStructureEncoder': 'sparse_structure_vae',
|
5 |
'SparseStructureDecoder': 'sparse_structure_vae',
|
6 |
+
|
7 |
'SparseStructureFlowModel': 'sparse_structure_flow',
|
8 |
+
|
9 |
'SLatEncoder': 'structured_latent_vae',
|
10 |
'SLatGaussianDecoder': 'structured_latent_vae',
|
11 |
'SLatRadianceFieldDecoder': 'structured_latent_vae',
|
12 |
'SLatMeshDecoder': 'structured_latent_vae',
|
13 |
+
'ElasticSLatEncoder': 'structured_latent_vae',
|
14 |
+
'ElasticSLatGaussianDecoder': 'structured_latent_vae',
|
15 |
+
'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
|
16 |
+
'ElasticSLatMeshDecoder': 'structured_latent_vae',
|
17 |
+
|
18 |
'SLatFlowModel': 'structured_latent_flow',
|
19 |
+
'ElasticSLatFlowModel': 'structured_latent_flow',
|
20 |
}
|
21 |
|
22 |
__submodules = []
|
|
|
72 |
|
73 |
# For Pylance
|
74 |
if __name__ == '__main__':
|
75 |
+
from .sparse_structure_vae import (
|
76 |
+
SparseStructureEncoder,
|
77 |
+
SparseStructureDecoder,
|
78 |
+
)
|
79 |
+
|
80 |
from .sparse_structure_flow import SparseStructureFlowModel
|
81 |
+
|
82 |
+
from .structured_latent_vae import (
|
83 |
+
SLatEncoder,
|
84 |
+
SLatGaussianDecoder,
|
85 |
+
SLatRadianceFieldDecoder,
|
86 |
+
SLatMeshDecoder,
|
87 |
+
ElasticSLatEncoder,
|
88 |
+
ElasticSLatGaussianDecoder,
|
89 |
+
ElasticSLatRadianceFieldDecoder,
|
90 |
+
ElasticSLatMeshDecoder,
|
91 |
+
)
|
92 |
+
|
93 |
+
from .structured_latent_flow import (
|
94 |
+
SLatFlowModel,
|
95 |
+
ElasticSLatFlowModel,
|
96 |
+
)
|
trellis/models/sparse_elastic_mixin.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import contextmanager
|
2 |
+
from typing import *
|
3 |
+
import math
|
4 |
+
from ..modules import sparse as sp
|
5 |
+
from ..utils.elastic_utils import ElasticModuleMixin
|
6 |
+
|
7 |
+
|
8 |
+
class SparseTransformerElasticMixin(ElasticModuleMixin):
|
9 |
+
def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
|
10 |
+
return x.feats.shape[0]
|
11 |
+
|
12 |
+
@contextmanager
|
13 |
+
def with_mem_ratio(self, mem_ratio=1.0):
|
14 |
+
if mem_ratio == 1.0:
|
15 |
+
yield 1.0
|
16 |
+
return
|
17 |
+
num_blocks = len(self.blocks)
|
18 |
+
num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
|
19 |
+
exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
|
20 |
+
for i in range(num_blocks):
|
21 |
+
self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
|
22 |
+
yield exact_mem_ratio
|
23 |
+
for i in range(num_blocks):
|
24 |
+
self.blocks[i].use_checkpoint = False
|
trellis/models/structured_latent_flow.py
CHANGED
@@ -9,6 +9,7 @@ from ..modules.norm import LayerNorm32
|
|
9 |
from ..modules import sparse as sp
|
10 |
from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
|
11 |
from .sparse_structure_flow import TimestepEmbedder
|
|
|
12 |
|
13 |
|
14 |
class SparseResBlock3d(nn.Module):
|
@@ -109,8 +110,9 @@ class SLatFlowModel(nn.Module):
|
|
109 |
self.qk_rms_norm_cross = qk_rms_norm_cross
|
110 |
self.dtype = torch.float16 if use_fp16 else torch.float32
|
111 |
|
112 |
-
|
113 |
-
|
|
|
114 |
|
115 |
self.t_embedder = TimestepEmbedder(model_channels)
|
116 |
if share_mod:
|
@@ -122,25 +124,27 @@ class SLatFlowModel(nn.Module):
|
|
122 |
if pe_mode == "ape":
|
123 |
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
124 |
|
125 |
-
self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
|
|
|
126 |
self.input_blocks = nn.ModuleList([])
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
142 |
)
|
143 |
-
)
|
144 |
|
145 |
self.blocks = nn.ModuleList([
|
146 |
ModulatedSparseTransformerCrossBlock(
|
@@ -159,24 +163,26 @@ class SLatFlowModel(nn.Module):
|
|
159 |
])
|
160 |
|
161 |
self.out_blocks = nn.ModuleList([])
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
self.out_blocks.extend([
|
172 |
-
SparseResBlock3d(
|
173 |
-
chs * 2 if self.use_skip_connection else chs,
|
174 |
-
model_channels,
|
175 |
-
out_channels=chs,
|
176 |
)
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
self.initialize_weights()
|
182 |
if use_fp16:
|
@@ -260,3 +266,11 @@ class SLatFlowModel(nn.Module):
|
|
260 |
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
261 |
h = self.out_layer(h.type(x.dtype))
|
262 |
return h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from ..modules import sparse as sp
|
10 |
from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
|
11 |
from .sparse_structure_flow import TimestepEmbedder
|
12 |
+
from .sparse_elastic_mixin import SparseTransformerElasticMixin
|
13 |
|
14 |
|
15 |
class SparseResBlock3d(nn.Module):
|
|
|
110 |
self.qk_rms_norm_cross = qk_rms_norm_cross
|
111 |
self.dtype = torch.float16 if use_fp16 else torch.float32
|
112 |
|
113 |
+
if self.io_block_channels is not None:
|
114 |
+
assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
|
115 |
+
assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
|
116 |
|
117 |
self.t_embedder = TimestepEmbedder(model_channels)
|
118 |
if share_mod:
|
|
|
124 |
if pe_mode == "ape":
|
125 |
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
126 |
|
127 |
+
self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0])
|
128 |
+
|
129 |
self.input_blocks = nn.ModuleList([])
|
130 |
+
if io_block_channels is not None:
|
131 |
+
for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
|
132 |
+
self.input_blocks.extend([
|
133 |
+
SparseResBlock3d(
|
134 |
+
chs,
|
135 |
+
model_channels,
|
136 |
+
out_channels=chs,
|
137 |
+
)
|
138 |
+
for _ in range(num_io_res_blocks-1)
|
139 |
+
])
|
140 |
+
self.input_blocks.append(
|
141 |
+
SparseResBlock3d(
|
142 |
+
chs,
|
143 |
+
model_channels,
|
144 |
+
out_channels=next_chs,
|
145 |
+
downsample=True,
|
146 |
+
)
|
147 |
)
|
|
|
148 |
|
149 |
self.blocks = nn.ModuleList([
|
150 |
ModulatedSparseTransformerCrossBlock(
|
|
|
163 |
])
|
164 |
|
165 |
self.out_blocks = nn.ModuleList([])
|
166 |
+
if io_block_channels is not None:
|
167 |
+
for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
|
168 |
+
self.out_blocks.append(
|
169 |
+
SparseResBlock3d(
|
170 |
+
prev_chs * 2 if self.use_skip_connection else prev_chs,
|
171 |
+
model_channels,
|
172 |
+
out_channels=chs,
|
173 |
+
upsample=True,
|
174 |
+
)
|
|
|
|
|
|
|
|
|
|
|
175 |
)
|
176 |
+
self.out_blocks.extend([
|
177 |
+
SparseResBlock3d(
|
178 |
+
chs * 2 if self.use_skip_connection else chs,
|
179 |
+
model_channels,
|
180 |
+
out_channels=chs,
|
181 |
+
)
|
182 |
+
for _ in range(num_io_res_blocks-1)
|
183 |
+
])
|
184 |
+
|
185 |
+
self.out_layer = sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], out_channels)
|
186 |
|
187 |
self.initialize_weights()
|
188 |
if use_fp16:
|
|
|
266 |
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
267 |
h = self.out_layer(h.type(x.dtype))
|
268 |
return h
|
269 |
+
|
270 |
+
|
271 |
+
class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
|
272 |
+
"""
|
273 |
+
SLat Flow Model with elastic memory management.
|
274 |
+
Used for training with low VRAM.
|
275 |
+
"""
|
276 |
+
pass
|
trellis/models/structured_latent_vae/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from .encoder import SLatEncoder
|
2 |
-
from .decoder_gs import SLatGaussianDecoder
|
3 |
-
from .decoder_rf import SLatRadianceFieldDecoder
|
4 |
-
from .decoder_mesh import SLatMeshDecoder
|
|
|
1 |
+
from .encoder import SLatEncoder, ElasticSLatEncoder
|
2 |
+
from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
|
3 |
+
from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
|
4 |
+
from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder
|
trellis/models/structured_latent_vae/decoder_gs.py
CHANGED
@@ -6,6 +6,7 @@ from ...modules import sparse as sp
|
|
6 |
from ...utils.random_utils import hammersley_sequence
|
7 |
from .base import SparseTransformerBase
|
8 |
from ...representations import Gaussian
|
|
|
9 |
|
10 |
|
11 |
class SLatGaussianDecoder(SparseTransformerBase):
|
@@ -120,3 +121,11 @@ class SLatGaussianDecoder(SparseTransformerBase):
|
|
120 |
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
121 |
h = self.out_layer(h)
|
122 |
return self.to_representation(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from ...utils.random_utils import hammersley_sequence
|
7 |
from .base import SparseTransformerBase
|
8 |
from ...representations import Gaussian
|
9 |
+
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
|
10 |
|
11 |
|
12 |
class SLatGaussianDecoder(SparseTransformerBase):
|
|
|
121 |
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
122 |
h = self.out_layer(h)
|
123 |
return self.to_representation(h)
|
124 |
+
|
125 |
+
|
126 |
+
class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder):
|
127 |
+
"""
|
128 |
+
Slat VAE Gaussian decoder with elastic memory management.
|
129 |
+
Used for training with low VRAM.
|
130 |
+
"""
|
131 |
+
pass
|
trellis/models/structured_latent_vae/decoder_mesh.py
CHANGED
@@ -8,6 +8,7 @@ from ...modules import sparse as sp
|
|
8 |
from .base import SparseTransformerBase
|
9 |
from ...representations import MeshExtractResult
|
10 |
from ...representations.mesh import SparseFeatures2Mesh
|
|
|
11 |
|
12 |
|
13 |
class SparseSubdivideBlock3d(nn.Module):
|
@@ -165,3 +166,11 @@ class SLatMeshDecoder(SparseTransformerBase):
|
|
165 |
h = h.type(x.dtype)
|
166 |
h = self.out_layer(h)
|
167 |
return self.to_representation(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from .base import SparseTransformerBase
|
9 |
from ...representations import MeshExtractResult
|
10 |
from ...representations.mesh import SparseFeatures2Mesh
|
11 |
+
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
|
12 |
|
13 |
|
14 |
class SparseSubdivideBlock3d(nn.Module):
|
|
|
166 |
h = h.type(x.dtype)
|
167 |
h = self.out_layer(h)
|
168 |
return self.to_representation(h)
|
169 |
+
|
170 |
+
|
171 |
+
class ElasticSLatMeshDecoder(SparseTransformerElasticMixin, SLatMeshDecoder):
|
172 |
+
"""
|
173 |
+
Slat VAE Mesh decoder with elastic memory management.
|
174 |
+
Used for training with low VRAM.
|
175 |
+
"""
|
176 |
+
pass
|
trellis/models/structured_latent_vae/decoder_rf.py
CHANGED
@@ -6,6 +6,7 @@ import numpy as np
|
|
6 |
from ...modules import sparse as sp
|
7 |
from .base import SparseTransformerBase
|
8 |
from ...representations import Strivec
|
|
|
9 |
|
10 |
|
11 |
class SLatRadianceFieldDecoder(SparseTransformerBase):
|
@@ -102,3 +103,11 @@ class SLatRadianceFieldDecoder(SparseTransformerBase):
|
|
102 |
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
103 |
h = self.out_layer(h)
|
104 |
return self.to_representation(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from ...modules import sparse as sp
|
7 |
from .base import SparseTransformerBase
|
8 |
from ...representations import Strivec
|
9 |
+
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
|
10 |
|
11 |
|
12 |
class SLatRadianceFieldDecoder(SparseTransformerBase):
|
|
|
103 |
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
104 |
h = self.out_layer(h)
|
105 |
return self.to_representation(h)
|
106 |
+
|
107 |
+
|
108 |
+
class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder):
|
109 |
+
"""
|
110 |
+
Slat VAE Radiance Field Decoder with elastic memory management.
|
111 |
+
Used for training with low VRAM.
|
112 |
+
"""
|
113 |
+
pass
|
trellis/models/structured_latent_vae/encoder.py
CHANGED
@@ -4,6 +4,7 @@ import torch.nn as nn
|
|
4 |
import torch.nn.functional as F
|
5 |
from ...modules import sparse as sp
|
6 |
from .base import SparseTransformerBase
|
|
|
7 |
|
8 |
|
9 |
class SLatEncoder(SparseTransformerBase):
|
@@ -70,3 +71,10 @@ class SLatEncoder(SparseTransformerBase):
|
|
70 |
return z, mean, logvar
|
71 |
else:
|
72 |
return z
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import torch.nn.functional as F
|
5 |
from ...modules import sparse as sp
|
6 |
from .base import SparseTransformerBase
|
7 |
+
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
|
8 |
|
9 |
|
10 |
class SLatEncoder(SparseTransformerBase):
|
|
|
71 |
return z, mean, logvar
|
72 |
else:
|
73 |
return z
|
74 |
+
|
75 |
+
|
76 |
+
class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder):
|
77 |
+
"""
|
78 |
+
SLat VAE encoder with elastic memory management.
|
79 |
+
Used for training with low VRAM.
|
80 |
+
"""
|
trellis/pipelines/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from . import samplers
|
2 |
from .trellis_image_to_3d import TrellisImageTo3DPipeline
|
|
|
3 |
|
4 |
|
5 |
def from_pretrained(path: str):
|
|
|
1 |
from . import samplers
|
2 |
from .trellis_image_to_3d import TrellisImageTo3DPipeline
|
3 |
+
from .trellis_text_to_3d import TrellisTextTo3DPipeline
|
4 |
|
5 |
|
6 |
def from_pretrained(path: str):
|
trellis/pipelines/base.py
CHANGED
@@ -36,10 +36,12 @@ class Pipeline:
|
|
36 |
with open(config_file, 'r') as f:
|
37 |
args = json.load(f)['args']
|
38 |
|
39 |
-
_models = {
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
43 |
|
44 |
new_pipeline = Pipeline(_models)
|
45 |
new_pipeline._pretrained_args = args
|
|
|
36 |
with open(config_file, 'r') as f:
|
37 |
args = json.load(f)['args']
|
38 |
|
39 |
+
_models = {}
|
40 |
+
for k, v in args['models'].items():
|
41 |
+
try:
|
42 |
+
_models[k] = models.from_pretrained(f"{path}/{v}")
|
43 |
+
except:
|
44 |
+
_models[k] = models.from_pretrained(v)
|
45 |
|
46 |
new_pipeline = Pipeline(_models)
|
47 |
new_pipeline._pretrained_args = args
|
trellis/pipelines/samplers/flow_euler.py
CHANGED
@@ -37,6 +37,8 @@ class FlowEulerSampler(Sampler):
|
|
37 |
|
38 |
def _inference_model(self, model, x_t, t, cond=None, **kwargs):
|
39 |
t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
|
|
|
|
|
40 |
return model(x_t, t, cond, **kwargs)
|
41 |
|
42 |
def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
|
|
|
37 |
|
38 |
def _inference_model(self, model, x_t, t, cond=None, **kwargs):
|
39 |
t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
|
40 |
+
if cond is not None and cond.shape[0] == 1 and x_t.shape[0] > 1:
|
41 |
+
cond = cond.repeat(x_t.shape[0], *([1] * (len(cond.shape) - 1)))
|
42 |
return model(x_t, t, cond, **kwargs)
|
43 |
|
44 |
def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
|
trellis/pipelines/trellis_image_to_3d.py
CHANGED
@@ -4,15 +4,12 @@ import torch
|
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
import numpy as np
|
7 |
-
from tqdm import tqdm
|
8 |
-
from easydict import EasyDict as edict
|
9 |
from torchvision import transforms
|
10 |
from PIL import Image
|
11 |
import rembg
|
12 |
from .base import Pipeline
|
13 |
from . import samplers
|
14 |
from ..modules import sparse as sp
|
15 |
-
from ..representations import Gaussian, Strivec, MeshExtractResult
|
16 |
|
17 |
|
18 |
class TrellisImageTo3DPipeline(Pipeline):
|
@@ -271,8 +268,10 @@ class TrellisImageTo3DPipeline(Pipeline):
|
|
271 |
Args:
|
272 |
image (Image.Image): The image prompt.
|
273 |
num_samples (int): The number of samples to generate.
|
|
|
274 |
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
|
275 |
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
|
|
|
276 |
preprocess_image (bool): Whether to preprocess the image.
|
277 |
"""
|
278 |
if preprocess_image:
|
|
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
import numpy as np
|
|
|
|
|
7 |
from torchvision import transforms
|
8 |
from PIL import Image
|
9 |
import rembg
|
10 |
from .base import Pipeline
|
11 |
from . import samplers
|
12 |
from ..modules import sparse as sp
|
|
|
13 |
|
14 |
|
15 |
class TrellisImageTo3DPipeline(Pipeline):
|
|
|
268 |
Args:
|
269 |
image (Image.Image): The image prompt.
|
270 |
num_samples (int): The number of samples to generate.
|
271 |
+
seed (int): The random seed.
|
272 |
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
|
273 |
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
|
274 |
+
formats (List[str]): The formats to decode the structured latent to.
|
275 |
preprocess_image (bool): Whether to preprocess the image.
|
276 |
"""
|
277 |
if preprocess_image:
|
trellis/pipelines/trellis_text_to_3d.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
from transformers import CLIPTextModel, AutoTokenizer
|
6 |
+
import open3d as o3d
|
7 |
+
from .base import Pipeline
|
8 |
+
from . import samplers
|
9 |
+
from ..modules import sparse as sp
|
10 |
+
|
11 |
+
|
12 |
+
class TrellisTextTo3DPipeline(Pipeline):
|
13 |
+
"""
|
14 |
+
Pipeline for inferring Trellis text-to-3D models.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
models (dict[str, nn.Module]): The models to use in the pipeline.
|
18 |
+
sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
|
19 |
+
slat_sampler (samplers.Sampler): The sampler for the structured latent.
|
20 |
+
slat_normalization (dict): The normalization parameters for the structured latent.
|
21 |
+
text_cond_model (str): The name of the text conditioning model.
|
22 |
+
"""
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
models: dict[str, nn.Module] = None,
|
26 |
+
sparse_structure_sampler: samplers.Sampler = None,
|
27 |
+
slat_sampler: samplers.Sampler = None,
|
28 |
+
slat_normalization: dict = None,
|
29 |
+
text_cond_model: str = None,
|
30 |
+
):
|
31 |
+
if models is None:
|
32 |
+
return
|
33 |
+
super().__init__(models)
|
34 |
+
self.sparse_structure_sampler = sparse_structure_sampler
|
35 |
+
self.slat_sampler = slat_sampler
|
36 |
+
self.sparse_structure_sampler_params = {}
|
37 |
+
self.slat_sampler_params = {}
|
38 |
+
self.slat_normalization = slat_normalization
|
39 |
+
self._init_text_cond_model(text_cond_model)
|
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
def from_pretrained(path: str) -> "TrellisTextTo3DPipeline":
|
43 |
+
"""
|
44 |
+
Load a pretrained model.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
path (str): The path to the model. Can be either local path or a Hugging Face repository.
|
48 |
+
"""
|
49 |
+
pipeline = super(TrellisTextTo3DPipeline, TrellisTextTo3DPipeline).from_pretrained(path)
|
50 |
+
new_pipeline = TrellisTextTo3DPipeline()
|
51 |
+
new_pipeline.__dict__ = pipeline.__dict__
|
52 |
+
args = pipeline._pretrained_args
|
53 |
+
|
54 |
+
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
|
55 |
+
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
|
56 |
+
|
57 |
+
new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
|
58 |
+
new_pipeline.slat_sampler_params = args['slat_sampler']['params']
|
59 |
+
|
60 |
+
new_pipeline.slat_normalization = args['slat_normalization']
|
61 |
+
|
62 |
+
new_pipeline._init_text_cond_model(args['text_cond_model'])
|
63 |
+
|
64 |
+
return new_pipeline
|
65 |
+
|
66 |
+
def _init_text_cond_model(self, name: str):
|
67 |
+
"""
|
68 |
+
Initialize the text conditioning model.
|
69 |
+
"""
|
70 |
+
# load model
|
71 |
+
model = CLIPTextModel.from_pretrained(name)
|
72 |
+
tokenizer = AutoTokenizer.from_pretrained(name)
|
73 |
+
model.eval()
|
74 |
+
model = model.cuda()
|
75 |
+
self.text_cond_model = {
|
76 |
+
'model': model,
|
77 |
+
'tokenizer': tokenizer,
|
78 |
+
}
|
79 |
+
self.text_cond_model['null_cond'] = self.encode_text([''])
|
80 |
+
|
81 |
+
@torch.no_grad()
|
82 |
+
def encode_text(self, text: List[str]) -> torch.Tensor:
|
83 |
+
"""
|
84 |
+
Encode the text.
|
85 |
+
"""
|
86 |
+
assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
|
87 |
+
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
88 |
+
tokens = encoding['input_ids'].cuda()
|
89 |
+
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
|
90 |
+
|
91 |
+
return embeddings
|
92 |
+
|
93 |
+
def get_cond(self, prompt: List[str]) -> dict:
|
94 |
+
"""
|
95 |
+
Get the conditioning information for the model.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
prompt (List[str]): The text prompt.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
dict: The conditioning information
|
102 |
+
"""
|
103 |
+
cond = self.encode_text(prompt)
|
104 |
+
neg_cond = self.text_cond_model['null_cond']
|
105 |
+
return {
|
106 |
+
'cond': cond,
|
107 |
+
'neg_cond': neg_cond,
|
108 |
+
}
|
109 |
+
|
110 |
+
def sample_sparse_structure(
|
111 |
+
self,
|
112 |
+
cond: dict,
|
113 |
+
num_samples: int = 1,
|
114 |
+
sampler_params: dict = {},
|
115 |
+
) -> torch.Tensor:
|
116 |
+
"""
|
117 |
+
Sample sparse structures with the given conditioning.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
cond (dict): The conditioning information.
|
121 |
+
num_samples (int): The number of samples to generate.
|
122 |
+
sampler_params (dict): Additional parameters for the sampler.
|
123 |
+
"""
|
124 |
+
# Sample occupancy latent
|
125 |
+
flow_model = self.models['sparse_structure_flow_model']
|
126 |
+
reso = flow_model.resolution
|
127 |
+
noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
|
128 |
+
sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
|
129 |
+
z_s = self.sparse_structure_sampler.sample(
|
130 |
+
flow_model,
|
131 |
+
noise,
|
132 |
+
**cond,
|
133 |
+
**sampler_params,
|
134 |
+
verbose=True
|
135 |
+
).samples
|
136 |
+
|
137 |
+
# Decode occupancy latent
|
138 |
+
decoder = self.models['sparse_structure_decoder']
|
139 |
+
coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
|
140 |
+
|
141 |
+
return coords
|
142 |
+
|
143 |
+
def decode_slat(
|
144 |
+
self,
|
145 |
+
slat: sp.SparseTensor,
|
146 |
+
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
|
147 |
+
) -> dict:
|
148 |
+
"""
|
149 |
+
Decode the structured latent.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
slat (sp.SparseTensor): The structured latent.
|
153 |
+
formats (List[str]): The formats to decode the structured latent to.
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
dict: The decoded structured latent.
|
157 |
+
"""
|
158 |
+
ret = {}
|
159 |
+
if 'mesh' in formats:
|
160 |
+
ret['mesh'] = self.models['slat_decoder_mesh'](slat)
|
161 |
+
if 'gaussian' in formats:
|
162 |
+
ret['gaussian'] = self.models['slat_decoder_gs'](slat)
|
163 |
+
if 'radiance_field' in formats:
|
164 |
+
ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
|
165 |
+
return ret
|
166 |
+
|
167 |
+
def sample_slat(
|
168 |
+
self,
|
169 |
+
cond: dict,
|
170 |
+
coords: torch.Tensor,
|
171 |
+
sampler_params: dict = {},
|
172 |
+
) -> sp.SparseTensor:
|
173 |
+
"""
|
174 |
+
Sample structured latent with the given conditioning.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
cond (dict): The conditioning information.
|
178 |
+
coords (torch.Tensor): The coordinates of the sparse structure.
|
179 |
+
sampler_params (dict): Additional parameters for the sampler.
|
180 |
+
"""
|
181 |
+
# Sample structured latent
|
182 |
+
flow_model = self.models['slat_flow_model']
|
183 |
+
noise = sp.SparseTensor(
|
184 |
+
feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
|
185 |
+
coords=coords,
|
186 |
+
)
|
187 |
+
sampler_params = {**self.slat_sampler_params, **sampler_params}
|
188 |
+
slat = self.slat_sampler.sample(
|
189 |
+
flow_model,
|
190 |
+
noise,
|
191 |
+
**cond,
|
192 |
+
**sampler_params,
|
193 |
+
verbose=True
|
194 |
+
).samples
|
195 |
+
|
196 |
+
std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
|
197 |
+
mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
|
198 |
+
slat = slat * std + mean
|
199 |
+
|
200 |
+
return slat
|
201 |
+
|
202 |
+
@torch.no_grad()
|
203 |
+
def run(
|
204 |
+
self,
|
205 |
+
prompt: str,
|
206 |
+
num_samples: int = 1,
|
207 |
+
seed: int = 42,
|
208 |
+
sparse_structure_sampler_params: dict = {},
|
209 |
+
slat_sampler_params: dict = {},
|
210 |
+
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
|
211 |
+
) -> dict:
|
212 |
+
"""
|
213 |
+
Run the pipeline.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
prompt (str): The text prompt.
|
217 |
+
num_samples (int): The number of samples to generate.
|
218 |
+
seed (int): The random seed.
|
219 |
+
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
|
220 |
+
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
|
221 |
+
formats (List[str]): The formats to decode the structured latent to.
|
222 |
+
"""
|
223 |
+
cond = self.get_cond([prompt])
|
224 |
+
torch.manual_seed(seed)
|
225 |
+
coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
|
226 |
+
slat = self.sample_slat(cond, coords, slat_sampler_params)
|
227 |
+
return self.decode_slat(slat, formats)
|
228 |
+
|
229 |
+
def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
|
230 |
+
"""
|
231 |
+
Voxelize a mesh.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.
|
235 |
+
sha256 (str): The SHA256 hash of the mesh.
|
236 |
+
output_dir (str): The output directory.
|
237 |
+
"""
|
238 |
+
vertices = np.asarray(mesh.vertices)
|
239 |
+
aabb = np.stack([vertices.min(0), vertices.max(0)])
|
240 |
+
center = (aabb[0] + aabb[1]) / 2
|
241 |
+
scale = (aabb[1] - aabb[0]).max()
|
242 |
+
vertices = (vertices - center) / scale
|
243 |
+
vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
|
244 |
+
mesh.vertices = o3d.utility.Vector3dVector(vertices)
|
245 |
+
voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
|
246 |
+
vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
|
247 |
+
return torch.tensor(vertices).int().cuda()
|
248 |
+
|
249 |
+
@torch.no_grad()
|
250 |
+
def run_variant(
|
251 |
+
self,
|
252 |
+
mesh: o3d.geometry.TriangleMesh,
|
253 |
+
prompt: str,
|
254 |
+
num_samples: int = 1,
|
255 |
+
seed: int = 42,
|
256 |
+
slat_sampler_params: dict = {},
|
257 |
+
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
|
258 |
+
) -> dict:
|
259 |
+
"""
|
260 |
+
Run the pipeline for making variants of an asset.
|
261 |
+
|
262 |
+
Args:
|
263 |
+
mesh (o3d.geometry.TriangleMesh): The base mesh.
|
264 |
+
prompt (str): The text prompt.
|
265 |
+
num_samples (int): The number of samples to generate.
|
266 |
+
seed (int): The random seed
|
267 |
+
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
|
268 |
+
formats (List[str]): The formats to decode the structured latent to.
|
269 |
+
"""
|
270 |
+
cond = self.get_cond([prompt])
|
271 |
+
coords = self.voxelize(mesh)
|
272 |
+
coords = torch.cat([
|
273 |
+
torch.arange(num_samples).repeat_interleave(coords.shape[0], 0)[:, None].int().cuda(),
|
274 |
+
coords.repeat(num_samples, 1)
|
275 |
+
], 1)
|
276 |
+
torch.manual_seed(seed)
|
277 |
+
slat = self.sample_slat(cond, coords, slat_sampler_params)
|
278 |
+
return self.decode_slat(slat, formats)
|
trellis/representations/mesh/cube2mesh.py
CHANGED
@@ -1,15 +1,8 @@
|
|
1 |
-
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
import torch
|
9 |
from ...modules.sparse import SparseTensor
|
10 |
from easydict import EasyDict as edict
|
11 |
from .utils_cube import *
|
12 |
-
from .
|
13 |
|
14 |
|
15 |
class MeshExtractResult:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from ...modules.sparse import SparseTensor
|
3 |
from easydict import EasyDict as edict
|
4 |
from .utils_cube import *
|
5 |
+
from .flexicubes.flexicubes import FlexiCubes
|
6 |
|
7 |
|
8 |
class MeshExtractResult:
|
trellis/representations/mesh/flexicubes/LICENSE.txt
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
|
3 |
+
|
4 |
+
NVIDIA Source Code License for FlexiCubes
|
5 |
+
|
6 |
+
|
7 |
+
=======================================================================
|
8 |
+
|
9 |
+
1. Definitions
|
10 |
+
|
11 |
+
“Licensor” means any person or entity that distributes its Work.
|
12 |
+
|
13 |
+
“Work” means (a) the original work of authorship made available under
|
14 |
+
this license, which may include software, documentation, or other files,
|
15 |
+
and (b) any additions to or derivative works thereof that are made
|
16 |
+
available under this license.
|
17 |
+
|
18 |
+
The terms “reproduce,” “reproduction,” “derivative works,” and
|
19 |
+
“distribution” have the meaning as provided under U.S. copyright law;
|
20 |
+
provided, however, that for the purposes of this license, derivative works
|
21 |
+
shall not include works that remain separable from, or merely link
|
22 |
+
(or bind by name) to the interfaces of, the Work.
|
23 |
+
|
24 |
+
Works are “made available” under this license by including in or with
|
25 |
+
the Work either (a) a copyright notice referencing the applicability of
|
26 |
+
this license to the Work, or (b) a copy of this license.
|
27 |
+
|
28 |
+
2. License Grant
|
29 |
+
|
30 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this license,
|
31 |
+
each Licensor grants to you a perpetual, worldwide, non-exclusive,
|
32 |
+
royalty-free, copyright license to use, reproduce, prepare derivative
|
33 |
+
works of, publicly display, publicly perform, sublicense and distribute
|
34 |
+
its Work and any resulting derivative works in any form.
|
35 |
+
|
36 |
+
3. Limitations
|
37 |
+
|
38 |
+
3.1 Redistribution. You may reproduce or distribute the Work only if
|
39 |
+
(a) you do so under this license, (b) you include a complete copy of
|
40 |
+
this license with your distribution, and (c) you retain without
|
41 |
+
modification any copyright, patent, trademark, or attribution notices
|
42 |
+
that are present in the Work.
|
43 |
+
|
44 |
+
3.2 Derivative Works. You may specify that additional or different terms
|
45 |
+
apply to the use, reproduction, and distribution of your derivative
|
46 |
+
works of the Work (“Your Terms”) only if (a) Your Terms provide that the
|
47 |
+
use limitation in Section 3.3 applies to your derivative works, and (b)
|
48 |
+
you identify the specific derivative works that are subject to Your Terms.
|
49 |
+
Notwithstanding Your Terms, this license (including the redistribution
|
50 |
+
requirements in Section 3.1) will continue to apply to the Work itself.
|
51 |
+
|
52 |
+
3.3 Use Limitation. The Work and any derivative works thereof only may be
|
53 |
+
used or intended for use non-commercially. Notwithstanding the foregoing,
|
54 |
+
NVIDIA Corporation and its affiliates may use the Work and any derivative
|
55 |
+
works commercially. As used herein, “non-commercially” means for research
|
56 |
+
or evaluation purposes only.
|
57 |
+
|
58 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim against
|
59 |
+
any Licensor (including any claim, cross-claim or counterclaim in a lawsuit)
|
60 |
+
to enforce any patents that you allege are infringed by any Work, then your
|
61 |
+
rights under this license from such Licensor (including the grant in
|
62 |
+
Section 2.1) will terminate immediately.
|
63 |
+
|
64 |
+
3.5 Trademarks. This license does not grant any rights to use any Licensor’s
|
65 |
+
or its affiliates’ names, logos, or trademarks, except as necessary to
|
66 |
+
reproduce the notices described in this license.
|
67 |
+
|
68 |
+
3.6 Termination. If you violate any term of this license, then your rights
|
69 |
+
under this license (including the grant in Section 2.1) will terminate
|
70 |
+
immediately.
|
71 |
+
|
72 |
+
4. Disclaimer of Warranty.
|
73 |
+
|
74 |
+
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
75 |
+
EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
76 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
|
77 |
+
YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
|
78 |
+
|
79 |
+
5. Limitation of Liability.
|
80 |
+
|
81 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY,
|
82 |
+
WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY
|
83 |
+
LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL,
|
84 |
+
INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE,
|
85 |
+
THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF
|
86 |
+
GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR
|
87 |
+
MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN
|
88 |
+
ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
89 |
+
|
90 |
+
=======================================================================
|
trellis/representations/mesh/flexicubes/README.md
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Flexible Isosurface Extraction for Gradient-Based Mesh Optimization (FlexiCubes)<br><sub>Official PyTorch implementation </sub>
|
2 |
+
|
3 |
+

|
4 |
+
|
5 |
+
FlexiCubes is a high-quality isosurface representation specifically designed for gradient-based mesh optimization with respect to geometric, visual, or even physical objectives. For more details, please refer to our [paper](https://arxiv.org/abs/2308.05371) and [project page](https://research.nvidia.com/labs/toronto-ai/flexicubes/).
|
6 |
+
|
7 |
+
## Highlights
|
8 |
+
* [Getting started](https://github.com/nv-tlabs/FlexiCubes#getting-started)
|
9 |
+
* [Basic workflow](https://github.com/nv-tlabs/FlexiCubes#example-usage)
|
10 |
+
* [nvdiffrec: image-based reconstruction example](https://github.com/NVlabs/nvdiffrec#news)
|
11 |
+
* [GET3D: generative AI example](https://github.com/nv-tlabs/GET3D#employing-flexicubes)
|
12 |
+
* [Bibtex](https://github.com/nv-tlabs/FlexiCubes#citation)
|
13 |
+
|
14 |
+
## Getting Started
|
15 |
+
|
16 |
+
The core functions of FlexiCubes are now in [Kaolin](https://github.com/NVIDIAGameWorks/kaolin/) starting from v0.15.0. See installation instructions [here](https://kaolin.readthedocs.io/en/latest/notes/installation.html) and API documentations [here](https://kaolin.readthedocs.io/en/latest/modules/kaolin.non_commercial.html#kaolin.non_commercial.FlexiCubes)
|
17 |
+
|
18 |
+
The original code of the paper is still visible in `flexicube.py`.
|
19 |
+
|
20 |
+
## Example Usage
|
21 |
+
|
22 |
+
### Gradient-Based Mesh Optimization
|
23 |
+
We provide examples demonstrating how to use FlexiCubes for reconstructing unknown meshes through gradient-based optimization. Specifically, starting from randomly initialized SDF, we optimize the shape towards the reference mesh by minimizing their geometric difference, measured by multiview mask and depth losses. This workflow is a simplified version of `nvdiffrec` with code largely borrowed from the [nvdiffrec GitHub](https://github.com/NVlabs/nvdiffrec). We use the same pipeline to conduct the analysis in Section 3 and the main experiments described in Section 5 of our paper. We provide a detailed tutorial in `examples/optimization.ipynb`, along with an optimization script in `examples/optimize.py` which accepts command-line arguments.
|
24 |
+
|
25 |
+
|
26 |
+
To run the examples, it is suggested to install the Conda environment as detailed below:
|
27 |
+
```sh
|
28 |
+
conda create -n flexicubes python=3.9
|
29 |
+
conda activate flexicubes
|
30 |
+
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
|
31 |
+
pip install imageio trimesh tqdm matplotlib torch_scatter ninja
|
32 |
+
pip install git+https://github.com/NVlabs/nvdiffrast/
|
33 |
+
pip install kaolin==0.15.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.12.0_cu113.html
|
34 |
+
```
|
35 |
+
|
36 |
+
Then download the dataset collected by [Myles et al.](https://vcg.isti.cnr.it/Publications/2014/MPZ14/) as follows. We include one shape in 'examples/data/inputmodels/block.obj' if you want to test without downloading the full dataset.
|
37 |
+
|
38 |
+
```sh
|
39 |
+
cd examples
|
40 |
+
python download_data.py
|
41 |
+
```
|
42 |
+
|
43 |
+
After downloading the data, run shape optimization with the following example command:
|
44 |
+
```sh
|
45 |
+
python optimize.py --ref_mesh data/inputmodels/block.obj --out_dir out/block
|
46 |
+
```
|
47 |
+
You can find visualization and output meshes in the `out/block`. Below, we show the initial and final shapes during optimization, with the reference shape on the right.
|
48 |
+
|
49 |
+
<img src="images/block_init.png" alt="block_init" width="80%" height="80%">
|
50 |
+
|
51 |
+
<img src="images/block_final.png" alt="block_final" width="80%" height="80%">
|
52 |
+
|
53 |
+
|
54 |
+
To further demonstrate the flexibility of our FlexiCubes representation, which can accommodates both reconstruction objectives and regularizers defined on the extracted mesh, you can add a developability regularizer (proposed by [Stein et al.](https://www.cs.cmu.edu/~kmcrane/Projects/DiscreteDevelopable/)) to the previous reconstruction pipeline to encourage fabricability from panels:
|
55 |
+
```sh
|
56 |
+
python optimize.py --ref_mesh data/inputmodels/david.obj --out_dir out/david_dev --develop_reg True --iter=1250
|
57 |
+
```
|
58 |
+
|
59 |
+
### Extract mesh from known signed distance field
|
60 |
+
While not its designated use case, our function can extract a mesh from a known Signed Distance Field (SDF) without optimization. Please refer to the tutorial found in `examples/extraction.ipynb` for details.
|
61 |
+
|
62 |
+
## Tips for using FlexiCubes
|
63 |
+
### Regularization losses:
|
64 |
+
We commonly use three regularizers in our mesh optimization pipelines, referenced in lines `L104-L106` in `examples/optimize.py`. The weights of these regularizers should be scaled according to the your application objectives. Initially, it is suggested to employ low weights because strong regularization can hinder convergence. You can incrementally increase the weights if you notice artifacts appearing in the optimized meshes. Specifically:
|
65 |
+
|
66 |
+
* The loss function at `L104` helps to remove floaters in areas of the shape that are not supervised by the application objective, such as internal faces when using image supervision only.
|
67 |
+
* The L_dev loss at `L105` can be increased if you observe artifacts in flat areas, as illustrated in the image below.
|
68 |
+
* Generally, the L1 regularizer on flexible weights at `L106` does not have a significant impact during the optimization of a single shape. However, we found it to be effective in stabilizing training in generative pipelines such as GET3D.
|
69 |
+
<img src="images/ablate_L_dev.jpg" alt="Ablating L_dev" width="80%" height="80%">
|
70 |
+
|
71 |
+
### Resolution of voxel grid vs. tetrahedral grid:
|
72 |
+
If you are switching from our previous work, DMTet, it's important to note the difference in grid resolution when compared to FlexiCubes. In both implementations, the resolution is defined by the edge length: a grid resolution of `n` means the grid edge length is 1/n for both the voxel and tetrahedral grids. However, a tetrahedral grid with a resolution of `n` contains only `(n/2+1)³` grid vertices, in contrast to the `(n+1)³` vertices in a voxel grid. Consequently, if you are switching from DMTet to FlexiCubes while maintaining the same resolution, you will notice not only a denser output mesh but also a substantial increase in computational cost. To align the triangle count in the output meshes more closely, we recommend adopting a 4:5 resolution ratio between the voxel grid and the tetrahedral grid. For instance, in our paper, `64³` FlexiCubes generate approximately the same number of triangles as `80³` DMTet.
|
73 |
+
|
74 |
+
## Applications
|
75 |
+
FlexiCubes is now integrated into NVIDIA applications as a drop-in replacement for DMTet. You can visit their GitHub pages to see how FlexiCubes is used in advanced photogrammetry and 3D generative pipelines.
|
76 |
+
|
77 |
+
[Extracting Triangular 3D Models, Materials, and Lighting From Images (nvdiffrec)](https://github.com/NVlabs/nvdiffrec#news)
|
78 |
+
|
79 |
+
[GET3D: A Generative Model of High Quality 3D Textured Shapes Learned from Images](https://github.com/nv-tlabs/GET3D#employing-flexicubes)
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
## License
|
84 |
+
Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
85 |
+
|
86 |
+
This work is made available under the [Nvidia Source Code License](LICENSE.txt).
|
87 |
+
|
88 |
+
For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
|
89 |
+
|
90 |
+
## Citation
|
91 |
+
```bibtex
|
92 |
+
@article{shen2023flexicubes,
|
93 |
+
author = {Shen, Tianchang and Munkberg, Jacob and Hasselgren, Jon and Yin, Kangxue and Wang, Zian
|
94 |
+
and Chen, Wenzheng and Gojcic, Zan and Fidler, Sanja and Sharp, Nicholas and Gao, Jun},
|
95 |
+
title = {Flexible Isosurface Extraction for Gradient-Based Mesh Optimization},
|
96 |
+
year = {2023},
|
97 |
+
issue_date = {August 2023},
|
98 |
+
publisher = {Association for Computing Machinery},
|
99 |
+
address = {New York, NY, USA},
|
100 |
+
volume = {42},
|
101 |
+
number = {4},
|
102 |
+
issn = {0730-0301},
|
103 |
+
url = {https://doi.org/10.1145/3592430},
|
104 |
+
doi = {10.1145/3592430},
|
105 |
+
journal = {ACM Trans. Graph.},
|
106 |
+
month = {jul},
|
107 |
+
articleno = {37},
|
108 |
+
numpages = {16}
|
109 |
+
}
|
110 |
+
```
|
trellis/representations/mesh/flexicubes/examples/data/inputmodels/block.obj
ADDED
The diff for this file is too large to render.
See raw diff
|
|
trellis/representations/mesh/flexicubes/examples/download_data.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
import requests
|
9 |
+
from zipfile import ZipFile
|
10 |
+
from tqdm import tqdm
|
11 |
+
import os
|
12 |
+
|
13 |
+
def download_file(url, output_path):
|
14 |
+
response = requests.get(url, stream=True)
|
15 |
+
response.raise_for_status()
|
16 |
+
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
17 |
+
block_size = 1024 #1 Kibibyte
|
18 |
+
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
19 |
+
|
20 |
+
with open(output_path, 'wb') as file:
|
21 |
+
for data in response.iter_content(block_size):
|
22 |
+
progress_bar.update(len(data))
|
23 |
+
file.write(data)
|
24 |
+
progress_bar.close()
|
25 |
+
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
|
26 |
+
raise Exception("ERROR, something went wrong")
|
27 |
+
|
28 |
+
|
29 |
+
url = "https://vcg.isti.cnr.it/Publications/2014/MPZ14/inputmodels.zip"
|
30 |
+
zip_file_path = './data/inputmodels.zip'
|
31 |
+
|
32 |
+
os.makedirs('./data', exist_ok=True)
|
33 |
+
|
34 |
+
download_file(url, zip_file_path)
|
35 |
+
|
36 |
+
with ZipFile(zip_file_path, 'r') as zip_ref:
|
37 |
+
zip_ref.extractall('./data')
|
38 |
+
|
39 |
+
os.remove(zip_file_path)
|
40 |
+
|
41 |
+
print("Download and extraction complete.")
|
trellis/representations/mesh/flexicubes/examples/extraction.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
trellis/representations/mesh/flexicubes/examples/loss.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
import torch
|
9 |
+
import torch_scatter
|
10 |
+
|
11 |
+
###############################################################################
|
12 |
+
# Pytorch implementation of the developability regularizer introduced in paper
|
13 |
+
# "Developability of Triangle Meshes" by Stein et al.
|
14 |
+
###############################################################################
|
15 |
+
def mesh_developable_reg(mesh):
|
16 |
+
|
17 |
+
verts = mesh.vertices
|
18 |
+
tris = mesh.faces
|
19 |
+
|
20 |
+
device = verts.device
|
21 |
+
V = verts.shape[0]
|
22 |
+
F = tris.shape[0]
|
23 |
+
|
24 |
+
POS_EPS = 1e-6
|
25 |
+
REL_EPS = 1e-6
|
26 |
+
|
27 |
+
def normalize(vecs):
|
28 |
+
return vecs / (torch.linalg.norm(vecs, dim=-1, keepdim=True) + POS_EPS)
|
29 |
+
|
30 |
+
tri_pos = verts[tris]
|
31 |
+
|
32 |
+
vert_normal_covariance_sum = torch.zeros((V, 9), device=device)
|
33 |
+
vert_area = torch.zeros(V, device=device)
|
34 |
+
vert_degree = torch.zeros(V, dtype=torch.int32, device=device)
|
35 |
+
|
36 |
+
for iC in range(3): # loop over three corners of each triangle
|
37 |
+
|
38 |
+
# gather tri verts
|
39 |
+
pRoot = tri_pos[:, iC, :]
|
40 |
+
pA = tri_pos[:, (iC + 1) % 3, :]
|
41 |
+
pB = tri_pos[:, (iC + 2) % 3, :]
|
42 |
+
|
43 |
+
# compute the corner angle & normal
|
44 |
+
vA = pA - pRoot
|
45 |
+
vAn = normalize(vA)
|
46 |
+
vB = pB - pRoot
|
47 |
+
vBn = normalize(vB)
|
48 |
+
area_normal = torch.linalg.cross(vA, vB, dim=-1)
|
49 |
+
face_area = 0.5 * torch.linalg.norm(area_normal, dim=-1)
|
50 |
+
normal = normalize(area_normal)
|
51 |
+
corner_angle = torch.acos(torch.clamp(torch.sum(vAn * vBn, dim=-1), min=-1., max=1.))
|
52 |
+
|
53 |
+
# add up the contribution to the covariance matrix
|
54 |
+
outer = normal[:, :, None] @ normal[:, None, :]
|
55 |
+
contrib = corner_angle[:, None] * outer.reshape(-1, 9)
|
56 |
+
|
57 |
+
# scatter the result to the appropriate matrices
|
58 |
+
vert_normal_covariance_sum = torch_scatter.scatter_add(src=contrib,
|
59 |
+
index=tris[:, iC],
|
60 |
+
dim=-2,
|
61 |
+
out=vert_normal_covariance_sum)
|
62 |
+
|
63 |
+
vert_area = torch_scatter.scatter_add(src=face_area / 3.,
|
64 |
+
index=tris[:, iC],
|
65 |
+
dim=-1,
|
66 |
+
out=vert_area)
|
67 |
+
|
68 |
+
vert_degree = torch_scatter.scatter_add(src=torch.ones(F, dtype=torch.int32, device=device),
|
69 |
+
index=tris[:, iC],
|
70 |
+
dim=-1,
|
71 |
+
out=vert_degree)
|
72 |
+
|
73 |
+
# The energy is the smallest eigenvalue of the outer-product matrix
|
74 |
+
vert_normal_covariance_sum = vert_normal_covariance_sum.reshape(
|
75 |
+
-1, 3, 3) # reshape to a batch of matrices
|
76 |
+
vert_normal_covariance_sum = vert_normal_covariance_sum + torch.eye(
|
77 |
+
3, device=device)[None, :, :] * REL_EPS
|
78 |
+
|
79 |
+
min_eigvals = torch.min(torch.linalg.eigvals(vert_normal_covariance_sum).abs(), dim=-1).values
|
80 |
+
|
81 |
+
# Mask out degree-3 vertices
|
82 |
+
vert_area = torch.where(vert_degree == 3, torch.tensor(0, dtype=vert_area.dtype,device=vert_area.device), vert_area)
|
83 |
+
|
84 |
+
# Adjust the vertex area weighting so it is unit-less, and 1 on average
|
85 |
+
vert_area = vert_area * (V / torch.sum(vert_area, dim=-1, keepdim=True))
|
86 |
+
|
87 |
+
return vert_area * min_eigvals
|
88 |
+
|
89 |
+
def sdf_reg_loss(sdf, all_edges):
|
90 |
+
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
|
91 |
+
mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
|
92 |
+
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
93 |
+
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
|
94 |
+
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
|
95 |
+
return sdf_diff
|
trellis/representations/mesh/flexicubes/examples/optimization.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
trellis/representations/mesh/flexicubes/examples/optimize.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
import argparse
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import nvdiffrast.torch as dr
|
12 |
+
import trimesh
|
13 |
+
import os
|
14 |
+
from util import *
|
15 |
+
import render
|
16 |
+
import loss
|
17 |
+
import imageio
|
18 |
+
|
19 |
+
import sys
|
20 |
+
sys.path.append('..')
|
21 |
+
from flexicubes import FlexiCubes
|
22 |
+
|
23 |
+
###############################################################################
|
24 |
+
# Functions adapted from https://github.com/NVlabs/nvdiffrec
|
25 |
+
###############################################################################
|
26 |
+
|
27 |
+
def lr_schedule(iter):
|
28 |
+
return max(0.0, 10**(-(iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
parser = argparse.ArgumentParser(description='flexicubes optimization')
|
32 |
+
parser.add_argument('-o', '--out_dir', type=str, default=None)
|
33 |
+
parser.add_argument('-rm', '--ref_mesh', type=str)
|
34 |
+
|
35 |
+
parser.add_argument('-i', '--iter', type=int, default=1000)
|
36 |
+
parser.add_argument('-b', '--batch', type=int, default=8)
|
37 |
+
parser.add_argument('-r', '--train_res', nargs=2, type=int, default=[2048, 2048])
|
38 |
+
parser.add_argument('-lr', '--learning_rate', type=float, default=0.01)
|
39 |
+
parser.add_argument('--voxel_grid_res', type=int, default=64)
|
40 |
+
|
41 |
+
parser.add_argument('--sdf_loss', type=bool, default=True)
|
42 |
+
parser.add_argument('--develop_reg', type=bool, default=False)
|
43 |
+
parser.add_argument('--sdf_regularizer', type=float, default=0.2)
|
44 |
+
|
45 |
+
parser.add_argument('-dr', '--display_res', nargs=2, type=int, default=[512, 512])
|
46 |
+
parser.add_argument('-si', '--save_interval', type=int, default=20)
|
47 |
+
FLAGS = parser.parse_args()
|
48 |
+
device = 'cuda'
|
49 |
+
|
50 |
+
os.makedirs(FLAGS.out_dir, exist_ok=True)
|
51 |
+
glctx = dr.RasterizeGLContext()
|
52 |
+
|
53 |
+
# Load GT mesh
|
54 |
+
gt_mesh = load_mesh(FLAGS.ref_mesh, device)
|
55 |
+
gt_mesh.auto_normals() # compute face normals for visualization
|
56 |
+
|
57 |
+
# ==============================================================================================
|
58 |
+
# Create and initialize FlexiCubes
|
59 |
+
# ==============================================================================================
|
60 |
+
fc = FlexiCubes(device)
|
61 |
+
x_nx3, cube_fx8 = fc.construct_voxel_grid(FLAGS.voxel_grid_res)
|
62 |
+
x_nx3 *= 2 # scale up the grid so that it's larger than the target object
|
63 |
+
|
64 |
+
sdf = torch.rand_like(x_nx3[:,0]) - 0.1 # randomly init SDF
|
65 |
+
sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
|
66 |
+
# set per-cube learnable weights to zeros
|
67 |
+
weight = torch.zeros((cube_fx8.shape[0], 21), dtype=torch.float, device='cuda')
|
68 |
+
weight = torch.nn.Parameter(weight.clone().detach(), requires_grad=True)
|
69 |
+
deform = torch.nn.Parameter(torch.zeros_like(x_nx3), requires_grad=True)
|
70 |
+
|
71 |
+
# Retrieve all the edges of the voxel grid; these edges will be utilized to
|
72 |
+
# compute the regularization loss in subsequent steps of the process.
|
73 |
+
all_edges = cube_fx8[:, fc.cube_edges].reshape(-1, 2)
|
74 |
+
grid_edges = torch.unique(all_edges, dim=0)
|
75 |
+
|
76 |
+
# ==============================================================================================
|
77 |
+
# Setup optimizer
|
78 |
+
# ==============================================================================================
|
79 |
+
optimizer = torch.optim.Adam([sdf, weight,deform], lr=FLAGS.learning_rate)
|
80 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x))
|
81 |
+
|
82 |
+
# ==============================================================================================
|
83 |
+
# Train loop
|
84 |
+
# ==============================================================================================
|
85 |
+
for it in range(FLAGS.iter):
|
86 |
+
optimizer.zero_grad()
|
87 |
+
# sample random camera poses
|
88 |
+
mv, mvp = render.get_random_camera_batch(FLAGS.batch, iter_res=FLAGS.train_res, device=device, use_kaolin=False)
|
89 |
+
# render gt mesh
|
90 |
+
target = render.render_mesh_paper(gt_mesh, mv, mvp, FLAGS.train_res)
|
91 |
+
# extract and render FlexiCubes mesh
|
92 |
+
grid_verts = x_nx3 + (2-1e-8) / (FLAGS.voxel_grid_res * 2) * torch.tanh(deform)
|
93 |
+
vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:,:12], alpha_fx8=weight[:,12:20],
|
94 |
+
gamma_f=weight[:,20], training=True)
|
95 |
+
flexicubes_mesh = Mesh(vertices, faces)
|
96 |
+
buffers = render.render_mesh_paper(flexicubes_mesh, mv, mvp, FLAGS.train_res)
|
97 |
+
|
98 |
+
# evaluate reconstruction loss
|
99 |
+
mask_loss = (buffers['mask'] - target['mask']).abs().mean()
|
100 |
+
depth_loss = (((((buffers['depth'] - (target['depth']))* target['mask'])**2).sum(-1)+1e-8)).sqrt().mean() * 10
|
101 |
+
|
102 |
+
t_iter = it / FLAGS.iter
|
103 |
+
sdf_weight = FLAGS.sdf_regularizer - (FLAGS.sdf_regularizer - FLAGS.sdf_regularizer/20)*min(1.0, 4.0 * t_iter)
|
104 |
+
reg_loss = loss.sdf_reg_loss(sdf, grid_edges).mean() * sdf_weight # Loss to eliminate internal floaters that are not visible
|
105 |
+
reg_loss += L_dev.mean() * 0.5
|
106 |
+
reg_loss += (weight[:,:20]).abs().mean() * 0.1
|
107 |
+
total_loss = mask_loss + depth_loss + reg_loss
|
108 |
+
|
109 |
+
if FLAGS.sdf_loss: # optionally add SDF loss to eliminate internal structures
|
110 |
+
with torch.no_grad():
|
111 |
+
pts = sample_random_points(1000, gt_mesh)
|
112 |
+
gt_sdf = compute_sdf(pts, gt_mesh.vertices, gt_mesh.faces)
|
113 |
+
pred_sdf = compute_sdf(pts, flexicubes_mesh.vertices, flexicubes_mesh.faces)
|
114 |
+
total_loss += torch.nn.functional.mse_loss(pred_sdf, gt_sdf) * 2e3
|
115 |
+
|
116 |
+
# optionally add developability regularizer, as described in paper section 5.2
|
117 |
+
if FLAGS.develop_reg:
|
118 |
+
reg_weight = max(0, t_iter - 0.8) * 5
|
119 |
+
if reg_weight > 0: # only applied after shape converges
|
120 |
+
reg_loss = loss.mesh_developable_reg(flexicubes_mesh).mean() * 10
|
121 |
+
reg_loss += (deform).abs().mean()
|
122 |
+
reg_loss += (weight[:,:20]).abs().mean()
|
123 |
+
total_loss = mask_loss + depth_loss + reg_loss
|
124 |
+
|
125 |
+
total_loss.backward()
|
126 |
+
optimizer.step()
|
127 |
+
scheduler.step()
|
128 |
+
|
129 |
+
if (it % FLAGS.save_interval == 0 or it == (FLAGS.iter-1)): # save normal image for visualization
|
130 |
+
with torch.no_grad():
|
131 |
+
# extract mesh with training=False
|
132 |
+
vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:,:12], alpha_fx8=weight[:,12:20],
|
133 |
+
gamma_f=weight[:,20], training=False)
|
134 |
+
flexicubes_mesh = Mesh(vertices, faces)
|
135 |
+
|
136 |
+
flexicubes_mesh.auto_normals() # compute face normals for visualization
|
137 |
+
mv, mvp = render.get_rotate_camera(it//FLAGS.save_interval, iter_res=FLAGS.display_res, device=device,use_kaolin=False)
|
138 |
+
val_buffers = render.render_mesh_paper(flexicubes_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
|
139 |
+
val_image = ((val_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
|
140 |
+
|
141 |
+
gt_buffers = render.render_mesh_paper(gt_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
|
142 |
+
gt_image = ((gt_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
|
143 |
+
imageio.imwrite(os.path.join(FLAGS.out_dir, '{:04d}.png'.format(it)), np.concatenate([val_image, gt_image], 1))
|
144 |
+
print(f"Optimization Step [{it}/{FLAGS.iter}], Loss: {total_loss.item():.4f}")
|
145 |
+
|
146 |
+
# ==============================================================================================
|
147 |
+
# Save ouput
|
148 |
+
# ==============================================================================================
|
149 |
+
mesh_np = trimesh.Trimesh(vertices = vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy(), process=False)
|
150 |
+
mesh_np.export(os.path.join(FLAGS.out_dir, 'output_mesh.obj'))
|
trellis/representations/mesh/flexicubes/examples/render.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
import numpy as np
|
9 |
+
import copy
|
10 |
+
import math
|
11 |
+
from ipywidgets import interactive, HBox, VBox, FloatLogSlider, IntSlider
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import nvdiffrast.torch as dr
|
15 |
+
import kaolin as kal
|
16 |
+
import util
|
17 |
+
|
18 |
+
###############################################################################
|
19 |
+
# Functions adapted from https://github.com/NVlabs/nvdiffrec
|
20 |
+
###############################################################################
|
21 |
+
|
22 |
+
def get_random_camera_batch(batch_size, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
|
23 |
+
if use_kaolin:
|
24 |
+
camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(
|
25 |
+
*kal.ops.random.sample_spherical_coords((batch_size,), azimuth_low=0., azimuth_high=math.pi * 2,
|
26 |
+
elevation_low=-math.pi / 2., elevation_high=math.pi / 2., device='cuda'),
|
27 |
+
cam_radius
|
28 |
+
), dim=-1)
|
29 |
+
return kal.render.camera.Camera.from_args(
|
30 |
+
eye=camera_pos + torch.rand((batch_size, 1), device='cuda') * 0.5 - 0.25,
|
31 |
+
at=torch.zeros(batch_size, 3),
|
32 |
+
up=torch.tensor([[0., 1., 0.]]),
|
33 |
+
fov=fovy,
|
34 |
+
near=cam_near_far[0], far=cam_near_far[1],
|
35 |
+
height=iter_res[0], width=iter_res[1],
|
36 |
+
device='cuda'
|
37 |
+
)
|
38 |
+
else:
|
39 |
+
def get_random_camera():
|
40 |
+
proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
|
41 |
+
mv = util.translate(0, 0, -cam_radius) @ util.random_rotation_translation(0.25)
|
42 |
+
mvp = proj_mtx @ mv
|
43 |
+
return mv, mvp
|
44 |
+
mv_batch = []
|
45 |
+
mvp_batch = []
|
46 |
+
for i in range(batch_size):
|
47 |
+
mv, mvp = get_random_camera()
|
48 |
+
mv_batch.append(mv)
|
49 |
+
mvp_batch.append(mvp)
|
50 |
+
return torch.stack(mv_batch).to(device), torch.stack(mvp_batch).to(device)
|
51 |
+
|
52 |
+
def get_rotate_camera(itr, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
|
53 |
+
if use_kaolin:
|
54 |
+
ang = (itr / 10) * np.pi * 2
|
55 |
+
camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(torch.tensor(ang), torch.tensor(0.4), -torch.tensor(cam_radius)))
|
56 |
+
return kal.render.camera.Camera.from_args(
|
57 |
+
eye=camera_pos,
|
58 |
+
at=torch.zeros(3),
|
59 |
+
up=torch.tensor([0., 1., 0.]),
|
60 |
+
fov=fovy,
|
61 |
+
near=cam_near_far[0], far=cam_near_far[1],
|
62 |
+
height=iter_res[0], width=iter_res[1],
|
63 |
+
device='cuda'
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
|
67 |
+
|
68 |
+
# Smooth rotation for display.
|
69 |
+
ang = (itr / 10) * np.pi * 2
|
70 |
+
mv = util.translate(0, 0, -cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
|
71 |
+
mvp = proj_mtx @ mv
|
72 |
+
return mv.to(device), mvp.to(device)
|
73 |
+
|
74 |
+
glctx = dr.RasterizeGLContext()
|
75 |
+
def render_mesh(mesh, camera, iter_res, return_types = ["mask", "depth"], white_bg=False, wireframe_thickness=0.4):
|
76 |
+
vertices_camera = camera.extrinsics.transform(mesh.vertices)
|
77 |
+
face_vertices_camera = kal.ops.mesh.index_vertices_by_faces(
|
78 |
+
vertices_camera, mesh.faces
|
79 |
+
)
|
80 |
+
|
81 |
+
# Projection: nvdiffrast take clip coordinates as input to apply barycentric perspective correction.
|
82 |
+
# Using `camera.intrinsics.transform(vertices_camera) would return the normalized device coordinates.
|
83 |
+
proj = camera.projection_matrix().unsqueeze(1)
|
84 |
+
proj[:, :, 1, 1] = -proj[:, :, 1, 1]
|
85 |
+
homogeneous_vecs = kal.render.camera.up_to_homogeneous(
|
86 |
+
vertices_camera
|
87 |
+
)
|
88 |
+
vertices_clip = (proj @ homogeneous_vecs.unsqueeze(-1)).squeeze(-1)
|
89 |
+
faces_int = mesh.faces.int()
|
90 |
+
|
91 |
+
rast, _ = dr.rasterize(
|
92 |
+
glctx, vertices_clip, faces_int, iter_res)
|
93 |
+
|
94 |
+
out_dict = {}
|
95 |
+
for type in return_types:
|
96 |
+
if type == "mask" :
|
97 |
+
img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
|
98 |
+
elif type == "depth":
|
99 |
+
img = dr.interpolate(homogeneous_vecs, rast, faces_int)[0]
|
100 |
+
elif type == "wireframe":
|
101 |
+
img = torch.logical_or(
|
102 |
+
torch.logical_or(rast[..., 0] < wireframe_thickness, rast[..., 1] < wireframe_thickness),
|
103 |
+
(rast[..., 0] + rast[..., 1]) > (1. - wireframe_thickness)
|
104 |
+
).unsqueeze(-1)
|
105 |
+
elif type == "normals" :
|
106 |
+
img = dr.interpolate(
|
107 |
+
mesh.face_normals.reshape(len(mesh), -1, 3), rast,
|
108 |
+
torch.arange(mesh.faces.shape[0] * 3, device='cuda', dtype=torch.int).reshape(-1, 3)
|
109 |
+
)[0]
|
110 |
+
if white_bg:
|
111 |
+
bg = torch.ones_like(img)
|
112 |
+
alpha = (rast[..., -1:] > 0).float()
|
113 |
+
img = torch.lerp(bg, img, alpha)
|
114 |
+
out_dict[type] = img
|
115 |
+
|
116 |
+
|
117 |
+
return out_dict
|
118 |
+
|
119 |
+
def render_mesh_paper(mesh, mv, mvp, iter_res, return_types = ["mask", "depth"], white_bg=False):
|
120 |
+
'''
|
121 |
+
The rendering function used to produce the results in the paper.
|
122 |
+
'''
|
123 |
+
v_pos_clip = util.xfm_points(mesh.vertices.unsqueeze(0), mvp) # Rotate it to camera coordinates
|
124 |
+
rast, db = dr.rasterize(
|
125 |
+
dr.RasterizeGLContext(), v_pos_clip, mesh.faces.int(), iter_res)
|
126 |
+
|
127 |
+
out_dict = {}
|
128 |
+
for type in return_types:
|
129 |
+
if type == "mask" :
|
130 |
+
img = dr.antialias((rast[..., -1:] > 0).float(), rast, v_pos_clip, mesh.faces.int())
|
131 |
+
elif type == "depth":
|
132 |
+
v_pos_cam = util.xfm_points(mesh.vertices.unsqueeze(0), mv)
|
133 |
+
img, _ = util.interpolate(v_pos_cam, rast, mesh.faces.int())
|
134 |
+
elif type == "normal" :
|
135 |
+
normal_indices = (torch.arange(0, mesh.nrm.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
|
136 |
+
img, _ = util.interpolate(mesh.nrm.unsqueeze(0).contiguous(), rast, normal_indices.int())
|
137 |
+
elif type == "vertex_normal":
|
138 |
+
img, _ = util.interpolate(mesh.v_nrm.unsqueeze(0).contiguous(), rast, mesh.faces.int())
|
139 |
+
img = dr.antialias((img + 1) * 0.5, rast, v_pos_clip, mesh.faces.int())
|
140 |
+
if white_bg:
|
141 |
+
bg = torch.ones_like(img)
|
142 |
+
alpha = (rast[..., -1:] > 0).float()
|
143 |
+
img = torch.lerp(bg, img, alpha)
|
144 |
+
out_dict[type] = img
|
145 |
+
return out_dict
|
146 |
+
|
147 |
+
class SplitVisualizer():
|
148 |
+
def __init__(self, lh_mesh, rh_mesh, height, width):
|
149 |
+
self.lh_mesh = lh_mesh
|
150 |
+
self.rh_mesh = rh_mesh
|
151 |
+
self.height = height
|
152 |
+
self.width = width
|
153 |
+
self.wireframe_thickness = 0.4
|
154 |
+
|
155 |
+
|
156 |
+
def render(self, camera):
|
157 |
+
lh_outputs = render_mesh(
|
158 |
+
self.lh_mesh, camera, (self.height, self.width),
|
159 |
+
return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
|
160 |
+
)
|
161 |
+
rh_outputs = render_mesh(
|
162 |
+
self.rh_mesh, camera, (self.height, self.width),
|
163 |
+
return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
|
164 |
+
)
|
165 |
+
outputs = {
|
166 |
+
k: torch.cat(
|
167 |
+
[lh_outputs[k][0].permute(1, 0, 2), rh_outputs[k][0].permute(1, 0, 2)],
|
168 |
+
dim=0
|
169 |
+
).permute(1, 0, 2) for k in ["normals", "wireframe"]
|
170 |
+
}
|
171 |
+
return {
|
172 |
+
'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8),
|
173 |
+
'normals': outputs['normals']
|
174 |
+
}
|
175 |
+
|
176 |
+
def show(self, init_camera):
|
177 |
+
visualizer = kal.visualize.IpyTurntableVisualizer(
|
178 |
+
self.height, self.width * 2, copy.deepcopy(init_camera), self.render,
|
179 |
+
max_fps=24, world_up_axis=1)
|
180 |
+
|
181 |
+
def slider_callback(new_wireframe_thickness):
|
182 |
+
"""ipywidgets sliders callback"""
|
183 |
+
with visualizer.out: # This is in case of bug
|
184 |
+
self.wireframe_thickness = new_wireframe_thickness
|
185 |
+
# this is how we request a new update
|
186 |
+
visualizer.render_update()
|
187 |
+
|
188 |
+
wireframe_thickness_slider = FloatLogSlider(
|
189 |
+
value=self.wireframe_thickness,
|
190 |
+
base=10,
|
191 |
+
min=-3,
|
192 |
+
max=-0.4,
|
193 |
+
step=0.1,
|
194 |
+
description='wireframe_thickness',
|
195 |
+
continuous_update=True,
|
196 |
+
readout=True,
|
197 |
+
readout_format='.3f',
|
198 |
+
)
|
199 |
+
|
200 |
+
interactive_slider = interactive(
|
201 |
+
slider_callback,
|
202 |
+
new_wireframe_thickness=wireframe_thickness_slider,
|
203 |
+
)
|
204 |
+
|
205 |
+
full_output = VBox([visualizer.canvas, interactive_slider])
|
206 |
+
display(full_output, visualizer.out)
|
207 |
+
|
208 |
+
class TimelineVisualizer():
|
209 |
+
def __init__(self, meshes, height, width):
|
210 |
+
self.meshes = meshes
|
211 |
+
self.height = height
|
212 |
+
self.width = width
|
213 |
+
self.wireframe_thickness = 0.4
|
214 |
+
self.idx = len(meshes) - 1
|
215 |
+
|
216 |
+
def render(self, camera):
|
217 |
+
outputs = render_mesh(
|
218 |
+
self.meshes[self.idx], camera, (self.height, self.width),
|
219 |
+
return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
|
220 |
+
)
|
221 |
+
|
222 |
+
return {
|
223 |
+
'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8)[0],
|
224 |
+
'normals': outputs['normals'][0]
|
225 |
+
}
|
226 |
+
|
227 |
+
def show(self, init_camera):
|
228 |
+
visualizer = kal.visualize.IpyTurntableVisualizer(
|
229 |
+
self.height, self.width, copy.deepcopy(init_camera), self.render,
|
230 |
+
max_fps=24, world_up_axis=1)
|
231 |
+
|
232 |
+
def slider_callback(new_wireframe_thickness, new_idx):
|
233 |
+
"""ipywidgets sliders callback"""
|
234 |
+
with visualizer.out: # This is in case of bug
|
235 |
+
self.wireframe_thickness = new_wireframe_thickness
|
236 |
+
self.idx = new_idx
|
237 |
+
# this is how we request a new update
|
238 |
+
visualizer.render_update()
|
239 |
+
|
240 |
+
wireframe_thickness_slider = FloatLogSlider(
|
241 |
+
value=self.wireframe_thickness,
|
242 |
+
base=10,
|
243 |
+
min=-3,
|
244 |
+
max=-0.4,
|
245 |
+
step=0.1,
|
246 |
+
description='wireframe_thickness',
|
247 |
+
continuous_update=True,
|
248 |
+
readout=True,
|
249 |
+
readout_format='.3f',
|
250 |
+
)
|
251 |
+
|
252 |
+
idx_slider = IntSlider(
|
253 |
+
value=self.idx,
|
254 |
+
min=0,
|
255 |
+
max=len(self.meshes) - 1,
|
256 |
+
description='idx',
|
257 |
+
continuous_update=True,
|
258 |
+
readout=True
|
259 |
+
)
|
260 |
+
|
261 |
+
interactive_slider = interactive(
|
262 |
+
slider_callback,
|
263 |
+
new_wireframe_thickness=wireframe_thickness_slider,
|
264 |
+
new_idx=idx_slider
|
265 |
+
)
|
266 |
+
full_output = HBox([visualizer.canvas, interactive_slider])
|
267 |
+
display(full_output, visualizer.out)
|
trellis/representations/mesh/flexicubes/examples/util.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import trimesh
|
11 |
+
import kaolin
|
12 |
+
import nvdiffrast.torch as dr
|
13 |
+
|
14 |
+
###############################################################################
|
15 |
+
# Functions adapted from https://github.com/NVlabs/nvdiffrec
|
16 |
+
###############################################################################
|
17 |
+
|
18 |
+
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
19 |
+
return torch.sum(x*y, -1, keepdim=True)
|
20 |
+
|
21 |
+
def length(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor:
|
22 |
+
return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
|
23 |
+
|
24 |
+
def safe_normalize(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor:
|
25 |
+
return x / length(x, eps)
|
26 |
+
|
27 |
+
def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
|
28 |
+
y = np.tan(fovy / 2)
|
29 |
+
return torch.tensor([[1/(y*aspect), 0, 0, 0],
|
30 |
+
[ 0, 1/-y, 0, 0],
|
31 |
+
[ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
|
32 |
+
[ 0, 0, -1, 0]], dtype=torch.float32, device=device)
|
33 |
+
|
34 |
+
def translate(x, y, z, device=None):
|
35 |
+
return torch.tensor([[1, 0, 0, x],
|
36 |
+
[0, 1, 0, y],
|
37 |
+
[0, 0, 1, z],
|
38 |
+
[0, 0, 0, 1]], dtype=torch.float32, device=device)
|
39 |
+
|
40 |
+
@torch.no_grad()
|
41 |
+
def random_rotation_translation(t, device=None):
|
42 |
+
m = np.random.normal(size=[3, 3])
|
43 |
+
m[1] = np.cross(m[0], m[2])
|
44 |
+
m[2] = np.cross(m[0], m[1])
|
45 |
+
m = m / np.linalg.norm(m, axis=1, keepdims=True)
|
46 |
+
m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
|
47 |
+
m[3, 3] = 1.0
|
48 |
+
m[:3, 3] = np.random.uniform(-t, t, size=[3])
|
49 |
+
return torch.tensor(m, dtype=torch.float32, device=device)
|
50 |
+
|
51 |
+
def rotate_x(a, device=None):
|
52 |
+
s, c = np.sin(a), np.cos(a)
|
53 |
+
return torch.tensor([[1, 0, 0, 0],
|
54 |
+
[0, c, s, 0],
|
55 |
+
[0, -s, c, 0],
|
56 |
+
[0, 0, 0, 1]], dtype=torch.float32, device=device)
|
57 |
+
|
58 |
+
def rotate_y(a, device=None):
|
59 |
+
s, c = np.sin(a), np.cos(a)
|
60 |
+
return torch.tensor([[ c, 0, s, 0],
|
61 |
+
[ 0, 1, 0, 0],
|
62 |
+
[-s, 0, c, 0],
|
63 |
+
[ 0, 0, 0, 1]], dtype=torch.float32, device=device)
|
64 |
+
|
65 |
+
class Mesh:
|
66 |
+
def __init__(self, vertices, faces):
|
67 |
+
self.vertices = vertices
|
68 |
+
self.faces = faces
|
69 |
+
|
70 |
+
def auto_normals(self):
|
71 |
+
v0 = self.vertices[self.faces[:, 0], :]
|
72 |
+
v1 = self.vertices[self.faces[:, 1], :]
|
73 |
+
v2 = self.vertices[self.faces[:, 2], :]
|
74 |
+
nrm = safe_normalize(torch.cross(v1 - v0, v2 - v0))
|
75 |
+
self.nrm = nrm
|
76 |
+
|
77 |
+
def load_mesh(path, device):
|
78 |
+
mesh_np = trimesh.load(path)
|
79 |
+
vertices = torch.tensor(mesh_np.vertices, device=device, dtype=torch.float)
|
80 |
+
faces = torch.tensor(mesh_np.faces, device=device, dtype=torch.long)
|
81 |
+
|
82 |
+
# Normalize
|
83 |
+
vmin, vmax = vertices.min(dim=0)[0], vertices.max(dim=0)[0]
|
84 |
+
scale = 1.8 / torch.max(vmax - vmin).item()
|
85 |
+
vertices = vertices - (vmax + vmin) / 2 # Center mesh on origin
|
86 |
+
vertices = vertices * scale # Rescale to [-0.9, 0.9]
|
87 |
+
return Mesh(vertices, faces)
|
88 |
+
|
89 |
+
def compute_sdf(points, vertices, faces):
|
90 |
+
face_vertices = kaolin.ops.mesh.index_vertices_by_faces(vertices.clone().unsqueeze(0), faces)
|
91 |
+
distance = kaolin.metrics.trianglemesh.point_to_mesh_distance(points.unsqueeze(0), face_vertices)[0]
|
92 |
+
with torch.no_grad():
|
93 |
+
sign = (kaolin.ops.mesh.check_sign(vertices.unsqueeze(0), faces, points.unsqueeze(0))<1).float() * 2 - 1
|
94 |
+
sdf = (sign*distance).squeeze(0)
|
95 |
+
return sdf
|
96 |
+
|
97 |
+
def sample_random_points(n, mesh):
|
98 |
+
pts_random = (torch.rand((n//2,3),device='cuda') - 0.5) * 2
|
99 |
+
pts_surface = kaolin.ops.mesh.sample_points(mesh.vertices.unsqueeze(0), mesh.faces, 500)[0].squeeze(0)
|
100 |
+
pts_surface += torch.randn_like(pts_surface) * 0.05
|
101 |
+
pts = torch.cat([pts_random, pts_surface])
|
102 |
+
return pts
|
103 |
+
|
104 |
+
def xfm_points(points, matrix):
|
105 |
+
'''Transform points.
|
106 |
+
Args:
|
107 |
+
points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
108 |
+
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
109 |
+
use_python: Use PyTorch's torch.matmul (for validation)
|
110 |
+
Returns:
|
111 |
+
Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
112 |
+
'''
|
113 |
+
out = torch.matmul(
|
114 |
+
torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
|
115 |
+
if torch.is_anomaly_enabled():
|
116 |
+
assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
|
117 |
+
return out
|
118 |
+
|
119 |
+
def interpolate(attr, rast, attr_idx, rast_db=None):
|
120 |
+
return dr.interpolate(
|
121 |
+
attr, rast, attr_idx, rast_db=rast_db,
|
122 |
+
diff_attrs=None if rast_db is None else 'all')
|
trellis/representations/mesh/flexicubes/flexicubes.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from .tables import *
|
11 |
+
from kaolin.utils.testing import check_tensor
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
'FlexiCubes'
|
15 |
+
]
|
16 |
+
|
17 |
+
|
18 |
+
class FlexiCubes:
|
19 |
+
def __init__(self, device="cuda"):
|
20 |
+
|
21 |
+
self.device = device
|
22 |
+
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
|
23 |
+
self.num_vd_table = torch.tensor(num_vd_table,
|
24 |
+
dtype=torch.long, device=device, requires_grad=False)
|
25 |
+
self.check_table = torch.tensor(
|
26 |
+
check_table,
|
27 |
+
dtype=torch.long, device=device, requires_grad=False)
|
28 |
+
|
29 |
+
self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
|
30 |
+
self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
31 |
+
self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
32 |
+
self.quad_split_train = torch.tensor(
|
33 |
+
[0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
|
34 |
+
|
35 |
+
self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
|
36 |
+
1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
|
37 |
+
self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
|
38 |
+
self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
|
39 |
+
2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
|
40 |
+
|
41 |
+
self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
|
42 |
+
dtype=torch.long, device=device)
|
43 |
+
self.dir_faces_table = torch.tensor([
|
44 |
+
[[5, 4], [3, 2], [4, 5], [2, 3]],
|
45 |
+
[[5, 4], [1, 0], [4, 5], [0, 1]],
|
46 |
+
[[3, 2], [1, 0], [2, 3], [0, 1]]
|
47 |
+
], dtype=torch.long, device=device)
|
48 |
+
self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
|
49 |
+
|
50 |
+
def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3,
|
51 |
+
weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False):
|
52 |
+
assert torch.is_tensor(voxelgrid_vertices) and \
|
53 |
+
check_tensor(voxelgrid_vertices, (None, 3), throw=False), \
|
54 |
+
"'voxelgrid_vertices' should be a tensor of shape (num_vertices, 3)"
|
55 |
+
num_vertices = voxelgrid_vertices.shape[0]
|
56 |
+
assert torch.is_tensor(scalar_field) and \
|
57 |
+
check_tensor(scalar_field, (num_vertices,), throw=False), \
|
58 |
+
"'scalar_field' should be a tensor of shape (num_vertices,)"
|
59 |
+
assert torch.is_tensor(cube_idx) and \
|
60 |
+
check_tensor(cube_idx, (None, 8), throw=False), \
|
61 |
+
"'cube_idx' should be a tensor of shape (num_cubes, 8)"
|
62 |
+
num_cubes = cube_idx.shape[0]
|
63 |
+
assert beta is None or (
|
64 |
+
torch.is_tensor(beta) and
|
65 |
+
check_tensor(beta, (num_cubes, 12), throw=False)
|
66 |
+
), "'beta' should be a tensor of shape (num_cubes, 12)"
|
67 |
+
assert alpha is None or (
|
68 |
+
torch.is_tensor(alpha) and
|
69 |
+
check_tensor(alpha, (num_cubes, 8), throw=False)
|
70 |
+
), "'alpha' should be a tensor of shape (num_cubes, 8)"
|
71 |
+
assert gamma_f is None or (
|
72 |
+
torch.is_tensor(gamma_f) and
|
73 |
+
check_tensor(gamma_f, (num_cubes,), throw=False)
|
74 |
+
), "'gamma_f' should be a tensor of shape (num_cubes,)"
|
75 |
+
|
76 |
+
surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx)
|
77 |
+
if surf_cubes.sum() == 0:
|
78 |
+
return (
|
79 |
+
torch.zeros((0, 3), device=self.device),
|
80 |
+
torch.zeros((0, 3), dtype=torch.long, device=self.device),
|
81 |
+
torch.zeros((0), device=self.device),
|
82 |
+
torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None
|
83 |
+
)
|
84 |
+
beta, alpha, gamma_f = self._normalize_weights(
|
85 |
+
beta, alpha, gamma_f, surf_cubes, weight_scale)
|
86 |
+
|
87 |
+
if voxelgrid_colors is not None:
|
88 |
+
voxelgrid_colors = torch.sigmoid(voxelgrid_colors)
|
89 |
+
|
90 |
+
case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution)
|
91 |
+
|
92 |
+
surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
|
93 |
+
scalar_field, cube_idx, surf_cubes
|
94 |
+
)
|
95 |
+
|
96 |
+
vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd(
|
97 |
+
voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field,
|
98 |
+
case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors)
|
99 |
+
vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate(
|
100 |
+
scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map,
|
101 |
+
vd_idx_map, surf_edges_mask, training, vd_color)
|
102 |
+
return vertices, faces, L_dev, vertices_color
|
103 |
+
|
104 |
+
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
|
105 |
+
"""
|
106 |
+
Regularizer L_dev as in Equation 8
|
107 |
+
"""
|
108 |
+
dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
|
109 |
+
mean_l2 = torch.zeros_like(vd[:, 0])
|
110 |
+
mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
|
111 |
+
mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
|
112 |
+
return mad
|
113 |
+
|
114 |
+
def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale):
|
115 |
+
"""
|
116 |
+
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
|
117 |
+
"""
|
118 |
+
n_cubes = surf_cubes.shape[0]
|
119 |
+
|
120 |
+
if beta is not None:
|
121 |
+
beta = (torch.tanh(beta) * weight_scale + 1)
|
122 |
+
else:
|
123 |
+
beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
|
124 |
+
|
125 |
+
if alpha is not None:
|
126 |
+
alpha = (torch.tanh(alpha) * weight_scale + 1)
|
127 |
+
else:
|
128 |
+
alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
|
129 |
+
|
130 |
+
if gamma_f is not None:
|
131 |
+
gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2
|
132 |
+
else:
|
133 |
+
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
|
134 |
+
|
135 |
+
return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes]
|
136 |
+
|
137 |
+
@torch.no_grad()
|
138 |
+
def _get_case_id(self, occ_fx8, surf_cubes, res):
|
139 |
+
"""
|
140 |
+
Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
|
141 |
+
ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
|
142 |
+
supplementary material. It should be noted that this function assumes a regular grid.
|
143 |
+
"""
|
144 |
+
case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
|
145 |
+
|
146 |
+
problem_config = self.check_table.to(self.device)[case_ids]
|
147 |
+
to_check = problem_config[..., 0] == 1
|
148 |
+
problem_config = problem_config[to_check]
|
149 |
+
if not isinstance(res, (list, tuple)):
|
150 |
+
res = [res, res, res]
|
151 |
+
|
152 |
+
# The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
|
153 |
+
# 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
|
154 |
+
# This allows efficient checking on adjacent cubes.
|
155 |
+
problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
|
156 |
+
vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
|
157 |
+
vol_idx_problem = vol_idx[surf_cubes][to_check]
|
158 |
+
problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
|
159 |
+
vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
|
160 |
+
|
161 |
+
within_range = (
|
162 |
+
vol_idx_problem_adj[..., 0] >= 0) & (
|
163 |
+
vol_idx_problem_adj[..., 0] < res[0]) & (
|
164 |
+
vol_idx_problem_adj[..., 1] >= 0) & (
|
165 |
+
vol_idx_problem_adj[..., 1] < res[1]) & (
|
166 |
+
vol_idx_problem_adj[..., 2] >= 0) & (
|
167 |
+
vol_idx_problem_adj[..., 2] < res[2])
|
168 |
+
|
169 |
+
vol_idx_problem = vol_idx_problem[within_range]
|
170 |
+
vol_idx_problem_adj = vol_idx_problem_adj[within_range]
|
171 |
+
problem_config = problem_config[within_range]
|
172 |
+
problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
|
173 |
+
vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
|
174 |
+
# If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
|
175 |
+
to_invert = (problem_config_adj[..., 0] == 1)
|
176 |
+
idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
|
177 |
+
case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
|
178 |
+
return case_ids
|
179 |
+
|
180 |
+
@torch.no_grad()
|
181 |
+
def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes):
|
182 |
+
"""
|
183 |
+
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
|
184 |
+
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
|
185 |
+
and marks the cube edges with this index.
|
186 |
+
"""
|
187 |
+
occ_n = scalar_field < 0
|
188 |
+
all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2)
|
189 |
+
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
190 |
+
|
191 |
+
unique_edges = unique_edges.long()
|
192 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
193 |
+
|
194 |
+
surf_edges_mask = mask_edges[_idx_map]
|
195 |
+
counts = counts[_idx_map]
|
196 |
+
|
197 |
+
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1
|
198 |
+
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device)
|
199 |
+
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
|
200 |
+
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
|
201 |
+
idx_map = mapping[_idx_map]
|
202 |
+
surf_edges = unique_edges[mask_edges]
|
203 |
+
return surf_edges, idx_map, counts, surf_edges_mask
|
204 |
+
|
205 |
+
@torch.no_grad()
|
206 |
+
def _identify_surf_cubes(self, scalar_field, cube_idx):
|
207 |
+
"""
|
208 |
+
Identifies grid cubes that intersect with the underlying surface by checking if the signs at
|
209 |
+
all corners are not identical.
|
210 |
+
"""
|
211 |
+
occ_n = scalar_field < 0
|
212 |
+
occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8)
|
213 |
+
_occ_sum = torch.sum(occ_fx8, -1)
|
214 |
+
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
|
215 |
+
return surf_cubes, occ_fx8
|
216 |
+
|
217 |
+
def _linear_interp(self, edges_weight, edges_x):
|
218 |
+
"""
|
219 |
+
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
|
220 |
+
"""
|
221 |
+
edge_dim = edges_weight.dim() - 2
|
222 |
+
assert edges_weight.shape[edge_dim] == 2
|
223 |
+
edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
|
224 |
+
torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)]
|
225 |
+
, edge_dim)
|
226 |
+
denominator = edges_weight.sum(edge_dim)
|
227 |
+
ue = (edges_x * edges_weight).sum(edge_dim) / denominator
|
228 |
+
return ue
|
229 |
+
|
230 |
+
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale):
|
231 |
+
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
|
232 |
+
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
|
233 |
+
c_bx3 = c_bx3.reshape(-1, 3)
|
234 |
+
A = norm_bxnx3
|
235 |
+
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
|
236 |
+
|
237 |
+
A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
|
238 |
+
B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1)
|
239 |
+
A = torch.cat([A, A_reg], 1)
|
240 |
+
B = torch.cat([B, B_reg], 1)
|
241 |
+
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
|
242 |
+
return dual_verts
|
243 |
+
|
244 |
+
def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field,
|
245 |
+
case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors):
|
246 |
+
"""
|
247 |
+
Computes the location of dual vertices as described in Section 4.2
|
248 |
+
"""
|
249 |
+
alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
|
250 |
+
surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
|
251 |
+
surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
|
252 |
+
zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
|
253 |
+
|
254 |
+
if voxelgrid_colors is not None:
|
255 |
+
C = voxelgrid_colors.shape[-1]
|
256 |
+
surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C)
|
257 |
+
|
258 |
+
idx_map = idx_map.reshape(-1, 12)
|
259 |
+
num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
|
260 |
+
edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
|
261 |
+
|
262 |
+
# if color is not None:
|
263 |
+
# vd_color = []
|
264 |
+
|
265 |
+
total_num_vd = 0
|
266 |
+
vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
|
267 |
+
|
268 |
+
for num in torch.unique(num_vd):
|
269 |
+
cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
|
270 |
+
curr_num_vd = cur_cubes.sum() * num
|
271 |
+
curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
|
272 |
+
curr_edge_group_to_vd = torch.arange(
|
273 |
+
curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
|
274 |
+
total_num_vd += curr_num_vd
|
275 |
+
curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
|
276 |
+
cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
|
277 |
+
|
278 |
+
curr_mask = (curr_edge_group != -1)
|
279 |
+
edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
|
280 |
+
edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
|
281 |
+
edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
|
282 |
+
vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
|
283 |
+
vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
|
284 |
+
# if color is not None:
|
285 |
+
# vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3))
|
286 |
+
|
287 |
+
edge_group = torch.cat(edge_group)
|
288 |
+
edge_group_to_vd = torch.cat(edge_group_to_vd)
|
289 |
+
edge_group_to_cube = torch.cat(edge_group_to_cube)
|
290 |
+
vd_num_edges = torch.cat(vd_num_edges)
|
291 |
+
vd_gamma = torch.cat(vd_gamma)
|
292 |
+
# if color is not None:
|
293 |
+
# vd_color = torch.cat(vd_color)
|
294 |
+
# else:
|
295 |
+
# vd_color = None
|
296 |
+
|
297 |
+
vd = torch.zeros((total_num_vd, 3), device=self.device)
|
298 |
+
beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
|
299 |
+
|
300 |
+
idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
|
301 |
+
|
302 |
+
x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
|
303 |
+
s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
|
304 |
+
|
305 |
+
|
306 |
+
zero_crossing_group = torch.index_select(
|
307 |
+
input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
|
308 |
+
|
309 |
+
alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
|
310 |
+
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
|
311 |
+
ue_group = self._linear_interp(s_group * alpha_group, x_group)
|
312 |
+
|
313 |
+
beta_group = torch.gather(input=beta.reshape(-1), dim=0,
|
314 |
+
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
|
315 |
+
beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
|
316 |
+
vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
|
317 |
+
|
318 |
+
'''
|
319 |
+
interpolate colors use the same method as dual vertices
|
320 |
+
'''
|
321 |
+
if voxelgrid_colors is not None:
|
322 |
+
vd_color = torch.zeros((total_num_vd, C), device=self.device)
|
323 |
+
c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C)
|
324 |
+
uc_group = self._linear_interp(s_group * alpha_group, c_group)
|
325 |
+
vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum
|
326 |
+
else:
|
327 |
+
vd_color = None
|
328 |
+
|
329 |
+
L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
|
330 |
+
|
331 |
+
v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
|
332 |
+
|
333 |
+
vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
|
334 |
+
12 + edge_group, src=v_idx[edge_group_to_vd])
|
335 |
+
|
336 |
+
return vd, L_dev, vd_gamma, vd_idx_map, vd_color
|
337 |
+
|
338 |
+
def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color):
|
339 |
+
"""
|
340 |
+
Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
|
341 |
+
triangles based on the gamma parameter, as described in Section 4.3.
|
342 |
+
"""
|
343 |
+
with torch.no_grad():
|
344 |
+
group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
|
345 |
+
group = idx_map.reshape(-1)[group_mask]
|
346 |
+
vd_idx = vd_idx_map[group_mask]
|
347 |
+
edge_indices, indices = torch.sort(group, stable=True)
|
348 |
+
quad_vd_idx = vd_idx[indices].reshape(-1, 4)
|
349 |
+
|
350 |
+
# Ensure all face directions point towards the positive SDF to maintain consistent winding.
|
351 |
+
s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
|
352 |
+
flip_mask = s_edges[:, 0] > 0
|
353 |
+
quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
|
354 |
+
quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
|
355 |
+
|
356 |
+
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
|
357 |
+
gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2]
|
358 |
+
gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3]
|
359 |
+
if not training:
|
360 |
+
mask = (gamma_02 > gamma_13)
|
361 |
+
faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
|
362 |
+
faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
|
363 |
+
faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
|
364 |
+
faces = faces.reshape(-1, 3)
|
365 |
+
else:
|
366 |
+
vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
|
367 |
+
vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2
|
368 |
+
vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2
|
369 |
+
weight_sum = (gamma_02 + gamma_13) + 1e-8
|
370 |
+
vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
|
371 |
+
|
372 |
+
if vd_color is not None:
|
373 |
+
color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1])
|
374 |
+
color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2
|
375 |
+
color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2
|
376 |
+
color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
|
377 |
+
vd_color = torch.cat([vd_color, color_center])
|
378 |
+
|
379 |
+
|
380 |
+
vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
|
381 |
+
vd = torch.cat([vd, vd_center])
|
382 |
+
faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
|
383 |
+
faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
|
384 |
+
return vd, faces, s_edges, edge_indices, vd_color
|
trellis/representations/mesh/flexicubes/images/ablate_L_dev.jpg
ADDED
![]() |
trellis/representations/mesh/flexicubes/images/block_final.png
ADDED
![]() |
Git LFS Details
|
trellis/representations/mesh/flexicubes/images/block_init.png
ADDED
![]() |
Git LFS Details
|
trellis/representations/mesh/flexicubes/images/teaser_top.png
ADDED
![]() |
Git LFS Details
|
trellis/representations/mesh/flexicubes/tables.py
ADDED
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
dmc_table = [
|
9 |
+
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
10 |
+
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
11 |
+
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
12 |
+
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
13 |
+
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
14 |
+
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
15 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
16 |
+
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
17 |
+
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
18 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
19 |
+
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
20 |
+
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
21 |
+
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
22 |
+
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
23 |
+
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
24 |
+
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
25 |
+
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
26 |
+
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
27 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
28 |
+
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
29 |
+
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
30 |
+
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
31 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
32 |
+
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
33 |
+
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
34 |
+
[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
35 |
+
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
36 |
+
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
37 |
+
[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
38 |
+
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
39 |
+
[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
40 |
+
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
41 |
+
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
42 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
43 |
+
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
44 |
+
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
45 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
46 |
+
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
47 |
+
[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
48 |
+
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
49 |
+
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
50 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
51 |
+
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
52 |
+
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
53 |
+
[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
54 |
+
[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
55 |
+
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
56 |
+
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
57 |
+
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
58 |
+
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
59 |
+
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
60 |
+
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
61 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
62 |
+
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
63 |
+
[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
64 |
+
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
65 |
+
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
66 |
+
[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
67 |
+
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
68 |
+
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
69 |
+
[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
70 |
+
[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
71 |
+
[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
72 |
+
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
73 |
+
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
74 |
+
[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
75 |
+
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
76 |
+
[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
77 |
+
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
78 |
+
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
79 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
80 |
+
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
81 |
+
[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
82 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
83 |
+
[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
84 |
+
[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
85 |
+
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
86 |
+
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
87 |
+
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
88 |
+
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
89 |
+
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
90 |
+
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
91 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
92 |
+
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
93 |
+
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
94 |
+
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
95 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
96 |
+
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
97 |
+
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
98 |
+
[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
99 |
+
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
100 |
+
[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
101 |
+
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
102 |
+
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
103 |
+
[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
104 |
+
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
105 |
+
[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
106 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
107 |
+
[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
108 |
+
[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
109 |
+
[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
110 |
+
[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
111 |
+
[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
112 |
+
[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
113 |
+
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
114 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
|
115 |
+
[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
116 |
+
[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
117 |
+
[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
118 |
+
[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
119 |
+
[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
120 |
+
[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
121 |
+
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
122 |
+
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
123 |
+
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
124 |
+
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
125 |
+
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
126 |
+
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
127 |
+
[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
128 |
+
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
129 |
+
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
130 |
+
[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
131 |
+
[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
132 |
+
[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
133 |
+
[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
134 |
+
[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
135 |
+
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
136 |
+
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
137 |
+
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
138 |
+
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
139 |
+
[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
140 |
+
[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
141 |
+
[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
142 |
+
[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
143 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
144 |
+
[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
145 |
+
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
146 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
147 |
+
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
148 |
+
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
149 |
+
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
150 |
+
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
151 |
+
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
152 |
+
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
153 |
+
[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
154 |
+
[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
155 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
156 |
+
[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
157 |
+
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
158 |
+
[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
159 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
|
160 |
+
[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
161 |
+
[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
162 |
+
[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
163 |
+
[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
164 |
+
[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
165 |
+
[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
166 |
+
[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
167 |
+
[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
168 |
+
[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
169 |
+
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
170 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
171 |
+
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
172 |
+
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
173 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
174 |
+
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
175 |
+
[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
176 |
+
[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
177 |
+
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
178 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
179 |
+
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
180 |
+
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
181 |
+
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
182 |
+
[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
183 |
+
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
184 |
+
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
185 |
+
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
186 |
+
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
187 |
+
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
188 |
+
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
189 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
190 |
+
[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
191 |
+
[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
192 |
+
[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
193 |
+
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
194 |
+
[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
195 |
+
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
196 |
+
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
197 |
+
[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
198 |
+
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
199 |
+
[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
200 |
+
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
201 |
+
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
202 |
+
[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
203 |
+
[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
204 |
+
[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
205 |
+
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
206 |
+
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
207 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
208 |
+
[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
209 |
+
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
210 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
211 |
+
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
212 |
+
[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
213 |
+
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
214 |
+
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
215 |
+
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
216 |
+
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
217 |
+
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
218 |
+
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
219 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
220 |
+
[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
221 |
+
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
222 |
+
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
223 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
224 |
+
[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
225 |
+
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
226 |
+
[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
227 |
+
[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
228 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
229 |
+
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
230 |
+
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
231 |
+
[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
232 |
+
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
233 |
+
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
234 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
235 |
+
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
236 |
+
[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
237 |
+
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
238 |
+
[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
239 |
+
[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
240 |
+
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
241 |
+
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
242 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
243 |
+
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
244 |
+
[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
245 |
+
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
246 |
+
[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
247 |
+
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
248 |
+
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
249 |
+
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
250 |
+
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
251 |
+
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
252 |
+
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
253 |
+
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
254 |
+
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
255 |
+
[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
256 |
+
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
257 |
+
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
258 |
+
[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
259 |
+
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
260 |
+
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
261 |
+
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
262 |
+
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
263 |
+
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
264 |
+
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
|
265 |
+
]
|
266 |
+
num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
|
267 |
+
2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
|
268 |
+
1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
|
269 |
+
1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
|
270 |
+
2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
|
271 |
+
3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
|
272 |
+
2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
|
273 |
+
1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
|
274 |
+
1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
|
275 |
+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
|
276 |
+
check_table = [
|
277 |
+
[0, 0, 0, 0, 0],
|
278 |
+
[0, 0, 0, 0, 0],
|
279 |
+
[0, 0, 0, 0, 0],
|
280 |
+
[0, 0, 0, 0, 0],
|
281 |
+
[0, 0, 0, 0, 0],
|
282 |
+
[0, 0, 0, 0, 0],
|
283 |
+
[0, 0, 0, 0, 0],
|
284 |
+
[0, 0, 0, 0, 0],
|
285 |
+
[0, 0, 0, 0, 0],
|
286 |
+
[0, 0, 0, 0, 0],
|
287 |
+
[0, 0, 0, 0, 0],
|
288 |
+
[0, 0, 0, 0, 0],
|
289 |
+
[0, 0, 0, 0, 0],
|
290 |
+
[0, 0, 0, 0, 0],
|
291 |
+
[0, 0, 0, 0, 0],
|
292 |
+
[0, 0, 0, 0, 0],
|
293 |
+
[0, 0, 0, 0, 0],
|
294 |
+
[0, 0, 0, 0, 0],
|
295 |
+
[0, 0, 0, 0, 0],
|
296 |
+
[0, 0, 0, 0, 0],
|
297 |
+
[0, 0, 0, 0, 0],
|
298 |
+
[0, 0, 0, 0, 0],
|
299 |
+
[0, 0, 0, 0, 0],
|
300 |
+
[0, 0, 0, 0, 0],
|
301 |
+
[0, 0, 0, 0, 0],
|
302 |
+
[0, 0, 0, 0, 0],
|
303 |
+
[0, 0, 0, 0, 0],
|
304 |
+
[0, 0, 0, 0, 0],
|
305 |
+
[0, 0, 0, 0, 0],
|
306 |
+
[0, 0, 0, 0, 0],
|
307 |
+
[0, 0, 0, 0, 0],
|
308 |
+
[0, 0, 0, 0, 0],
|
309 |
+
[0, 0, 0, 0, 0],
|
310 |
+
[0, 0, 0, 0, 0],
|
311 |
+
[0, 0, 0, 0, 0],
|
312 |
+
[0, 0, 0, 0, 0],
|
313 |
+
[0, 0, 0, 0, 0],
|
314 |
+
[0, 0, 0, 0, 0],
|
315 |
+
[0, 0, 0, 0, 0],
|
316 |
+
[0, 0, 0, 0, 0],
|
317 |
+
[0, 0, 0, 0, 0],
|
318 |
+
[0, 0, 0, 0, 0],
|
319 |
+
[0, 0, 0, 0, 0],
|
320 |
+
[0, 0, 0, 0, 0],
|
321 |
+
[0, 0, 0, 0, 0],
|
322 |
+
[0, 0, 0, 0, 0],
|
323 |
+
[0, 0, 0, 0, 0],
|
324 |
+
[0, 0, 0, 0, 0],
|
325 |
+
[0, 0, 0, 0, 0],
|
326 |
+
[0, 0, 0, 0, 0],
|
327 |
+
[0, 0, 0, 0, 0],
|
328 |
+
[0, 0, 0, 0, 0],
|
329 |
+
[0, 0, 0, 0, 0],
|
330 |
+
[0, 0, 0, 0, 0],
|
331 |
+
[0, 0, 0, 0, 0],
|
332 |
+
[0, 0, 0, 0, 0],
|
333 |
+
[0, 0, 0, 0, 0],
|
334 |
+
[0, 0, 0, 0, 0],
|
335 |
+
[0, 0, 0, 0, 0],
|
336 |
+
[0, 0, 0, 0, 0],
|
337 |
+
[0, 0, 0, 0, 0],
|
338 |
+
[1, 1, 0, 0, 194],
|
339 |
+
[1, -1, 0, 0, 193],
|
340 |
+
[0, 0, 0, 0, 0],
|
341 |
+
[0, 0, 0, 0, 0],
|
342 |
+
[0, 0, 0, 0, 0],
|
343 |
+
[0, 0, 0, 0, 0],
|
344 |
+
[0, 0, 0, 0, 0],
|
345 |
+
[0, 0, 0, 0, 0],
|
346 |
+
[0, 0, 0, 0, 0],
|
347 |
+
[0, 0, 0, 0, 0],
|
348 |
+
[0, 0, 0, 0, 0],
|
349 |
+
[0, 0, 0, 0, 0],
|
350 |
+
[0, 0, 0, 0, 0],
|
351 |
+
[0, 0, 0, 0, 0],
|
352 |
+
[0, 0, 0, 0, 0],
|
353 |
+
[0, 0, 0, 0, 0],
|
354 |
+
[0, 0, 0, 0, 0],
|
355 |
+
[0, 0, 0, 0, 0],
|
356 |
+
[0, 0, 0, 0, 0],
|
357 |
+
[0, 0, 0, 0, 0],
|
358 |
+
[0, 0, 0, 0, 0],
|
359 |
+
[0, 0, 0, 0, 0],
|
360 |
+
[0, 0, 0, 0, 0],
|
361 |
+
[0, 0, 0, 0, 0],
|
362 |
+
[0, 0, 0, 0, 0],
|
363 |
+
[0, 0, 0, 0, 0],
|
364 |
+
[0, 0, 0, 0, 0],
|
365 |
+
[0, 0, 0, 0, 0],
|
366 |
+
[0, 0, 0, 0, 0],
|
367 |
+
[0, 0, 0, 0, 0],
|
368 |
+
[1, 0, 1, 0, 164],
|
369 |
+
[0, 0, 0, 0, 0],
|
370 |
+
[0, 0, 0, 0, 0],
|
371 |
+
[1, 0, -1, 0, 161],
|
372 |
+
[0, 0, 0, 0, 0],
|
373 |
+
[0, 0, 0, 0, 0],
|
374 |
+
[0, 0, 0, 0, 0],
|
375 |
+
[0, 0, 0, 0, 0],
|
376 |
+
[0, 0, 0, 0, 0],
|
377 |
+
[0, 0, 0, 0, 0],
|
378 |
+
[0, 0, 0, 0, 0],
|
379 |
+
[0, 0, 0, 0, 0],
|
380 |
+
[1, 0, 0, 1, 152],
|
381 |
+
[0, 0, 0, 0, 0],
|
382 |
+
[0, 0, 0, 0, 0],
|
383 |
+
[0, 0, 0, 0, 0],
|
384 |
+
[0, 0, 0, 0, 0],
|
385 |
+
[0, 0, 0, 0, 0],
|
386 |
+
[0, 0, 0, 0, 0],
|
387 |
+
[1, 0, 0, 1, 145],
|
388 |
+
[1, 0, 0, 1, 144],
|
389 |
+
[0, 0, 0, 0, 0],
|
390 |
+
[0, 0, 0, 0, 0],
|
391 |
+
[0, 0, 0, 0, 0],
|
392 |
+
[0, 0, 0, 0, 0],
|
393 |
+
[0, 0, 0, 0, 0],
|
394 |
+
[0, 0, 0, 0, 0],
|
395 |
+
[1, 0, 0, -1, 137],
|
396 |
+
[0, 0, 0, 0, 0],
|
397 |
+
[0, 0, 0, 0, 0],
|
398 |
+
[0, 0, 0, 0, 0],
|
399 |
+
[1, 0, 1, 0, 133],
|
400 |
+
[1, 0, 1, 0, 132],
|
401 |
+
[1, 1, 0, 0, 131],
|
402 |
+
[1, 1, 0, 0, 130],
|
403 |
+
[0, 0, 0, 0, 0],
|
404 |
+
[0, 0, 0, 0, 0],
|
405 |
+
[0, 0, 0, 0, 0],
|
406 |
+
[0, 0, 0, 0, 0],
|
407 |
+
[0, 0, 0, 0, 0],
|
408 |
+
[0, 0, 0, 0, 0],
|
409 |
+
[0, 0, 0, 0, 0],
|
410 |
+
[0, 0, 0, 0, 0],
|
411 |
+
[0, 0, 0, 0, 0],
|
412 |
+
[0, 0, 0, 0, 0],
|
413 |
+
[0, 0, 0, 0, 0],
|
414 |
+
[0, 0, 0, 0, 0],
|
415 |
+
[0, 0, 0, 0, 0],
|
416 |
+
[0, 0, 0, 0, 0],
|
417 |
+
[0, 0, 0, 0, 0],
|
418 |
+
[0, 0, 0, 0, 0],
|
419 |
+
[0, 0, 0, 0, 0],
|
420 |
+
[0, 0, 0, 0, 0],
|
421 |
+
[0, 0, 0, 0, 0],
|
422 |
+
[0, 0, 0, 0, 0],
|
423 |
+
[0, 0, 0, 0, 0],
|
424 |
+
[0, 0, 0, 0, 0],
|
425 |
+
[0, 0, 0, 0, 0],
|
426 |
+
[0, 0, 0, 0, 0],
|
427 |
+
[0, 0, 0, 0, 0],
|
428 |
+
[0, 0, 0, 0, 0],
|
429 |
+
[0, 0, 0, 0, 0],
|
430 |
+
[0, 0, 0, 0, 0],
|
431 |
+
[0, 0, 0, 0, 0],
|
432 |
+
[1, 0, 0, 1, 100],
|
433 |
+
[0, 0, 0, 0, 0],
|
434 |
+
[1, 0, 0, 1, 98],
|
435 |
+
[0, 0, 0, 0, 0],
|
436 |
+
[1, 0, 0, 1, 96],
|
437 |
+
[0, 0, 0, 0, 0],
|
438 |
+
[0, 0, 0, 0, 0],
|
439 |
+
[0, 0, 0, 0, 0],
|
440 |
+
[0, 0, 0, 0, 0],
|
441 |
+
[0, 0, 0, 0, 0],
|
442 |
+
[0, 0, 0, 0, 0],
|
443 |
+
[0, 0, 0, 0, 0],
|
444 |
+
[1, 0, 1, 0, 88],
|
445 |
+
[0, 0, 0, 0, 0],
|
446 |
+
[0, 0, 0, 0, 0],
|
447 |
+
[0, 0, 0, 0, 0],
|
448 |
+
[0, 0, 0, 0, 0],
|
449 |
+
[0, 0, 0, 0, 0],
|
450 |
+
[1, 0, -1, 0, 82],
|
451 |
+
[0, 0, 0, 0, 0],
|
452 |
+
[0, 0, 0, 0, 0],
|
453 |
+
[0, 0, 0, 0, 0],
|
454 |
+
[0, 0, 0, 0, 0],
|
455 |
+
[0, 0, 0, 0, 0],
|
456 |
+
[0, 0, 0, 0, 0],
|
457 |
+
[0, 0, 0, 0, 0],
|
458 |
+
[1, 0, 1, 0, 74],
|
459 |
+
[0, 0, 0, 0, 0],
|
460 |
+
[1, 0, 1, 0, 72],
|
461 |
+
[0, 0, 0, 0, 0],
|
462 |
+
[1, 0, 0, -1, 70],
|
463 |
+
[0, 0, 0, 0, 0],
|
464 |
+
[0, 0, 0, 0, 0],
|
465 |
+
[1, -1, 0, 0, 67],
|
466 |
+
[0, 0, 0, 0, 0],
|
467 |
+
[1, -1, 0, 0, 65],
|
468 |
+
[0, 0, 0, 0, 0],
|
469 |
+
[0, 0, 0, 0, 0],
|
470 |
+
[0, 0, 0, 0, 0],
|
471 |
+
[0, 0, 0, 0, 0],
|
472 |
+
[0, 0, 0, 0, 0],
|
473 |
+
[0, 0, 0, 0, 0],
|
474 |
+
[0, 0, 0, 0, 0],
|
475 |
+
[0, 0, 0, 0, 0],
|
476 |
+
[1, 1, 0, 0, 56],
|
477 |
+
[0, 0, 0, 0, 0],
|
478 |
+
[0, 0, 0, 0, 0],
|
479 |
+
[0, 0, 0, 0, 0],
|
480 |
+
[1, -1, 0, 0, 52],
|
481 |
+
[0, 0, 0, 0, 0],
|
482 |
+
[0, 0, 0, 0, 0],
|
483 |
+
[0, 0, 0, 0, 0],
|
484 |
+
[0, 0, 0, 0, 0],
|
485 |
+
[0, 0, 0, 0, 0],
|
486 |
+
[0, 0, 0, 0, 0],
|
487 |
+
[0, 0, 0, 0, 0],
|
488 |
+
[1, 1, 0, 0, 44],
|
489 |
+
[0, 0, 0, 0, 0],
|
490 |
+
[0, 0, 0, 0, 0],
|
491 |
+
[0, 0, 0, 0, 0],
|
492 |
+
[1, 1, 0, 0, 40],
|
493 |
+
[0, 0, 0, 0, 0],
|
494 |
+
[1, 0, 0, -1, 38],
|
495 |
+
[1, 0, -1, 0, 37],
|
496 |
+
[0, 0, 0, 0, 0],
|
497 |
+
[0, 0, 0, 0, 0],
|
498 |
+
[0, 0, 0, 0, 0],
|
499 |
+
[1, 0, -1, 0, 33],
|
500 |
+
[0, 0, 0, 0, 0],
|
501 |
+
[0, 0, 0, 0, 0],
|
502 |
+
[0, 0, 0, 0, 0],
|
503 |
+
[0, 0, 0, 0, 0],
|
504 |
+
[1, -1, 0, 0, 28],
|
505 |
+
[0, 0, 0, 0, 0],
|
506 |
+
[1, 0, -1, 0, 26],
|
507 |
+
[1, 0, 0, -1, 25],
|
508 |
+
[0, 0, 0, 0, 0],
|
509 |
+
[0, 0, 0, 0, 0],
|
510 |
+
[0, 0, 0, 0, 0],
|
511 |
+
[0, 0, 0, 0, 0],
|
512 |
+
[1, -1, 0, 0, 20],
|
513 |
+
[0, 0, 0, 0, 0],
|
514 |
+
[1, 0, -1, 0, 18],
|
515 |
+
[0, 0, 0, 0, 0],
|
516 |
+
[0, 0, 0, 0, 0],
|
517 |
+
[0, 0, 0, 0, 0],
|
518 |
+
[0, 0, 0, 0, 0],
|
519 |
+
[0, 0, 0, 0, 0],
|
520 |
+
[0, 0, 0, 0, 0],
|
521 |
+
[0, 0, 0, 0, 0],
|
522 |
+
[0, 0, 0, 0, 0],
|
523 |
+
[1, 0, 0, -1, 9],
|
524 |
+
[0, 0, 0, 0, 0],
|
525 |
+
[0, 0, 0, 0, 0],
|
526 |
+
[1, 0, 0, -1, 6],
|
527 |
+
[0, 0, 0, 0, 0],
|
528 |
+
[0, 0, 0, 0, 0],
|
529 |
+
[0, 0, 0, 0, 0],
|
530 |
+
[0, 0, 0, 0, 0],
|
531 |
+
[0, 0, 0, 0, 0],
|
532 |
+
[0, 0, 0, 0, 0]
|
533 |
+
]
|
534 |
+
tet_table = [
|
535 |
+
[-1, -1, -1, -1, -1, -1],
|
536 |
+
[0, 0, 0, 0, 0, 0],
|
537 |
+
[0, 0, 0, 0, 0, 0],
|
538 |
+
[1, 1, 1, 1, 1, 1],
|
539 |
+
[4, 4, 4, 4, 4, 4],
|
540 |
+
[0, 0, 0, 0, 0, 0],
|
541 |
+
[4, 0, 0, 4, 4, -1],
|
542 |
+
[1, 1, 1, 1, 1, 1],
|
543 |
+
[4, 4, 4, 4, 4, 4],
|
544 |
+
[0, 4, 0, 4, 4, -1],
|
545 |
+
[0, 0, 0, 0, 0, 0],
|
546 |
+
[1, 1, 1, 1, 1, 1],
|
547 |
+
[5, 5, 5, 5, 5, 5],
|
548 |
+
[0, 0, 0, 0, 0, 0],
|
549 |
+
[0, 0, 0, 0, 0, 0],
|
550 |
+
[1, 1, 1, 1, 1, 1],
|
551 |
+
[2, 2, 2, 2, 2, 2],
|
552 |
+
[0, 0, 0, 0, 0, 0],
|
553 |
+
[2, 0, 2, -1, 0, 2],
|
554 |
+
[1, 1, 1, 1, 1, 1],
|
555 |
+
[2, -1, 2, 4, 4, 2],
|
556 |
+
[0, 0, 0, 0, 0, 0],
|
557 |
+
[2, 0, 2, 4, 4, 2],
|
558 |
+
[1, 1, 1, 1, 1, 1],
|
559 |
+
[2, 4, 2, 4, 4, 2],
|
560 |
+
[0, 4, 0, 4, 4, 0],
|
561 |
+
[2, 0, 2, 0, 0, 2],
|
562 |
+
[1, 1, 1, 1, 1, 1],
|
563 |
+
[2, 5, 2, 5, 5, 2],
|
564 |
+
[0, 0, 0, 0, 0, 0],
|
565 |
+
[2, 0, 2, 0, 0, 2],
|
566 |
+
[1, 1, 1, 1, 1, 1],
|
567 |
+
[1, 1, 1, 1, 1, 1],
|
568 |
+
[0, 1, 1, -1, 0, 1],
|
569 |
+
[0, 0, 0, 0, 0, 0],
|
570 |
+
[2, 2, 2, 2, 2, 2],
|
571 |
+
[4, 1, 1, 4, 4, 1],
|
572 |
+
[0, 1, 1, 0, 0, 1],
|
573 |
+
[4, 0, 0, 4, 4, 0],
|
574 |
+
[2, 2, 2, 2, 2, 2],
|
575 |
+
[-1, 1, 1, 4, 4, 1],
|
576 |
+
[0, 1, 1, 4, 4, 1],
|
577 |
+
[0, 0, 0, 0, 0, 0],
|
578 |
+
[2, 2, 2, 2, 2, 2],
|
579 |
+
[5, 1, 1, 5, 5, 1],
|
580 |
+
[0, 1, 1, 0, 0, 1],
|
581 |
+
[0, 0, 0, 0, 0, 0],
|
582 |
+
[2, 2, 2, 2, 2, 2],
|
583 |
+
[1, 1, 1, 1, 1, 1],
|
584 |
+
[0, 0, 0, 0, 0, 0],
|
585 |
+
[0, 0, 0, 0, 0, 0],
|
586 |
+
[8, 8, 8, 8, 8, 8],
|
587 |
+
[1, 1, 1, 4, 4, 1],
|
588 |
+
[0, 0, 0, 0, 0, 0],
|
589 |
+
[4, 0, 0, 4, 4, 0],
|
590 |
+
[4, 4, 4, 4, 4, 4],
|
591 |
+
[1, 1, 1, 4, 4, 1],
|
592 |
+
[0, 4, 0, 4, 4, 0],
|
593 |
+
[0, 0, 0, 0, 0, 0],
|
594 |
+
[4, 4, 4, 4, 4, 4],
|
595 |
+
[1, 1, 1, 5, 5, 1],
|
596 |
+
[0, 0, 0, 0, 0, 0],
|
597 |
+
[0, 0, 0, 0, 0, 0],
|
598 |
+
[5, 5, 5, 5, 5, 5],
|
599 |
+
[6, 6, 6, 6, 6, 6],
|
600 |
+
[6, -1, 0, 6, 0, 6],
|
601 |
+
[6, 0, 0, 6, 0, 6],
|
602 |
+
[6, 1, 1, 6, 1, 6],
|
603 |
+
[4, 4, 4, 4, 4, 4],
|
604 |
+
[0, 0, 0, 0, 0, 0],
|
605 |
+
[4, 0, 0, 4, 4, 4],
|
606 |
+
[1, 1, 1, 1, 1, 1],
|
607 |
+
[6, 4, -1, 6, 4, 6],
|
608 |
+
[6, 4, 0, 6, 4, 6],
|
609 |
+
[6, 0, 0, 6, 0, 6],
|
610 |
+
[6, 1, 1, 6, 1, 6],
|
611 |
+
[5, 5, 5, 5, 5, 5],
|
612 |
+
[0, 0, 0, 0, 0, 0],
|
613 |
+
[0, 0, 0, 0, 0, 0],
|
614 |
+
[1, 1, 1, 1, 1, 1],
|
615 |
+
[2, 2, 2, 2, 2, 2],
|
616 |
+
[0, 0, 0, 0, 0, 0],
|
617 |
+
[2, 0, 2, 2, 0, 2],
|
618 |
+
[1, 1, 1, 1, 1, 1],
|
619 |
+
[2, 2, 2, 2, 2, 2],
|
620 |
+
[0, 0, 0, 0, 0, 0],
|
621 |
+
[2, 0, 2, 2, 2, 2],
|
622 |
+
[1, 1, 1, 1, 1, 1],
|
623 |
+
[2, 4, 2, 2, 4, 2],
|
624 |
+
[0, 4, 0, 4, 4, 0],
|
625 |
+
[2, 0, 2, 2, 0, 2],
|
626 |
+
[1, 1, 1, 1, 1, 1],
|
627 |
+
[2, 2, 2, 2, 2, 2],
|
628 |
+
[0, 0, 0, 0, 0, 0],
|
629 |
+
[0, 0, 0, 0, 0, 0],
|
630 |
+
[1, 1, 1, 1, 1, 1],
|
631 |
+
[6, 1, 1, 6, -1, 6],
|
632 |
+
[6, 1, 1, 6, 0, 6],
|
633 |
+
[6, 0, 0, 6, 0, 6],
|
634 |
+
[6, 2, 2, 6, 2, 6],
|
635 |
+
[4, 1, 1, 4, 4, 1],
|
636 |
+
[0, 1, 1, 0, 0, 1],
|
637 |
+
[4, 0, 0, 4, 4, 4],
|
638 |
+
[2, 2, 2, 2, 2, 2],
|
639 |
+
[6, 1, 1, 6, 4, 6],
|
640 |
+
[6, 1, 1, 6, 4, 6],
|
641 |
+
[6, 0, 0, 6, 0, 6],
|
642 |
+
[6, 2, 2, 6, 2, 6],
|
643 |
+
[5, 1, 1, 5, 5, 1],
|
644 |
+
[0, 1, 1, 0, 0, 1],
|
645 |
+
[0, 0, 0, 0, 0, 0],
|
646 |
+
[2, 2, 2, 2, 2, 2],
|
647 |
+
[1, 1, 1, 1, 1, 1],
|
648 |
+
[0, 0, 0, 0, 0, 0],
|
649 |
+
[0, 0, 0, 0, 0, 0],
|
650 |
+
[6, 6, 6, 6, 6, 6],
|
651 |
+
[1, 1, 1, 1, 1, 1],
|
652 |
+
[0, 0, 0, 0, 0, 0],
|
653 |
+
[0, 0, 0, 0, 0, 0],
|
654 |
+
[4, 4, 4, 4, 4, 4],
|
655 |
+
[1, 1, 1, 1, 4, 1],
|
656 |
+
[0, 4, 0, 4, 4, 0],
|
657 |
+
[0, 0, 0, 0, 0, 0],
|
658 |
+
[4, 4, 4, 4, 4, 4],
|
659 |
+
[1, 1, 1, 1, 1, 1],
|
660 |
+
[0, 0, 0, 0, 0, 0],
|
661 |
+
[0, 5, 0, 5, 0, 5],
|
662 |
+
[5, 5, 5, 5, 5, 5],
|
663 |
+
[5, 5, 5, 5, 5, 5],
|
664 |
+
[0, 5, 0, 5, 0, 5],
|
665 |
+
[-1, 5, 0, 5, 0, 5],
|
666 |
+
[1, 5, 1, 5, 1, 5],
|
667 |
+
[4, 5, -1, 5, 4, 5],
|
668 |
+
[0, 5, 0, 5, 0, 5],
|
669 |
+
[4, 5, 0, 5, 4, 5],
|
670 |
+
[1, 5, 1, 5, 1, 5],
|
671 |
+
[4, 4, 4, 4, 4, 4],
|
672 |
+
[0, 4, 0, 4, 4, 4],
|
673 |
+
[0, 0, 0, 0, 0, 0],
|
674 |
+
[1, 1, 1, 1, 1, 1],
|
675 |
+
[6, 6, 6, 6, 6, 6],
|
676 |
+
[0, 0, 0, 0, 0, 0],
|
677 |
+
[0, 0, 0, 0, 0, 0],
|
678 |
+
[1, 1, 1, 1, 1, 1],
|
679 |
+
[2, 5, 2, 5, -1, 5],
|
680 |
+
[0, 5, 0, 5, 0, 5],
|
681 |
+
[2, 5, 2, 5, 0, 5],
|
682 |
+
[1, 5, 1, 5, 1, 5],
|
683 |
+
[2, 5, 2, 5, 4, 5],
|
684 |
+
[0, 5, 0, 5, 0, 5],
|
685 |
+
[2, 5, 2, 5, 4, 5],
|
686 |
+
[1, 5, 1, 5, 1, 5],
|
687 |
+
[2, 4, 2, 4, 4, 2],
|
688 |
+
[0, 4, 0, 4, 4, 4],
|
689 |
+
[2, 0, 2, 0, 0, 2],
|
690 |
+
[1, 1, 1, 1, 1, 1],
|
691 |
+
[2, 6, 2, 6, 6, 2],
|
692 |
+
[0, 0, 0, 0, 0, 0],
|
693 |
+
[2, 0, 2, 0, 0, 2],
|
694 |
+
[1, 1, 1, 1, 1, 1],
|
695 |
+
[1, 1, 1, 1, 1, 1],
|
696 |
+
[0, 1, 1, 1, 0, 1],
|
697 |
+
[0, 0, 0, 0, 0, 0],
|
698 |
+
[2, 2, 2, 2, 2, 2],
|
699 |
+
[4, 1, 1, 1, 4, 1],
|
700 |
+
[0, 1, 1, 1, 0, 1],
|
701 |
+
[4, 0, 0, 4, 4, 0],
|
702 |
+
[2, 2, 2, 2, 2, 2],
|
703 |
+
[1, 1, 1, 1, 1, 1],
|
704 |
+
[0, 1, 1, 1, 1, 1],
|
705 |
+
[0, 0, 0, 0, 0, 0],
|
706 |
+
[2, 2, 2, 2, 2, 2],
|
707 |
+
[1, 1, 1, 1, 1, 1],
|
708 |
+
[0, 0, 0, 0, 0, 0],
|
709 |
+
[0, 0, 0, 0, 0, 0],
|
710 |
+
[2, 2, 2, 2, 2, 2],
|
711 |
+
[1, 1, 1, 1, 1, 1],
|
712 |
+
[0, 0, 0, 0, 0, 0],
|
713 |
+
[0, 0, 0, 0, 0, 0],
|
714 |
+
[5, 5, 5, 5, 5, 5],
|
715 |
+
[1, 1, 1, 1, 4, 1],
|
716 |
+
[0, 0, 0, 0, 0, 0],
|
717 |
+
[4, 0, 0, 4, 4, 0],
|
718 |
+
[4, 4, 4, 4, 4, 4],
|
719 |
+
[1, 1, 1, 1, 1, 1],
|
720 |
+
[0, 0, 0, 0, 0, 0],
|
721 |
+
[0, 0, 0, 0, 0, 0],
|
722 |
+
[4, 4, 4, 4, 4, 4],
|
723 |
+
[1, 1, 1, 1, 1, 1],
|
724 |
+
[6, 0, 0, 6, 0, 6],
|
725 |
+
[0, 0, 0, 0, 0, 0],
|
726 |
+
[6, 6, 6, 6, 6, 6],
|
727 |
+
[5, 5, 5, 5, 5, 5],
|
728 |
+
[5, 5, 0, 5, 0, 5],
|
729 |
+
[5, 5, 0, 5, 0, 5],
|
730 |
+
[5, 5, 1, 5, 1, 5],
|
731 |
+
[4, 4, 4, 4, 4, 4],
|
732 |
+
[0, 0, 0, 0, 0, 0],
|
733 |
+
[4, 4, 0, 4, 4, 4],
|
734 |
+
[1, 1, 1, 1, 1, 1],
|
735 |
+
[4, 4, 4, 4, 4, 4],
|
736 |
+
[4, 4, 0, 4, 4, 4],
|
737 |
+
[0, 0, 0, 0, 0, 0],
|
738 |
+
[1, 1, 1, 1, 1, 1],
|
739 |
+
[8, 8, 8, 8, 8, 8],
|
740 |
+
[0, 0, 0, 0, 0, 0],
|
741 |
+
[0, 0, 0, 0, 0, 0],
|
742 |
+
[1, 1, 1, 1, 1, 1],
|
743 |
+
[2, 2, 2, 2, 2, 2],
|
744 |
+
[0, 0, 0, 0, 0, 0],
|
745 |
+
[2, 2, 2, 2, 0, 2],
|
746 |
+
[1, 1, 1, 1, 1, 1],
|
747 |
+
[2, 2, 2, 2, 2, 2],
|
748 |
+
[0, 0, 0, 0, 0, 0],
|
749 |
+
[2, 2, 2, 2, 2, 2],
|
750 |
+
[1, 1, 1, 1, 1, 1],
|
751 |
+
[2, 2, 2, 2, 2, 2],
|
752 |
+
[0, 0, 0, 0, 0, 0],
|
753 |
+
[0, 0, 0, 0, 0, 0],
|
754 |
+
[4, 1, 1, 4, 4, 1],
|
755 |
+
[2, 2, 2, 2, 2, 2],
|
756 |
+
[0, 0, 0, 0, 0, 0],
|
757 |
+
[0, 0, 0, 0, 0, 0],
|
758 |
+
[1, 1, 1, 1, 1, 1],
|
759 |
+
[1, 1, 1, 1, 1, 1],
|
760 |
+
[1, 1, 1, 1, 0, 1],
|
761 |
+
[0, 0, 0, 0, 0, 0],
|
762 |
+
[2, 2, 2, 2, 2, 2],
|
763 |
+
[1, 1, 1, 1, 1, 1],
|
764 |
+
[0, 0, 0, 0, 0, 0],
|
765 |
+
[0, 0, 0, 0, 0, 0],
|
766 |
+
[2, 4, 2, 4, 4, 2],
|
767 |
+
[1, 1, 1, 1, 1, 1],
|
768 |
+
[1, 1, 1, 1, 1, 1],
|
769 |
+
[0, 0, 0, 0, 0, 0],
|
770 |
+
[2, 2, 2, 2, 2, 2],
|
771 |
+
[1, 1, 1, 1, 1, 1],
|
772 |
+
[0, 0, 0, 0, 0, 0],
|
773 |
+
[0, 0, 0, 0, 0, 0],
|
774 |
+
[2, 2, 2, 2, 2, 2],
|
775 |
+
[1, 1, 1, 1, 1, 1],
|
776 |
+
[0, 0, 0, 0, 0, 0],
|
777 |
+
[0, 0, 0, 0, 0, 0],
|
778 |
+
[5, 5, 5, 5, 5, 5],
|
779 |
+
[1, 1, 1, 1, 1, 1],
|
780 |
+
[0, 0, 0, 0, 0, 0],
|
781 |
+
[0, 0, 0, 0, 0, 0],
|
782 |
+
[4, 4, 4, 4, 4, 4],
|
783 |
+
[1, 1, 1, 1, 1, 1],
|
784 |
+
[0, 0, 0, 0, 0, 0],
|
785 |
+
[0, 0, 0, 0, 0, 0],
|
786 |
+
[4, 4, 4, 4, 4, 4],
|
787 |
+
[1, 1, 1, 1, 1, 1],
|
788 |
+
[0, 0, 0, 0, 0, 0],
|
789 |
+
[0, 0, 0, 0, 0, 0],
|
790 |
+
[12, 12, 12, 12, 12, 12]
|
791 |
+
]
|
trellis/representations/octree/octree_dfs.py
CHANGED
@@ -3,21 +3,6 @@ import torch.nn as nn
|
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
|
6 |
-
DEFAULT_TRIVEC_CONFIG = {
|
7 |
-
'dim': 8,
|
8 |
-
'rank': 8,
|
9 |
-
}
|
10 |
-
|
11 |
-
DEFAULT_VOXEL_CONFIG = {
|
12 |
-
'solid': False,
|
13 |
-
}
|
14 |
-
|
15 |
-
DEFAULT_DECOPOLY_CONFIG = {
|
16 |
-
'degree': 8,
|
17 |
-
'rank': 16,
|
18 |
-
}
|
19 |
-
|
20 |
-
|
21 |
class DfsOctree:
|
22 |
"""
|
23 |
Sparse Voxel Octree (SVO) implementation for PyTorch.
|
@@ -145,8 +130,8 @@ class DfsOctree:
|
|
145 |
|
146 |
@property
|
147 |
def get_density(self):
|
148 |
-
if self.primitive == 'voxel' and self.
|
149 |
-
return torch.full((self.position.shape[0], 1),
|
150 |
return self.density_activation(self.density)
|
151 |
|
152 |
@property
|
@@ -172,7 +157,7 @@ class DfsOctree:
|
|
172 |
return torch.cat([self.features_dc, self.features_ac], dim=-2)
|
173 |
|
174 |
def state_dict(self):
|
175 |
-
ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, '
|
176 |
if hasattr(self, 'density_shift'):
|
177 |
ret['density_shift'] = self.density_shift
|
178 |
for data in set(self.data + self.param_names):
|
|
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
class DfsOctree:
|
7 |
"""
|
8 |
Sparse Voxel Octree (SVO) implementation for PyTorch.
|
|
|
130 |
|
131 |
@property
|
132 |
def get_density(self):
|
133 |
+
if self.primitive == 'voxel' and self.primitive_config.get('solid', False):
|
134 |
+
return torch.full((self.position.shape[0], 1), torch.finfo(torch.float32).max, dtype=torch.float32, device=self.device)
|
135 |
return self.density_activation(self.density)
|
136 |
|
137 |
@property
|
|
|
157 |
return torch.cat([self.features_dc, self.features_ac], dim=-2)
|
158 |
|
159 |
def state_dict(self):
|
160 |
+
ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'primitive_config': self.primitive_config, 'primitive': self.primitive}
|
161 |
if hasattr(self, 'density_shift'):
|
162 |
ret['density_shift'] = self.density_shift
|
163 |
for data in set(self.data + self.param_names):
|
trellis/trainers/__init__.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
__attributes = {
|
4 |
+
'BasicTrainer': 'basic',
|
5 |
+
|
6 |
+
'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
|
7 |
+
|
8 |
+
'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
|
9 |
+
'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
|
10 |
+
'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',
|
11 |
+
|
12 |
+
'FlowMatchingTrainer': 'flow_matching.flow_matching',
|
13 |
+
'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
14 |
+
'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
15 |
+
'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
|
16 |
+
|
17 |
+
'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
|
18 |
+
'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
19 |
+
'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
20 |
+
'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
|
21 |
+
}
|
22 |
+
|
23 |
+
__submodules = []
|
24 |
+
|
25 |
+
__all__ = list(__attributes.keys()) + __submodules
|
26 |
+
|
27 |
+
def __getattr__(name):
|
28 |
+
if name not in globals():
|
29 |
+
if name in __attributes:
|
30 |
+
module_name = __attributes[name]
|
31 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
32 |
+
globals()[name] = getattr(module, name)
|
33 |
+
elif name in __submodules:
|
34 |
+
module = importlib.import_module(f".{name}", __name__)
|
35 |
+
globals()[name] = module
|
36 |
+
else:
|
37 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
38 |
+
return globals()[name]
|
39 |
+
|
40 |
+
|
41 |
+
# For Pylance
|
42 |
+
if __name__ == '__main__':
|
43 |
+
from .basic import BasicTrainer
|
44 |
+
|
45 |
+
from .vae.sparse_structure_vae import SparseStructureVaeTrainer
|
46 |
+
|
47 |
+
from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
|
48 |
+
from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
|
49 |
+
from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer
|
50 |
+
|
51 |
+
from .flow_matching.flow_matching import (
|
52 |
+
FlowMatchingTrainer,
|
53 |
+
FlowMatchingCFGTrainer,
|
54 |
+
TextConditionedFlowMatchingCFGTrainer,
|
55 |
+
ImageConditionedFlowMatchingCFGTrainer,
|
56 |
+
)
|
57 |
+
|
58 |
+
from .flow_matching.sparse_flow_matching import (
|
59 |
+
SparseFlowMatchingTrainer,
|
60 |
+
SparseFlowMatchingCFGTrainer,
|
61 |
+
TextConditionedSparseFlowMatchingCFGTrainer,
|
62 |
+
ImageConditionedSparseFlowMatchingCFGTrainer,
|
63 |
+
)
|
trellis/trainers/base.py
ADDED
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import json
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchvision import utils
|
12 |
+
from torch.utils.tensorboard import SummaryWriter
|
13 |
+
|
14 |
+
from .utils import *
|
15 |
+
from ..utils.general_utils import *
|
16 |
+
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
|
17 |
+
|
18 |
+
|
19 |
+
class Trainer:
|
20 |
+
"""
|
21 |
+
Base class for training.
|
22 |
+
"""
|
23 |
+
def __init__(self,
|
24 |
+
models,
|
25 |
+
dataset,
|
26 |
+
*,
|
27 |
+
output_dir,
|
28 |
+
load_dir,
|
29 |
+
step,
|
30 |
+
max_steps,
|
31 |
+
batch_size=None,
|
32 |
+
batch_size_per_gpu=None,
|
33 |
+
batch_split=None,
|
34 |
+
optimizer={},
|
35 |
+
lr_scheduler=None,
|
36 |
+
elastic=None,
|
37 |
+
grad_clip=None,
|
38 |
+
ema_rate=0.9999,
|
39 |
+
fp16_mode='inflat_all',
|
40 |
+
fp16_scale_growth=1e-3,
|
41 |
+
finetune_ckpt=None,
|
42 |
+
log_param_stats=False,
|
43 |
+
prefetch_data=True,
|
44 |
+
i_print=1000,
|
45 |
+
i_log=500,
|
46 |
+
i_sample=10000,
|
47 |
+
i_save=10000,
|
48 |
+
i_ddpcheck=10000,
|
49 |
+
**kwargs
|
50 |
+
):
|
51 |
+
assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
|
52 |
+
|
53 |
+
self.models = models
|
54 |
+
self.dataset = dataset
|
55 |
+
self.batch_split = batch_split if batch_split is not None else 1
|
56 |
+
self.max_steps = max_steps
|
57 |
+
self.optimizer_config = optimizer
|
58 |
+
self.lr_scheduler_config = lr_scheduler
|
59 |
+
self.elastic_controller_config = elastic
|
60 |
+
self.grad_clip = grad_clip
|
61 |
+
self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
|
62 |
+
self.fp16_mode = fp16_mode
|
63 |
+
self.fp16_scale_growth = fp16_scale_growth
|
64 |
+
self.log_param_stats = log_param_stats
|
65 |
+
self.prefetch_data = prefetch_data
|
66 |
+
if self.prefetch_data:
|
67 |
+
self._data_prefetched = None
|
68 |
+
|
69 |
+
self.output_dir = output_dir
|
70 |
+
self.i_print = i_print
|
71 |
+
self.i_log = i_log
|
72 |
+
self.i_sample = i_sample
|
73 |
+
self.i_save = i_save
|
74 |
+
self.i_ddpcheck = i_ddpcheck
|
75 |
+
|
76 |
+
if dist.is_initialized():
|
77 |
+
# Multi-GPU params
|
78 |
+
self.world_size = dist.get_world_size()
|
79 |
+
self.rank = dist.get_rank()
|
80 |
+
self.local_rank = dist.get_rank() % torch.cuda.device_count()
|
81 |
+
self.is_master = self.rank == 0
|
82 |
+
else:
|
83 |
+
# Single-GPU params
|
84 |
+
self.world_size = 1
|
85 |
+
self.rank = 0
|
86 |
+
self.local_rank = 0
|
87 |
+
self.is_master = True
|
88 |
+
|
89 |
+
self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
|
90 |
+
self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
|
91 |
+
assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
|
92 |
+
assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
|
93 |
+
|
94 |
+
self.init_models_and_more(**kwargs)
|
95 |
+
self.prepare_dataloader(**kwargs)
|
96 |
+
|
97 |
+
# Load checkpoint
|
98 |
+
self.step = 0
|
99 |
+
if load_dir is not None and step is not None:
|
100 |
+
self.load(load_dir, step)
|
101 |
+
elif finetune_ckpt is not None:
|
102 |
+
self.finetune_from(finetune_ckpt)
|
103 |
+
|
104 |
+
if self.is_master:
|
105 |
+
os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
|
106 |
+
os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
|
107 |
+
self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
|
108 |
+
|
109 |
+
if self.world_size > 1:
|
110 |
+
self.check_ddp()
|
111 |
+
|
112 |
+
if self.is_master:
|
113 |
+
print('\n\nTrainer initialized.')
|
114 |
+
print(self)
|
115 |
+
|
116 |
+
@property
|
117 |
+
def device(self):
|
118 |
+
for _, model in self.models.items():
|
119 |
+
if hasattr(model, 'device'):
|
120 |
+
return model.device
|
121 |
+
return next(list(self.models.values())[0].parameters()).device
|
122 |
+
|
123 |
+
@abstractmethod
|
124 |
+
def init_models_and_more(self, **kwargs):
|
125 |
+
"""
|
126 |
+
Initialize models and more.
|
127 |
+
"""
|
128 |
+
pass
|
129 |
+
|
130 |
+
def prepare_dataloader(self, **kwargs):
|
131 |
+
"""
|
132 |
+
Prepare dataloader.
|
133 |
+
"""
|
134 |
+
self.data_sampler = ResumableSampler(
|
135 |
+
self.dataset,
|
136 |
+
shuffle=True,
|
137 |
+
)
|
138 |
+
self.dataloader = DataLoader(
|
139 |
+
self.dataset,
|
140 |
+
batch_size=self.batch_size_per_gpu,
|
141 |
+
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
|
142 |
+
pin_memory=True,
|
143 |
+
drop_last=True,
|
144 |
+
persistent_workers=True,
|
145 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
146 |
+
sampler=self.data_sampler,
|
147 |
+
)
|
148 |
+
self.data_iterator = cycle(self.dataloader)
|
149 |
+
|
150 |
+
@abstractmethod
|
151 |
+
def load(self, load_dir, step=0):
|
152 |
+
"""
|
153 |
+
Load a checkpoint.
|
154 |
+
Should be called by all processes.
|
155 |
+
"""
|
156 |
+
pass
|
157 |
+
|
158 |
+
@abstractmethod
|
159 |
+
def save(self):
|
160 |
+
"""
|
161 |
+
Save a checkpoint.
|
162 |
+
Should be called only by the rank 0 process.
|
163 |
+
"""
|
164 |
+
pass
|
165 |
+
|
166 |
+
@abstractmethod
|
167 |
+
def finetune_from(self, finetune_ckpt):
|
168 |
+
"""
|
169 |
+
Finetune from a checkpoint.
|
170 |
+
Should be called by all processes.
|
171 |
+
"""
|
172 |
+
pass
|
173 |
+
|
174 |
+
@abstractmethod
|
175 |
+
def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
|
176 |
+
"""
|
177 |
+
Run a snapshot of the model.
|
178 |
+
"""
|
179 |
+
pass
|
180 |
+
|
181 |
+
@torch.no_grad()
|
182 |
+
def visualize_sample(self, sample):
|
183 |
+
"""
|
184 |
+
Convert a sample to an image.
|
185 |
+
"""
|
186 |
+
if hasattr(self.dataset, 'visualize_sample'):
|
187 |
+
return self.dataset.visualize_sample(sample)
|
188 |
+
else:
|
189 |
+
return sample
|
190 |
+
|
191 |
+
@torch.no_grad()
|
192 |
+
def snapshot_dataset(self, num_samples=100):
|
193 |
+
"""
|
194 |
+
Sample images from the dataset.
|
195 |
+
"""
|
196 |
+
dataloader = torch.utils.data.DataLoader(
|
197 |
+
self.dataset,
|
198 |
+
batch_size=num_samples,
|
199 |
+
num_workers=0,
|
200 |
+
shuffle=True,
|
201 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
202 |
+
)
|
203 |
+
data = next(iter(dataloader))
|
204 |
+
data = recursive_to_device(data, self.device)
|
205 |
+
vis = self.visualize_sample(data)
|
206 |
+
if isinstance(vis, dict):
|
207 |
+
save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
|
208 |
+
else:
|
209 |
+
save_cfg = [('dataset', vis)]
|
210 |
+
for name, image in save_cfg:
|
211 |
+
utils.save_image(
|
212 |
+
image,
|
213 |
+
os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
|
214 |
+
nrow=int(np.sqrt(num_samples)),
|
215 |
+
normalize=True,
|
216 |
+
value_range=self.dataset.value_range,
|
217 |
+
)
|
218 |
+
|
219 |
+
@torch.no_grad()
|
220 |
+
def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
|
221 |
+
"""
|
222 |
+
Sample images from the model.
|
223 |
+
NOTE: This function should be called by all processes.
|
224 |
+
"""
|
225 |
+
if self.is_master:
|
226 |
+
print(f'\nSampling {num_samples} images...', end='')
|
227 |
+
|
228 |
+
if suffix is None:
|
229 |
+
suffix = f'step{self.step:07d}'
|
230 |
+
|
231 |
+
# Assign tasks
|
232 |
+
num_samples_per_process = int(np.ceil(num_samples / self.world_size))
|
233 |
+
samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
|
234 |
+
|
235 |
+
# Preprocess images
|
236 |
+
for key in list(samples.keys()):
|
237 |
+
if samples[key]['type'] == 'sample':
|
238 |
+
vis = self.visualize_sample(samples[key]['value'])
|
239 |
+
if isinstance(vis, dict):
|
240 |
+
for k, v in vis.items():
|
241 |
+
samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
|
242 |
+
del samples[key]
|
243 |
+
else:
|
244 |
+
samples[key] = {'value': vis, 'type': 'image'}
|
245 |
+
|
246 |
+
# Gather results
|
247 |
+
if self.world_size > 1:
|
248 |
+
for key in samples.keys():
|
249 |
+
samples[key]['value'] = samples[key]['value'].contiguous()
|
250 |
+
if self.is_master:
|
251 |
+
all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
|
252 |
+
else:
|
253 |
+
all_images = []
|
254 |
+
dist.gather(samples[key]['value'], all_images, dst=0)
|
255 |
+
if self.is_master:
|
256 |
+
samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
|
257 |
+
|
258 |
+
# Save images
|
259 |
+
if self.is_master:
|
260 |
+
os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
|
261 |
+
for key in samples.keys():
|
262 |
+
if samples[key]['type'] == 'image':
|
263 |
+
utils.save_image(
|
264 |
+
samples[key]['value'],
|
265 |
+
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
|
266 |
+
nrow=int(np.sqrt(num_samples)),
|
267 |
+
normalize=True,
|
268 |
+
value_range=self.dataset.value_range,
|
269 |
+
)
|
270 |
+
elif samples[key]['type'] == 'number':
|
271 |
+
min = samples[key]['value'].min()
|
272 |
+
max = samples[key]['value'].max()
|
273 |
+
images = (samples[key]['value'] - min) / (max - min)
|
274 |
+
images = utils.make_grid(
|
275 |
+
images,
|
276 |
+
nrow=int(np.sqrt(num_samples)),
|
277 |
+
normalize=False,
|
278 |
+
)
|
279 |
+
save_image_with_notes(
|
280 |
+
images,
|
281 |
+
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
|
282 |
+
notes=f'{key} min: {min}, max: {max}',
|
283 |
+
)
|
284 |
+
|
285 |
+
if self.is_master:
|
286 |
+
print(' Done.')
|
287 |
+
|
288 |
+
@abstractmethod
|
289 |
+
def update_ema(self):
|
290 |
+
"""
|
291 |
+
Update exponential moving average.
|
292 |
+
Should only be called by the rank 0 process.
|
293 |
+
"""
|
294 |
+
pass
|
295 |
+
|
296 |
+
@abstractmethod
|
297 |
+
def check_ddp(self):
|
298 |
+
"""
|
299 |
+
Check if DDP is working properly.
|
300 |
+
Should be called by all process.
|
301 |
+
"""
|
302 |
+
pass
|
303 |
+
|
304 |
+
@abstractmethod
|
305 |
+
def training_losses(**mb_data):
|
306 |
+
"""
|
307 |
+
Compute training losses.
|
308 |
+
"""
|
309 |
+
pass
|
310 |
+
|
311 |
+
def load_data(self):
|
312 |
+
"""
|
313 |
+
Load data.
|
314 |
+
"""
|
315 |
+
if self.prefetch_data:
|
316 |
+
if self._data_prefetched is None:
|
317 |
+
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
318 |
+
data = self._data_prefetched
|
319 |
+
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
320 |
+
else:
|
321 |
+
data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
|
322 |
+
|
323 |
+
# if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
|
324 |
+
if isinstance(data, dict):
|
325 |
+
if self.batch_split == 1:
|
326 |
+
data_list = [data]
|
327 |
+
else:
|
328 |
+
batch_size = list(data.values())[0].shape[0]
|
329 |
+
data_list = [
|
330 |
+
{k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
|
331 |
+
for i in range(self.batch_split)
|
332 |
+
]
|
333 |
+
elif isinstance(data, list):
|
334 |
+
data_list = data
|
335 |
+
else:
|
336 |
+
raise ValueError('Data must be a dict or a list of dicts.')
|
337 |
+
|
338 |
+
return data_list
|
339 |
+
|
340 |
+
@abstractmethod
|
341 |
+
def run_step(self, data_list):
|
342 |
+
"""
|
343 |
+
Run a training step.
|
344 |
+
"""
|
345 |
+
pass
|
346 |
+
|
347 |
+
def run(self):
|
348 |
+
"""
|
349 |
+
Run training.
|
350 |
+
"""
|
351 |
+
if self.is_master:
|
352 |
+
print('\nStarting training...')
|
353 |
+
self.snapshot_dataset()
|
354 |
+
if self.step == 0:
|
355 |
+
self.snapshot(suffix='init')
|
356 |
+
else: # resume
|
357 |
+
self.snapshot(suffix=f'resume_step{self.step:07d}')
|
358 |
+
|
359 |
+
log = []
|
360 |
+
time_last_print = 0.0
|
361 |
+
time_elapsed = 0.0
|
362 |
+
while self.step < self.max_steps:
|
363 |
+
time_start = time.time()
|
364 |
+
|
365 |
+
data_list = self.load_data()
|
366 |
+
step_log = self.run_step(data_list)
|
367 |
+
|
368 |
+
time_end = time.time()
|
369 |
+
time_elapsed += time_end - time_start
|
370 |
+
|
371 |
+
self.step += 1
|
372 |
+
|
373 |
+
# Print progress
|
374 |
+
if self.is_master and self.step % self.i_print == 0:
|
375 |
+
speed = self.i_print / (time_elapsed - time_last_print) * 3600
|
376 |
+
columns = [
|
377 |
+
f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
|
378 |
+
f'Elapsed: {time_elapsed / 3600:.2f} h',
|
379 |
+
f'Speed: {speed:.2f} steps/h',
|
380 |
+
f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
|
381 |
+
]
|
382 |
+
print(' | '.join([c.ljust(25) for c in columns]), flush=True)
|
383 |
+
time_last_print = time_elapsed
|
384 |
+
|
385 |
+
# Check ddp
|
386 |
+
if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
|
387 |
+
self.check_ddp()
|
388 |
+
|
389 |
+
# Sample images
|
390 |
+
if self.step % self.i_sample == 0:
|
391 |
+
self.snapshot()
|
392 |
+
|
393 |
+
if self.is_master:
|
394 |
+
log.append((self.step, {}))
|
395 |
+
|
396 |
+
# Log time
|
397 |
+
log[-1][1]['time'] = {
|
398 |
+
'step': time_end - time_start,
|
399 |
+
'elapsed': time_elapsed,
|
400 |
+
}
|
401 |
+
|
402 |
+
# Log losses
|
403 |
+
if step_log is not None:
|
404 |
+
log[-1][1].update(step_log)
|
405 |
+
|
406 |
+
# Log scale
|
407 |
+
if self.fp16_mode == 'amp':
|
408 |
+
log[-1][1]['scale'] = self.scaler.get_scale()
|
409 |
+
elif self.fp16_mode == 'inflat_all':
|
410 |
+
log[-1][1]['log_scale'] = self.log_scale
|
411 |
+
|
412 |
+
# Save log
|
413 |
+
if self.step % self.i_log == 0:
|
414 |
+
## save to log file
|
415 |
+
log_str = '\n'.join([
|
416 |
+
f'{step}: {json.dumps(log)}' for step, log in log
|
417 |
+
])
|
418 |
+
with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
|
419 |
+
log_file.write(log_str + '\n')
|
420 |
+
|
421 |
+
# show with mlflow
|
422 |
+
log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))]
|
423 |
+
log_show = dict_reduce(log_show, lambda x: np.mean(x))
|
424 |
+
log_show = dict_flatten(log_show, sep='/')
|
425 |
+
for key, value in log_show.items():
|
426 |
+
self.writer.add_scalar(key, value, self.step)
|
427 |
+
log = []
|
428 |
+
|
429 |
+
# Save checkpoint
|
430 |
+
if self.step % self.i_save == 0:
|
431 |
+
self.save()
|
432 |
+
|
433 |
+
if self.is_master:
|
434 |
+
self.snapshot(suffix='final')
|
435 |
+
self.writer.close()
|
436 |
+
print('Training finished.')
|
437 |
+
|
438 |
+
def profile(self, wait=2, warmup=3, active=5):
|
439 |
+
"""
|
440 |
+
Profile the training loop.
|
441 |
+
"""
|
442 |
+
with torch.profiler.profile(
|
443 |
+
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
|
444 |
+
on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
|
445 |
+
profile_memory=True,
|
446 |
+
with_stack=True,
|
447 |
+
) as prof:
|
448 |
+
for _ in range(wait + warmup + active):
|
449 |
+
self.run_step()
|
450 |
+
prof.step()
|
451 |
+
|
trellis/trainers/basic.py
ADDED
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
from functools import partial
|
4 |
+
from contextlib import nullcontext
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from .utils import *
|
12 |
+
from .base import Trainer
|
13 |
+
from ..utils.general_utils import *
|
14 |
+
from ..utils.dist_utils import *
|
15 |
+
from ..utils import grad_clip_utils, elastic_utils
|
16 |
+
|
17 |
+
|
18 |
+
class BasicTrainer(Trainer):
|
19 |
+
"""
|
20 |
+
Trainer for basic training loop.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
models (dict[str, nn.Module]): Models to train.
|
24 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
25 |
+
output_dir (str): Output directory.
|
26 |
+
load_dir (str): Load directory.
|
27 |
+
step (int): Step to load.
|
28 |
+
batch_size (int): Batch size.
|
29 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
30 |
+
batch_split (int): Split batch with gradient accumulation.
|
31 |
+
max_steps (int): Max steps.
|
32 |
+
optimizer (dict): Optimizer config.
|
33 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
34 |
+
elastic (dict): Elastic memory management config.
|
35 |
+
grad_clip (float or dict): Gradient clip config.
|
36 |
+
ema_rate (float or list): Exponential moving average rates.
|
37 |
+
fp16_mode (str): FP16 mode.
|
38 |
+
- None: No FP16.
|
39 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
40 |
+
- 'amp': Automatic mixed precision.
|
41 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
42 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
43 |
+
log_param_stats (bool): Log parameter stats.
|
44 |
+
i_print (int): Print interval.
|
45 |
+
i_log (int): Log interval.
|
46 |
+
i_sample (int): Sample interval.
|
47 |
+
i_save (int): Save interval.
|
48 |
+
i_ddpcheck (int): DDP check interval.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __str__(self):
|
52 |
+
lines = []
|
53 |
+
lines.append(self.__class__.__name__)
|
54 |
+
lines.append(f' - Models:')
|
55 |
+
for name, model in self.models.items():
|
56 |
+
lines.append(f' - {name}: {model.__class__.__name__}')
|
57 |
+
lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
|
58 |
+
lines.append(f' - Dataloader:')
|
59 |
+
lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
|
60 |
+
lines.append(f' - Num workers: {self.dataloader.num_workers}')
|
61 |
+
lines.append(f' - Number of steps: {self.max_steps}')
|
62 |
+
lines.append(f' - Number of GPUs: {self.world_size}')
|
63 |
+
lines.append(f' - Batch size: {self.batch_size}')
|
64 |
+
lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
|
65 |
+
lines.append(f' - Batch split: {self.batch_split}')
|
66 |
+
lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
|
67 |
+
lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
|
68 |
+
if self.lr_scheduler_config is not None:
|
69 |
+
lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
|
70 |
+
if self.elastic_controller_config is not None:
|
71 |
+
lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
|
72 |
+
if self.grad_clip is not None:
|
73 |
+
lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
|
74 |
+
lines.append(f' - EMA rate: {self.ema_rate}')
|
75 |
+
lines.append(f' - FP16 mode: {self.fp16_mode}')
|
76 |
+
return '\n'.join(lines)
|
77 |
+
|
78 |
+
def init_models_and_more(self, **kwargs):
|
79 |
+
"""
|
80 |
+
Initialize models and more.
|
81 |
+
"""
|
82 |
+
if self.world_size > 1:
|
83 |
+
# Prepare distributed data parallel
|
84 |
+
self.training_models = {
|
85 |
+
name: DDP(
|
86 |
+
model,
|
87 |
+
device_ids=[self.local_rank],
|
88 |
+
output_device=self.local_rank,
|
89 |
+
bucket_cap_mb=128,
|
90 |
+
find_unused_parameters=False
|
91 |
+
)
|
92 |
+
for name, model in self.models.items()
|
93 |
+
}
|
94 |
+
else:
|
95 |
+
self.training_models = self.models
|
96 |
+
|
97 |
+
# Build master params
|
98 |
+
self.model_params = sum(
|
99 |
+
[[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
|
100 |
+
, [])
|
101 |
+
if self.fp16_mode == 'amp':
|
102 |
+
self.master_params = self.model_params
|
103 |
+
self.scaler = torch.GradScaler() if self.fp16_mode == 'amp' else None
|
104 |
+
elif self.fp16_mode == 'inflat_all':
|
105 |
+
self.master_params = make_master_params(self.model_params)
|
106 |
+
self.fp16_scale_growth = self.fp16_scale_growth
|
107 |
+
self.log_scale = 20.0
|
108 |
+
elif self.fp16_mode is None:
|
109 |
+
self.master_params = self.model_params
|
110 |
+
else:
|
111 |
+
raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.')
|
112 |
+
|
113 |
+
# Build EMA params
|
114 |
+
if self.is_master:
|
115 |
+
self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
|
116 |
+
|
117 |
+
# Initialize optimizer
|
118 |
+
if hasattr(torch.optim, self.optimizer_config['name']):
|
119 |
+
self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
|
120 |
+
else:
|
121 |
+
self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
|
122 |
+
|
123 |
+
# Initalize learning rate scheduler
|
124 |
+
if self.lr_scheduler_config is not None:
|
125 |
+
if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
|
126 |
+
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
|
127 |
+
else:
|
128 |
+
self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
|
129 |
+
|
130 |
+
# Initialize elastic memory controller
|
131 |
+
if self.elastic_controller_config is not None:
|
132 |
+
assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
|
133 |
+
'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
|
134 |
+
self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
|
135 |
+
for model in self.models.values():
|
136 |
+
if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
|
137 |
+
model.register_memory_controller(self.elastic_controller)
|
138 |
+
|
139 |
+
# Initialize gradient clipper
|
140 |
+
if self.grad_clip is not None:
|
141 |
+
if isinstance(self.grad_clip, (float, int)):
|
142 |
+
self.grad_clip = float(self.grad_clip)
|
143 |
+
else:
|
144 |
+
self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
|
145 |
+
|
146 |
+
def _master_params_to_state_dicts(self, master_params):
|
147 |
+
"""
|
148 |
+
Convert master params to dict of state_dicts.
|
149 |
+
"""
|
150 |
+
if self.fp16_mode == 'inflat_all':
|
151 |
+
master_params = unflatten_master_params(self.model_params, master_params)
|
152 |
+
state_dicts = {name: model.state_dict() for name, model in self.models.items()}
|
153 |
+
master_params_names = sum(
|
154 |
+
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
|
155 |
+
, [])
|
156 |
+
for i, (model_name, param_name) in enumerate(master_params_names):
|
157 |
+
state_dicts[model_name][param_name] = master_params[i]
|
158 |
+
return state_dicts
|
159 |
+
|
160 |
+
def _state_dicts_to_master_params(self, master_params, state_dicts):
|
161 |
+
"""
|
162 |
+
Convert a state_dict to master params.
|
163 |
+
"""
|
164 |
+
master_params_names = sum(
|
165 |
+
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
|
166 |
+
, [])
|
167 |
+
params = [state_dicts[name][param_name] for name, param_name in master_params_names]
|
168 |
+
if self.fp16_mode == 'inflat_all':
|
169 |
+
model_params_to_master_params(params, master_params)
|
170 |
+
else:
|
171 |
+
for i, param in enumerate(params):
|
172 |
+
master_params[i].data.copy_(param.data)
|
173 |
+
|
174 |
+
def load(self, load_dir, step=0):
|
175 |
+
"""
|
176 |
+
Load a checkpoint.
|
177 |
+
Should be called by all processes.
|
178 |
+
"""
|
179 |
+
if self.is_master:
|
180 |
+
print(f'\nLoading checkpoint from step {step}...', end='')
|
181 |
+
|
182 |
+
model_ckpts = {}
|
183 |
+
for name, model in self.models.items():
|
184 |
+
model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
|
185 |
+
model_ckpts[name] = model_ckpt
|
186 |
+
model.load_state_dict(model_ckpt)
|
187 |
+
if self.fp16_mode == 'inflat_all':
|
188 |
+
model.convert_to_fp16()
|
189 |
+
self._state_dicts_to_master_params(self.master_params, model_ckpts)
|
190 |
+
del model_ckpts
|
191 |
+
|
192 |
+
if self.is_master:
|
193 |
+
for i, ema_rate in enumerate(self.ema_rate):
|
194 |
+
ema_ckpts = {}
|
195 |
+
for name, model in self.models.items():
|
196 |
+
ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
|
197 |
+
ema_ckpts[name] = ema_ckpt
|
198 |
+
self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
|
199 |
+
del ema_ckpts
|
200 |
+
|
201 |
+
misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
|
202 |
+
self.optimizer.load_state_dict(misc_ckpt['optimizer'])
|
203 |
+
self.step = misc_ckpt['step']
|
204 |
+
self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
|
205 |
+
if self.fp16_mode == 'amp':
|
206 |
+
self.scaler.load_state_dict(misc_ckpt['scaler'])
|
207 |
+
elif self.fp16_mode == 'inflat_all':
|
208 |
+
self.log_scale = misc_ckpt['log_scale']
|
209 |
+
if self.lr_scheduler_config is not None:
|
210 |
+
self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
|
211 |
+
if self.elastic_controller_config is not None:
|
212 |
+
self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
|
213 |
+
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
|
214 |
+
self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
|
215 |
+
del misc_ckpt
|
216 |
+
|
217 |
+
if self.world_size > 1:
|
218 |
+
dist.barrier()
|
219 |
+
if self.is_master:
|
220 |
+
print(' Done.')
|
221 |
+
|
222 |
+
if self.world_size > 1:
|
223 |
+
self.check_ddp()
|
224 |
+
|
225 |
+
def save(self):
|
226 |
+
"""
|
227 |
+
Save a checkpoint.
|
228 |
+
Should be called only by the rank 0 process.
|
229 |
+
"""
|
230 |
+
assert self.is_master, 'save() should be called only by the rank 0 process.'
|
231 |
+
print(f'\nSaving checkpoint at step {self.step}...', end='')
|
232 |
+
|
233 |
+
model_ckpts = self._master_params_to_state_dicts(self.master_params)
|
234 |
+
for name, model_ckpt in model_ckpts.items():
|
235 |
+
torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
|
236 |
+
|
237 |
+
for i, ema_rate in enumerate(self.ema_rate):
|
238 |
+
ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
|
239 |
+
for name, ema_ckpt in ema_ckpts.items():
|
240 |
+
torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
|
241 |
+
|
242 |
+
misc_ckpt = {
|
243 |
+
'optimizer': self.optimizer.state_dict(),
|
244 |
+
'step': self.step,
|
245 |
+
'data_sampler': self.data_sampler.state_dict(),
|
246 |
+
}
|
247 |
+
if self.fp16_mode == 'amp':
|
248 |
+
misc_ckpt['scaler'] = self.scaler.state_dict()
|
249 |
+
elif self.fp16_mode == 'inflat_all':
|
250 |
+
misc_ckpt['log_scale'] = self.log_scale
|
251 |
+
if self.lr_scheduler_config is not None:
|
252 |
+
misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
|
253 |
+
if self.elastic_controller_config is not None:
|
254 |
+
misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
|
255 |
+
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
|
256 |
+
misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
|
257 |
+
torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
|
258 |
+
print(' Done.')
|
259 |
+
|
260 |
+
def finetune_from(self, finetune_ckpt):
|
261 |
+
"""
|
262 |
+
Finetune from a checkpoint.
|
263 |
+
Should be called by all processes.
|
264 |
+
"""
|
265 |
+
if self.is_master:
|
266 |
+
print('\nFinetuning from:')
|
267 |
+
for name, path in finetune_ckpt.items():
|
268 |
+
print(f' - {name}: {path}')
|
269 |
+
|
270 |
+
model_ckpts = {}
|
271 |
+
for name, model in self.models.items():
|
272 |
+
model_state_dict = model.state_dict()
|
273 |
+
if name in finetune_ckpt:
|
274 |
+
model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
|
275 |
+
for k, v in model_ckpt.items():
|
276 |
+
if model_ckpt[k].shape != model_state_dict[k].shape:
|
277 |
+
if self.is_master:
|
278 |
+
print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
|
279 |
+
model_ckpt[k] = model_state_dict[k]
|
280 |
+
model_ckpts[name] = model_ckpt
|
281 |
+
model.load_state_dict(model_ckpt)
|
282 |
+
if self.fp16_mode == 'inflat_all':
|
283 |
+
model.convert_to_fp16()
|
284 |
+
else:
|
285 |
+
if self.is_master:
|
286 |
+
print(f'Warning: {name} not found in finetune_ckpt, skipped.')
|
287 |
+
model_ckpts[name] = model_state_dict
|
288 |
+
self._state_dicts_to_master_params(self.master_params, model_ckpts)
|
289 |
+
del model_ckpts
|
290 |
+
|
291 |
+
if self.world_size > 1:
|
292 |
+
dist.barrier()
|
293 |
+
if self.is_master:
|
294 |
+
print('Done.')
|
295 |
+
|
296 |
+
if self.world_size > 1:
|
297 |
+
self.check_ddp()
|
298 |
+
|
299 |
+
def update_ema(self):
|
300 |
+
"""
|
301 |
+
Update exponential moving average.
|
302 |
+
Should only be called by the rank 0 process.
|
303 |
+
"""
|
304 |
+
assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
|
305 |
+
for i, ema_rate in enumerate(self.ema_rate):
|
306 |
+
for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
|
307 |
+
ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
|
308 |
+
|
309 |
+
def check_ddp(self):
|
310 |
+
"""
|
311 |
+
Check if DDP is working properly.
|
312 |
+
Should be called by all process.
|
313 |
+
"""
|
314 |
+
if self.is_master:
|
315 |
+
print('\nPerforming DDP check...')
|
316 |
+
|
317 |
+
if self.is_master:
|
318 |
+
print('Checking if parameters are consistent across processes...')
|
319 |
+
dist.barrier()
|
320 |
+
try:
|
321 |
+
for p in self.master_params:
|
322 |
+
# split to avoid OOM
|
323 |
+
for i in range(0, p.numel(), 10000000):
|
324 |
+
sub_size = min(10000000, p.numel() - i)
|
325 |
+
sub_p = p.detach().view(-1)[i:i+sub_size]
|
326 |
+
# gather from all processes
|
327 |
+
sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
|
328 |
+
dist.all_gather(sub_p_gather, sub_p)
|
329 |
+
# check if equal
|
330 |
+
assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
|
331 |
+
except AssertionError as e:
|
332 |
+
if self.is_master:
|
333 |
+
print(f'\n\033[91mError: {e}\033[0m')
|
334 |
+
print('DDP check failed.')
|
335 |
+
raise e
|
336 |
+
|
337 |
+
dist.barrier()
|
338 |
+
if self.is_master:
|
339 |
+
print('Done.')
|
340 |
+
|
341 |
+
def run_step(self, data_list):
|
342 |
+
"""
|
343 |
+
Run a training step.
|
344 |
+
"""
|
345 |
+
step_log = {'loss': {}, 'status': {}}
|
346 |
+
amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext
|
347 |
+
elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
|
348 |
+
|
349 |
+
# Train
|
350 |
+
losses = []
|
351 |
+
statuses = []
|
352 |
+
elastic_controller_logs = []
|
353 |
+
zero_grad(self.model_params)
|
354 |
+
for i, mb_data in enumerate(data_list):
|
355 |
+
## sync at the end of each batch split
|
356 |
+
sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
|
357 |
+
with nested_contexts(*sync_contexts), elastic_controller_context():
|
358 |
+
with amp_context():
|
359 |
+
loss, status = self.training_losses(**mb_data)
|
360 |
+
l = loss['loss'] / len(data_list)
|
361 |
+
## backward
|
362 |
+
if self.fp16_mode == 'amp':
|
363 |
+
self.scaler.scale(l).backward()
|
364 |
+
elif self.fp16_mode == 'inflat_all':
|
365 |
+
scaled_l = l * (2 ** self.log_scale)
|
366 |
+
scaled_l.backward()
|
367 |
+
else:
|
368 |
+
l.backward()
|
369 |
+
## log
|
370 |
+
losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
|
371 |
+
statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
|
372 |
+
if self.elastic_controller_config is not None:
|
373 |
+
elastic_controller_logs.append(self.elastic_controller.log())
|
374 |
+
## gradient clip
|
375 |
+
if self.grad_clip is not None:
|
376 |
+
if self.fp16_mode == 'amp':
|
377 |
+
self.scaler.unscale_(self.optimizer)
|
378 |
+
elif self.fp16_mode == 'inflat_all':
|
379 |
+
model_grads_to_master_grads(self.model_params, self.master_params)
|
380 |
+
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
|
381 |
+
if isinstance(self.grad_clip, float):
|
382 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
|
383 |
+
else:
|
384 |
+
grad_norm = self.grad_clip(self.master_params)
|
385 |
+
if torch.isfinite(grad_norm):
|
386 |
+
statuses[-1]['grad_norm'] = grad_norm.item()
|
387 |
+
## step
|
388 |
+
if self.fp16_mode == 'amp':
|
389 |
+
prev_scale = self.scaler.get_scale()
|
390 |
+
self.scaler.step(self.optimizer)
|
391 |
+
self.scaler.update()
|
392 |
+
elif self.fp16_mode == 'inflat_all':
|
393 |
+
prev_scale = 2 ** self.log_scale
|
394 |
+
if not any(not p.grad.isfinite().all() for p in self.model_params):
|
395 |
+
if self.grad_clip is None:
|
396 |
+
model_grads_to_master_grads(self.model_params, self.master_params)
|
397 |
+
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
|
398 |
+
self.optimizer.step()
|
399 |
+
master_params_to_model_params(self.model_params, self.master_params)
|
400 |
+
self.log_scale += self.fp16_scale_growth
|
401 |
+
else:
|
402 |
+
self.log_scale -= 1
|
403 |
+
else:
|
404 |
+
prev_scale = 1.0
|
405 |
+
if not any(not p.grad.isfinite().all() for p in self.model_params):
|
406 |
+
self.optimizer.step()
|
407 |
+
else:
|
408 |
+
print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
|
409 |
+
## adjust learning rate
|
410 |
+
if self.lr_scheduler_config is not None:
|
411 |
+
statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
|
412 |
+
self.lr_scheduler.step()
|
413 |
+
|
414 |
+
# Logs
|
415 |
+
step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
|
416 |
+
step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
|
417 |
+
if self.elastic_controller_config is not None:
|
418 |
+
step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
|
419 |
+
if self.grad_clip is not None:
|
420 |
+
step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
|
421 |
+
|
422 |
+
# Check grad and norm of each param
|
423 |
+
if self.log_param_stats:
|
424 |
+
param_norms = {}
|
425 |
+
param_grads = {}
|
426 |
+
for name, param in self.backbone.named_parameters():
|
427 |
+
if param.requires_grad:
|
428 |
+
param_norms[name] = param.norm().item()
|
429 |
+
if param.grad is not None and torch.isfinite(param.grad).all():
|
430 |
+
param_grads[name] = param.grad.norm().item() / prev_scale
|
431 |
+
step_log['param_norms'] = param_norms
|
432 |
+
step_log['param_grads'] = param_grads
|
433 |
+
|
434 |
+
# Update exponential moving average
|
435 |
+
if self.is_master:
|
436 |
+
self.update_ema()
|
437 |
+
|
438 |
+
return step_log
|
trellis/trainers/flow_matching/flow_matching.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import numpy as np
|
7 |
+
from easydict import EasyDict as edict
|
8 |
+
|
9 |
+
from ..basic import BasicTrainer
|
10 |
+
from ...pipelines import samplers
|
11 |
+
from ...utils.general_utils import dict_reduce
|
12 |
+
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
|
13 |
+
from .mixins.text_conditioned import TextConditionedMixin
|
14 |
+
from .mixins.image_conditioned import ImageConditionedMixin
|
15 |
+
|
16 |
+
|
17 |
+
class FlowMatchingTrainer(BasicTrainer):
|
18 |
+
"""
|
19 |
+
Trainer for diffusion model with flow matching objective.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
models (dict[str, nn.Module]): Models to train.
|
23 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
24 |
+
output_dir (str): Output directory.
|
25 |
+
load_dir (str): Load directory.
|
26 |
+
step (int): Step to load.
|
27 |
+
batch_size (int): Batch size.
|
28 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
29 |
+
batch_split (int): Split batch with gradient accumulation.
|
30 |
+
max_steps (int): Max steps.
|
31 |
+
optimizer (dict): Optimizer config.
|
32 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
33 |
+
elastic (dict): Elastic memory management config.
|
34 |
+
grad_clip (float or dict): Gradient clip config.
|
35 |
+
ema_rate (float or list): Exponential moving average rates.
|
36 |
+
fp16_mode (str): FP16 mode.
|
37 |
+
- None: No FP16.
|
38 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
39 |
+
- 'amp': Automatic mixed precision.
|
40 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
41 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
42 |
+
log_param_stats (bool): Log parameter stats.
|
43 |
+
i_print (int): Print interval.
|
44 |
+
i_log (int): Log interval.
|
45 |
+
i_sample (int): Sample interval.
|
46 |
+
i_save (int): Save interval.
|
47 |
+
i_ddpcheck (int): DDP check interval.
|
48 |
+
|
49 |
+
t_schedule (dict): Time schedule for flow matching.
|
50 |
+
sigma_min (float): Minimum noise level.
|
51 |
+
"""
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
*args,
|
55 |
+
t_schedule: dict = {
|
56 |
+
'name': 'logitNormal',
|
57 |
+
'args': {
|
58 |
+
'mean': 0.0,
|
59 |
+
'std': 1.0,
|
60 |
+
}
|
61 |
+
},
|
62 |
+
sigma_min: float = 1e-5,
|
63 |
+
**kwargs
|
64 |
+
):
|
65 |
+
super().__init__(*args, **kwargs)
|
66 |
+
self.t_schedule = t_schedule
|
67 |
+
self.sigma_min = sigma_min
|
68 |
+
|
69 |
+
def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
|
70 |
+
"""
|
71 |
+
Diffuse the data for a given number of diffusion steps.
|
72 |
+
In other words, sample from q(x_t | x_0).
|
73 |
+
|
74 |
+
Args:
|
75 |
+
x_0: The [N x C x ...] tensor of noiseless inputs.
|
76 |
+
t: The [N] tensor of diffusion steps [0-1].
|
77 |
+
noise: If specified, use this noise instead of generating new noise.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
x_t, the noisy version of x_0 under timestep t.
|
81 |
+
"""
|
82 |
+
if noise is None:
|
83 |
+
noise = torch.randn_like(x_0)
|
84 |
+
assert noise.shape == x_0.shape, "noise must have same shape as x_0"
|
85 |
+
|
86 |
+
t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
|
87 |
+
x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
|
88 |
+
|
89 |
+
return x_t
|
90 |
+
|
91 |
+
def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
|
92 |
+
"""
|
93 |
+
Get original image from noisy version under timestep t.
|
94 |
+
"""
|
95 |
+
assert noise.shape == x_t.shape, "noise must have same shape as x_t"
|
96 |
+
t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
|
97 |
+
x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
|
98 |
+
return x_0
|
99 |
+
|
100 |
+
def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
101 |
+
"""
|
102 |
+
Compute the velocity of the diffusion process at time t.
|
103 |
+
"""
|
104 |
+
return (1 - self.sigma_min) * noise - x_0
|
105 |
+
|
106 |
+
def get_cond(self, cond, **kwargs):
|
107 |
+
"""
|
108 |
+
Get the conditioning data.
|
109 |
+
"""
|
110 |
+
return cond
|
111 |
+
|
112 |
+
def get_inference_cond(self, cond, **kwargs):
|
113 |
+
"""
|
114 |
+
Get the conditioning data for inference.
|
115 |
+
"""
|
116 |
+
return {'cond': cond, **kwargs}
|
117 |
+
|
118 |
+
def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
|
119 |
+
"""
|
120 |
+
Get the sampler for the diffusion process.
|
121 |
+
"""
|
122 |
+
return samplers.FlowEulerSampler(self.sigma_min)
|
123 |
+
|
124 |
+
def vis_cond(self, **kwargs):
|
125 |
+
"""
|
126 |
+
Visualize the conditioning data.
|
127 |
+
"""
|
128 |
+
return {}
|
129 |
+
|
130 |
+
def sample_t(self, batch_size: int) -> torch.Tensor:
|
131 |
+
"""
|
132 |
+
Sample timesteps.
|
133 |
+
"""
|
134 |
+
if self.t_schedule['name'] == 'uniform':
|
135 |
+
t = torch.rand(batch_size)
|
136 |
+
elif self.t_schedule['name'] == 'logitNormal':
|
137 |
+
mean = self.t_schedule['args']['mean']
|
138 |
+
std = self.t_schedule['args']['std']
|
139 |
+
t = torch.sigmoid(torch.randn(batch_size) * std + mean)
|
140 |
+
else:
|
141 |
+
raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
|
142 |
+
return t
|
143 |
+
|
144 |
+
def training_losses(
|
145 |
+
self,
|
146 |
+
x_0: torch.Tensor,
|
147 |
+
cond=None,
|
148 |
+
**kwargs
|
149 |
+
) -> Tuple[Dict, Dict]:
|
150 |
+
"""
|
151 |
+
Compute training losses for a single timestep.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
x_0: The [N x C x ...] tensor of noiseless inputs.
|
155 |
+
cond: The [N x ...] tensor of additional conditions.
|
156 |
+
kwargs: Additional arguments to pass to the backbone.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
a dict with the key "loss" containing a tensor of shape [N].
|
160 |
+
may also contain other keys for different terms.
|
161 |
+
"""
|
162 |
+
noise = torch.randn_like(x_0)
|
163 |
+
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
|
164 |
+
x_t = self.diffuse(x_0, t, noise=noise)
|
165 |
+
cond = self.get_cond(cond, **kwargs)
|
166 |
+
|
167 |
+
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
|
168 |
+
assert pred.shape == noise.shape == x_0.shape
|
169 |
+
target = self.get_v(x_0, noise, t)
|
170 |
+
terms = edict()
|
171 |
+
terms["mse"] = F.mse_loss(pred, target)
|
172 |
+
terms["loss"] = terms["mse"]
|
173 |
+
|
174 |
+
# log loss with time bins
|
175 |
+
mse_per_instance = np.array([
|
176 |
+
F.mse_loss(pred[i], target[i]).item()
|
177 |
+
for i in range(x_0.shape[0])
|
178 |
+
])
|
179 |
+
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
|
180 |
+
for i in range(10):
|
181 |
+
if (time_bin == i).sum() != 0:
|
182 |
+
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
|
183 |
+
|
184 |
+
return terms, {}
|
185 |
+
|
186 |
+
@torch.no_grad()
|
187 |
+
def run_snapshot(
|
188 |
+
self,
|
189 |
+
num_samples: int,
|
190 |
+
batch_size: int,
|
191 |
+
verbose: bool = False,
|
192 |
+
) -> Dict:
|
193 |
+
dataloader = DataLoader(
|
194 |
+
copy.deepcopy(self.dataset),
|
195 |
+
batch_size=batch_size,
|
196 |
+
shuffle=True,
|
197 |
+
num_workers=0,
|
198 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
199 |
+
)
|
200 |
+
|
201 |
+
# inference
|
202 |
+
sampler = self.get_sampler()
|
203 |
+
sample_gt = []
|
204 |
+
sample = []
|
205 |
+
cond_vis = []
|
206 |
+
for i in range(0, num_samples, batch_size):
|
207 |
+
batch = min(batch_size, num_samples - i)
|
208 |
+
data = next(iter(dataloader))
|
209 |
+
data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
|
210 |
+
noise = torch.randn_like(data['x_0'])
|
211 |
+
sample_gt.append(data['x_0'])
|
212 |
+
cond_vis.append(self.vis_cond(**data))
|
213 |
+
del data['x_0']
|
214 |
+
args = self.get_inference_cond(**data)
|
215 |
+
res = sampler.sample(
|
216 |
+
self.models['denoiser'],
|
217 |
+
noise=noise,
|
218 |
+
**args,
|
219 |
+
steps=50, cfg_strength=3.0, verbose=verbose,
|
220 |
+
)
|
221 |
+
sample.append(res.samples)
|
222 |
+
|
223 |
+
sample_gt = torch.cat(sample_gt, dim=0)
|
224 |
+
sample = torch.cat(sample, dim=0)
|
225 |
+
sample_dict = {
|
226 |
+
'sample_gt': {'value': sample_gt, 'type': 'sample'},
|
227 |
+
'sample': {'value': sample, 'type': 'sample'},
|
228 |
+
}
|
229 |
+
sample_dict.update(dict_reduce(cond_vis, None, {
|
230 |
+
'value': lambda x: torch.cat(x, dim=0),
|
231 |
+
'type': lambda x: x[0],
|
232 |
+
}))
|
233 |
+
|
234 |
+
return sample_dict
|
235 |
+
|
236 |
+
|
237 |
+
class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
|
238 |
+
"""
|
239 |
+
Trainer for diffusion model with flow matching objective and classifier-free guidance.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
models (dict[str, nn.Module]): Models to train.
|
243 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
244 |
+
output_dir (str): Output directory.
|
245 |
+
load_dir (str): Load directory.
|
246 |
+
step (int): Step to load.
|
247 |
+
batch_size (int): Batch size.
|
248 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
249 |
+
batch_split (int): Split batch with gradient accumulation.
|
250 |
+
max_steps (int): Max steps.
|
251 |
+
optimizer (dict): Optimizer config.
|
252 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
253 |
+
elastic (dict): Elastic memory management config.
|
254 |
+
grad_clip (float or dict): Gradient clip config.
|
255 |
+
ema_rate (float or list): Exponential moving average rates.
|
256 |
+
fp16_mode (str): FP16 mode.
|
257 |
+
- None: No FP16.
|
258 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
259 |
+
- 'amp': Automatic mixed precision.
|
260 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
261 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
262 |
+
log_param_stats (bool): Log parameter stats.
|
263 |
+
i_print (int): Print interval.
|
264 |
+
i_log (int): Log interval.
|
265 |
+
i_sample (int): Sample interval.
|
266 |
+
i_save (int): Save interval.
|
267 |
+
i_ddpcheck (int): DDP check interval.
|
268 |
+
|
269 |
+
t_schedule (dict): Time schedule for flow matching.
|
270 |
+
sigma_min (float): Minimum noise level.
|
271 |
+
p_uncond (float): Probability of dropping conditions.
|
272 |
+
"""
|
273 |
+
pass
|
274 |
+
|
275 |
+
|
276 |
+
class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
|
277 |
+
"""
|
278 |
+
Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
models (dict[str, nn.Module]): Models to train.
|
282 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
283 |
+
output_dir (str): Output directory.
|
284 |
+
load_dir (str): Load directory.
|
285 |
+
step (int): Step to load.
|
286 |
+
batch_size (int): Batch size.
|
287 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
288 |
+
batch_split (int): Split batch with gradient accumulation.
|
289 |
+
max_steps (int): Max steps.
|
290 |
+
optimizer (dict): Optimizer config.
|
291 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
292 |
+
elastic (dict): Elastic memory management config.
|
293 |
+
grad_clip (float or dict): Gradient clip config.
|
294 |
+
ema_rate (float or list): Exponential moving average rates.
|
295 |
+
fp16_mode (str): FP16 mode.
|
296 |
+
- None: No FP16.
|
297 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
298 |
+
- 'amp': Automatic mixed precision.
|
299 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
300 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
301 |
+
log_param_stats (bool): Log parameter stats.
|
302 |
+
i_print (int): Print interval.
|
303 |
+
i_log (int): Log interval.
|
304 |
+
i_sample (int): Sample interval.
|
305 |
+
i_save (int): Save interval.
|
306 |
+
i_ddpcheck (int): DDP check interval.
|
307 |
+
|
308 |
+
t_schedule (dict): Time schedule for flow matching.
|
309 |
+
sigma_min (float): Minimum noise level.
|
310 |
+
p_uncond (float): Probability of dropping conditions.
|
311 |
+
text_cond_model(str): Text conditioning model.
|
312 |
+
"""
|
313 |
+
pass
|
314 |
+
|
315 |
+
|
316 |
+
class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
|
317 |
+
"""
|
318 |
+
Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
models (dict[str, nn.Module]): Models to train.
|
322 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
323 |
+
output_dir (str): Output directory.
|
324 |
+
load_dir (str): Load directory.
|
325 |
+
step (int): Step to load.
|
326 |
+
batch_size (int): Batch size.
|
327 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
328 |
+
batch_split (int): Split batch with gradient accumulation.
|
329 |
+
max_steps (int): Max steps.
|
330 |
+
optimizer (dict): Optimizer config.
|
331 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
332 |
+
elastic (dict): Elastic memory management config.
|
333 |
+
grad_clip (float or dict): Gradient clip config.
|
334 |
+
ema_rate (float or list): Exponential moving average rates.
|
335 |
+
fp16_mode (str): FP16 mode.
|
336 |
+
- None: No FP16.
|
337 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
338 |
+
- 'amp': Automatic mixed precision.
|
339 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
340 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
341 |
+
log_param_stats (bool): Log parameter stats.
|
342 |
+
i_print (int): Print interval.
|
343 |
+
i_log (int): Log interval.
|
344 |
+
i_sample (int): Sample interval.
|
345 |
+
i_save (int): Save interval.
|
346 |
+
i_ddpcheck (int): DDP check interval.
|
347 |
+
|
348 |
+
t_schedule (dict): Time schedule for flow matching.
|
349 |
+
sigma_min (float): Minimum noise level.
|
350 |
+
p_uncond (float): Probability of dropping conditions.
|
351 |
+
image_cond_model (str): Image conditioning model.
|
352 |
+
"""
|
353 |
+
pass
|
trellis/trainers/flow_matching/mixins/classifier_free_guidance.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from ....utils.general_utils import dict_foreach
|
4 |
+
from ....pipelines import samplers
|
5 |
+
|
6 |
+
|
7 |
+
class ClassifierFreeGuidanceMixin:
|
8 |
+
def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
|
9 |
+
super().__init__(*args, **kwargs)
|
10 |
+
self.p_uncond = p_uncond
|
11 |
+
|
12 |
+
def get_cond(self, cond, neg_cond=None, **kwargs):
|
13 |
+
"""
|
14 |
+
Get the conditioning data.
|
15 |
+
"""
|
16 |
+
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
|
17 |
+
|
18 |
+
if self.p_uncond > 0:
|
19 |
+
# randomly drop the class label
|
20 |
+
def get_batch_size(cond):
|
21 |
+
if isinstance(cond, torch.Tensor):
|
22 |
+
return cond.shape[0]
|
23 |
+
elif isinstance(cond, list):
|
24 |
+
return len(cond)
|
25 |
+
else:
|
26 |
+
raise ValueError(f"Unsupported type of cond: {type(cond)}")
|
27 |
+
|
28 |
+
ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
|
29 |
+
B = get_batch_size(ref_cond)
|
30 |
+
|
31 |
+
def select(cond, neg_cond, mask):
|
32 |
+
if isinstance(cond, torch.Tensor):
|
33 |
+
mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
|
34 |
+
return torch.where(mask, neg_cond, cond)
|
35 |
+
elif isinstance(cond, list):
|
36 |
+
return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
|
37 |
+
else:
|
38 |
+
raise ValueError(f"Unsupported type of cond: {type(cond)}")
|
39 |
+
|
40 |
+
mask = list(np.random.rand(B) < self.p_uncond)
|
41 |
+
if not isinstance(cond, dict):
|
42 |
+
cond = select(cond, neg_cond, mask)
|
43 |
+
else:
|
44 |
+
cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
|
45 |
+
|
46 |
+
return cond
|
47 |
+
|
48 |
+
def get_inference_cond(self, cond, neg_cond=None, **kwargs):
|
49 |
+
"""
|
50 |
+
Get the conditioning data for inference.
|
51 |
+
"""
|
52 |
+
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
|
53 |
+
return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
|
54 |
+
|
55 |
+
def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
|
56 |
+
"""
|
57 |
+
Get the sampler for the diffusion process.
|
58 |
+
"""
|
59 |
+
return samplers.FlowEulerCfgSampler(self.sigma_min)
|
trellis/trainers/flow_matching/mixins/image_conditioned.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision import transforms
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from ....utils import dist_utils
|
9 |
+
|
10 |
+
|
11 |
+
class ImageConditionedMixin:
|
12 |
+
"""
|
13 |
+
Mixin for image-conditioned models.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
image_cond_model: The image conditioning model.
|
17 |
+
"""
|
18 |
+
def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
|
19 |
+
super().__init__(*args, **kwargs)
|
20 |
+
self.image_cond_model_name = image_cond_model
|
21 |
+
self.image_cond_model = None # the model is init lazily
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def prepare_for_training(image_cond_model: str, **kwargs):
|
25 |
+
"""
|
26 |
+
Prepare for training.
|
27 |
+
"""
|
28 |
+
if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
|
29 |
+
super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)
|
30 |
+
# download the model
|
31 |
+
torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)
|
32 |
+
|
33 |
+
def _init_image_cond_model(self):
|
34 |
+
"""
|
35 |
+
Initialize the image conditioning model.
|
36 |
+
"""
|
37 |
+
with dist_utils.local_master_first():
|
38 |
+
dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
|
39 |
+
dinov2_model.eval().cuda()
|
40 |
+
transform = transforms.Compose([
|
41 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
42 |
+
])
|
43 |
+
self.image_cond_model = {
|
44 |
+
'model': dinov2_model,
|
45 |
+
'transform': transform,
|
46 |
+
}
|
47 |
+
|
48 |
+
@torch.no_grad()
|
49 |
+
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
|
50 |
+
"""
|
51 |
+
Encode the image.
|
52 |
+
"""
|
53 |
+
if isinstance(image, torch.Tensor):
|
54 |
+
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
|
55 |
+
elif isinstance(image, list):
|
56 |
+
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
|
57 |
+
image = [i.resize((518, 518), Image.LANCZOS) for i in image]
|
58 |
+
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
|
59 |
+
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
|
60 |
+
image = torch.stack(image).cuda()
|
61 |
+
else:
|
62 |
+
raise ValueError(f"Unsupported type of image: {type(image)}")
|
63 |
+
|
64 |
+
if self.image_cond_model is None:
|
65 |
+
self._init_image_cond_model()
|
66 |
+
image = self.image_cond_model['transform'](image).cuda()
|
67 |
+
features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
|
68 |
+
patchtokens = F.layer_norm(features, features.shape[-1:])
|
69 |
+
return patchtokens
|
70 |
+
|
71 |
+
def get_cond(self, cond, **kwargs):
|
72 |
+
"""
|
73 |
+
Get the conditioning data.
|
74 |
+
"""
|
75 |
+
cond = self.encode_image(cond)
|
76 |
+
kwargs['neg_cond'] = torch.zeros_like(cond)
|
77 |
+
cond = super().get_cond(cond, **kwargs)
|
78 |
+
return cond
|
79 |
+
|
80 |
+
def get_inference_cond(self, cond, **kwargs):
|
81 |
+
"""
|
82 |
+
Get the conditioning data for inference.
|
83 |
+
"""
|
84 |
+
cond = self.encode_image(cond)
|
85 |
+
kwargs['neg_cond'] = torch.zeros_like(cond)
|
86 |
+
cond = super().get_inference_cond(cond, **kwargs)
|
87 |
+
return cond
|
88 |
+
|
89 |
+
def vis_cond(self, cond, **kwargs):
|
90 |
+
"""
|
91 |
+
Visualize the conditioning data.
|
92 |
+
"""
|
93 |
+
return {'image': {'value': cond, 'type': 'image'}}
|
trellis/trainers/flow_matching/mixins/text_conditioned.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import os
|
3 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
4 |
+
import torch
|
5 |
+
from transformers import AutoTokenizer, CLIPTextModel
|
6 |
+
|
7 |
+
from ....utils import dist_utils
|
8 |
+
|
9 |
+
|
10 |
+
class TextConditionedMixin:
|
11 |
+
"""
|
12 |
+
Mixin for text-conditioned models.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
text_cond_model: The text conditioning model.
|
16 |
+
"""
|
17 |
+
def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
|
18 |
+
super().__init__(*args, **kwargs)
|
19 |
+
self.text_cond_model_name = text_cond_model
|
20 |
+
self.text_cond_model = None # the model is init lazily
|
21 |
+
|
22 |
+
def _init_text_cond_model(self):
|
23 |
+
"""
|
24 |
+
Initialize the text conditioning model.
|
25 |
+
"""
|
26 |
+
# load model
|
27 |
+
with dist_utils.local_master_first():
|
28 |
+
model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
|
30 |
+
model.eval()
|
31 |
+
model = model.cuda()
|
32 |
+
self.text_cond_model = {
|
33 |
+
'model': model,
|
34 |
+
'tokenizer': tokenizer,
|
35 |
+
}
|
36 |
+
self.text_cond_model['null_cond'] = self.encode_text([''])
|
37 |
+
|
38 |
+
@torch.no_grad()
|
39 |
+
def encode_text(self, text: List[str]) -> torch.Tensor:
|
40 |
+
"""
|
41 |
+
Encode the text.
|
42 |
+
"""
|
43 |
+
assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
|
44 |
+
if self.text_cond_model is None:
|
45 |
+
self._init_text_cond_model()
|
46 |
+
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
|
47 |
+
tokens = encoding['input_ids'].cuda()
|
48 |
+
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
|
49 |
+
|
50 |
+
return embeddings
|
51 |
+
|
52 |
+
def get_cond(self, cond, **kwargs):
|
53 |
+
"""
|
54 |
+
Get the conditioning data.
|
55 |
+
"""
|
56 |
+
cond = self.encode_text(cond)
|
57 |
+
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
|
58 |
+
cond = super().get_cond(cond, **kwargs)
|
59 |
+
return cond
|
60 |
+
|
61 |
+
def get_inference_cond(self, cond, **kwargs):
|
62 |
+
"""
|
63 |
+
Get the conditioning data for inference.
|
64 |
+
"""
|
65 |
+
cond = self.encode_text(cond)
|
66 |
+
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
|
67 |
+
cond = super().get_inference_cond(cond, **kwargs)
|
68 |
+
return cond
|
trellis/trainers/flow_matching/sparse_flow_matching.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import os
|
3 |
+
import copy
|
4 |
+
import functools
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
import numpy as np
|
9 |
+
from easydict import EasyDict as edict
|
10 |
+
|
11 |
+
from ...modules import sparse as sp
|
12 |
+
from ...utils.general_utils import dict_reduce
|
13 |
+
from ...utils.data_utils import cycle, BalancedResumableSampler
|
14 |
+
from .flow_matching import FlowMatchingTrainer
|
15 |
+
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
|
16 |
+
from .mixins.text_conditioned import TextConditionedMixin
|
17 |
+
from .mixins.image_conditioned import ImageConditionedMixin
|
18 |
+
|
19 |
+
|
20 |
+
class SparseFlowMatchingTrainer(FlowMatchingTrainer):
|
21 |
+
"""
|
22 |
+
Trainer for sparse diffusion model with flow matching objective.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
models (dict[str, nn.Module]): Models to train.
|
26 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
27 |
+
output_dir (str): Output directory.
|
28 |
+
load_dir (str): Load directory.
|
29 |
+
step (int): Step to load.
|
30 |
+
batch_size (int): Batch size.
|
31 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
32 |
+
batch_split (int): Split batch with gradient accumulation.
|
33 |
+
max_steps (int): Max steps.
|
34 |
+
optimizer (dict): Optimizer config.
|
35 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
36 |
+
elastic (dict): Elastic memory management config.
|
37 |
+
grad_clip (float or dict): Gradient clip config.
|
38 |
+
ema_rate (float or list): Exponential moving average rates.
|
39 |
+
fp16_mode (str): FP16 mode.
|
40 |
+
- None: No FP16.
|
41 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
42 |
+
- 'amp': Automatic mixed precision.
|
43 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
44 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
45 |
+
log_param_stats (bool): Log parameter stats.
|
46 |
+
i_print (int): Print interval.
|
47 |
+
i_log (int): Log interval.
|
48 |
+
i_sample (int): Sample interval.
|
49 |
+
i_save (int): Save interval.
|
50 |
+
i_ddpcheck (int): DDP check interval.
|
51 |
+
|
52 |
+
t_schedule (dict): Time schedule for flow matching.
|
53 |
+
sigma_min (float): Minimum noise level.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def prepare_dataloader(self, **kwargs):
|
57 |
+
"""
|
58 |
+
Prepare dataloader.
|
59 |
+
"""
|
60 |
+
self.data_sampler = BalancedResumableSampler(
|
61 |
+
self.dataset,
|
62 |
+
shuffle=True,
|
63 |
+
batch_size=self.batch_size_per_gpu,
|
64 |
+
)
|
65 |
+
self.dataloader = DataLoader(
|
66 |
+
self.dataset,
|
67 |
+
batch_size=self.batch_size_per_gpu,
|
68 |
+
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
|
69 |
+
pin_memory=True,
|
70 |
+
drop_last=True,
|
71 |
+
persistent_workers=True,
|
72 |
+
collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
|
73 |
+
sampler=self.data_sampler,
|
74 |
+
)
|
75 |
+
self.data_iterator = cycle(self.dataloader)
|
76 |
+
|
77 |
+
def training_losses(
|
78 |
+
self,
|
79 |
+
x_0: sp.SparseTensor,
|
80 |
+
cond=None,
|
81 |
+
**kwargs
|
82 |
+
) -> Tuple[Dict, Dict]:
|
83 |
+
"""
|
84 |
+
Compute training losses for a single timestep.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
x_0: The [N x ... x C] sparse tensor of the inputs.
|
88 |
+
cond: The [N x ...] tensor of additional conditions.
|
89 |
+
kwargs: Additional arguments to pass to the backbone.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
a dict with the key "loss" containing a tensor of shape [N].
|
93 |
+
may also contain other keys for different terms.
|
94 |
+
"""
|
95 |
+
noise = x_0.replace(torch.randn_like(x_0.feats))
|
96 |
+
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
|
97 |
+
x_t = self.diffuse(x_0, t, noise=noise)
|
98 |
+
cond = self.get_cond(cond, **kwargs)
|
99 |
+
|
100 |
+
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
|
101 |
+
assert pred.shape == noise.shape == x_0.shape
|
102 |
+
target = self.get_v(x_0, noise, t)
|
103 |
+
terms = edict()
|
104 |
+
terms["mse"] = F.mse_loss(pred.feats, target.feats)
|
105 |
+
terms["loss"] = terms["mse"]
|
106 |
+
|
107 |
+
# log loss with time bins
|
108 |
+
mse_per_instance = np.array([
|
109 |
+
F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
|
110 |
+
for i in range(x_0.shape[0])
|
111 |
+
])
|
112 |
+
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
|
113 |
+
for i in range(10):
|
114 |
+
if (time_bin == i).sum() != 0:
|
115 |
+
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
|
116 |
+
|
117 |
+
return terms, {}
|
118 |
+
|
119 |
+
@torch.no_grad()
|
120 |
+
def run_snapshot(
|
121 |
+
self,
|
122 |
+
num_samples: int,
|
123 |
+
batch_size: int,
|
124 |
+
verbose: bool = False,
|
125 |
+
) -> Dict:
|
126 |
+
dataloader = DataLoader(
|
127 |
+
copy.deepcopy(self.dataset),
|
128 |
+
batch_size=batch_size,
|
129 |
+
shuffle=True,
|
130 |
+
num_workers=0,
|
131 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
132 |
+
)
|
133 |
+
|
134 |
+
# inference
|
135 |
+
sampler = self.get_sampler()
|
136 |
+
sample_gt = []
|
137 |
+
sample = []
|
138 |
+
cond_vis = []
|
139 |
+
for i in range(0, num_samples, batch_size):
|
140 |
+
batch = min(batch_size, num_samples - i)
|
141 |
+
data = next(iter(dataloader))
|
142 |
+
data = {k: v[:batch].cuda() if not isinstance(v, list) else v[:batch] for k, v in data.items()}
|
143 |
+
noise = data['x_0'].replace(torch.randn_like(data['x_0'].feats))
|
144 |
+
sample_gt.append(data['x_0'])
|
145 |
+
cond_vis.append(self.vis_cond(**data))
|
146 |
+
del data['x_0']
|
147 |
+
args = self.get_inference_cond(**data)
|
148 |
+
res = sampler.sample(
|
149 |
+
self.models['denoiser'],
|
150 |
+
noise=noise,
|
151 |
+
**args,
|
152 |
+
steps=50, cfg_strength=3.0, verbose=verbose,
|
153 |
+
)
|
154 |
+
sample.append(res.samples)
|
155 |
+
|
156 |
+
sample_gt = sp.sparse_cat(sample_gt)
|
157 |
+
sample = sp.sparse_cat(sample)
|
158 |
+
sample_dict = {
|
159 |
+
'sample_gt': {'value': sample_gt, 'type': 'sample'},
|
160 |
+
'sample': {'value': sample, 'type': 'sample'},
|
161 |
+
}
|
162 |
+
sample_dict.update(dict_reduce(cond_vis, None, {
|
163 |
+
'value': lambda x: torch.cat(x, dim=0),
|
164 |
+
'type': lambda x: x[0],
|
165 |
+
}))
|
166 |
+
|
167 |
+
return sample_dict
|
168 |
+
|
169 |
+
|
170 |
+
class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
|
171 |
+
"""
|
172 |
+
Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
models (dict[str, nn.Module]): Models to train.
|
176 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
177 |
+
output_dir (str): Output directory.
|
178 |
+
load_dir (str): Load directory.
|
179 |
+
step (int): Step to load.
|
180 |
+
batch_size (int): Batch size.
|
181 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
182 |
+
batch_split (int): Split batch with gradient accumulation.
|
183 |
+
max_steps (int): Max steps.
|
184 |
+
optimizer (dict): Optimizer config.
|
185 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
186 |
+
elastic (dict): Elastic memory management config.
|
187 |
+
grad_clip (float or dict): Gradient clip config.
|
188 |
+
ema_rate (float or list): Exponential moving average rates.
|
189 |
+
fp16_mode (str): FP16 mode.
|
190 |
+
- None: No FP16.
|
191 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
192 |
+
- 'amp': Automatic mixed precision.
|
193 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
194 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
195 |
+
log_param_stats (bool): Log parameter stats.
|
196 |
+
i_print (int): Print interval.
|
197 |
+
i_log (int): Log interval.
|
198 |
+
i_sample (int): Sample interval.
|
199 |
+
i_save (int): Save interval.
|
200 |
+
i_ddpcheck (int): DDP check interval.
|
201 |
+
|
202 |
+
t_schedule (dict): Time schedule for flow matching.
|
203 |
+
sigma_min (float): Minimum noise level.
|
204 |
+
p_uncond (float): Probability of dropping conditions.
|
205 |
+
"""
|
206 |
+
pass
|
207 |
+
|
208 |
+
|
209 |
+
class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
|
210 |
+
"""
|
211 |
+
Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
models (dict[str, nn.Module]): Models to train.
|
215 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
216 |
+
output_dir (str): Output directory.
|
217 |
+
load_dir (str): Load directory.
|
218 |
+
step (int): Step to load.
|
219 |
+
batch_size (int): Batch size.
|
220 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
221 |
+
batch_split (int): Split batch with gradient accumulation.
|
222 |
+
max_steps (int): Max steps.
|
223 |
+
optimizer (dict): Optimizer config.
|
224 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
225 |
+
elastic (dict): Elastic memory management config.
|
226 |
+
grad_clip (float or dict): Gradient clip config.
|
227 |
+
ema_rate (float or list): Exponential moving average rates.
|
228 |
+
fp16_mode (str): FP16 mode.
|
229 |
+
- None: No FP16.
|
230 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
231 |
+
- 'amp': Automatic mixed precision.
|
232 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
233 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
234 |
+
log_param_stats (bool): Log parameter stats.
|
235 |
+
i_print (int): Print interval.
|
236 |
+
i_log (int): Log interval.
|
237 |
+
i_sample (int): Sample interval.
|
238 |
+
i_save (int): Save interval.
|
239 |
+
i_ddpcheck (int): DDP check interval.
|
240 |
+
|
241 |
+
t_schedule (dict): Time schedule for flow matching.
|
242 |
+
sigma_min (float): Minimum noise level.
|
243 |
+
p_uncond (float): Probability of dropping conditions.
|
244 |
+
text_cond_model(str): Text conditioning model.
|
245 |
+
"""
|
246 |
+
pass
|
247 |
+
|
248 |
+
|
249 |
+
class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
|
250 |
+
"""
|
251 |
+
Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
models (dict[str, nn.Module]): Models to train.
|
255 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
256 |
+
output_dir (str): Output directory.
|
257 |
+
load_dir (str): Load directory.
|
258 |
+
step (int): Step to load.
|
259 |
+
batch_size (int): Batch size.
|
260 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
261 |
+
batch_split (int): Split batch with gradient accumulation.
|
262 |
+
max_steps (int): Max steps.
|
263 |
+
optimizer (dict): Optimizer config.
|
264 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
265 |
+
elastic (dict): Elastic memory management config.
|
266 |
+
grad_clip (float or dict): Gradient clip config.
|
267 |
+
ema_rate (float or list): Exponential moving average rates.
|
268 |
+
fp16_mode (str): FP16 mode.
|
269 |
+
- None: No FP16.
|
270 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
271 |
+
- 'amp': Automatic mixed precision.
|
272 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
273 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
274 |
+
log_param_stats (bool): Log parameter stats.
|
275 |
+
i_print (int): Print interval.
|
276 |
+
i_log (int): Log interval.
|
277 |
+
i_sample (int): Sample interval.
|
278 |
+
i_save (int): Save interval.
|
279 |
+
i_ddpcheck (int): DDP check interval.
|
280 |
+
|
281 |
+
t_schedule (dict): Time schedule for flow matching.
|
282 |
+
sigma_min (float): Minimum noise level.
|
283 |
+
p_uncond (float): Probability of dropping conditions.
|
284 |
+
image_cond_model (str): Image conditioning model.
|
285 |
+
"""
|
286 |
+
pass
|
trellis/trainers/utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
# FP16 utils
|
5 |
+
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
6 |
+
|
7 |
+
def make_master_params(model_params):
|
8 |
+
"""
|
9 |
+
Copy model parameters into a inflated tensor of full-precision parameters.
|
10 |
+
"""
|
11 |
+
master_params = _flatten_dense_tensors(
|
12 |
+
[param.detach().float() for param in model_params]
|
13 |
+
)
|
14 |
+
master_params = nn.Parameter(master_params)
|
15 |
+
master_params.requires_grad = True
|
16 |
+
return [master_params]
|
17 |
+
|
18 |
+
|
19 |
+
def unflatten_master_params(model_params, master_params):
|
20 |
+
"""
|
21 |
+
Unflatten the master parameters to look like model_params.
|
22 |
+
"""
|
23 |
+
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
|
24 |
+
|
25 |
+
|
26 |
+
def model_params_to_master_params(model_params, master_params):
|
27 |
+
"""
|
28 |
+
Copy the model parameter data into the master parameters.
|
29 |
+
"""
|
30 |
+
master_params[0].detach().copy_(
|
31 |
+
_flatten_dense_tensors([param.detach().float() for param in model_params])
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def master_params_to_model_params(model_params, master_params):
|
36 |
+
"""
|
37 |
+
Copy the master parameter data back into the model parameters.
|
38 |
+
"""
|
39 |
+
for param, master_param in zip(
|
40 |
+
model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
|
41 |
+
):
|
42 |
+
param.detach().copy_(master_param)
|
43 |
+
|
44 |
+
|
45 |
+
def model_grads_to_master_grads(model_params, master_params):
|
46 |
+
"""
|
47 |
+
Copy the gradients from the model parameters into the master parameters
|
48 |
+
from make_master_params().
|
49 |
+
"""
|
50 |
+
master_params[0].grad = _flatten_dense_tensors(
|
51 |
+
[param.grad.data.detach().float() for param in model_params]
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
def zero_grad(model_params):
|
56 |
+
for param in model_params:
|
57 |
+
if param.grad is not None:
|
58 |
+
if param.grad.grad_fn is not None:
|
59 |
+
param.grad.detach_()
|
60 |
+
else:
|
61 |
+
param.grad.requires_grad_(False)
|
62 |
+
param.grad.zero_()
|
63 |
+
|
64 |
+
|
65 |
+
# LR Schedulers
|
66 |
+
from torch.optim.lr_scheduler import LambdaLR
|
67 |
+
|
68 |
+
class LinearWarmupLRScheduler(LambdaLR):
|
69 |
+
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
|
70 |
+
self.warmup_steps = warmup_steps
|
71 |
+
super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
|
72 |
+
|
73 |
+
def lr_lambda(self, current_step):
|
74 |
+
if current_step < self.warmup_steps:
|
75 |
+
return float(current_step + 1) / self.warmup_steps
|
76 |
+
return 1.0
|
77 |
+
|
trellis/trainers/vae/sparse_structure_vae.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from easydict import EasyDict as edict
|
7 |
+
|
8 |
+
from ..basic import BasicTrainer
|
9 |
+
|
10 |
+
|
11 |
+
class SparseStructureVaeTrainer(BasicTrainer):
|
12 |
+
"""
|
13 |
+
Trainer for Sparse Structure VAE.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
models (dict[str, nn.Module]): Models to train.
|
17 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
18 |
+
output_dir (str): Output directory.
|
19 |
+
load_dir (str): Load directory.
|
20 |
+
step (int): Step to load.
|
21 |
+
batch_size (int): Batch size.
|
22 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
23 |
+
batch_split (int): Split batch with gradient accumulation.
|
24 |
+
max_steps (int): Max steps.
|
25 |
+
optimizer (dict): Optimizer config.
|
26 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
27 |
+
elastic (dict): Elastic memory management config.
|
28 |
+
grad_clip (float or dict): Gradient clip config.
|
29 |
+
ema_rate (float or list): Exponential moving average rates.
|
30 |
+
fp16_mode (str): FP16 mode.
|
31 |
+
- None: No FP16.
|
32 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
33 |
+
- 'amp': Automatic mixed precision.
|
34 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
35 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
36 |
+
log_param_stats (bool): Log parameter stats.
|
37 |
+
i_print (int): Print interval.
|
38 |
+
i_log (int): Log interval.
|
39 |
+
i_sample (int): Sample interval.
|
40 |
+
i_save (int): Save interval.
|
41 |
+
i_ddpcheck (int): DDP check interval.
|
42 |
+
|
43 |
+
loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
|
44 |
+
lambda_kl (float): KL divergence loss weight.
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
*args,
|
50 |
+
loss_type='bce',
|
51 |
+
lambda_kl=1e-6,
|
52 |
+
**kwargs
|
53 |
+
):
|
54 |
+
super().__init__(*args, **kwargs)
|
55 |
+
self.loss_type = loss_type
|
56 |
+
self.lambda_kl = lambda_kl
|
57 |
+
|
58 |
+
def training_losses(
|
59 |
+
self,
|
60 |
+
ss: torch.Tensor,
|
61 |
+
**kwargs
|
62 |
+
) -> Tuple[Dict, Dict]:
|
63 |
+
"""
|
64 |
+
Compute training losses.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
a dict with the key "loss" containing a scalar tensor.
|
71 |
+
may also contain other keys for different terms.
|
72 |
+
"""
|
73 |
+
z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True)
|
74 |
+
logits = self.training_models['decoder'](z)
|
75 |
+
|
76 |
+
terms = edict(loss = 0.0)
|
77 |
+
if self.loss_type == 'bce':
|
78 |
+
terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean')
|
79 |
+
terms["loss"] = terms["loss"] + terms["bce"]
|
80 |
+
elif self.loss_type == 'l1':
|
81 |
+
terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean')
|
82 |
+
terms["loss"] = terms["loss"] + terms["l1"]
|
83 |
+
elif self.loss_type == 'dice':
|
84 |
+
logits = F.sigmoid(logits)
|
85 |
+
terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
|
86 |
+
terms["loss"] = terms["loss"] + terms["dice"]
|
87 |
+
else:
|
88 |
+
raise ValueError(f'Invalid loss type {self.loss_type}')
|
89 |
+
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
|
90 |
+
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
|
91 |
+
|
92 |
+
return terms, {}
|
93 |
+
|
94 |
+
@torch.no_grad()
|
95 |
+
def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
|
96 |
+
super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)
|
97 |
+
|
98 |
+
@torch.no_grad()
|
99 |
+
def run_snapshot(
|
100 |
+
self,
|
101 |
+
num_samples: int,
|
102 |
+
batch_size: int,
|
103 |
+
verbose: bool = False,
|
104 |
+
) -> Dict:
|
105 |
+
dataloader = DataLoader(
|
106 |
+
copy.deepcopy(self.dataset),
|
107 |
+
batch_size=batch_size,
|
108 |
+
shuffle=True,
|
109 |
+
num_workers=0,
|
110 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
111 |
+
)
|
112 |
+
|
113 |
+
# inference
|
114 |
+
gts = []
|
115 |
+
recons = []
|
116 |
+
for i in range(0, num_samples, batch_size):
|
117 |
+
batch = min(batch_size, num_samples - i)
|
118 |
+
data = next(iter(dataloader))
|
119 |
+
args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
|
120 |
+
z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
|
121 |
+
logits = self.models['decoder'](z)
|
122 |
+
recon = (logits > 0).long()
|
123 |
+
gts.append(args['ss'])
|
124 |
+
recons.append(recon)
|
125 |
+
|
126 |
+
sample_dict = {
|
127 |
+
'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
|
128 |
+
'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
|
129 |
+
}
|
130 |
+
return sample_dict
|
trellis/trainers/vae/structured_latent_vae_gaussian.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import numpy as np
|
6 |
+
from easydict import EasyDict as edict
|
7 |
+
import utils3d.torch
|
8 |
+
|
9 |
+
from ..basic import BasicTrainer
|
10 |
+
from ...representations import Gaussian
|
11 |
+
from ...renderers import GaussianRenderer
|
12 |
+
from ...modules.sparse import SparseTensor
|
13 |
+
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
|
14 |
+
|
15 |
+
|
16 |
+
class SLatVaeGaussianTrainer(BasicTrainer):
|
17 |
+
"""
|
18 |
+
Trainer for structured latent VAE.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
models (dict[str, nn.Module]): Models to train.
|
22 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
23 |
+
output_dir (str): Output directory.
|
24 |
+
load_dir (str): Load directory.
|
25 |
+
step (int): Step to load.
|
26 |
+
batch_size (int): Batch size.
|
27 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
28 |
+
batch_split (int): Split batch with gradient accumulation.
|
29 |
+
max_steps (int): Max steps.
|
30 |
+
optimizer (dict): Optimizer config.
|
31 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
32 |
+
elastic (dict): Elastic memory management config.
|
33 |
+
grad_clip (float or dict): Gradient clip config.
|
34 |
+
ema_rate (float or list): Exponential moving average rates.
|
35 |
+
fp16_mode (str): FP16 mode.
|
36 |
+
- None: No FP16.
|
37 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
38 |
+
- 'amp': Automatic mixed precision.
|
39 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
40 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
41 |
+
log_param_stats (bool): Log parameter stats.
|
42 |
+
i_print (int): Print interval.
|
43 |
+
i_log (int): Log interval.
|
44 |
+
i_sample (int): Sample interval.
|
45 |
+
i_save (int): Save interval.
|
46 |
+
i_ddpcheck (int): DDP check interval.
|
47 |
+
|
48 |
+
loss_type (str): Loss type. Can be 'l1', 'l2'
|
49 |
+
lambda_ssim (float): SSIM loss weight.
|
50 |
+
lambda_lpips (float): LPIPS loss weight.
|
51 |
+
lambda_kl (float): KL loss weight.
|
52 |
+
regularizations (dict): Regularization config.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
*args,
|
58 |
+
loss_type: str = 'l1',
|
59 |
+
lambda_ssim: float = 0.2,
|
60 |
+
lambda_lpips: float = 0.2,
|
61 |
+
lambda_kl: float = 1e-6,
|
62 |
+
regularizations: Dict = {},
|
63 |
+
**kwargs
|
64 |
+
):
|
65 |
+
super().__init__(*args, **kwargs)
|
66 |
+
self.loss_type = loss_type
|
67 |
+
self.lambda_ssim = lambda_ssim
|
68 |
+
self.lambda_lpips = lambda_lpips
|
69 |
+
self.lambda_kl = lambda_kl
|
70 |
+
self.regularizations = regularizations
|
71 |
+
|
72 |
+
self._init_renderer()
|
73 |
+
|
74 |
+
def _init_renderer(self):
|
75 |
+
rendering_options = {"near" : 0.8,
|
76 |
+
"far" : 1.6,
|
77 |
+
"bg_color" : 'random'}
|
78 |
+
self.renderer = GaussianRenderer(rendering_options)
|
79 |
+
self.renderer.pipe.kernel_size = self.models['decoder'].rep_config['2d_filter_kernel_size']
|
80 |
+
|
81 |
+
def _render_batch(self, reps: List[Gaussian], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
82 |
+
"""
|
83 |
+
Render a batch of representations.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
reps: The dictionary of lists of representations.
|
87 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
88 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
89 |
+
"""
|
90 |
+
ret = None
|
91 |
+
for i, representation in enumerate(reps):
|
92 |
+
render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
|
93 |
+
if ret is None:
|
94 |
+
ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
|
95 |
+
for k, v in render_pack.items():
|
96 |
+
ret[k].append(v)
|
97 |
+
ret['bg_color'].append(self.renderer.bg_color)
|
98 |
+
for k, v in ret.items():
|
99 |
+
ret[k] = torch.stack(v, dim=0)
|
100 |
+
return ret
|
101 |
+
|
102 |
+
@torch.no_grad()
|
103 |
+
def _get_status(self, z: SparseTensor, reps: List[Gaussian]) -> Dict:
|
104 |
+
xyz = torch.cat([g.get_xyz for g in reps], dim=0)
|
105 |
+
xyz_base = (z.coords[:, 1:].float() + 0.5) / self.models['decoder'].resolution - 0.5
|
106 |
+
offset = xyz - xyz_base.unsqueeze(1).expand(-1, self.models['decoder'].rep_config['num_gaussians'], -1).reshape(-1, 3)
|
107 |
+
status = {
|
108 |
+
'xyz': xyz,
|
109 |
+
'offset': offset,
|
110 |
+
'scale': torch.cat([g.get_scaling for g in reps], dim=0),
|
111 |
+
'opacity': torch.cat([g.get_opacity for g in reps], dim=0),
|
112 |
+
}
|
113 |
+
|
114 |
+
for k in list(status.keys()):
|
115 |
+
status[k] = {
|
116 |
+
'mean': status[k].mean().item(),
|
117 |
+
'max': status[k].max().item(),
|
118 |
+
'min': status[k].min().item(),
|
119 |
+
}
|
120 |
+
|
121 |
+
return status
|
122 |
+
|
123 |
+
def _get_regularization_loss(self, reps: List[Gaussian]) -> Tuple[torch.Tensor, Dict]:
|
124 |
+
loss = 0.0
|
125 |
+
terms = {}
|
126 |
+
if 'lambda_vol' in self.regularizations:
|
127 |
+
scales = torch.cat([g.get_scaling for g in reps], dim=0) # [N x 3]
|
128 |
+
volume = torch.prod(scales, dim=1) # [N]
|
129 |
+
terms[f'reg_vol'] = volume.mean()
|
130 |
+
loss = loss + self.regularizations['lambda_vol'] * terms[f'reg_vol']
|
131 |
+
if 'lambda_opacity' in self.regularizations:
|
132 |
+
opacity = torch.cat([g.get_opacity for g in reps], dim=0)
|
133 |
+
terms[f'reg_opacity'] = (opacity - 1).pow(2).mean()
|
134 |
+
loss = loss + self.regularizations['lambda_opacity'] * terms[f'reg_opacity']
|
135 |
+
return loss, terms
|
136 |
+
|
137 |
+
def training_losses(
|
138 |
+
self,
|
139 |
+
feats: SparseTensor,
|
140 |
+
image: torch.Tensor,
|
141 |
+
alpha: torch.Tensor,
|
142 |
+
extrinsics: torch.Tensor,
|
143 |
+
intrinsics: torch.Tensor,
|
144 |
+
return_aux: bool = False,
|
145 |
+
**kwargs
|
146 |
+
) -> Tuple[Dict, Dict]:
|
147 |
+
"""
|
148 |
+
Compute training losses.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
feats: The [N x * x C] sparse tensor of features.
|
152 |
+
image: The [N x 3 x H x W] tensor of images.
|
153 |
+
alpha: The [N x H x W] tensor of alpha channels.
|
154 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
155 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
156 |
+
return_aux: Whether to return auxiliary information.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
a dict with the key "loss" containing a scalar tensor.
|
160 |
+
may also contain other keys for different terms.
|
161 |
+
"""
|
162 |
+
z, mean, logvar = self.training_models['encoder'](feats, sample_posterior=True, return_raw=True)
|
163 |
+
reps = self.training_models['decoder'](z)
|
164 |
+
self.renderer.rendering_options.resolution = image.shape[-1]
|
165 |
+
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
166 |
+
|
167 |
+
terms = edict(loss = 0.0, rec = 0.0)
|
168 |
+
|
169 |
+
rec_image = render_results['color']
|
170 |
+
gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
|
171 |
+
|
172 |
+
if self.loss_type == 'l1':
|
173 |
+
terms["l1"] = l1_loss(rec_image, gt_image)
|
174 |
+
terms["rec"] = terms["rec"] + terms["l1"]
|
175 |
+
elif self.loss_type == 'l2':
|
176 |
+
terms["l2"] = l2_loss(rec_image, gt_image)
|
177 |
+
terms["rec"] = terms["rec"] + terms["l2"]
|
178 |
+
else:
|
179 |
+
raise ValueError(f"Invalid loss type: {self.loss_type}")
|
180 |
+
if self.lambda_ssim > 0:
|
181 |
+
terms["ssim"] = 1 - ssim(rec_image, gt_image)
|
182 |
+
terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
|
183 |
+
if self.lambda_lpips > 0:
|
184 |
+
terms["lpips"] = lpips(rec_image, gt_image)
|
185 |
+
terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
|
186 |
+
terms["loss"] = terms["loss"] + terms["rec"]
|
187 |
+
|
188 |
+
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
|
189 |
+
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
|
190 |
+
|
191 |
+
reg_loss, reg_terms = self._get_regularization_loss(reps)
|
192 |
+
terms.update(reg_terms)
|
193 |
+
terms["loss"] = terms["loss"] + reg_loss
|
194 |
+
|
195 |
+
status = self._get_status(z, reps)
|
196 |
+
|
197 |
+
if return_aux:
|
198 |
+
return terms, status, {'rec_image': rec_image, 'gt_image': gt_image}
|
199 |
+
return terms, status
|
200 |
+
|
201 |
+
@torch.no_grad()
|
202 |
+
def run_snapshot(
|
203 |
+
self,
|
204 |
+
num_samples: int,
|
205 |
+
batch_size: int,
|
206 |
+
verbose: bool = False,
|
207 |
+
) -> Dict:
|
208 |
+
dataloader = DataLoader(
|
209 |
+
copy.deepcopy(self.dataset),
|
210 |
+
batch_size=batch_size,
|
211 |
+
shuffle=True,
|
212 |
+
num_workers=0,
|
213 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
214 |
+
)
|
215 |
+
|
216 |
+
# inference
|
217 |
+
ret_dict = {}
|
218 |
+
gt_images = []
|
219 |
+
exts = []
|
220 |
+
ints = []
|
221 |
+
reps = []
|
222 |
+
for i in range(0, num_samples, batch_size):
|
223 |
+
batch = min(batch_size, num_samples - i)
|
224 |
+
data = next(iter(dataloader))
|
225 |
+
args = {k: v[:batch].cuda() for k, v in data.items()}
|
226 |
+
gt_images.append(args['image'] * args['alpha'][:, None])
|
227 |
+
exts.append(args['extrinsics'])
|
228 |
+
ints.append(args['intrinsics'])
|
229 |
+
z = self.models['encoder'](args['feats'], sample_posterior=True, return_raw=False)
|
230 |
+
reps.extend(self.models['decoder'](z))
|
231 |
+
gt_images = torch.cat(gt_images, dim=0)
|
232 |
+
ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
|
233 |
+
|
234 |
+
# render single view
|
235 |
+
exts = torch.cat(exts, dim=0)
|
236 |
+
ints = torch.cat(ints, dim=0)
|
237 |
+
self.renderer.rendering_options.bg_color = (0, 0, 0)
|
238 |
+
self.renderer.rendering_options.resolution = gt_images.shape[-1]
|
239 |
+
render_results = self._render_batch(reps, exts, ints)
|
240 |
+
ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
|
241 |
+
|
242 |
+
# render multiview
|
243 |
+
self.renderer.rendering_options.resolution = 512
|
244 |
+
## Build camera
|
245 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
246 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
247 |
+
yaws = [y + yaws_offset for y in yaws]
|
248 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
249 |
+
|
250 |
+
## render each view
|
251 |
+
miltiview_images = []
|
252 |
+
for yaw, pitch in zip(yaws, pitch):
|
253 |
+
orig = torch.tensor([
|
254 |
+
np.sin(yaw) * np.cos(pitch),
|
255 |
+
np.cos(yaw) * np.cos(pitch),
|
256 |
+
np.sin(pitch),
|
257 |
+
]).float().cuda() * 2
|
258 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
259 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
260 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
261 |
+
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
262 |
+
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
263 |
+
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
264 |
+
miltiview_images.append(render_results['color'])
|
265 |
+
|
266 |
+
## Concatenate views
|
267 |
+
miltiview_images = torch.cat([
|
268 |
+
torch.cat(miltiview_images[:2], dim=-2),
|
269 |
+
torch.cat(miltiview_images[2:], dim=-2),
|
270 |
+
], dim=-1)
|
271 |
+
ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}})
|
272 |
+
|
273 |
+
self.renderer.rendering_options.bg_color = 'random'
|
274 |
+
|
275 |
+
return ret_dict
|
trellis/trainers/vae/structured_latent_vae_mesh_dec.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import numpy as np
|
6 |
+
from easydict import EasyDict as edict
|
7 |
+
import utils3d.torch
|
8 |
+
|
9 |
+
from ..basic import BasicTrainer
|
10 |
+
from ...representations import MeshExtractResult
|
11 |
+
from ...renderers import MeshRenderer
|
12 |
+
from ...modules.sparse import SparseTensor
|
13 |
+
from ...utils.loss_utils import l1_loss, smooth_l1_loss, ssim, lpips
|
14 |
+
from ...utils.data_utils import recursive_to_device
|
15 |
+
|
16 |
+
|
17 |
+
class SLatVaeMeshDecoderTrainer(BasicTrainer):
|
18 |
+
"""
|
19 |
+
Trainer for structured latent VAE Mesh Decoder.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
models (dict[str, nn.Module]): Models to train.
|
23 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
24 |
+
output_dir (str): Output directory.
|
25 |
+
load_dir (str): Load directory.
|
26 |
+
step (int): Step to load.
|
27 |
+
batch_size (int): Batch size.
|
28 |
+
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
29 |
+
batch_split (int): Split batch with gradient accumulation.
|
30 |
+
max_steps (int): Max steps.
|
31 |
+
optimizer (dict): Optimizer config.
|
32 |
+
lr_scheduler (dict): Learning rate scheduler config.
|
33 |
+
elastic (dict): Elastic memory management config.
|
34 |
+
grad_clip (float or dict): Gradient clip config.
|
35 |
+
ema_rate (float or list): Exponential moving average rates.
|
36 |
+
fp16_mode (str): FP16 mode.
|
37 |
+
- None: No FP16.
|
38 |
+
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
39 |
+
- 'amp': Automatic mixed precision.
|
40 |
+
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
41 |
+
finetune_ckpt (dict): Finetune checkpoint.
|
42 |
+
log_param_stats (bool): Log parameter stats.
|
43 |
+
i_print (int): Print interval.
|
44 |
+
i_log (int): Log interval.
|
45 |
+
i_sample (int): Sample interval.
|
46 |
+
i_save (int): Save interval.
|
47 |
+
i_ddpcheck (int): DDP check interval.
|
48 |
+
|
49 |
+
loss_type (str): Loss type. Can be 'l1', 'l2'
|
50 |
+
lambda_ssim (float): SSIM loss weight.
|
51 |
+
lambda_lpips (float): LPIPS loss weight.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
*args,
|
57 |
+
depth_loss_type: str = 'l1',
|
58 |
+
lambda_depth: int = 1,
|
59 |
+
lambda_ssim: float = 0.2,
|
60 |
+
lambda_lpips: float = 0.2,
|
61 |
+
lambda_tsdf: float = 0.01,
|
62 |
+
lambda_color: float = 0.1,
|
63 |
+
**kwargs
|
64 |
+
):
|
65 |
+
super().__init__(*args, **kwargs)
|
66 |
+
self.depth_loss_type = depth_loss_type
|
67 |
+
self.lambda_depth = lambda_depth
|
68 |
+
self.lambda_ssim = lambda_ssim
|
69 |
+
self.lambda_lpips = lambda_lpips
|
70 |
+
self.lambda_tsdf = lambda_tsdf
|
71 |
+
self.lambda_color = lambda_color
|
72 |
+
self.use_color = self.lambda_color > 0
|
73 |
+
|
74 |
+
self._init_renderer()
|
75 |
+
|
76 |
+
def _init_renderer(self):
|
77 |
+
rendering_options = {"near" : 1,
|
78 |
+
"far" : 3}
|
79 |
+
self.renderer = MeshRenderer(rendering_options, device=self.device)
|
80 |
+
|
81 |
+
def _render_batch(self, reps: List[MeshExtractResult], extrinsics: torch.Tensor, intrinsics: torch.Tensor,
|
82 |
+
return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
|
83 |
+
"""
|
84 |
+
Render a batch of representations.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
reps: The dictionary of lists of representations.
|
88 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
89 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
90 |
+
return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
a dict with
|
94 |
+
reg_loss : [N] tensor of regularization losses
|
95 |
+
mask : [N x 1 x H x W] tensor of rendered masks
|
96 |
+
normal : [N x 3 x H x W] tensor of rendered normals
|
97 |
+
depth : [N x 1 x H x W] tensor of rendered depths
|
98 |
+
"""
|
99 |
+
ret = {k : [] for k in return_types}
|
100 |
+
for i, rep in enumerate(reps):
|
101 |
+
out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
|
102 |
+
for k in out_dict:
|
103 |
+
ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
|
104 |
+
for k in ret:
|
105 |
+
ret[k] = torch.stack(ret[k])
|
106 |
+
return ret
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def _tsdf_reg_loss(rep: MeshExtractResult, depth_map: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
110 |
+
# Calculate tsdf
|
111 |
+
with torch.no_grad():
|
112 |
+
# Project points to camera and calculate pseudo-sdf as difference between gt depth and projected depth
|
113 |
+
projected_pts, pts_depth = utils3d.torch.project_cv(extrinsics=extrinsics, intrinsics=intrinsics, points=rep.tsdf_v)
|
114 |
+
projected_pts = (projected_pts - 0.5) * 2.0
|
115 |
+
depth_map_res = depth_map.shape[1]
|
116 |
+
gt_depth = torch.nn.functional.grid_sample(depth_map.reshape(1, 1, depth_map_res, depth_map_res),
|
117 |
+
projected_pts.reshape(1, 1, -1, 2), mode='bilinear', padding_mode='border', align_corners=True)
|
118 |
+
pseudo_sdf = gt_depth.flatten() - pts_depth.flatten()
|
119 |
+
# Truncate pseudo-sdf
|
120 |
+
delta = 1 / rep.res * 3.0
|
121 |
+
trunc_mask = pseudo_sdf > -delta
|
122 |
+
|
123 |
+
# Loss
|
124 |
+
gt_tsdf = pseudo_sdf[trunc_mask]
|
125 |
+
tsdf = rep.tsdf_s.flatten()[trunc_mask]
|
126 |
+
gt_tsdf = torch.clamp(gt_tsdf, -delta, delta)
|
127 |
+
return torch.mean((tsdf - gt_tsdf) ** 2)
|
128 |
+
|
129 |
+
def _calc_tsdf_loss(self, reps : list[MeshExtractResult], depth_maps, extrinsics, intrinsics) -> torch.Tensor:
|
130 |
+
tsdf_loss = 0.0
|
131 |
+
for i, rep in enumerate(reps):
|
132 |
+
tsdf_loss += self._tsdf_reg_loss(rep, depth_maps[i], extrinsics[i], intrinsics[i])
|
133 |
+
return tsdf_loss / len(reps)
|
134 |
+
|
135 |
+
@torch.no_grad()
|
136 |
+
def _flip_normal(self, normal: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
137 |
+
"""
|
138 |
+
Flip normal to align with camera.
|
139 |
+
"""
|
140 |
+
normal = normal * 2.0 - 1.0
|
141 |
+
R = torch.zeros_like(extrinsics)
|
142 |
+
R[:, :3, :3] = extrinsics[:, :3, :3]
|
143 |
+
R[:, 3, 3] = 1.0
|
144 |
+
view_dir = utils3d.torch.unproject_cv(
|
145 |
+
utils3d.torch.image_uv(*normal.shape[-2:], device=self.device).reshape(1, -1, 2),
|
146 |
+
torch.ones(*normal.shape[-2:], device=self.device).reshape(1, -1),
|
147 |
+
R, intrinsics
|
148 |
+
).reshape(-1, *normal.shape[-2:], 3).permute(0, 3, 1, 2)
|
149 |
+
unflip = (normal * view_dir).sum(1, keepdim=True) < 0
|
150 |
+
normal *= unflip * 2.0 - 1.0
|
151 |
+
return (normal + 1.0) / 2.0
|
152 |
+
|
153 |
+
def _perceptual_loss(self, gt: torch.Tensor, pred: torch.Tensor, name: str) -> Dict[str, torch.Tensor]:
|
154 |
+
"""
|
155 |
+
Combination of L1, SSIM, and LPIPS loss.
|
156 |
+
"""
|
157 |
+
if gt.shape[1] != 3:
|
158 |
+
assert gt.shape[-1] == 3
|
159 |
+
gt = gt.permute(0, 3, 1, 2)
|
160 |
+
if pred.shape[1] != 3:
|
161 |
+
assert pred.shape[-1] == 3
|
162 |
+
pred = pred.permute(0, 3, 1, 2)
|
163 |
+
terms = {
|
164 |
+
f"{name}_loss" : l1_loss(gt, pred),
|
165 |
+
f"{name}_loss_ssim" : 1 - ssim(gt, pred),
|
166 |
+
f"{name}_loss_lpips" : lpips(gt, pred)
|
167 |
+
}
|
168 |
+
terms[f"{name}_loss_perceptual"] = terms[f"{name}_loss"] + terms[f"{name}_loss_ssim"] * self.lambda_ssim + terms[f"{name}_loss_lpips"] * self.lambda_lpips
|
169 |
+
return terms
|
170 |
+
|
171 |
+
def geometry_losses(
|
172 |
+
self,
|
173 |
+
reps: List[MeshExtractResult],
|
174 |
+
mesh: List[Dict],
|
175 |
+
normal_map: torch.Tensor,
|
176 |
+
extrinsics: torch.Tensor,
|
177 |
+
intrinsics: torch.Tensor,
|
178 |
+
):
|
179 |
+
with torch.no_grad():
|
180 |
+
gt_meshes = []
|
181 |
+
for i in range(len(reps)):
|
182 |
+
gt_mesh = MeshExtractResult(mesh[i]['vertices'].to(self.device), mesh[i]['faces'].to(self.device))
|
183 |
+
gt_meshes.append(gt_mesh)
|
184 |
+
target = self._render_batch(gt_meshes, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal'])
|
185 |
+
target['normal'] = self._flip_normal(target['normal'], extrinsics, intrinsics)
|
186 |
+
|
187 |
+
terms = edict(geo_loss = 0.0)
|
188 |
+
if self.lambda_tsdf > 0:
|
189 |
+
tsdf_loss = self._calc_tsdf_loss(reps, target['depth'], extrinsics, intrinsics)
|
190 |
+
terms['tsdf_loss'] = tsdf_loss
|
191 |
+
terms['geo_loss'] += tsdf_loss * self.lambda_tsdf
|
192 |
+
|
193 |
+
return_types = ['mask', 'depth', 'normal', 'normal_map'] if self.use_color else ['mask', 'depth', 'normal']
|
194 |
+
buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
|
195 |
+
|
196 |
+
success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
|
197 |
+
if success_mask.sum() != 0:
|
198 |
+
for k, v in buffer.items():
|
199 |
+
buffer[k] = v[success_mask]
|
200 |
+
for k, v in target.items():
|
201 |
+
target[k] = v[success_mask]
|
202 |
+
|
203 |
+
terms['mask_loss'] = l1_loss(buffer['mask'], target['mask'])
|
204 |
+
if self.depth_loss_type == 'l1':
|
205 |
+
terms['depth_loss'] = l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'])
|
206 |
+
elif self.depth_loss_type == 'smooth_l1':
|
207 |
+
terms['depth_loss'] = smooth_l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'], beta=1.0 / (2 * reps[0].res))
|
208 |
+
else:
|
209 |
+
raise ValueError(f"Unsupported depth loss type: {self.depth_loss_type}")
|
210 |
+
terms.update(self._perceptual_loss(buffer['normal'] * target['mask'], target['normal'] * target['mask'], 'normal'))
|
211 |
+
terms['geo_loss'] = terms['geo_loss'] + terms['mask_loss'] + terms['depth_loss'] * self.lambda_depth + terms['normal_loss_perceptual']
|
212 |
+
if self.use_color and normal_map is not None:
|
213 |
+
terms.update(self._perceptual_loss(normal_map[success_mask], buffer['normal_map'], 'normal_map'))
|
214 |
+
terms['geo_loss'] = terms['geo_loss'] + terms['normal_map_loss_perceptual'] * self.lambda_color
|
215 |
+
|
216 |
+
return terms
|
217 |
+
|
218 |
+
def color_losses(self, reps, image, alpha, extrinsics, intrinsics):
|
219 |
+
terms = edict(color_loss = torch.tensor(0.0, device=self.device))
|
220 |
+
buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=['color'])
|
221 |
+
success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
|
222 |
+
if success_mask.sum() != 0:
|
223 |
+
terms.update(self._perceptual_loss(image * alpha[:, None][success_mask], buffer['color'][success_mask], 'color'))
|
224 |
+
terms['color_loss'] = terms['color_loss'] + terms['color_loss_perceptual'] * self.lambda_color
|
225 |
+
return terms
|
226 |
+
|
227 |
+
def training_losses(
|
228 |
+
self,
|
229 |
+
latents: SparseTensor,
|
230 |
+
image: torch.Tensor,
|
231 |
+
alpha: torch.Tensor,
|
232 |
+
mesh: List[Dict],
|
233 |
+
extrinsics: torch.Tensor,
|
234 |
+
intrinsics: torch.Tensor,
|
235 |
+
normal_map: torch.Tensor = None,
|
236 |
+
) -> Tuple[Dict, Dict]:
|
237 |
+
"""
|
238 |
+
Compute training losses.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
latents: The [N x * x C] sparse latents
|
242 |
+
image: The [N x 3 x H x W] tensor of images.
|
243 |
+
alpha: The [N x H x W] tensor of alpha channels.
|
244 |
+
mesh: The list of dictionaries of meshes.
|
245 |
+
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
246 |
+
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
a dict with the key "loss" containing a scalar tensor.
|
250 |
+
may also contain other keys for different terms.
|
251 |
+
"""
|
252 |
+
reps = self.training_models['decoder'](latents)
|
253 |
+
self.renderer.rendering_options.resolution = image.shape[-1]
|
254 |
+
|
255 |
+
terms = edict(loss = 0.0, rec = 0.0)
|
256 |
+
|
257 |
+
terms['reg_loss'] = sum([rep.reg_loss for rep in reps]) / len(reps)
|
258 |
+
terms['loss'] = terms['loss'] + terms['reg_loss']
|
259 |
+
|
260 |
+
geo_terms = self.geometry_losses(reps, mesh, normal_map, extrinsics, intrinsics)
|
261 |
+
terms.update(geo_terms)
|
262 |
+
terms['loss'] = terms['loss'] + terms['geo_loss']
|
263 |
+
|
264 |
+
if self.use_color:
|
265 |
+
color_terms = self.color_losses(reps, image, alpha, extrinsics, intrinsics)
|
266 |
+
terms.update(color_terms)
|
267 |
+
terms['loss'] = terms['loss'] + terms['color_loss']
|
268 |
+
|
269 |
+
return terms, {}
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
def run_snapshot(
|
273 |
+
self,
|
274 |
+
num_samples: int,
|
275 |
+
batch_size: int,
|
276 |
+
verbose: bool = False,
|
277 |
+
) -> Dict:
|
278 |
+
dataloader = DataLoader(
|
279 |
+
copy.deepcopy(self.dataset),
|
280 |
+
batch_size=batch_size,
|
281 |
+
shuffle=True,
|
282 |
+
num_workers=0,
|
283 |
+
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
284 |
+
)
|
285 |
+
|
286 |
+
# inference
|
287 |
+
ret_dict = {}
|
288 |
+
gt_images = []
|
289 |
+
gt_normal_maps = []
|
290 |
+
gt_meshes = []
|
291 |
+
exts = []
|
292 |
+
ints = []
|
293 |
+
reps = []
|
294 |
+
for i in range(0, num_samples, batch_size):
|
295 |
+
batch = min(batch_size, num_samples - i)
|
296 |
+
data = next(iter(dataloader))
|
297 |
+
args = recursive_to_device(data, 'cuda')
|
298 |
+
gt_images.append(args['image'] * args['alpha'][:, None])
|
299 |
+
if self.use_color and 'normal_map' in data:
|
300 |
+
gt_normal_maps.append(args['normal_map'])
|
301 |
+
gt_meshes.extend(args['mesh'])
|
302 |
+
exts.append(args['extrinsics'])
|
303 |
+
ints.append(args['intrinsics'])
|
304 |
+
reps.extend(self.models['decoder'](args['latents']))
|
305 |
+
gt_images = torch.cat(gt_images, dim=0)
|
306 |
+
ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
|
307 |
+
if self.use_color and gt_normal_maps:
|
308 |
+
gt_normal_maps = torch.cat(gt_normal_maps, dim=0)
|
309 |
+
ret_dict.update({f'gt_normal_map': {'value': gt_normal_maps, 'type': 'image'}})
|
310 |
+
|
311 |
+
# render single view
|
312 |
+
exts = torch.cat(exts, dim=0)
|
313 |
+
ints = torch.cat(ints, dim=0)
|
314 |
+
self.renderer.rendering_options.bg_color = (0, 0, 0)
|
315 |
+
self.renderer.rendering_options.resolution = gt_images.shape[-1]
|
316 |
+
gt_render_results = self._render_batch([
|
317 |
+
MeshExtractResult(vertices=mesh['vertices'].to(self.device), faces=mesh['faces'].to(self.device))
|
318 |
+
for mesh in gt_meshes
|
319 |
+
], exts, ints, return_types=['normal'])
|
320 |
+
ret_dict.update({f'gt_normal': {'value': self._flip_normal(gt_render_results['normal'], exts, ints), 'type': 'image'}})
|
321 |
+
return_types = ['normal']
|
322 |
+
if self.use_color:
|
323 |
+
return_types.append('color')
|
324 |
+
if 'normal_map' in data:
|
325 |
+
return_types.append('normal_map')
|
326 |
+
render_results = self._render_batch(reps, exts, ints, return_types=return_types)
|
327 |
+
ret_dict.update({f'rec_normal': {'value': render_results['normal'], 'type': 'image'}})
|
328 |
+
if 'color' in return_types:
|
329 |
+
ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
|
330 |
+
if 'normal_map' in return_types:
|
331 |
+
ret_dict.update({f'rec_normal_map': {'value': render_results['normal_map'], 'type': 'image'}})
|
332 |
+
|
333 |
+
# render multiview
|
334 |
+
self.renderer.rendering_options.resolution = 512
|
335 |
+
## Build camera
|
336 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
337 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
338 |
+
yaws = [y + yaws_offset for y in yaws]
|
339 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
340 |
+
|
341 |
+
## render each view
|
342 |
+
multiview_normals = []
|
343 |
+
multiview_normal_maps = []
|
344 |
+
miltiview_images = []
|
345 |
+
for yaw, pitch in zip(yaws, pitch):
|
346 |
+
orig = torch.tensor([
|
347 |
+
np.sin(yaw) * np.cos(pitch),
|
348 |
+
np.cos(yaw) * np.cos(pitch),
|
349 |
+
np.sin(pitch),
|
350 |
+
]).float().cuda() * 2
|
351 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
352 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
353 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
354 |
+
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
355 |
+
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
356 |
+
render_results = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
|
357 |
+
multiview_normals.append(render_results['normal'])
|
358 |
+
if 'color' in return_types:
|
359 |
+
miltiview_images.append(render_results['color'])
|
360 |
+
if 'normal_map' in return_types:
|
361 |
+
multiview_normal_maps.append(render_results['normal_map'])
|
362 |
+
|
363 |
+
## Concatenate views
|
364 |
+
multiview_normals = torch.cat([
|
365 |
+
torch.cat(multiview_normals[:2], dim=-2),
|
366 |
+
torch.cat(multiview_normals[2:], dim=-2),
|
367 |
+
], dim=-1)
|
368 |
+
ret_dict.update({f'multiview_normal': {'value': multiview_normals, 'type': 'image'}})
|
369 |
+
if 'color' in return_types:
|
370 |
+
miltiview_images = torch.cat([
|
371 |
+
torch.cat(miltiview_images[:2], dim=-2),
|
372 |
+
torch.cat(miltiview_images[2:], dim=-2),
|
373 |
+
], dim=-1)
|
374 |
+
ret_dict.update({f'multiview_image': {'value': miltiview_images, 'type': 'image'}})
|
375 |
+
if 'normal_map' in return_types:
|
376 |
+
multiview_normal_maps = torch.cat([
|
377 |
+
torch.cat(multiview_normal_maps[:2], dim=-2),
|
378 |
+
torch.cat(multiview_normal_maps[2:], dim=-2),
|
379 |
+
], dim=-1)
|
380 |
+
ret_dict.update({f'multiview_normal_map': {'value': multiview_normal_maps, 'type': 'image'}})
|
381 |
+
|
382 |
+
return ret_dict
|