FQiao committed
Commit a0c86ca · verified · 1 Parent(s): 24be5a8

Update app.py

Files changed (1):
  1. app.py +307 -126
app.py CHANGED
@@ -1,24 +1,24 @@
- import os
- from os.path import basename, splitext, join
  import tempfile
- import gradio as gr
  import numpy as np
  from PIL import Image
  import torch
- import cv2
  from torchvision.transforms.functional import to_tensor, to_pil_image
- from torch import Tensor
- from genstereo import GenStereo, AdaptiveFusionLayer
- import ssl
  from huggingface_hub import hf_hub_download
- import spaces

- from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
- ssl._create_default_https_context = ssl._create_unverified_context

- IMAGE_SIZE = 512
- DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
- CHECKPOINT_NAME = 'genstereo'

  def download_models():
      models = [
@@ -37,17 +37,10 @@ def download_models():
      'token': None
  },
  {
- 'repo': 'FQiao/GenStereo',
- 'sub': None,
- 'dst': 'checkpoints/genstereo',
- 'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
- 'token': None
- },
- {
- 'repo': 'depth-anything/Depth-Anything-V2-Large',
- 'sub': None,
  'dst': 'checkpoints',
- 'files': [f'depth_anything_v2_vitl.pth'],
  'token': None
  }
  ]
@@ -65,43 +58,53 @@ def download_models():
  # Setup.
  download_models()

- # DepthAnythingV2
- def get_dam2_model():
- model_configs = {
- 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
- 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
- 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
- }
-
- encoder = 'vitl'
- encoder_size_map = {'vits': 'Small', 'vitb': 'Base', 'vitl': 'Large'}
-
- if encoder not in encoder_size_map:
- raise ValueError(f"Unsupported encoder: {encoder}. Supported: {list(encoder_size_map.keys())}")
-
- dam2 = DepthAnythingV2(**model_configs[encoder])
- dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth'
- dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu'))
- dam2 = dam2.to(DEVICE).eval()
- return dam2
-
- # GenStereo
- def get_genstereo_model():
- genwarp_cfg = dict(
- pretrained_model_path='checkpoints',
- checkpoint_name=CHECKPOINT_NAME,
- half_precision_weights=True
- )
- genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
- return genstereo
-
- # Adaptive Fusion
- def get_fusion_model():
- fusion_model = AdaptiveFusionLayer()
- fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
- fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
- fusion_model = fusion_model.to(DEVICE).eval()
- return fusion_model

  # Crop the image to the shorter side.
  def crop(img: Image) -> Image:
@@ -112,64 +115,211 @@ def crop(img: Image) -> Image:
  else:
  left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
  top, bottom = 0, H
- return img.crop((left, top, right, bottom))

- # Gradio app
  with tempfile.TemporaryDirectory() as tmpdir:
  with gr.Blocks(
- title='StereoGen Demo',
  css='img {display: inline;}'
  ) as demo:
  # Internal states.
- src_image = gr.State()
- src_depth = gr.State()
-
- def normalize_disp(disp):
- return (disp - disp.min()) / (disp.max() - disp.min())

  # Callbacks
- @spaces.GPU()
  def cb_mde(image_file: str):
- if not image_file:
- # Return None if no image is provided (e.g., when file is cleared).
- return None, None, None, None

- image = crop(Image.open(image_file).convert('RGB')) # Load image using PIL
- image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
-
- image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-
- dam2 = get_dam2_model()
- depth_dam2 = dam2.infer_image(image_bgr)
- depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float()
-
- depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET)
-
- return image, depth_image, image, depth

  @spaces.GPU()
- def cb_generate(image, depth: Tensor, scale_factor):
- norm_disp = normalize_disp(depth.cuda())
- disp = norm_disp * scale_factor / 100 * IMAGE_SIZE

- genstereo = get_genstereo_model()
- fusion_model = get_fusion_model()

- renders = genstereo(
- src_image=image,
- src_disparity=disp,
- ratio=None,
  )
- warped = (renders['warped'] + 1) / 2
-
- synthesized = renders['synthesized']
- mask = renders['mask']
- fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())

- warped_pil = to_pil_image(warped[0])
- fusion_pil = to_pil_image(fusion_image[0])

- return warped_pil, fusion_pil

  # Blocks.
  gr.Markdown(
@@ -180,55 +330,86 @@ with tempfile.TemporaryDirectory() as tmpdir:
  [![Github](https://img.shields.io/badge/Github-Repo-orange?logo=github)](https://github.com/sony/genwarp/)
  [![Models](https://img.shields.io/badge/Models-checkpoints-blue?logo=huggingface)](https://huggingface.co/Sony/genwarp)
  [![arXiv](https://img.shields.io/badge/arXiv-2405.17251-red?logo=arxiv)](https://arxiv.org/abs/2405.17251)
-
  ## Introduction
  This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". Genwarp can generate novel view images from a single input conditioned on camera poses. In this demo, we offer a basic use of inference of the model. For detailed information, please refer to the [paper](https://arxiv.org/abs/2405.17251).
  """
  )
- file = gr.File(label='Left', file_types=['image'])
- examples = gr.Examples(
- examples=['./assets/COCO_val2017_000000070229.jpg',
- './assets/COCO_val2017_000000092839.jpg',
- './assets/KITTI2015_000003_10.png',
- './assets/KITTI2015_000147_10.png'],
- inputs=file
- )
  with gr.Row():
  image_widget = gr.Image(
- label='Depth', type='filepath',
  interactive=False
  )
  depth_widget = gr.Image(label='Estimated Depth', type='pil')
-
- # Add scale factor slider
- scale_slider = gr.Slider(
- label='Scale Factor',
- minimum=1.0,
- maximum=30.0,
- value=15.0,
- step=0.1,
- )
-
- button = gr.Button('Generate a right image', size='lg', variant='primary')
  with gr.Row():
  warped_widget = gr.Image(
  label='Warped Image', type='pil', interactive=False
  )
  gen_widget = gr.Image(
- label='Generated Right', type='pil', interactive=False
  )

  # Events
- file.change(
  fn=cb_mde,
  inputs=file,
- outputs=[image_widget, depth_widget, src_image, src_depth]
  )
  button.click(
  fn=cb_generate,
- inputs=[src_image, src_depth, scale_slider],
  outputs=[warped_widget, gen_widget]
  )

  if __name__ == '__main__':
  demo.launch()
 
+ import sys
+ from subprocess import check_call
  import tempfile
+
+ from os.path import basename, splitext, join
+ from io import BytesIO
+
  import numpy as np
+ from scipy.spatial import KDTree
  from PIL import Image
+
  import torch
+ import torch.nn.functional as F
  from torchvision.transforms.functional import to_tensor, to_pil_image
+ from einops import rearrange
+ import gradio as gr
  from huggingface_hub import hf_hub_download

+ from extern.ZoeDepth.zoedepth.utils.misc import colorize

+ from gradio_model3dgscamera import Model3DGSCamera

  def download_models():
      models = [

      'token': None
  },
  {
+ 'repo': 'Sony/genwarp',
+ 'sub': 'multi1',
  'dst': 'checkpoints',
+ 'files': ['config.json', 'denoising_unet.pth', 'pose_guider.pth', 'reference_unet.pth'],
  'token': None
  }
  ]

  # Setup.
  download_models()

+ mde = torch.hub.load(
+ './extern/ZoeDepth',
+ 'ZoeD_N',
+ source='local',
+ pretrained=True,
+ trust_repo=True
+ )
+
+ import spaces
+
+ check_call([
+ sys.executable, '-m', 'pip', 'install',
+ 'extern/splatting-0.0.1-py3-none-any.whl'
+ ])
+
+ from genwarp import GenWarp
+ from genwarp.ops import (
+ camera_lookat, get_projection_matrix, get_viewport_matrix
+ )
+
+ # GenWarp
+ genwarp_cfg = dict(
+ pretrained_model_path='checkpoints',
+ checkpoint_name='multi1',
+ half_precision_weights=True
+ )
+ genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
+
+ # Fixed parameters.
+ IMAGE_SIZE = 512
+ NEAR, FAR = 0.01, 100
+ FOVY = np.deg2rad(55)
+ PROJ_MTX = get_projection_matrix(
+ fovy=torch.ones(1) * FOVY,
+ aspect_wh=1.,
+ near=NEAR,
+ far=FAR
+ )
+ VIEW_MTX = camera_lookat(
+ torch.tensor([[0., 0., 0.]]),
+ torch.tensor([[0., 0., 1.]]),
+ torch.tensor([[0., -1., 0.]])
+ )
+ VIEWPORT_MTX = get_viewport_matrix(
+ IMAGE_SIZE, IMAGE_SIZE,
+ batch_size=1
+ )

  # Crop the image to the shorter side.
  def crop(img: Image) -> Image:

  else:
  left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
  top, bottom = 0, H
+ img = img.crop((left, top, right, bottom))
+ img = img.resize((IMAGE_SIZE, IMAGE_SIZE))
+ return img
+
+ def save_as_splat(
+ filepath: str,
+ xyz: np.ndarray,
+ rgb: np.ndarray
+ ):
+ # To gaussian splat
+ inv_sigmoid = lambda x: np.log(x / (1 - x))
+ dist2 = np.clip(calc_dist2(xyz), a_min=0.0000001, a_max=None)
+ scales = np.repeat(np.log(np.sqrt(dist2))[..., np.newaxis], 3, axis=1)
+ rots = np.zeros((xyz.shape[0], 4))
+ rots[:, 0] = 1
+ opacities = inv_sigmoid(0.1 * np.ones((xyz.shape[0], 1)))
+
+ sorted_indices = np.argsort((
+ -np.exp(np.sum(scales, axis=-1, keepdims=True))
+ / (1 + np.exp(-opacities))
+ ).squeeze())
+
+ buffer = BytesIO()
+ for idx in sorted_indices:
+ position = xyz[idx]
+ scale = np.exp(scales[idx]).astype(np.float32)
+ rot = rots[idx].astype(np.float32)
+ color = np.concatenate(
+ (rgb[idx], 1 / (1 + np.exp(-opacities[idx]))),
+ axis=-1
+ )
+ buffer.write(position.tobytes())
+ buffer.write(scale.tobytes())
+ buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
+ buffer.write(
+ ((rot / np.linalg.norm(rot)) * 128 + 128)
+ .clip(0, 255)
+ .astype(np.uint8)
+ .tobytes()
+ )
+
+ with open(filepath, "wb") as f:
+ f.write(buffer.getvalue())
+
+ def calc_dist2(points: np.ndarray):
+ dists, _ = KDTree(points).query(points, k=4)
+ mean_dists = (dists[:, 1:] ** 2).mean(1)
+ return mean_dists
+
+ def unproject(depth):
+ H, W = depth.shape[2:4]
+ mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
+
+ # Matrices.
+ viewport_mtx = VIEWPORT_MTX.to(depth)
+ proj_mtx = PROJ_MTX.to(depth)
+ view_mtx = VIEW_MTX.to(depth)
+ scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
+
+ grid = torch.stack(torch.meshgrid(
+ torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
+ ).to(depth)[None] # BHW2
+
+ screen = F.pad(grid, (0, 1), 'constant', 0)
+ screen = F.pad(screen, (0, 1), 'constant', 1)
+ screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
+
+ eye = screen_flat @ torch.linalg.inv_ex(
+ scr_mtx.float()
+ )[0].mT.to(depth)
+ eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
+ eye[..., 3] = 1
+
+ points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
+ points = points[0, :, :3]
+
+ # Translate to the origin.
+ points[..., 2] -= mean_depth
+ camera_pos = (0, 0, -mean_depth)
+
+ return points, camera_pos
+
+ def view_from_rt(position, rotation):
+ t = np.array(position)
+ euler = np.array(rotation)
+
+ cx = np.cos(euler[0])
+ sx = np.sin(euler[0])
+ cy = np.cos(euler[1])
+ sy = np.sin(euler[1])
+ cz = np.cos(euler[2])
+ sz = np.sin(euler[2])
+ R = np.array([
+ cy * cz + sy * sx * sz,
+ -cy * sz + sy * sx * cz,
+ sy * cx,
+ cx * sz,
+ cx * cz,
+ -sx,
+ -sy * cz + cy * sx * sz,
+ sy * sz + cy * sx * cz,
+ cy * cx
+ ])
+ view_mtx = np.array([
+ [R[0], R[1], R[2], 0],
+ [R[3], R[4], R[5], 0],
+ [R[6], R[7], R[8], 0],
+ [
+ -t[0] * R[0] - t[1] * R[3] - t[2] * R[6],
+ -t[0] * R[1] - t[1] * R[4] - t[2] * R[7],
+ -t[0] * R[2] - t[1] * R[5] - t[2] * R[8],
+ 1
+ ]
+ ]).T
+
+ B = np.array([
+ [1, 0, 0, 0],
+ [0, -1, 0, 0],
+ [0, 0, -1, 0],
+ [0, 0, 0, 1]
+ ])
+ return B @ view_mtx
+

  with tempfile.TemporaryDirectory() as tmpdir:
  with gr.Blocks(
+ title='GenWarp Demo',
  css='img {display: inline;}'
  ) as demo:
  # Internal states.
+ image = gr.State()
+ depth = gr.State()

  # Callbacks
+ @spaces.GPU()
  def cb_mde(image_file: str):
+ # Load an image.
+ image_pil = crop(Image.open(image_file).convert('RGB'))
+ image = to_tensor(image_pil)[None].detach()
+ # Get depth.
+ depth = mde.cuda().infer(image.cuda()).cpu().detach()
+ depth_pil = to_pil_image(colorize(depth[0]))
+ return image_pil, depth_pil, image, depth

+ @spaces.GPU()
+ def cb_3d(image_file, image, depth):
+ # Unproject.
+ xyz, camera_pos = unproject(depth.cuda())
+ xyz = xyz.cpu().detach().numpy()
+ # Save as a splat.
+ ## Output filename.
+ splat_file = join(
+ tmpdir, f'./{splitext(basename(image_file))[0]}.splat')
+ rgb = rearrange(image, 'b c h w -> b (h w) c')[0].numpy()
+ save_as_splat(splat_file, xyz, rgb)
+ return splat_file, camera_pos, (0, 0, 0)

  @spaces.GPU()
+ def cb_generate(viewer, image, depth):
+ if depth is None:
+ gr.Error('Image and Depth are not set. Try again.')
+ return None, None

+ mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
+ src_view_mtx = camera_lookat(
+ torch.tensor([[0., 0., -mean_depth]]),
+ torch.tensor([[0., 0., 0.]]),
+ torch.tensor([[0., -1., 0.]])
+ ).to(depth)
+ tar_camera_pos, tar_camera_rot = viewer[1:3]
+ tar_view_mtx = torch.from_numpy(view_from_rt(
+ tar_camera_pos, tar_camera_rot
+ ))
+ rel_view_mtx = (
+ tar_view_mtx @ torch.linalg.inv(src_view_mtx.double())
+ ).half().cuda()
+ proj_mtx = PROJ_MTX.half().cuda()

+ # GenWarp.
+ renders = genwarp_nvs.to('cuda')(
+ src_image=image.half().cuda(),
+ src_depth=depth.half().cuda(),
+ rel_view_mtx=rel_view_mtx,
+ src_proj_mtx=proj_mtx,
+ tar_proj_mtx=proj_mtx
  )
+ warped_pil = to_pil_image(renders['warped'].cpu()[0])
+ synthesized_pil = to_pil_image(renders['synthesized'].cpu()[0])

+ return warped_pil, synthesized_pil

+ def process_example(image_file):
+ gr.Error('')
+ image_pil, depth_pil, image, depth = cb_mde(image_file)
+ viewer = cb_3d(image_file, image, depth)
+ # Fixed angle for examples.
+ viewer = (viewer[0], (-2.020, -0.727, -5.236), (-0.132, 0.378, 0.0))
+ warped_pil, synthsized_pil = cb_generate(
+ viewer, image, depth
+ )
+ return (
+ image_pil, depth_pil, viewer,
+ warped_pil, synthsized_pil,
+ None, None # Clear internal states.
+ )

  # Blocks.
  gr.Markdown(

  [![Github](https://img.shields.io/badge/Github-Repo-orange?logo=github)](https://github.com/sony/genwarp/)
  [![Models](https://img.shields.io/badge/Models-checkpoints-blue?logo=huggingface)](https://huggingface.co/Sony/genwarp)
  [![arXiv](https://img.shields.io/badge/arXiv-2405.17251-red?logo=arxiv)](https://arxiv.org/abs/2405.17251)
+
  ## Introduction
  This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". Genwarp can generate novel view images from a single input conditioned on camera poses. In this demo, we offer a basic use of inference of the model. For detailed information, please refer to the [paper](https://arxiv.org/abs/2405.17251).
+
+ ## How to Use
+
+ ### Try examples
+ - Examples are in the bottom section of the page
+
+ ### Upload your own images
+ 1. Upload a reference image to "Reference Input"
+ 2. Move the camera to your desired view in "Unprojected 3DGS" 3D viewer
+ 3. Hit "Generate a novel view" button and check the result
+
+ ## Tips
+ - This model is mainly trained for indoor/outdoor scenery. It might not work well for object-centric inputs. For details on training the model, please check our [paper](https://arxiv.org/abs/2405.17251).
+ - Extremely large camera movement from the input view might cause low performance results due to the unexpected deviation from the training distribution, which is not the scope of this model. Instead, you can feed the generation result for the small camera movement repeatedly and progressively move towards a desired view.
+ - 3D viewer might take some time to update especially when trying different images back to back. Wait until it fully updates to the new image.
+
  """
  )
+ file = gr.File(label='Reference Input', file_types=['image'])
  with gr.Row():
  image_widget = gr.Image(
+ label='Reference View', type='filepath',
  interactive=False
  )
  depth_widget = gr.Image(label='Estimated Depth', type='pil')
+ viewer = Model3DGSCamera(
+ label = 'Unprojected 3DGS',
+ width=IMAGE_SIZE,
+ height=IMAGE_SIZE,
+ camera_width=IMAGE_SIZE,
+ camera_height=IMAGE_SIZE,
+ camera_fx=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
+ camera_fy=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
+ camera_near=NEAR,
+ camera_far=FAR
+ )
+ button = gr.Button('Generate a novel view', size='lg', variant='primary')
  with gr.Row():
  warped_widget = gr.Image(
  label='Warped Image', type='pil', interactive=False
  )
  gen_widget = gr.Image(
+ label='Generated View', type='pil', interactive=False
  )
+ examples = gr.Examples(
+ examples=[
+ './assets/pexels-heyho-5998120_19mm.jpg',
+ './assets/pexels-itsterrymag-12639296_24mm.jpg'
+ ],
+ fn=process_example,
+ inputs=file,
+ outputs=[image_widget, depth_widget, viewer,
+ warped_widget, gen_widget,
+ image, depth]
+ )

  # Events
+ file.upload(
  fn=cb_mde,
  inputs=file,
+ outputs=[image_widget, depth_widget, image, depth]
+ ).success(
+ fn=cb_3d,
+ inputs=[image_widget, image, depth],
+ outputs=viewer
  )
  button.click(
  fn=cb_generate,
+ inputs=[viewer, image, depth],
  outputs=[warped_widget, gen_widget]
  )
+ # To re-calculate the uncached depth for examples in background.
+ examples.load_input_event.success(
+ fn=lambda x: cb_mde(x)[2:4],
+ inputs=file,
+ outputs=[image, depth]
+ )

  if __name__ == '__main__':
  demo.launch()
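
Note on the .splat output added in this commit: save_as_splat above packs each Gaussian into a fixed 32-byte record (3 float32 position values, 3 float32 scales, 4 uint8 RGBA bytes, 4 uint8 quaternion bytes). As a minimal sketch only, not part of the commit, a reader for that layout could look like the following (read_splat is an illustrative, hypothetical name):

import numpy as np

def read_splat(filepath: str):
    # Hypothetical helper: inverts the 32-byte record layout implied by save_as_splat.
    raw = np.fromfile(filepath, dtype=np.uint8).reshape(-1, 32)
    positions = raw[:, 0:12].copy().view(np.float32)    # (N, 3) float32
    scales = raw[:, 12:24].copy().view(np.float32)      # (N, 3) float32
    colors = raw[:, 24:28].astype(np.float32) / 255.0   # RGB + sigmoid(opacity), in [0, 1]
    rotations = (raw[:, 28:32].astype(np.float32) - 128.0) / 128.0  # approx. unit quaternion
    return positions, scales, colors, rotations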