Update app.py
app.py
CHANGED
@@ -1,24 +1,24 @@
- import
- from subprocess import check_call
- import tempfile
  from os.path import basename, splitext, join
  import numpy as np
- from scipy.spatial import KDTree
  from PIL import Image
  import torch
- import
  from torchvision.transforms.functional import to_tensor, to_pil_image
- from
- import
  from huggingface_hub import hf_hub_download

- from extern.

  def download_models():
      models = [
@@ -37,10 +37,17 @@ def download_models():
              'token': None
          },
          {
-             'repo': '
-             'sub':
              'dst': 'checkpoints',
-             'files': ['
              'token': None
          }
      ]
@@ -58,53 +65,43 @@ def download_models():
  # Setup.
  download_models()

- #
- )
- VIEW_MTX = camera_lookat(
-     torch.tensor([[0., 0., 0.]]),
-     torch.tensor([[0., 0., 1.]]),
-     torch.tensor([[0., -1., 0.]])
- )
- VIEWPORT_MTX = get_viewport_matrix(
-     IMAGE_SIZE, IMAGE_SIZE,
-     batch_size=1
- )

  # Crop the image to the shorter side.
  def crop(img: Image) -> Image:
@@ -115,301 +112,130 @@ def crop(img: Image) -> Image:
      else:
          left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
          top, bottom = 0, H
-     img = img.resize((IMAGE_SIZE, IMAGE_SIZE))
-     return img

- def save_as_splat(
-     filepath: str,
-     xyz: np.ndarray,
-     rgb: np.ndarray
- ):
-     # To gaussian splat
-     inv_sigmoid = lambda x: np.log(x / (1 - x))
-     dist2 = np.clip(calc_dist2(xyz), a_min=0.0000001, a_max=None)
-     scales = np.repeat(np.log(np.sqrt(dist2))[..., np.newaxis], 3, axis=1)
-     rots = np.zeros((xyz.shape[0], 4))
-     rots[:, 0] = 1
-     opacities = inv_sigmoid(0.1 * np.ones((xyz.shape[0], 1)))
-     sorted_indices = np.argsort((
-         -np.exp(np.sum(scales, axis=-1, keepdims=True))
-         / (1 + np.exp(-opacities))
-     ).squeeze())
-     buffer = BytesIO()
-     for idx in sorted_indices:
-         position = xyz[idx]
-         scale = np.exp(scales[idx]).astype(np.float32)
-         rot = rots[idx].astype(np.float32)
-         color = np.concatenate(
-             (rgb[idx], 1 / (1 + np.exp(-opacities[idx]))),
-             axis=-1
-         )
-         buffer.write(position.tobytes())
-         buffer.write(scale.tobytes())
-         buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
-         buffer.write(
-             ((rot / np.linalg.norm(rot)) * 128 + 128)
-             .clip(0, 255)
-             .astype(np.uint8)
-             .tobytes()
-         )
-     with open(filepath, "wb") as f:
-         f.write(buffer.getvalue())

- def calc_dist2(points: np.ndarray):
-     dists, _ = KDTree(points).query(points, k=4)
-     mean_dists = (dists[:, 1:] ** 2).mean(1)
-     return mean_dists

- def unproject(depth):
-     H, W = depth.shape[2:4]
-     mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
-     # Matrices.
-     viewport_mtx = VIEWPORT_MTX.to(depth)
-     proj_mtx = PROJ_MTX.to(depth)
-     view_mtx = VIEW_MTX.to(depth)
-     scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
-     grid = torch.stack(torch.meshgrid(
-         torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
-     ).to(depth)[None]  # BHW2
-     screen = F.pad(grid, (0, 1), 'constant', 0)
-     screen = F.pad(screen, (0, 1), 'constant', 1)
-     screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
-     eye = screen_flat @ torch.linalg.inv_ex(
-         scr_mtx.float()
-     )[0].mT.to(depth)
-     eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
-     eye[..., 3] = 1
-     points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
-     points = points[0, :, :3]
-     # Translate to the origin.
-     points[..., 2] -= mean_depth
-     camera_pos = (0, 0, -mean_depth)
-     return points, camera_pos

- def view_from_rt(position, rotation):
-     t = np.array(position)
-     euler = np.array(rotation)
-     cx = np.cos(euler[0])
-     sx = np.sin(euler[0])
-     cy = np.cos(euler[1])
-     sy = np.sin(euler[1])
-     cz = np.cos(euler[2])
-     sz = np.sin(euler[2])
-     R = np.array([
-         cy * cz + sy * sx * sz,
-         -cy * sz + sy * sx * cz,
-         sy * cx,
-         cx * sz,
-         cx * cz,
-         -sx,
-         -sy * cz + cy * sx * sz,
-         sy * sz + cy * sx * cz,
-         cy * cx
-     ])
-     view_mtx = np.array([
-         [R[0], R[1], R[2], 0],
-         [R[3], R[4], R[5], 0],
-         [R[6], R[7], R[8], 0],
-         [
-             -t[0] * R[0] - t[1] * R[3] - t[2] * R[6],
-             -t[0] * R[1] - t[1] * R[4] - t[2] * R[7],
-             -t[0] * R[2] - t[1] * R[5] - t[2] * R[8],
-             1
-         ]
-     ]).T
-     B = np.array([
-         [1, 0, 0, 0],
-         [0, -1, 0, 0],
-         [0, 0, -1, 0],
-         [0, 0, 0, 1]
-     ])
-     return B @ view_mtx

  with tempfile.TemporaryDirectory() as tmpdir:
      with gr.Blocks(
-         title='
          css='img {display: inline;}'
      ) as demo:
          # Internal states.

          # Callbacks
-         @spaces.GPU()
          def cb_mde(image_file: str):
-             # Get depth.
-             depth = mde.cuda().infer(image.cuda()).cpu().detach()
-             depth_pil = to_pil_image(colorize(depth[0]))
-             return image_pil, depth_pil, image, depth

-             return

          @spaces.GPU()
-         def cb_generate(
-             tar_camera_pos, tar_camera_rot = viewer[1:3]
-             tar_view_mtx = torch.from_numpy(view_from_rt(
-                 tar_camera_pos, tar_camera_rot
-             ))
-             rel_view_mtx = (
-                 tar_view_mtx @ torch.linalg.inv(src_view_mtx.double())
-             ).half().cuda()
-             proj_mtx = PROJ_MTX.half().cuda()
-             # GenWarp.
-             renders = genwarp_nvs.to('cuda')(
-                 src_image=image.half().cuda(),
-                 src_depth=depth.half().cuda(),
-                 rel_view_mtx=rel_view_mtx,
-                 src_proj_mtx=proj_mtx,
-                 tar_proj_mtx=proj_mtx
-             )
-             warped_pil = to_pil_image(renders['warped'].cpu()[0])
-             synthesized_pil = to_pil_image(renders['synthesized'].cpu()[0])
-             return warped_pil, synthesized_pil

-         def process_example(image_file):
-             gr.Error('')
-             image_pil, depth_pil, image, depth = cb_mde(image_file)
-             viewer = cb_3d(image_file, image, depth)
-             # Fixed angle for examples.
-             viewer = (viewer[0], (-2.020, -0.727, -5.236), (-0.132, 0.378, 0.0))
-             warped_pil, synthsized_pil = cb_generate(
-                 viewer, image, depth
-             )
-             return (
-                 image_pil, depth_pil, viewer,
-                 warped_pil, synthsized_pil,
-                 None, None  # Clear internal states.
              )

          # Blocks.
          gr.Markdown(
              """
-             #
-             [](https://
-             [](https://huggingface.co/spaces/
-             [](https://github.com/
-             [](https://huggingface.co/
              [](https://arxiv.org/abs/2405.17251)
-             ## Introduction
-             This is an official demo for the paper "[
              ## How to Use

-             ### Upload your own images
-             1. Upload a reference image to "Reference Input"
-             2. Move the camera to your desired view in "Unprojected 3DGS" 3D viewer
-             3. Hit "Generate a novel view" button and check the result
-             ## Tips
-             - This model is mainly trained for indoor/outdoor scenery. It might not work well for object-centric inputs. For details on training the model, please check our [paper](https://arxiv.org/abs/2405.17251).
-             - Extremely large camera movement from the input view might cause low performance results due to the unexpected deviation from the training distribution, which is not the scope of this model. Instead, you can feed the generation result for the small camera movement repeatedly and progressively move towards a desired view.
-             - 3D viewer might take some time to update especially when trying different images back to back. Wait until it fully updates to the new image.

              """
          )
-         file = gr.File(label='
          with gr.Row():
              image_widget = gr.Image(
-                 label='
                  interactive=False
              )
              depth_widget = gr.Image(label='Estimated Depth', type='pil')
-             button = gr.Button('Generate a novel view', size='lg', variant='primary')
          with gr.Row():
              warped_widget = gr.Image(
                  label='Warped Image', type='pil', interactive=False
              )
              gen_widget = gr.Image(
-                 label='Generated
              )
-         examples = gr.Examples(
-             examples=[
-                 './assets/pexels-heyho-5998120_19mm.jpg',
-                 './assets/pexels-itsterrymag-12639296_24mm.jpg'
-             ],
-             fn=process_example,
-             inputs=file,
-             outputs=[image_widget, depth_widget, viewer,
-                      warped_widget, gen_widget,
-                      image, depth]
-         )

          # Events
-         file.
              fn=cb_mde,
              inputs=file,
-             outputs=[image_widget, depth_widget,
-         ).success(
-             fn=cb_3d,
-             inputs=[image_widget, image, depth],
-             outputs=viewer
          )
          button.click(
              fn=cb_generate,
-             inputs=[
              outputs=[warped_widget, gen_widget]
          )
-         # To re-calculate the uncached depth for examples in background.
-         examples.load_input_event.success(
-             fn=lambda x: cb_mde(x)[2:4],
-             inputs=file,
-             outputs=[image, depth]
-         )

  if __name__ == '__main__':
      demo.launch()
@@ -1,24 +1,24 @@
+ import os
  from os.path import basename, splitext, join
+ import tempfile
+ import gradio as gr
  import numpy as np
  from PIL import Image
  import torch
+ import cv2
  from torchvision.transforms.functional import to_tensor, to_pil_image
+ from torch import Tensor
+ from genstereo import GenStereo, AdaptiveFusionLayer
+ import ssl
  from huggingface_hub import hf_hub_download
+ import spaces

+ from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
+ ssl._create_default_https_context = ssl._create_unverified_context

+ IMAGE_SIZE = 512
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+ CHECKPOINT_NAME = 'genstereo'

  def download_models():
      models = [
@@ -37,10 +37,17 @@ def download_models():
              'token': None
          },
          {
+             'repo': 'FQiao/GenStereo',
+             'sub': None,
+             'dst': 'checkpoints/genstereo',
+             'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
+             'token': None
+         },
+         {
+             'repo': 'depth-anything/Depth-Anything-V2-Large',
+             'sub': None,
              'dst': 'checkpoints',
+             'files': [f'depth_anything_v2_vitl.pth'],
              'token': None
          }
      ]
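The loop that consumes these entries is outside this hunk. As a rough, illustrative sketch only (the loop shape and names below are assumptions, not the actual body of download_models()), entries of this form could be fetched with hf_hub_download like so:

    # Illustrative sketch; the real download loop is not shown in this diff.
    for model in models:
        for filename in model['files']:
            hf_hub_download(
                repo_id=model['repo'],
                filename=filename,
                subfolder=model['sub'],   # None means the file sits at the repo root
                local_dir=model['dst'],
                token=model['token'],
            )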
@@ -58,53 +65,43 @@ def download_models():
  # Setup.
  download_models()

+ # DepthAnythingV2
+ def get_dam2_model():
+     model_configs = {
+         'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+         'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+         'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+     }
+
+     encoder = 'vitl'
+     encoder_size_map = {'vits': 'Small', 'vitb': 'Base', 'vitl': 'Large'}
+
+     if encoder not in encoder_size_map:
+         raise ValueError(f"Unsupported encoder: {encoder}. Supported: {list(encoder_size_map.keys())}")
+
+     dam2 = DepthAnythingV2(**model_configs[encoder])
+     dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth'
+     dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu'))
+     dam2 = dam2.to(DEVICE).eval()
+     return dam2
+
+ # GenStereo
+ def get_genstereo_model():
+     genwarp_cfg = dict(
+         pretrained_model_path='checkpoints',
+         checkpoint_name=CHECKPOINT_NAME,
+         half_precision_weights=True
+     )
+     genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
+     return genstereo
+
+ # Adaptive Fusion
+ def get_fusion_model():
+     fusion_model = AdaptiveFusionLayer()
+     fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
+     fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
+     fusion_model = fusion_model.to(DEVICE).eval()
+     return fusion_model

  # Crop the image to the shorter side.
  def crop(img: Image) -> Image:
@@ -115,301 +112,130 @@ def crop(img: Image) -> Image:
      else:
          left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
          top, bottom = 0, H
+     return img.crop((left, top, right, bottom))
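A quick check of the crop() helper above, as a sketch assuming a landscape 800x600 input: the longer side is center-cropped to the shorter one, and cb_mde below then resizes the square crop to IMAGE_SIZE.

    # Illustrative only: center-crop an 800x600 (WxH) image to 600x600.
    # left = ceil((800 - 600) / 2) = 100, right = floor(100) + 600 = 700.
    img = Image.new('RGB', (800, 600))
    square = crop(img)
    assert square.size == (600, 600)
    square = square.resize((IMAGE_SIZE, IMAGE_SIZE))  # 512x512, as done in cb_mde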
+ def normalize_disp(disp):
+     return (disp - disp.min()) / (disp.max() - disp.min())

+ # Gradio app
  with tempfile.TemporaryDirectory() as tmpdir:
      with gr.Blocks(
+         title='StereoGen Demo',
          css='img {display: inline;}'
      ) as demo:
          # Internal states.
+         src_image = gr.State()
+         src_depth = gr.State()

          # Callbacks
+         @spaces.GPU()
          def cb_mde(image_file: str):
+             if not image_file:
+                 # Return None if no image is provided (e.g., when file is cleared).
+                 return None, None, None, None

+             image = crop(Image.open(image_file).convert('RGB'))  # Load image using PIL
+             image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
+
+             image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+
+             dam2 = get_dam2_model()
+             depth_dam2 = dam2.infer_image(image_bgr)
+             depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float()
+
+             depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET)
+
+             return image, depth_image, image, depth

          @spaces.GPU()
+         def cb_generate(image, depth: Tensor, scale_factor):
+             norm_disp = normalize_disp(depth.cuda())
+             disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
+
+             genstereo = get_genstereo_model()
+             fusion_model = get_fusion_model()
+
+             renders = genstereo(
+                 src_image=image,
+                 src_disparity=disp,
+                 ratio=None,
              )
+             warped = (renders['warped'] + 1) / 2
+
+             synthesized = renders['synthesized']
+             mask = renders['mask']
+             fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())
+
+             warped_pil = to_pil_image(warped[0])
+             fusion_pil = to_pil_image(fusion_image[0])
+
+             return warped_pil, fusion_pil
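To make scale_factor concrete, here is a worked example of the formula in cb_generate above (not additional app code): the normalized disparity is scaled to a fraction of the image width, controlled by the Scale Factor slider defined below (default 15.0).

    # With scale_factor = 15 and IMAGE_SIZE = 512, the largest disparity handed
    # to GenStereo is 1.0 * 15 / 100 * 512 = 76.8 pixels; the slider range of
    # 1-30 therefore corresponds to roughly 5-154 pixels.
    max_disp = 1.0 * 15 / 100 * IMAGE_SIZE
    print(max_disp)  # 76.8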
          # Blocks.
          gr.Markdown(
              """
+             # StereoGen: Towards Open-World Generation of Stereo Images and Unsupervised Matching
+             [](https://qjizhi.github.io/genstereo)
+             [](https://huggingface.co/spaces/FQiao/GenStereo)
+             [](https://github.com/Qjizhi/GenStereo)
+             [](https://huggingface.co/FQiao/GenStereo/tree/main)
              [](https://arxiv.org/abs/2405.17251)
+
+             ## Introduction
+             This is an official demo for the paper "[Towards Open-World Generation of Stereo Images and Unsupervised Matching](https://qjizhi.github.io/genstereo)". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image.
+
              ## How to Use

+             1. Upload a reference image to "Left Image"
+                - You can also select an image from "Examples"
+             2. Hit "Generate a right image" button and check the result

              """
          )
+         file = gr.File(label='Left', file_types=['image'])
+         examples = gr.Examples(
+             examples=['./assets/COCO_val2017_000000070229.jpg',
+                       './assets/COCO_val2017_000000092839.jpg',
+                       './assets/KITTI2015_000003_10.png',
+                       './assets/KITTI2015_000147_10.png'],
+             inputs=file
+         )
          with gr.Row():
              image_widget = gr.Image(
+                 label='Depth', type='filepath',
                  interactive=False
              )
              depth_widget = gr.Image(label='Estimated Depth', type='pil')
+
+         # Add scale factor slider
+         scale_slider = gr.Slider(
+             label='Scale Factor',
+             minimum=1.0,
+             maximum=30.0,
+             value=15.0,
+             step=0.1,
+         )
+
+         button = gr.Button('Generate a right image', size='lg', variant='primary')
          with gr.Row():
              warped_widget = gr.Image(
                  label='Warped Image', type='pil', interactive=False
              )
              gen_widget = gr.Image(
+                 label='Generated Right', type='pil', interactive=False
              )

          # Events
+         file.change(
              fn=cb_mde,
              inputs=file,
+             outputs=[image_widget, depth_widget, src_image, src_depth]
          )
          button.click(
              fn=cb_generate,
+             inputs=[src_image, src_depth, scale_slider],
              outputs=[warped_widget, gen_widget]
          )

  if __name__ == '__main__':
      demo.launch()