zerchen committed
Commit b32d972 · 1 Parent(s): 5d50546

update theme

Files changed (2)
  1. app.py +18 -8
  2. hort/utils/renderer.py +8 -0
app.py CHANGED
@@ -12,6 +12,7 @@ from ultralytics import YOLO
 from pathlib import Path
 import argparse
 import json
+import trimesh
 from torchvision import transforms
 from typing import Dict, Optional
 from PIL import Image, ImageDraw
@@ -175,13 +176,17 @@ def run_model(image, conf, IoU_threshold=0.5):
 
         reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up, 'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length}
 
-        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions
-    else:
-        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), None
+        camera_translation = cam_t.copy()
+        hand_mesh = renderer.mesh(verts, camera_translation, LIGHT_PURPLE, is_right=is_right)
+        obj_pcd = trimesh.PointCloud(reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'] + camera_translation, colors=[70, 130, 180, 255])
+        scene = trimesh.Scene([hand_mesh, obj_pcd])
+        scene_path = "/tmp/test.glb"
+        scene.export(scene_path)
 
+        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions, scene_path
 
 def render_reconstruction(image, conf, IoU_threshold=0.3):
-    input_img, num_dets, reconstructions = run_model(image, conf, IoU_threshold=0.5)
+    input_img, num_dets, reconstructions, scene_path = run_model(image, conf, IoU_threshold=0.5)
     # Render front view
     misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE, scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal'])
     cam_view = renderer.render_rgba(reconstructions['verts'], reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'], cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args)
@@ -190,7 +195,7 @@ def render_reconstruction(image, conf, IoU_threshold=0.3):
     input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
     input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
 
-    return input_img_overlay, f'{num_dets} hands detected'
+    return input_img_overlay, f'{num_dets} hands detected', scene_path
 
 
 header = ('''
@@ -215,8 +220,12 @@ header = ('''
 <a href='https://github.com/zerchen/hort'><img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a>
 ''')
 
-
-with gr.Blocks(title="HORT: Monocular Hand-held Objects Reconstruction with Transformers", css=".gradio-container") as demo:
+theme = gr.themes.Ocean()
+theme.set(
+    checkbox_label_background_fill_selected="*button_primary_background_fill",
+    checkbox_label_text_color_selected="*button_primary_text_color",
+)
+with gr.Blocks(theme=theme, title="HORT: Monocular Hand-held Objects Reconstruction with Transformers", css=".gradio-container") as demo:
 
     gr.Markdown(header)
 
@@ -229,9 +238,10 @@ with gr.Blocks(title="HORT: Monocular Hand-held Objects Reconstruction with Tran
 
         with gr.Column():
            reconstruction = gr.Image(label="Reconstructions", type="numpy")
+            output_meshes = gr.Model3D(height=300, zoom_speed=0.5, pan_speed=0.5)
            hands_detected = gr.Textbox(label="Hands Detected")
 
-    submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected])
+    submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected, output_meshes])
 
     with gr.Row():
         example_images = gr.Examples([
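
Taken together, the app.py changes add a second, interactive output path: run_model now also exports the hand mesh and object point cloud as a GLB scene, and the Blocks UI gains the Ocean theme plus a Model3D viewer fed by that file. Below is a minimal, self-contained sketch of the same wiring, with a trimesh cube standing in for the real reconstruction; the reconstruct stub, the /tmp/demo.glb path, and the component layout are illustrative assumptions, not the Space's actual code.

import gradio as gr
import trimesh

# Theme setup as in the commit: Ocean base with two checkbox overrides.
theme = gr.themes.Ocean()
theme.set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
)

def reconstruct(image):
    # Stand-in for run_model(): export some trimesh geometry to GLB.
    scene_path = "/tmp/demo.glb"  # hypothetical path, mirroring /tmp/test.glb
    trimesh.creation.box(extents=(1.0, 1.0, 1.0)).export(scene_path)
    return image, "1 hands detected", scene_path

with gr.Blocks(theme=theme) as demo:
    input_image = gr.Image(type="numpy")
    submit = gr.Button("Submit")
    reconstruction = gr.Image(label="Reconstructions", type="numpy")
    output_meshes = gr.Model3D(height=300, zoom_speed=0.5, pan_speed=0.5)
    hands_detected = gr.Textbox(label="Hands Detected")
    # The third return value (a .glb file path) drives the Model3D viewer.
    submit.click(fn=reconstruct, inputs=[input_image],
                 outputs=[reconstruction, hands_detected, output_meshes])

demo.launch()

Model3D accepts a file path to common mesh formats, including .glb, so exporting the trimesh scene to disk is enough to make the reconstruction rotatable in the browser.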
hort/utils/renderer.py CHANGED
@@ -280,6 +280,14 @@ class Renderer:
         mesh.apply_transform(rot)
         return mesh
 
+    def mesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9), rot_axis=[1,0,0], rot_angle=0, is_right=1):
+        vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0])
+        if is_right:
+            mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors)
+        else:
+            mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces_left.copy(), vertex_colors=vertex_colors)
+        return mesh
+
     def render_rgba(
         self,
         vertices: np.array,
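
The new Renderer.mesh helper is a plain trimesh construction: offset the vertices by the camera translation, attach one RGBA color per vertex, and pick the right- or left-hand face topology. A self-contained sketch of that pattern follows; the tetrahedron, color value, and output path are placeholders, not MANO data.

import numpy as np
import trimesh

# Placeholder geometry standing in for MANO vertices and Renderer.faces.
vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)
faces = np.array([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])
camera_translation = np.array([0.0, 0.0, 2.5])  # plain offset into camera space

mesh_base_color = (0.867, 0.765, 0.937)  # stand-in RGB for LIGHT_PURPLE
# One RGBA row per vertex, exactly as in Renderer.mesh().
vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0])

mesh = trimesh.Trimesh(vertices + camera_translation, faces, vertex_colors=vertex_colors)
mesh.export("/tmp/hand_stub.glb")  # any trimesh-supported format works

Note that the helper accepts rot_axis and rot_angle but never applies them, unlike the method above it that ends with mesh.apply_transform(rot); the returned geometry therefore stays in the camera frame that the GLB scene assembled in app.py expects.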