File size: 2,242 Bytes
3324de2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from accelerate.test_utils.testing import get_backend
from PIL import Image
import os
import sys 
from config import LOGS_DIR, DEPTH_FM_CHECKPOINT, DEPTH_FM_DIR
sys.path.append(DEPTH_FM_DIR + '/depthfm')
from dfm import DepthFM
from unet import UNetModel
import einops
import numpy as np
from torchvision import transforms


class DepthEstimator:
    """Run DepthFM depth estimation over a directory of images.

    The model is created lazily on first use, moved to the backend device
    for the duration of a batch, and moved back to the CPU afterwards so
    accelerator memory is released between batches.
    """

    # Lower-case file extensions that ``estimate_depth`` will process.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, image_dir=LOGS_DIR):
        """Create an estimator without loading the model yet.

        Args:
            image_dir: Default directory of input images; stored on the
                instance for callers to pass to ``estimate_depth``.
        """
        # get_backend() returns a 3-tuple; only the device is needed here.
        self.device, _, _ = get_backend()
        self.image_dir = image_dir
        self.model = None  # instantiated on demand by _load_model()

    def _load_model(self):
        """Instantiate the model if needed and put it on the device in eval mode."""
        if self.model is None:
            self.model = DepthFM(DEPTH_FM_CHECKPOINT).to(self.device).eval()
        else:
            # Model was parked on the CPU by _unload_model(); bring it back.
            self.model = self.model.to(self.device).eval()

    def _unload_model(self):
        """Move the model back to the CPU and release cached CUDA memory."""
        if self.model is not None:
            self.model = self.model.to("cpu")
            torch.cuda.empty_cache()

    def estimate_depth(self, image_path: str) -> list:
        """Predict a depth map for every image in a directory.

        Files are selected by extension (``.jpg``, ``.jpeg``, ``.png``,
        case-insensitive) and processed in sorted filename order so the
        result order is reproducible.

        Args:
            image_path: Path of the directory containing the input images.

        Returns:
            A list with one ``{"depth": PIL.Image}`` dict per processed image.
        """
        print("Estimating depth...")
        predictions_list = []
        to_pil = transforms.ToPILImage()  # loop-invariant, build once

        # List the directory before loading the model so a bad path fails
        # fast without touching the accelerator.
        image_names = sorted(
            name
            for name in os.listdir(image_path)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )

        self._load_model()
        try:
            for name in image_names:
                # The context manager closes the file handle; convert("RGB")
                # guarantees an (h, w, 3) array for grayscale/palette/RGBA
                # inputs.
                with Image.open(os.path.join(image_path, name)) as image:
                    x = np.array(image.convert("RGB"))
                x = einops.rearrange(x, 'h w c -> c h w')
                x = x / 127.5 - 1  # scale uint8 [0, 255] to [-1, 1]
                x = torch.tensor(x, dtype=torch.float32)[None]  # add batch dim
                with torch.no_grad():
                    depth = self.model.predict_depth(
                        x.to(self.device), num_steps=2, ensemble_size=4
                    )  # returns a tensor
                # Tensor.cpu() is not in-place: rebind so the result really
                # lives on the CPU before the device copy is dropped.
                depth = depth.cpu()
                predictions_list.append({"depth": to_pil(depth.squeeze())})
                del x, depth
                torch.cuda.empty_cache()
        finally:
            # Always park the model on the CPU, even if an image fails.
            self._unload_model()
        print("Depth estimation complete.")
        return predictions_list

    def visualize(self, predictions_list: list, output_dir: str = ".") -> None:
        """Save each predicted depth map as ``depth_<index>.png``.

        Args:
            predictions_list: Output of ``estimate_depth``.
            output_dir: Directory to write the PNG files into; created if it
                does not exist. Defaults to the current working directory.
        """
        os.makedirs(output_dir, exist_ok=True)
        for i, prediction in enumerate(predictions_list):
            prediction["depth"].save(os.path.join(output_dir, f"depth_{i}.png"))


# Example usage (uncomment to run):
# estimator = DepthEstimator()
# predictions = estimator.estimate_depth(estimator.image_dir)
# estimator.visualize(predictions)