import os, subprocess, shlex, sys, gc
from gradio_client import Client

# client = Client("endless-ai/SDXL", hf_token=os.getenv("HF_TOKEN"))

import time
import torch
import numpy as np
import shutil
import argparse
import gradio as gr
import uuid
import spaces

subprocess.run(shlex.split("pip install wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
subprocess.run(shlex.split("pip install wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl"))
subprocess.run(shlex.split("pip install wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl"))

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "mast3r")))
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "mast3r", "dust3r")))
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.device import to_numpy
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images

from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams
from train_feat2gs import training
from run_video import render_sets
GRADIO_CACHE_FOLDER = './gradio_cache_folder'

from utils.feat_utils import FeatureExtractor
from dust3r.demo import _convert_scene_output_to_glb
#############################################################################################################################################


def get_dust3r_args_parser():
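    """Build the argument parser that holds the defaults for Step 1
    (DUSt3R initialization and per-pixel feature extraction)."""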
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--model_path", type=str, default="naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt", help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--schedule", type=str, default='linear')
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--niter", type=int, default=300)
    parser.add_argument("--focal_avg", type=bool, default=True)
    parser.add_argument("--n_views", type=int, default=3)
    parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER) 
    parser.add_argument("--feat_dim", type=int, default=256, help="PCA dimension. If None, PCA is not applied, and the original feature dimension is retained.")
    parser.add_argument("--feat_type", type=str, nargs='*', default=["dust3r",], help="Feature type(s). Multiple types can be specified for combination.")
    parser.add_argument("--vis_feat", action="store_true", default=True, help="Visualize features")
    parser.add_argument("--vis_key", type=str, default=None, help="Feature type to visualize (only for mast3r), e.g., 'decfeat' or 'desc'")
    parser.add_argument("--method", type=str, default='dust3r', help="Method of Initialization, e.g., 'dust3r' or 'mast3r'")

    return parser
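
# A minimal headless sketch (not used by the demo itself): the defaults above
# can be overridden programmatically, e.g. to initialize with MASt3R instead
# of DUSt3R, assuming a matching checkpoint is supplied via --model_path.
#   opt = get_dust3r_args_parser().parse_args(
#       ["--method", "mast3r", "--feat_type", "mast3r"])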


@spaces.GPU(duration=300)
def run_dust3r(inputfiles, input_path=None):
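    """Step 1: copy/move the input images into a per-session folder, run
    DUSt3R global alignment, extract per-pixel features, and export a
    COLMAP-style sparse model, points3D.ply, and a GLB preview."""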

    if input_path is not None:
        imgs_path = './assets/example/' + input_path
        imgs_names = sorted(os.listdir(imgs_path))

        inputfiles = []
        for imgs_name in imgs_names:
            file_path = os.path.join(imgs_path, imgs_name)
            print(file_path)
            inputfiles.append(file_path)
        print(inputfiles)

    # ------ Step(1) DUSt3R initialization & Feature extraction ------
    # os.system(f"rm -rf {GRADIO_CACHE_FOLDER}")
    parser = get_dust3r_args_parser()
    opt = parser.parse_args()

    method = opt.method

    tmp_user_folder = str(uuid.uuid4()).replace("-", "")
    opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
    img_folder_path = os.path.join(opt.img_base_path, "images")    
 
    model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
    os.makedirs(img_folder_path, exist_ok=True)

    if not inputfiles:
        raise gr.Error("Please upload images first.")
    opt.n_views = len(inputfiles)
    if opt.n_views < 2:
        raise gr.Error("The number of input images should be greater than 1.")
    print("Multiple images: ", inputfiles)
    # for image_file in inputfiles:
    #     image_path = image_file.name if hasattr(image_file, 'name') else image_file
    #     shutil.copy(image_path, img_folder_path)
    for image_path in inputfiles:
        if input_path is not None:
            shutil.copy(image_path, img_folder_path)
        else:
            shutil.move(image_path, img_folder_path)
    train_img_list = sorted(os.listdir(img_folder_path))
    assert len(train_img_list)==opt.n_views, f"Number of images in the folder is not equal to {opt.n_views}"
    images, ori_size = load_images(img_folder_path, size=512) 
    # images, ori_size, imgs_resolution = load_images(img_folder_path, size=512) 
    # resolutions_are_equal = len(set(imgs_resolution)) == 1
    # if resolutions_are_equal == False:
    #     raise gr.Error("The resolution of the input image should be the same.")
    print("ori_size", ori_size)
    start_time = time.time()
    ######################################################
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
    output = inference(pairs, model, opt.device, batch_size=opt.batch_size)

    scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
    loss = compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
    scene = scene.clean_pointcloud()   

    imgs = to_numpy(scene.imgs)
    focals = scene.get_focals()
    poses = to_numpy(scene.get_im_poses())
    pts3d = to_numpy(scene.get_pts3d())
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
    confidence_masks = to_numpy(scene.get_masks())
    intrinsics = to_numpy(scene.get_intrinsics())

    ######################################################
    end_time = time.time()
    print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")
    
    output_colmap_path = img_folder_path.replace("images", f"sparse/0/{method}")
    
    # Feature extraction per point (i.e., per pixel)
    extractor = FeatureExtractor(images, opt, method)
    feats = extractor(scene=scene)
    feat_type_str = '-'.join(extractor.feat_type)
    output_colmap_path = os.path.join(output_colmap_path, feat_type_str)
    os.makedirs(output_colmap_path, exist_ok=True)

    outfile = _convert_scene_output_to_glb(output_colmap_path, imgs, pts3d, confidence_masks, focals, poses, as_pointcloud=True, cam_size=0.03)
    feat_image_path = os.path.join(opt.img_base_path, "feat_dim0-9_dust3r.png")

    save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
    save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
    pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
    color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
    color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
    feat_4_3dgs = np.concatenate([p[m] for p, m in zip(feats, confidence_masks)])
    storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs, feat_4_3dgs)

    del scene
    torch.cuda.empty_cache()
    gc.collect()

    return outfile, feat_image_path, opt, None, None

run_dust3r.zerogpu = True

@spaces.GPU(duration=300)
def run_feat2gs(opt, niter=2000):
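    """Step 2: read out a 3DGS model from the Step-1 features while jointly
    optimizing camera poses for `niter` iterations; returns the path to the
    trained point_cloud.ply plus the training args for Step 3."""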

    if opt is None:
        raise gr.Error("Please run Step 1 first!")
    
    try:
        if not os.path.exists(opt.img_base_path):
            raise ValueError(f"Input path does not exist: {opt.img_base_path}")
            
        if not os.path.exists(os.path.join(opt.img_base_path, "images")):
            raise ValueError("Input images not found. Please run Step 1 first")
        
        if not os.path.exists(os.path.join(opt.img_base_path, f"sparse/0/{opt.method}")):
            raise ValueError("DUSt3R output not found. Please run Step 1 first")
        
        # ------ Step(2) Readout 3DGS from features & Jointly optimize pose ------
        parser = ArgumentParser(description="Training script parameters")
        lp = ModelParams(parser)
        op = OptimizationParams(parser)
        pp = PipelineParams(parser)
        parser.add_argument('--debug_from', type=int, default=-1)
        parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--start_checkpoint", type=str, default = None)
        parser.add_argument("--scene", type=str, default="demo")
        parser.add_argument("--n_views", type=int, default=3)
        parser.add_argument("--get_video", action="store_true")
        parser.add_argument("--optim_pose", type=bool, default=True)
        parser.add_argument("--feat_type", type=str, nargs='*', default=["dust3r",], help="Feature type(s). Multiple types can be specified for combination.")
        parser.add_argument("--method", type=str, default='dust3r', help="Method of Initialization, e.g., 'dust3r' or 'mast3r'")
        parser.add_argument("--feat_dim", type=int, default=256, help="Feture dimension after PCA . If None, PCA is not applied.")
        parser.add_argument("--model", type=str, default='Gft', help="Model of Feat2gs, 'G'='geometry'/'T'='texture'/'A'='all'")
        parser.add_argument("--dataset", default="demo", type=str)
        parser.add_argument("--resize", action="store_true", default=False, 
                        help="If True, resize rendering to square")
        
        args = parser.parse_args(sys.argv[1:])
        args.iterations = niter
        args.save_iterations.append(args.iterations)
        args.model_path = opt.img_base_path + '/output/'    
        args.source_path = opt.img_base_path
        # args.model_path = GRADIO_CACHE_FOLDER + '/output/'    
        # args.source_path = GRADIO_CACHE_FOLDER
        args.iteration = niter
        os.makedirs(args.model_path, exist_ok=True)
        training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)

        output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
        # output_ply_path = GRADIO_CACHE_FOLDER+ f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'

        torch.cuda.empty_cache()
        gc.collect()

        return output_ply_path, args, None

    except Exception as e:
        raise gr.Error(f"Step 2 failed: {str(e)}")

run_feat2gs.zerogpu = True

@spaces.GPU(duration=300)
def run_render(opt, args, cam_traj='ellipse'):
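    """Step 3: render a video of the trained 3DGS model along the chosen
    camera trajectory and return the path to the resulting MP4."""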
    if opt is None or args is None:
        raise gr.Error("Please run Steps 1 and 2 first!")
    
    try:
        iteration_path = os.path.join(opt.img_base_path, f"output/point_cloud/iteration_{args.iteration}/point_cloud.ply")
        if not os.path.exists(iteration_path):
            raise ValueError("Training results not found. Please run Step 2 first")
        
        # ------ Step(3) Render video with camera trajectory ------
        parser = ArgumentParser(description="Testing script parameters")
        model = ModelParams(parser, sentinel=True)
        pipeline = PipelineParams(parser)
        args.eval = True
        args.get_video = True
        args.n_views = opt.n_views
        args.cam_traj = cam_traj
        render_sets(
            model.extract(args),
            args.iteration,
            pipeline.extract(args),
            args,
        )

        output_video_path = opt.img_base_path + f'/output/videos/demo_{opt.n_views}_view_{args.cam_traj}.mp4'

        torch.cuda.empty_cache()
        gc.collect()

        return output_video_path
    
    except Exception as e:
        raise gr.Error(f"Step 3 failed: {str(e)}")

run_render.zerogpu = True

# @spaces.GPU(duration=1000)
# def process_example(inputfiles, input_path):
#     dust3r_model, feat_image, dust3r_state, _, _ = run_dust3r(inputfiles, input_path=input_path)
    
#     output_model, feat2gs_state, _ = run_feat2gs(dust3r_state, niter=2000)
    
#     output_video = run_render(dust3r_state, feat2gs_state, cam_traj='interpolated')
    
#     return dust3r_model, feat_image, output_model, output_video

def reset_dust3r_state():
    return None, None, None, None, None

def reset_feat2gs_state():
    return None, None, None

_TITLE = '''Feat2GS Demo'''
_DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center;">
    <div style="width: 100%; text-align: center; font-size: 30px;">
        <strong><span style="font-family: 'Comic Sans MS';"><span style="color: #E0933F">Feat</span><span style="color: #B24C33">2</span><span style="color: #E0933F">GS</span></span>: Probing Visual Foundation Models with Gaussian Splatting</strong>
    </div>
</div> 
<p></p>
<div align="center">
    <a style="display:inline-block" href="https://fanegg.github.io/Feat2GS/"><img src='https://img.shields.io/badge/Project-Website-green.svg'></a>&nbsp;
    <a style="display:inline-block" href="https://arxiv.org/abs/2412.09606"><img src="https://img.shields.io/badge/Arxiv-2412.09606-b31b1b.svg?logo=arXiv" alt='arxiv'></a>&nbsp;
    <a style="display:inline-block" href="https://youtu.be/4fT5lzcAJqo?si=_fCSIuXNBSmov2VA"><img src='https://img.shields.io/badge/Video-E33122?logo=Youtube'></a>&nbsp;
    <a style="display:inline-block" href="https://github.com/fanegg/Feat2GS"><img src="https://img.shields.io/badge/Code-black?logo=Github" alt='Code'></a>&nbsp;
    <a title="X" href="https://twitter.com/faneggchen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/badge/@Yue%20Chen-black?logo=X" alt="X">
    </a>&nbsp;
    <a title="Bluesky" href="https://bsky.app/profile/fanegg.bsky.social" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/badge/@Yue%20Chen-white?logo=Bluesky" alt="Bluesky">
    </a>
</div>
<p></p>
'''
_CITE_ = r"""
## 📝 **Citation**
If you find our work useful for your research or applications, please consider citing the following paper:
```bibtex
@article{chen2024feat2gs,
  title={Feat2GS: Probing Visual Foundation Models with Gaussian Splatting},
  author={Chen, Yue and Chen, Xingyu and Chen, Anpei and Pons-Moll, Gerard and Xiu, Yuliang},
  journal={arXiv preprint arXiv:2412.09606},
  year={2024}
}
```
"""


# demo = gr.Blocks(title=_TITLE).queue()
demo = gr.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="Feat2GS Demo", theme=gr.themes.Monochrome()).queue()
with demo:
    dust3r_state = gr.State(None)
    feat2gs_state = gr.State(None)
    render_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("🚀 Quickstart", open=False):
                gr.Markdown("""
                1. **Input Images**
                * Upload 2 or more images of the same scene from different views
                * For best results, ensure images have good overlap

                2. **Step 1: DUSt3R Initialization & Feature Extraction**
                * Click "RUN Step 1" to process your images
                * This step estimates an initial DUSt3R point cloud and camera poses, and extracts per-pixel DUSt3R features

                3. **Step 2: Readout 3DGS from Features**
                * Set the number of training iterations; a larger number gives better quality but takes longer (default: 2000, max: 8000)
                * Click "RUN Step 2" to optimize the 3D model

                4. **Step 3: Video Rendering**
                * Choose a camera trajectory
                * Click "RUN Step 3" to generate a video of your 3D model
                """)
            
            with gr.Accordion("💡 Tips", open=False):
                gr.Markdown("""
                * Processing time depends on image resolution (**recommended <1K**) and quantity (**recommended 2-6 views**)
                * For optimal performance, test on high-end GPUs (A100/4090)
                * Use the mouse to interact with 3D models:
                  - Left-click: Rotate
                  - Scroll: Zoom
                  - Right-click: Move
                """)

    with gr.Row():
        with gr.Column(scale=1):
            # gr.Markdown('# ' + _TITLE)
            gr.HTML(_DESCRIPTION)

    with gr.Row(variant='panel'):
        with gr.Tab("Input"):
            inputfiles = gr.File(file_count="multiple", label="images")
            input_path = gr.Textbox(visible=False, label="example_path")
            # button_gen = gr.Button("RUN")
            
    with gr.Row(variant='panel'):
        with gr.Tab("Step 1: DUSt3R initialization & Feature extraction"):
            dust3r_run = gr.Button("RUN Step 1")
            with gr.Column(scale=2):
                with gr.Group():
                    dust3r_model = gr.Model3D(
                        label="DUSt3R Output",
                        interactive=False,
                        # camera_position=[0.5, 0.5, 1],
                    )
                    gr.Markdown(
                        """
                        <div class="model-description">
                           &nbsp;&nbsp;Left-click to rotate, Scroll to zoom, and Right-click to move.
                        </div>
                        """
                    )    
                    feat_image = gr.Image(
                        label="Feature Visualization",
                        type="filepath"
                    )


    with gr.Row(variant='panel'):
        with gr.Tab("Step 2: Readout 3DGS from features & Jointly optimize pose"):
            niter = gr.Number(value=2000, precision=0, minimum=2000, maximum=8000, label="Training iterations")
            feat2gs_run = gr.Button("RUN Step 2")
            with gr.Column(scale=1):
                with gr.Group():
                    output_model = gr.Model3D(
                        label="3D Gaussian Splats Output, need more time to visualize",
                        interactive=False,
                        # camera_position=[0.5, 0.5, 1],
                    )
                    gr.Markdown(
                        """
                        <div class="model-description">
                           &nbsp;&nbsp;Failed to visualize or got a GPU runtime error? <a href="https://github.com/fanegg/Feat2GS" target="_blank">Run the demo on your local computer!</a>
                        </div>
                        """
                    )  

    with gr.Row(variant='panel'):
        with gr.Tab("Step 3: Render video with camera trajectory"):
            cam_traj = gr.Dropdown(["arc", "spiral", "lemniscate", "wander", "ellipse", "interpolated"], value='ellipse', label="Camera trajectory")
            render_run = gr.Button("RUN Step 3")
            with gr.Column(scale=1):
                output_video = gr.Video(label="video", height=800)
   
    dust3r_run.click(
        fn=reset_dust3r_state,
        inputs=None,
        outputs=[dust3r_model, feat_image, dust3r_state, feat2gs_state, render_state],
        queue=False
    ).then(
        fn=run_dust3r,
        inputs=[inputfiles],
        outputs=[dust3r_model, feat_image, dust3r_state, feat2gs_state, render_state]
    )
    feat2gs_run.click(
        fn=reset_feat2gs_state,
        inputs=None,
        outputs=[output_model, feat2gs_state, render_state],
        queue=False
    ).then(
        fn=run_feat2gs,
        inputs=[dust3r_state, niter],
        outputs=[output_model, feat2gs_state, render_state]
    )
    render_run.click(run_render, inputs=[dust3r_state, feat2gs_state, cam_traj], outputs=[output_video])

    gr.Markdown(_CITE_)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)