zhuhai111 committed on
Commit
ddbac99
·
verified ·
1 Parent(s): 75b2ef1

Upload mesh_renderer.py

Browse files
Files changed (1) hide show
  1. trellis/renderers/mesh_renderer.py +187 -140
trellis/renderers/mesh_renderer.py CHANGED
@@ -1,140 +1,187 @@
1
- # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
- import torch
9
- import nvdiffrast.torch as dr
10
- from easydict import EasyDict as edict
11
- from ..representations.mesh import MeshExtractResult
12
- import torch.nn.functional as F
13
-
14
-
15
def intrinsics_to_projection(
        intrinsics: torch.Tensor,
        near: float,
        far: float,
    ) -> torch.Tensor:
    """
    Build an OpenGL-style perspective matrix from an OpenCV intrinsics matrix.

    Args:
        intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
        near (float): distance of the near clipping plane
        far (float): distance of the far clipping plane
    Returns:
        (torch.Tensor): [4, 4] perspective matrix with the dtype and device of `intrinsics`
    """
    focal_x, focal_y = intrinsics[0, 0], intrinsics[1, 1]
    center_x, center_y = intrinsics[0, 2], intrinsics[1, 2]
    depth_range = far - near

    # NOTE(review): the 2*f and 2*c - 1 terms suggest intrinsics normalized to
    # [0, 1] image coordinates -- confirm against the callers.
    proj = intrinsics.new_zeros((4, 4))
    proj[0, 0] = 2 * focal_x
    proj[0, 2] = 2 * center_x - 1
    proj[1, 1] = 2 * focal_y
    proj[1, 2] = - 2 * center_y + 1
    # Depth row, followed by the row that copies camera-space z into clip-space w.
    proj[2, 2] = far / depth_range
    proj[2, 3] = near * far / (near - far)
    proj[3, 2] = 1.
    return proj
41
-
42
-
43
class MeshRenderer:
    """
    Renderer for the Mesh representation.

    Args:
        rendering_options (dict): Rendering options; the keys read by this class are
            "resolution", "near", "far" and "ssaa".
        device (str): Device on which the nvdiffrast CUDA rasterization context
            and the output tensors are created.
    """
    # Return types rendered with three channels; every other type is single-channel.
    _THREE_CHANNEL_TYPES = ("normal", "normal_map", "color")
    _SUPPORTED_TYPES = ("mask", "depth", "normal", "normal_map", "color")

    def __init__(self, rendering_options=None, device='cuda'):
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1
        })
        # `None` default instead of a shared mutable `{}`; passing a dict works as before.
        if rendering_options is not None:
            self.rendering_options.update(rendering_options)
        self.glctx = dr.RasterizeCudaContext(device=device)
        self.device = device

    def render(
            self,
            mesh : MeshExtractResult,
            extrinsics: torch.Tensor,
            intrinsics: torch.Tensor,
            return_types = ("mask", "normal", "depth", "color")
        ) -> edict:
        """
        Render the mesh.

        Args:
            mesh : meshmodel
            extrinsics (torch.Tensor): (4, 4) camera extrinsics
            intrinsics (torch.Tensor): (3, 3) camera intrinsics
            return_types (list): list of return types, can be "mask", "depth", "normal_map", "normal", "color"

        Returns:
            edict based on return_types containing:
                color (torch.Tensor): [3, H, W] rendered color image
                depth (torch.Tensor): [H, W] rendered depth image
                normal (torch.Tensor): [3, H, W] rendered normal image
                normal_map (torch.Tensor): [3, H, W] rendered normal map image
                mask (torch.Tensor): [H, W] rendered mask image

        Raises:
            ValueError: if a non-empty mesh is rendered with an unsupported return type.
        """
        resolution = self.rendering_options["resolution"]
        near = self.rendering_options["near"]
        far = self.rendering_options["far"]
        ssaa = self.rendering_options["ssaa"]

        if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
            # Nothing to rasterize: return all-zero images that follow the same
            # layout as a real render ([3, H, W] or [H, W]) inside an edict, so
            # callers can treat empty and non-empty meshes uniformly.
            blank = edict()
            for return_type in return_types:
                if return_type in self._THREE_CHANNEL_TYPES:
                    shape = (3, resolution, resolution)
                else:
                    shape = (resolution, resolution)
                blank[return_type] = torch.zeros(shape, dtype=torch.float32, device=self.device)
            return blank

        # Fail early with a clear message rather than on `None.permute` further down.
        unsupported = [t for t in return_types if t not in self._SUPPORTED_TYPES]
        if unsupported:
            raise ValueError(f"Unsupported return types: {unsupported}")

        perspective = intrinsics_to_projection(intrinsics, near, far)

        RT = extrinsics.unsqueeze(0)
        full_proj = (perspective @ extrinsics).unsqueeze(0)

        vertices = mesh.vertices.unsqueeze(0)

        # Homogeneous vertices, transformed to camera space (for depth) and clip space (for rasterization).
        vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
        vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2))
        vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
        faces_int = mesh.faces.int()
        # Rasterize at the supersampled resolution; outputs are downsampled below when ssaa > 1.
        rast, _ = dr.rasterize(
            self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa))

        out_dict = edict()
        for return_type in return_types:
            if return_type == "mask":
                img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
            elif return_type == "depth":
                img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0]
                img = dr.antialias(img, rast, vertices_clip, faces_int)
            elif return_type == "normal":
                # One attribute row per face corner, indexed by a running
                # [num_faces, 3] index table instead of the shared vertex indices.
                img = dr.interpolate(
                    mesh.face_normal.reshape(1, -1, 3), rast,
                    torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3)
                )[0]
                img = dr.antialias(img, rast, vertices_clip, faces_int)
                # normalize norm pictures
                img = (img + 1) / 2
            elif return_type == "normal_map":
                img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0]
                img = dr.antialias(img, rast, vertices_clip, faces_int)
            else:  # "color"
                img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0]
                img = dr.antialias(img, rast, vertices_clip, faces_int)

            # Move channels first; squeeze() then drops the batch dimension
            # (and the channel dimension of single-channel images).
            img = img.permute(0, 3, 1, 2)
            if ssaa > 1:
                img = F.interpolate(img, (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
            out_dict[return_type] = img.squeeze()

        return out_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import torch
9
+ import os
10
+ from easydict import EasyDict as edict
11
+ from ..representations.mesh import MeshExtractResult
12
+ import torch.nn.functional as F
13
+
14
# CPU-only mode is requested explicitly with CPU_ONLY=1 and implied when CUDA is unavailable.
CPU_ONLY = os.environ.get('CPU_ONLY', '0') == '1' or not torch.cuda.is_available()

# nvdiffrast is only imported outside CPU-only mode; HAS_NVDIFFRAST records
# whether that import succeeded, and `dr` is bound only when it did.
HAS_NVDIFFRAST = False
if not CPU_ONLY:
    try:
        import nvdiffrast.torch as dr
    except ImportError:
        pass
    else:
        HAS_NVDIFFRAST = True
26
+
27
def intrinsics_to_projection(
        intrinsics: torch.Tensor,
        near: float,
        far: float,
    ) -> torch.Tensor:
    """
    Build an OpenGL-style perspective matrix from an OpenCV intrinsics matrix.

    Args:
        intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
        near (float): distance of the near clipping plane
        far (float): distance of the far clipping plane
    Returns:
        (torch.Tensor): [4, 4] perspective matrix with the dtype and device of `intrinsics`
    """
    focal_x, focal_y = intrinsics[0, 0], intrinsics[1, 1]
    center_x, center_y = intrinsics[0, 2], intrinsics[1, 2]
    depth_range = far - near

    # NOTE(review): the 2*f and 2*c - 1 terms suggest intrinsics normalized to
    # [0, 1] image coordinates -- confirm against the callers.
    proj = intrinsics.new_zeros((4, 4))
    proj[0, 0] = 2 * focal_x
    proj[0, 2] = 2 * center_x - 1
    proj[1, 1] = 2 * focal_y
    proj[1, 2] = - 2 * center_y + 1
    # Depth row, followed by the row that copies camera-space z into clip-space w.
    proj[2, 2] = far / depth_range
    proj[2, 3] = near * far / (near - far)
    proj[3, 2] = 1.
    return proj
53
+
54
+
55
class MeshRenderer:
    """
    Renderer for the Mesh representation.

    Rasterizes with nvdiffrast when it was imported successfully and a CUDA
    device is requested. Otherwise it runs in a CPU fallback mode that returns
    constant placeholder images instead of an actual rendering.

    Args:
        rendering_options (dict): Rendering options; the keys read by this class are
            "resolution", "near", "far" and "ssaa".
        device (str or torch.device): Device for the rasterization context and the
            output tensors. In fallback mode it is replaced by 'cpu' when a CUDA
            device was requested but CUDA is not available.
    """
    # Return types rendered with three channels; every other type is single-channel.
    _THREE_CHANNEL_TYPES = ("normal", "normal_map", "color")
    _SUPPORTED_TYPES = ("mask", "depth", "normal", "normal_map", "color")

    def __init__(self, rendering_options=None, device='cuda'):
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1
        })
        # `None` default instead of a shared mutable `{}`; passing a dict works as before.
        if rendering_options is not None:
            self.rendering_options.update(rendering_options)
        self.device = device
        self.glctx = None

        # Compare the device *type* so that 'cuda:0' and torch.device objects are
        # classified correctly; a plain `device != 'cpu'` string test is not
        # reliable for them.
        wants_cuda = torch.device(device).type == 'cuda'

        # Set up renderer based on environment
        if HAS_NVDIFFRAST and wants_cuda:
            self.glctx = dr.RasterizeCudaContext(device=device)
            self.use_cpu_fallback = False
        else:
            # CPU fallback mode
            self.use_cpu_fallback = True
            if wants_cuda and not torch.cuda.is_available():
                # The default device is 'cuda'; allocating the placeholder images
                # there would fail on a machine without CUDA.
                self.device = 'cpu'
            print("[WARNING] Using CPU fallback renderer. Rendering will be simplified.")

    def render(
            self,
            mesh : MeshExtractResult,
            extrinsics: torch.Tensor,
            intrinsics: torch.Tensor,
            return_types = ("mask", "normal", "depth", "color")
        ) -> edict:
        """
        Render the mesh.

        In CPU fallback mode the camera parameters and "ssaa" are ignored and
        constant placeholder images are returned.

        Args:
            mesh : meshmodel
            extrinsics (torch.Tensor): (4, 4) camera extrinsics
            intrinsics (torch.Tensor): (3, 3) camera intrinsics
            return_types (list): list of return types, can be "mask", "depth", "normal_map", "normal", "color"

        Returns:
            edict based on return_types containing:
                color (torch.Tensor): [3, H, W] rendered color image
                depth (torch.Tensor): [H, W] rendered depth image
                normal (torch.Tensor): [3, H, W] rendered normal image
                normal_map (torch.Tensor): [3, H, W] rendered normal map image
                mask (torch.Tensor): [H, W] rendered mask image

        Raises:
            ValueError: if the nvdiffrast path is asked for an unsupported return type.
        """
        resolution = self.rendering_options["resolution"]
        near = self.rendering_options["near"]
        far = self.rendering_options["far"]
        ssaa = self.rendering_options["ssaa"]

        if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
            # Nothing to rasterize: return all-zero images that follow the same
            # layout as a real render ([3, H, W] or [H, W]) inside an edict, so
            # callers can treat empty and non-empty meshes uniformly.
            blank = edict()
            for return_type in return_types:
                blank[return_type] = torch.zeros(
                    self._image_shape(return_type, resolution), dtype=torch.float32, device=self.device)
            return blank

        # CPU fallback rendering - simplified version
        if self.use_cpu_fallback:
            return self._render_placeholders(return_types, resolution)

        # GPU rendering with nvdiffrast
        # Fail early with a clear message rather than on `None.permute` further down.
        unsupported = [t for t in return_types if t not in self._SUPPORTED_TYPES]
        if unsupported:
            raise ValueError(f"Unsupported return types: {unsupported}")

        perspective = intrinsics_to_projection(intrinsics, near, far)

        RT = extrinsics.unsqueeze(0)
        full_proj = (perspective @ extrinsics).unsqueeze(0)

        vertices = mesh.vertices.unsqueeze(0)

        # Homogeneous vertices, transformed to camera space (for depth) and clip space (for rasterization).
        vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
        vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2))
        vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
        faces_int = mesh.faces.int()
        # Rasterize at the supersampled resolution; outputs are downsampled below when ssaa > 1.
        rast, _ = dr.rasterize(
            self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa))

        out_dict = edict()
        for return_type in return_types:
            if return_type == "mask":
                img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
            elif return_type == "depth":
                img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0]
                img = dr.antialias(img, rast, vertices_clip, faces_int)
            elif return_type == "normal":
                # One attribute row per face corner, indexed by a running
                # [num_faces, 3] index table instead of the shared vertex indices.
                img = dr.interpolate(
                    mesh.face_normal.reshape(1, -1, 3), rast,
                    torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3)
                )[0]
                img = dr.antialias(img, rast, vertices_clip, faces_int)
                # normalize norm pictures
                img = (img + 1) / 2
            elif return_type == "normal_map":
                img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0]
                img = dr.antialias(img, rast, vertices_clip, faces_int)
            else:  # "color"
                img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0]
                img = dr.antialias(img, rast, vertices_clip, faces_int)

            # Move channels first; squeeze() then drops the batch dimension
            # (and the channel dimension of single-channel images).
            img = img.permute(0, 3, 1, 2)
            if ssaa > 1:
                img = F.interpolate(img, (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
            out_dict[return_type] = img.squeeze()

        return out_dict

    def _image_shape(self, return_type, resolution):
        """Return the output shape for a return type: [3, H, W] for three-channel types, else [H, W]."""
        if return_type in self._THREE_CHANNEL_TYPES:
            return (3, resolution, resolution)
        return (resolution, resolution)

    def _render_placeholders(self, return_types, resolution):
        """Build the constant images returned in CPU fallback mode, shaped like real renders."""
        out_dict = edict()
        for return_type in return_types:
            shape = self._image_shape(return_type, resolution)
            if return_type in self._THREE_CHANNEL_TYPES:
                # "normal_map" keeps the all-zero base image.
                img = torch.zeros(shape, dtype=torch.float32, device=self.device)
                if return_type == "normal":
                    # Simple light blue for normal map
                    img[0] = 0.5  # R
                    img[1] = 0.5  # G
                    img[2] = 1.0  # B
                elif return_type == "color":
                    # Simple gray for color
                    img.fill_(0.7)
            else:
                # For mask and depth, create a simple all-ones placeholder
                img = torch.ones(shape, dtype=torch.float32, device=self.device)
            out_dict[return_type] = img
        return out_dict