xinjie.wang committed
Commit 4a6c4ad · 1 Parent(s): c417f1a

Files changed (2):
  1. app.py (+18 -29)
  2. common.py (+23 -8)
app.py CHANGED

@@ -10,7 +10,9 @@ from common import (
     extract_3d_representations_v2,
     extract_urdf,
     get_seed,
+    image_css,
     image_to_3d,
+    lighting_css,
     preprocess_image_fn,
     preprocess_sam_image_fn,
     select_point,
@@ -29,6 +31,8 @@ with gr.Blocks(
     The service is temporarily deployed on `dev015-10.34.8.82: CUDA 4`.
     """
 )
+gr.HTML(image_css)
+gr.HTML(lighting_css)
 with gr.Row():
     with gr.Column(scale=2):
         with gr.Tabs() as input_tabs:
@@ -41,23 +45,13 @@
     type="pil",
     visible=False,
 )
-
-image_css = """
-<style>
-#img-fit .image-frame {
-    object-fit: contain !important;
-    height: 100% !important;
-}
-</style>
-"""
-gr.HTML(image_css)
 image_prompt = gr.Image(
     label="Input Image",
     format="png",
     image_mode="RGBA",
     type="pil",
-    height=300,
-    elem_id="img-fit",
+    height=400,
+    elem_classes=["image_fit"],
 )
 gr.Markdown(
     """
@@ -70,7 +64,10 @@
 with gr.Row():
     with gr.Column(scale=1):
         image_prompt_sam = gr.Image(
-            label="Input Image", type="numpy", height=500
+            label="Input Image",
+            type="numpy",
+            height=400,
+            elem_classes=["image_fit"],
         )
         image_seg_sam = gr.Image(
             label="SAM Seg Image",
@@ -80,7 +77,9 @@
         visible=False,
     )
 with gr.Column(scale=1):
-    image_mask_sam = gr.AnnotatedImage()
+    image_mask_sam = gr.AnnotatedImage(
+        elem_classes=["image_fit"]
+    )

     fg_bg_radio = gr.Radio(
         ["foreground_point", "background_point"],
@@ -238,26 +237,17 @@
     label="Gaussian Representation", height=300, interactive=False
 )
 aligned_gs = gr.Textbox(visible=False)
-
-lighting_css = """
-<style>
-#lighter_mesh canvas {
-    filter: brightness(2.8) !important;
-}
-</style>
-"""
-gr.HTML(lighting_css)
+gr.Markdown(
+    """ The rendering of `Gaussian Representation` takes additional 10s. """ # noqa
+)
 with gr.Row():
     model_output_mesh = gr.Model3D(
         label="Mesh Representation",
         height=300,
         interactive=False,
-        clear_color=[1, 1, 1, 1],
-        elem_id="lighter_mesh"
+        clear_color=[0.9, 0.9, 0.9, 1],
+        elem_id="lighter_mesh",
     )
-gr.Markdown(
-    """ The rendering of `Gaussian Representation` takes additional 10s. """ # noqa
-)

 is_samimage = gr.State(False)
 output_buf = gr.State()
@@ -456,4 +446,3 @@

 if __name__ == "__main__":
     demo.launch()
-
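The net effect of the app.py changes: the two inline `<style>` blocks that used to sit inside the layout are removed, the stylesheets are imported from common.py and injected once near the top of the Blocks, and image components opt in through a shared `image_fit` class instead of a one-off `img-fit` id. Below is a minimal sketch of that pattern, assuming only what the diff shows (the component names and CSS rules come from the diff; the surrounding layout is trimmed):

import gradio as gr

# Shared CSS strings, as now defined in common.py: a class-based rule that
# letterboxes image previews, and an id-based rule that brightens the mesh canvas.
image_css = """
<style>
.image_fit .image-frame {
    object-fit: contain !important;
    height: 100% !important;
}
</style>
"""

lighting_css = """
<style>
#lighter_mesh canvas {
    filter: brightness(2.5) !important;
}
</style>
"""

with gr.Blocks() as demo:
    # Inject both stylesheets once, near the top of the layout.
    gr.HTML(image_css)
    gr.HTML(lighting_css)

    # Any image component can reuse the class-based rule ...
    image_prompt = gr.Image(
        label="Input Image",
        type="pil",
        height=400,
        elem_classes=["image_fit"],
    )
    # ... while the id-based rule targets exactly one Model3D viewer.
    model_output_mesh = gr.Model3D(
        label="Mesh Representation",
        height=300,
        interactive=False,
        clear_color=[0.9, 0.9, 0.9, 1],
        elem_id="lighter_mesh",
    )

if __name__ == "__main__":
    demo.launch()

An alternative would be passing the same rules through `gr.Blocks(css=...)`; keeping them as plain strings in common.py and injecting them with `gr.HTML` lets both be reused without touching the Blocks constructor.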
common.py CHANGED

@@ -127,6 +127,24 @@ elif os.getenv("GRADIO_APP") == "texture_edit":
 os.makedirs(TMP_DIR, exist_ok=True)


+lighting_css = """
+<style>
+#lighter_mesh canvas {
+    filter: brightness(2.5) !important;
+}
+</style>
+"""
+
+image_css = """
+<style>
+.image_fit .image-frame {
+    object-fit: contain !important;
+    height: 100% !important;
+}
+</style>
+"""
+
+
 def start_session(req: gr.Request) -> None:
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     os.makedirs(user_dir, exist_ok=True)
@@ -193,13 +211,13 @@ def render_video(

 @spaces.GPU
 def preprocess_image_fn(
-    image: str | np.ndarray | Image.Image
+    image: str | np.ndarray | Image.Image,
 ) -> tuple[Image.Image, Image.Image]:
     if isinstance(image, str):
         image = Image.open(image)
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
-
+
     image_cache = image.copy().resize((512, 512))

     image = RBG_REMOVER(image)
@@ -208,9 +226,8 @@ def preprocess_image_fn(
     return image, image_cache


-# @spaces.GPU
 def preprocess_sam_image_fn(
-    image: Image.Image
+    image: Image.Image,
 ) -> tuple[Image.Image, Image.Image]:
     if isinstance(image, np.ndarray):
         image = Image.fromarray(image)
@@ -304,7 +321,6 @@ def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
     return np.random.randint(0, max_seed) if randomize_seed else seed


-# @spaces.GPU
 def select_point(
     image: np.ndarray,
     sel_pix: list,
@@ -333,7 +349,7 @@ def select_point(
         thickness=10,
     )

-    # torch.cuda.empty_cache()
+    torch.cuda.empty_cache()

     return (image, masks), seg_image

@@ -387,8 +403,7 @@ def image_to_3d(
     mesh_model = outputs["mesh"][0]
     color_images = render_video(gs_model)["color"]
     normal_images = render_video(mesh_model)["normal"]
-
-
+
     video_path = os.path.join(output_root, "gs_mesh.mp4")
     merge_images_video(color_images, normal_images, video_path)
     state = pack_state(gs_model, mesh_model)
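On the common.py side, the formerly commented-out `torch.cuda.empty_cache()` is re-enabled after the point-selection/segmentation step in `select_point`. A rough sketch of what that buys, with a placeholder `predictor` standing in for the segmentation model (the predictor name and call signature are assumptions, not the repo's API):

import numpy as np
import torch


def segment_once(predictor, image: np.ndarray, point: tuple[int, int]) -> np.ndarray:
    """Run one GPU-backed segmentation call, then release cached GPU memory."""
    with torch.no_grad():                  # inference only; keep no autograd graph
        masks = predictor(image, point)    # placeholder for the real segmentation call
    result = masks.detach().cpu().numpy()  # copy the result off the GPU
    del masks                              # drop the last GPU tensor reference first,
    torch.cuda.empty_cache()               # then hand cached blocks back to the driver
    return result

Note that `empty_cache()` only releases allocator blocks whose tensors are no longer referenced, so it does not speed up the current call; it mainly makes the freed memory visible outside the process (e.g. in nvidia-smi or to other jobs on the same GPU).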