Spaces:
dylanebert
/
Running on Zero

File size: 2,168 Bytes
5f9d349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch

import sys
sys.path.append('./submodules/gaussian-splatting/')

from random import randint
from scene import Scene, GaussianModel
from gaussian_renderer import render
from source.data_utils import scene_cameras_train_test_split

class Warper3DGS(torch.nn.Module):
    """``torch.nn.Module`` wrapper that renders a Gaussian-splatting scene.

    The constructor does four things:

    * builds a ``GaussianModel``;
    * loads it into a ``Scene`` built from ``dataset``;
    * optionally applies ``scene_cameras_train_test_split`` to that scene;
    * calls ``training_setup(opt)`` on the model.

    ``forward`` renders the Gaussians from one camera.
    """

    def __init__(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose,
                 do_train_test_split: bool = True):
        """Init Warper using all the objects necessary for rendering gaussian splats.

        Here we merely link class attributes to objects instantiated outside the class.

        Args:
            sh_degree: passed through to ``GaussianModel``.
            opt: stored as ``self.gs_config_opt`` and passed to
                ``GaussianModel.training_setup``.
            pipe: forwarded to ``render`` on every ``forward`` call.
            dataset: passed to ``Scene``. ``dataset.white_background`` selects
                a white (``[1, 1, 1]``) or black (``[0, 0, 0]``) background.
            viewpoint_stack: initial list of cameras sampled by ``forward``.
                If falsy (``None`` or empty), a copy of the scene's train
                cameras is used instead.
            verbose: forwarded to ``scene_cameras_train_test_split``.
            do_train_test_split: if True, call ``scene_cameras_train_test_split``
                on the freshly loaded scene.
        """
        super().__init__()
        self.gaussians = GaussianModel(sh_degree)
        # NOTE(review): allocated before the Scene loads points into the model,
        # so its length is whatever ``get_xyz`` reports at this moment. Confirm
        # it is re-allocated before anything indexes it per-Gaussian.
        self.gaussians.tmp_radii = torch.zeros((self.gaussians.get_xyz.shape[0]), device="cuda")
        self.render = render
        self.gs_config_opt = opt
        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        self.bg = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.pipe = pipe
        self.scene = Scene(dataset, self.gaussians, shuffle=False)
        if do_train_test_split:
            scene_cameras_train_test_split(self.scene, verbose=verbose)

        self.gaussians.training_setup(opt)
        self.viewpoint_stack = viewpoint_stack
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()

    def forward(self, viewpoint_cam=None):
        """Render the Gaussians from a single camera.

        If ``viewpoint_cam`` is None, a camera is drawn uniformly at random from
        ``self.viewpoint_stack``. The stack is first refilled with a copy of the
        scene's train cameras if it is empty. The drawn camera is NOT removed
        from the stack, so repeated calls sample with replacement.

        Args:
            viewpoint_cam: camera to render from, or None to sample one.

        Returns:
            Whatever ``render`` returns for
            ``(camera, self.gaussians, self.pipe, self.bg)``.
        """
        # Compare against None explicitly: a caller-supplied camera object that
        # happens to be falsy (defines __len__/__bool__) must still be honoured.
        if viewpoint_cam is None:
            if not self.viewpoint_stack:
                self.viewpoint_stack = self.scene.getTrainCameras().copy()
            viewpoint_cam = self.viewpoint_stack[randint(0, len(self.viewpoint_stack) - 1)]

        return self.render(viewpoint_cam, self.gaussians, self.pipe, self.bg)