# SoundingStreet / DepthEstimator.py
# Author: FQiao — commit 3324de2 ("Upload 70 files").
import torch
from accelerate.test_utils.testing import get_backend
from PIL import Image
import os
import sys
from config import LOGS_DIR, DEPTH_FM_CHECKPOINT, DEPTH_FM_DIR
sys.path.append(DEPTH_FM_DIR + '/depthfm')
from dfm import DepthFM
from unet import UNetModel
import einops
import numpy as np
from torchvision import transforms
class DepthEstimator:
    """Run DepthFM depth estimation over a directory of images.

    The DepthFM model is created lazily on the first call to
    ``estimate_depth``. After each run it is moved back to the CPU, so
    accelerator memory is only held while inference is in progress.
    """

    # File-name suffixes (compared case-insensitively) treated as input images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, image_dir=LOGS_DIR):
        """Create an estimator.

        Args:
            image_dir: Directory of images, stored as ``self.image_dir`` so
                callers can pass it to ``estimate_depth``.
        """
        # get_backend() returns a 3-tuple; only its first element (the torch
        # device used for inference) is needed here.
        self.device, _, _ = get_backend()
        self.image_dir = image_dir
        self.model = None

    def _load_model(self):
        """Build the DepthFM model on first use, then move it to the device in eval mode."""
        if self.model is None:
            self.model = DepthFM(DEPTH_FM_CHECKPOINT)
        self.model = self.model.to(self.device).eval()

    def _unload_model(self):
        """Move the model to the CPU and release cached CUDA memory.

        The model object is kept, so a later ``_load_model`` call does not
        rebuild it from the checkpoint.
        """
        if self.model is not None:
            self.model = self.model.to("cpu")
        torch.cuda.empty_cache()

    @staticmethod
    def _load_image(path):
        """Read an image file into a (1, 3, H, W) float32 tensor scaled to [-1, 1]."""
        # convert("RGB") turns grayscale, palette and RGBA images into three
        # channels. Without it, 'h w c -> c h w' fails on 2-D grayscale arrays
        # and RGBA images reach the model with four channels. The context
        # manager closes the file handle as soon as the pixels are read.
        with Image.open(path) as image:
            pixels = np.array(image.convert("RGB"))
        pixels = einops.rearrange(pixels, "h w c -> c h w")
        pixels = pixels / 127.5 - 1  # uint8 [0, 255] -> [-1, 1]
        return torch.tensor(pixels, dtype=torch.float32)[None]

    def estimate_depth(self, image_path: str) -> list:
        """Estimate a depth map for every image file in a directory.

        Args:
            image_path: Directory to scan. Regular files whose names end in
                .jpg, .jpeg or .png (in any case) are processed in sorted
                filename order; everything else is ignored.

        Returns:
            One ``{"depth": PIL.Image.Image}`` dict per processed image, in
            processing order.
        """
        print("Estimating depth...")
        predictions_list = []
        to_pil = transforms.ToPILImage()
        self._load_model()
        try:
            # os.listdir order is arbitrary. Sorting makes the output order
            # deterministic, which matters because visualize() names files
            # by position.
            for img in sorted(os.listdir(image_path)):
                full_path = os.path.join(image_path, img)
                if not img.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                if not os.path.isfile(full_path):
                    continue
                x = self._load_image(full_path)
                with torch.no_grad():
                    depth = self.model.predict_depth(x.to(self.device), num_steps=2, ensemble_size=4)
                # Tensor.cpu() returns a copy instead of moving in place, so
                # rebind the name to actually get the CPU tensor.
                depth = depth.cpu()
                # NOTE(review): ToPILImage scales float tensors by 255, so this
                # assumes predict_depth returns values in [0, 1]; confirm.
                predictions_list.append({"depth": to_pil(depth.squeeze())})
                del x, depth
                torch.cuda.empty_cache()
        finally:
            # Release accelerator memory even when loading or inference fails.
            self._unload_model()
        print("Depth estimation complete.")
        return predictions_list

    def visualize(self, predictions_list: list, output_dir=None) -> None:
        """Save each depth map as ``depth_<index>.png``.

        Args:
            predictions_list: Output of ``estimate_depth``.
            output_dir: Optional directory to write into. When ``None`` (the
                default), files are written to the current working directory.
        """
        for i, prediction in enumerate(predictions_list):
            filename = f"depth_{i}.png"
            if output_dir is not None:
                filename = os.path.join(output_dir, filename)
            prediction["depth"].save(filename)
# Example usage:
# Estimator = DepthEstimator()
# predictions = Estimator.estimate_depth(Estimator.image_dir)
# Estimator.visualize(predictions)