FQiao committed
Commit 02b9ec5 · verified · 1 Parent(s): a0c86ca

Update app.py

Files changed (1)
  1. app.py +140 -314
app.py CHANGED
@@ -1,24 +1,24 @@
- import sys
- from subprocess import check_call
- import tempfile
-
  from os.path import basename, splitext, join
- from io import BytesIO
-
  import numpy as np
- from scipy.spatial import KDTree
  from PIL import Image
-
  import torch
- import torch.nn.functional as F
  from torchvision.transforms.functional import to_tensor, to_pil_image
- from einops import rearrange
- import gradio as gr
  from huggingface_hub import hf_hub_download

- from extern.ZoeDepth.zoedepth.utils.misc import colorize

- from gradio_model3dgscamera import Model3DGSCamera

  def download_models():
      models = [
@@ -37,10 +37,17 @@ def download_models():
              'token': None
          },
          {
-             'repo': 'Sony/genwarp',
-             'sub': 'multi1',
              'dst': 'checkpoints',
-             'files': ['config.json', 'denoising_unet.pth', 'pose_guider.pth', 'reference_unet.pth'],
              'token': None
          }
      ]
@@ -58,53 +65,43 @@ def download_models():
  # Setup.
  download_models()

- mde = torch.hub.load(
-     './extern/ZoeDepth',
-     'ZoeD_N',
-     source='local',
-     pretrained=True,
-     trust_repo=True
- )
-
- import spaces
-
- check_call([
-     sys.executable, '-m', 'pip', 'install',
-     'extern/splatting-0.0.1-py3-none-any.whl'
- ])
-
- from genwarp import GenWarp
- from genwarp.ops import (
-     camera_lookat, get_projection_matrix, get_viewport_matrix
- )
-
- # GenWarp
- genwarp_cfg = dict(
-     pretrained_model_path='checkpoints',
-     checkpoint_name='multi1',
-     half_precision_weights=True
- )
- genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
-
- # Fixed parameters.
- IMAGE_SIZE = 512
- NEAR, FAR = 0.01, 100
- FOVY = np.deg2rad(55)
- PROJ_MTX = get_projection_matrix(
-     fovy=torch.ones(1) * FOVY,
-     aspect_wh=1.,
-     near=NEAR,
-     far=FAR
- )
- VIEW_MTX = camera_lookat(
-     torch.tensor([[0., 0., 0.]]),
-     torch.tensor([[0., 0., 1.]]),
-     torch.tensor([[0., -1., 0.]])
- )
- VIEWPORT_MTX = get_viewport_matrix(
-     IMAGE_SIZE, IMAGE_SIZE,
-     batch_size=1
- )

  # Crop the image to the shorter side.
  def crop(img: Image) -> Image:
@@ -115,301 +112,130 @@ def crop(img: Image) -> Image:
      else:
          left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
      top, bottom = 0, H
-     img = img.crop((left, top, right, bottom))
-     img = img.resize((IMAGE_SIZE, IMAGE_SIZE))
-     return img
-
- def save_as_splat(
-     filepath: str,
-     xyz: np.ndarray,
-     rgb: np.ndarray
- ):
-     # To gaussian splat
-     inv_sigmoid = lambda x: np.log(x / (1 - x))
-     dist2 = np.clip(calc_dist2(xyz), a_min=0.0000001, a_max=None)
-     scales = np.repeat(np.log(np.sqrt(dist2))[..., np.newaxis], 3, axis=1)
-     rots = np.zeros((xyz.shape[0], 4))
-     rots[:, 0] = 1
-     opacities = inv_sigmoid(0.1 * np.ones((xyz.shape[0], 1)))
-
-     sorted_indices = np.argsort((
-         -np.exp(np.sum(scales, axis=-1, keepdims=True))
-         / (1 + np.exp(-opacities))
-     ).squeeze())
-
-     buffer = BytesIO()
-     for idx in sorted_indices:
-         position = xyz[idx]
-         scale = np.exp(scales[idx]).astype(np.float32)
-         rot = rots[idx].astype(np.float32)
-         color = np.concatenate(
-             (rgb[idx], 1 / (1 + np.exp(-opacities[idx]))),
-             axis=-1
-         )
-         buffer.write(position.tobytes())
-         buffer.write(scale.tobytes())
-         buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
-         buffer.write(
-             ((rot / np.linalg.norm(rot)) * 128 + 128)
-             .clip(0, 255)
-             .astype(np.uint8)
-             .tobytes()
-         )
-
-     with open(filepath, "wb") as f:
-         f.write(buffer.getvalue())
-
- def calc_dist2(points: np.ndarray):
-     dists, _ = KDTree(points).query(points, k=4)
-     mean_dists = (dists[:, 1:] ** 2).mean(1)
-     return mean_dists
-
- def unproject(depth):
-     H, W = depth.shape[2:4]
-     mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
-
-     # Matrices.
-     viewport_mtx = VIEWPORT_MTX.to(depth)
-     proj_mtx = PROJ_MTX.to(depth)
-     view_mtx = VIEW_MTX.to(depth)
-     scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
-
-     grid = torch.stack(torch.meshgrid(
-         torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
-     ).to(depth)[None] # BHW2
-
-     screen = F.pad(grid, (0, 1), 'constant', 0)
-     screen = F.pad(screen, (0, 1), 'constant', 1)
-     screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
-
-     eye = screen_flat @ torch.linalg.inv_ex(
-         scr_mtx.float()
-     )[0].mT.to(depth)
-     eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
-     eye[..., 3] = 1
-
-     points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
-     points = points[0, :, :3]
-
-     # Translate to the origin.
-     points[..., 2] -= mean_depth
-     camera_pos = (0, 0, -mean_depth)
-
-     return points, camera_pos
-
- def view_from_rt(position, rotation):
-     t = np.array(position)
-     euler = np.array(rotation)
-
-     cx = np.cos(euler[0])
-     sx = np.sin(euler[0])
-     cy = np.cos(euler[1])
-     sy = np.sin(euler[1])
-     cz = np.cos(euler[2])
-     sz = np.sin(euler[2])
-     R = np.array([
-         cy * cz + sy * sx * sz,
-         -cy * sz + sy * sx * cz,
-         sy * cx,
-         cx * sz,
-         cx * cz,
-         -sx,
-         -sy * cz + cy * sx * sz,
-         sy * sz + cy * sx * cz,
-         cy * cx
-     ])
-     view_mtx = np.array([
-         [R[0], R[1], R[2], 0],
-         [R[3], R[4], R[5], 0],
-         [R[6], R[7], R[8], 0],
-         [
-             -t[0] * R[0] - t[1] * R[3] - t[2] * R[6],
-             -t[0] * R[1] - t[1] * R[4] - t[2] * R[7],
-             -t[0] * R[2] - t[1] * R[5] - t[2] * R[8],
-             1
-         ]
-     ]).T
-
-     B = np.array([
-         [1, 0, 0, 0],
-         [0, -1, 0, 0],
-         [0, 0, -1, 0],
-         [0, 0, 0, 1]
-     ])
-     return B @ view_mtx

  with tempfile.TemporaryDirectory() as tmpdir:
      with gr.Blocks(
-         title='GenWarp Demo',
          css='img {display: inline;}'
      ) as demo:
          # Internal states.
-         image = gr.State()
-         depth = gr.State()

          # Callbacks
-         @spaces.GPU()
          def cb_mde(image_file: str):
-             # Load an image.
-             image_pil = crop(Image.open(image_file).convert('RGB'))
-             image = to_tensor(image_pil)[None].detach()
-             # Get depth.
-             depth = mde.cuda().infer(image.cuda()).cpu().detach()
-             depth_pil = to_pil_image(colorize(depth[0]))
-             return image_pil, depth_pil, image, depth

-         @spaces.GPU()
-         def cb_3d(image_file, image, depth):
-             # Unproject.
-             xyz, camera_pos = unproject(depth.cuda())
-             xyz = xyz.cpu().detach().numpy()
-             # Save as a splat.
-             ## Output filename.
-             splat_file = join(
-                 tmpdir, f'./{splitext(basename(image_file))[0]}.splat')
-             rgb = rearrange(image, 'b c h w -> b (h w) c')[0].numpy()
-             save_as_splat(splat_file, xyz, rgb)
-             return splat_file, camera_pos, (0, 0, 0)

          @spaces.GPU()
-         def cb_generate(viewer, image, depth):
-             if depth is None:
-                 gr.Error('Image and Depth are not set. Try again.')
-                 return None, None
-
-             mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
-             src_view_mtx = camera_lookat(
-                 torch.tensor([[0., 0., -mean_depth]]),
-                 torch.tensor([[0., 0., 0.]]),
-                 torch.tensor([[0., -1., 0.]])
-             ).to(depth)
-             tar_camera_pos, tar_camera_rot = viewer[1:3]
-             tar_view_mtx = torch.from_numpy(view_from_rt(
-                 tar_camera_pos, tar_camera_rot
-             ))
-             rel_view_mtx = (
-                 tar_view_mtx @ torch.linalg.inv(src_view_mtx.double())
-             ).half().cuda()
-             proj_mtx = PROJ_MTX.half().cuda()
-
-             # GenWarp.
-             renders = genwarp_nvs.to('cuda')(
-                 src_image=image.half().cuda(),
-                 src_depth=depth.half().cuda(),
-                 rel_view_mtx=rel_view_mtx,
-                 src_proj_mtx=proj_mtx,
-                 tar_proj_mtx=proj_mtx
-             )
-             warped_pil = to_pil_image(renders['warped'].cpu()[0])
-             synthesized_pil = to_pil_image(renders['synthesized'].cpu()[0])
-
-             return warped_pil, synthesized_pil
-
-         def process_example(image_file):
-             gr.Error('')
-             image_pil, depth_pil, image, depth = cb_mde(image_file)
-             viewer = cb_3d(image_file, image, depth)
-             # Fixed angle for examples.
-             viewer = (viewer[0], (-2.020, -0.727, -5.236), (-0.132, 0.378, 0.0))
-             warped_pil, synthsized_pil = cb_generate(
-                 viewer, image, depth
-             )
-             return (
-                 image_pil, depth_pil, viewer,
-                 warped_pil, synthsized_pil,
-                 None, None # Clear internal states.
              )

          # Blocks.
          gr.Markdown(
              """
- # GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping
- [![Project Site](https://img.shields.io/badge/Project-Web-green)](https://genwarp-nvs.github.io/)
- [![Spaces](https://img.shields.io/badge/Spaces-Demo-yellow?logo=huggingface)](https://huggingface.co/spaces/Sony/GenWarp)
- [![Github](https://img.shields.io/badge/Github-Repo-orange?logo=github)](https://github.com/sony/genwarp/)
- [![Models](https://img.shields.io/badge/Models-checkpoints-blue?logo=huggingface)](https://huggingface.co/Sony/genwarp)
  [![arXiv](https://img.shields.io/badge/arXiv-2405.17251-red?logo=arxiv)](https://arxiv.org/abs/2405.17251)
-
- ## Introduction
- This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". Genwarp can generate novel view images from a single input conditioned on camera poses. In this demo, we offer a basic use of inference of the model. For detailed information, please refer to the [paper](https://arxiv.org/abs/2405.17251).
-
  ## How to Use

- ### Try examples
- - Examples are in the bottom section of the page
-
- ### Upload your own images
- 1. Upload a reference image to "Reference Input"
- 2. Move the camera to your desired view in "Unprojected 3DGS" 3D viewer
- 3. Hit "Generate a novel view" button and check the result
-
- ## Tips
- - This model is mainly trained for indoor/outdoor scenery. It might not work well for object-centric inputs. For details on training the model, please check our [paper](https://arxiv.org/abs/2405.17251).
- - Extremely large camera movement from the input view might cause low performance results due to the unexpected deviation from the training distribution, which is not the scope of this model. Instead, you can feed the generation result for the small camera movement repeatedly and progressively move towards a desired view.
- - 3D viewer might take some time to update especially when trying different images back to back. Wait until it fully updates to the new image.

              """
          )
-         file = gr.File(label='Reference Input', file_types=['image'])
          with gr.Row():
              image_widget = gr.Image(
-                 label='Reference View', type='filepath',
                  interactive=False
              )
              depth_widget = gr.Image(label='Estimated Depth', type='pil')
-             viewer = Model3DGSCamera(
-                 label = 'Unprojected 3DGS',
-                 width=IMAGE_SIZE,
-                 height=IMAGE_SIZE,
-                 camera_width=IMAGE_SIZE,
-                 camera_height=IMAGE_SIZE,
-                 camera_fx=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
-                 camera_fy=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
-                 camera_near=NEAR,
-                 camera_far=FAR
-             )
-         button = gr.Button('Generate a novel view', size='lg', variant='primary')
          with gr.Row():
              warped_widget = gr.Image(
                  label='Warped Image', type='pil', interactive=False
              )
              gen_widget = gr.Image(
-                 label='Generated View', type='pil', interactive=False
              )
-         examples = gr.Examples(
-             examples=[
-                 './assets/pexels-heyho-5998120_19mm.jpg',
-                 './assets/pexels-itsterrymag-12639296_24mm.jpg'
-             ],
-             fn=process_example,
-             inputs=file,
-             outputs=[image_widget, depth_widget, viewer,
-                      warped_widget, gen_widget,
-                      image, depth]
-         )

          # Events
-         file.upload(
              fn=cb_mde,
              inputs=file,
-             outputs=[image_widget, depth_widget, image, depth]
-         ).success(
-             fn=cb_3d,
-             inputs=[image_widget, image, depth],
-             outputs=viewer
          )
          button.click(
              fn=cb_generate,
-             inputs=[viewer, image, depth],
              outputs=[warped_widget, gen_widget]
          )
-         # To re-calculate the uncached depth for examples in background.
-         examples.load_input_event.success(
-             fn=lambda x: cb_mde(x)[2:4],
-             inputs=file,
-             outputs=[image, depth]
-         )

      if __name__ == '__main__':
          demo.launch()
 
+ import os
  from os.path import basename, splitext, join
+ import tempfile
+ import gradio as gr
  import numpy as np
  from PIL import Image
  import torch
+ import cv2
  from torchvision.transforms.functional import to_tensor, to_pil_image
+ from torch import Tensor
+ from genstereo import GenStereo, AdaptiveFusionLayer
+ import ssl
  from huggingface_hub import hf_hub_download
+ import spaces

+ from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
+ ssl._create_default_https_context = ssl._create_unverified_context

+ IMAGE_SIZE = 512
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+ CHECKPOINT_NAME = 'genstereo'

  def download_models():
      models = [

              'token': None
          },
          {
+             'repo': 'FQiao/GenStereo',
+             'sub': None,
+             'dst': 'checkpoints/genstereo',
+             'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
+             'token': None
+         },
+         {
+             'repo': 'depth-anything/Depth-Anything-V2-Large',
+             'sub': None,
              'dst': 'checkpoints',
+             'files': [f'depth_anything_v2_vitl.pth'],
              'token': None
          }
      ]
 
  # Setup.
  download_models()

+ # DepthAnythingV2
+ def get_dam2_model():
+     model_configs = {
+         'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+         'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+         'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+     }
+
+     encoder = 'vitl'
+     encoder_size_map = {'vits': 'Small', 'vitb': 'Base', 'vitl': 'Large'}
+
+     if encoder not in encoder_size_map:
+         raise ValueError(f"Unsupported encoder: {encoder}. Supported: {list(encoder_size_map.keys())}")
+
+     dam2 = DepthAnythingV2(**model_configs[encoder])
+     dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth'
+     dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu'))
+     dam2 = dam2.to(DEVICE).eval()
+     return dam2
+
+ # GenStereo
+ def get_genstereo_model():
+     genwarp_cfg = dict(
+         pretrained_model_path='checkpoints',
+         checkpoint_name=CHECKPOINT_NAME,
+         half_precision_weights=True
+     )
+     genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
+     return genstereo
+
+ # Adaptive Fusion
+ def get_fusion_model():
+     fusion_model = AdaptiveFusionLayer()
+     fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
+     fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
+     fusion_model = fusion_model.to(DEVICE).eval()
+     return fusion_model

  # Crop the image to the shorter side.
  def crop(img: Image) -> Image:
 
      else:
          left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
      top, bottom = 0, H
+     return img.crop((left, top, right, bottom))

+ def normalize_disp(disp):
+     return (disp - disp.min()) / (disp.max() - disp.min())

+ # Gradio app
  with tempfile.TemporaryDirectory() as tmpdir:
      with gr.Blocks(
+         title='StereoGen Demo',
          css='img {display: inline;}'
      ) as demo:
          # Internal states.
+         src_image = gr.State()
+         src_depth = gr.State()

          # Callbacks
+         @spaces.GPU()
          def cb_mde(image_file: str):
+             if not image_file:
+                 # Return None if no image is provided (e.g., when file is cleared).
+                 return None, None, None, None

+             image = crop(Image.open(image_file).convert('RGB')) # Load image using PIL
+             image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
+
+             image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+
+             dam2 = get_dam2_model()
+             depth_dam2 = dam2.infer_image(image_bgr)
+             depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float()
+
+             depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET)
+
+             return image, depth_image, image, depth

          @spaces.GPU()
+         def cb_generate(image, depth: Tensor, scale_factor):
+             norm_disp = normalize_disp(depth.cuda())
+             disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
+
+             genstereo = get_genstereo_model()
+             fusion_model = get_fusion_model()
+
+             renders = genstereo(
+                 src_image=image,
+                 src_disparity=disp,
+                 ratio=None,
              )
+             warped = (renders['warped'] + 1) / 2
+
+             synthesized = renders['synthesized']
+             mask = renders['mask']
+             fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())
+
+             warped_pil = to_pil_image(warped[0])
+             fusion_pil = to_pil_image(fusion_image[0])
+
+             return warped_pil, fusion_pil

          # Blocks.
          gr.Markdown(
              """
+ # StereoGen: Towards Open-World Generation of Stereo Images and Unsupervised Matching
+ [![Project Site](https://img.shields.io/badge/Project-Web-green)](https://qjizhi.github.io/genstereo)
+ [![Spaces](https://img.shields.io/badge/Spaces-Demo-yellow?logo=huggingface)](https://huggingface.co/spaces/FQiao/GenStereo)
+ [![Github](https://img.shields.io/badge/Github-Repo-orange?logo=github)](https://github.com/Qjizhi/GenStereo)
+ [![Models](https://img.shields.io/badge/Models-checkpoints-blue?logo=huggingface)](https://huggingface.co/FQiao/GenStereo/tree/main)
  [![arXiv](https://img.shields.io/badge/arXiv-2405.17251-red?logo=arxiv)](https://arxiv.org/abs/2405.17251)
+
+ ## Introduction
+ This is an official demo for the paper "[Towards Open-World Generation of Stereo Images and Unsupervised Matching](https://qjizhi.github.io/genstereo)". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image.
+
  ## How to Use

+ 1. Upload a reference image to "Left Image"
+    - You can also select an image from "Examples"
+ 2. Hit "Generate a right image" button and check the result

              """
          )
+         file = gr.File(label='Left', file_types=['image'])
+         examples = gr.Examples(
+             examples=['./assets/COCO_val2017_000000070229.jpg',
+                       './assets/COCO_val2017_000000092839.jpg',
+                       './assets/KITTI2015_000003_10.png',
+                       './assets/KITTI2015_000147_10.png'],
+             inputs=file
+         )
          with gr.Row():
              image_widget = gr.Image(
+                 label='Depth', type='filepath',
                  interactive=False
              )
              depth_widget = gr.Image(label='Estimated Depth', type='pil')
+
+         # Add scale factor slider
+         scale_slider = gr.Slider(
+             label='Scale Factor',
+             minimum=1.0,
+             maximum=30.0,
+             value=15.0,
+             step=0.1,
+         )
+
+         button = gr.Button('Generate a right image', size='lg', variant='primary')
          with gr.Row():
              warped_widget = gr.Image(
                  label='Warped Image', type='pil', interactive=False
              )
              gen_widget = gr.Image(
+                 label='Generated Right', type='pil', interactive=False
              )

          # Events
+         file.change(
              fn=cb_mde,
              inputs=file,
+             outputs=[image_widget, depth_widget, src_image, src_depth]
          )
          button.click(
              fn=cb_generate,
+             inputs=[src_image, src_depth, scale_slider],
              outputs=[warped_widget, gen_widget]
          )

      if __name__ == '__main__':
          demo.launch()
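
In short, the commit replaces the GenWarp novel-view pipeline with a stereo one: estimate monocular depth with Depth Anything V2, normalize it and scale it into a pixel-space disparity map (the slider value divided by 100, times the 512-pixel image size), then let GenStereo warp and synthesize the right view and blend the two with the adaptive fusion layer. Below is a minimal sketch of that path outside Gradio, reconstructed from the new `app.py`; it assumes the same `genstereo` and `extern.DAM2` packages and downloaded checkpoints, and the file names `left.jpg` / `right.png` are placeholders.

```python
# Minimal sketch of the new GenStereo inference path (assumes the packages
# and checkpoints that app.py imports and downloads are available locally).
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image

from genstereo import GenStereo, AdaptiveFusionLayer
from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2

IMAGE_SIZE = 512
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Monocular depth with Depth Anything V2 (ViT-L config, as in app.py).
dam2 = DepthAnythingV2(encoder='vitl', features=256,
                       out_channels=[256, 512, 1024, 1024])
dam2.load_state_dict(torch.load('checkpoints/depth_anything_v2_vitl.pth',
                                map_location='cpu'))
dam2 = dam2.to(DEVICE).eval()

image = Image.open('left.jpg').convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE))
depth = dam2.infer_image(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))

# 2. Normalize depth and scale it into a pixel-space disparity map
#    (scale_factor plays the role of the demo's slider, default 15.0).
scale_factor = 15.0
disp = (depth - depth.min()) / (depth.max() - depth.min())
disp = torch.tensor(disp)[None, None].float().to(DEVICE) * scale_factor / 100 * IMAGE_SIZE

# 3. Warp and synthesize the right view with GenStereo, then fuse the outputs.
genstereo = GenStereo(cfg=dict(pretrained_model_path='checkpoints',
                               checkpoint_name='genstereo',
                               half_precision_weights=True),
                      device=DEVICE)
fusion = AdaptiveFusionLayer()
fusion.load_state_dict(torch.load('checkpoints/genstereo/fusion_layer.pth',
                                  map_location='cpu'))
fusion = fusion.to(DEVICE).eval()

renders = genstereo(src_image=image, src_disparity=disp, ratio=None)
warped = (renders['warped'] + 1) / 2  # map from [-1, 1] to [0, 1]
right = fusion(renders['synthesized'].float(), warped.float(), renders['mask'].float())
to_pil_image(right[0]).save('right.png')
```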