# Spaces: Running on Zero
import torch | |
from accelerate.test_utils.testing import get_backend | |
from PIL import Image | |
import os | |
import sys | |
from config import LOGS_DIR, DEPTH_FM_CHECKPOINT, DEPTH_FM_DIR | |
sys.path.append(DEPTH_FM_DIR + '/depthfm') | |
from dfm import DepthFM | |
from unet import UNetModel | |
import einops | |
import numpy as np | |
from torchvision import transforms | |
class DepthEstimator:
    """Run the DepthFM model over a directory of images.

    The model is created lazily on the first call to ``estimate_depth`` and
    is moved back to the CPU once a run finishes so that accelerator memory
    is released between runs.
    """

    # File extensions (lower-case) that ``estimate_depth`` will process.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, image_dir=LOGS_DIR):
        """Store the default image directory and resolve the compute device.

        Args:
            image_dir: Directory holding the input images; defaults to
                ``LOGS_DIR`` from the project config.
        """
        # get_backend() returns a 3-tuple; only the device entry is used here.
        self.device, _, _ = get_backend()
        self.image_dir = image_dir
        self.model = None

    def _load_model(self):
        """Create the model on first use, otherwise move it back to the device."""
        if self.model is None:
            self.model = DepthFM(DEPTH_FM_CHECKPOINT).to(self.device).eval()
        else:
            self.model = self.model.to(self.device).eval()

    def _unload_model(self):
        """Move the model to the CPU and release cached CUDA memory."""
        if self.model is not None:
            self.model = self.model.to("cpu")
            torch.cuda.empty_cache()

    def estimate_depth(self, image_path: str) -> list:
        """Predict a depth map for every supported image in a directory.

        Args:
            image_path: Directory to scan. Files whose names end in .jpg,
                .jpeg or .png (case-insensitive) are processed in sorted
                name order; everything else is ignored.

        Returns:
            A list with one ``{"depth": PIL.Image}`` dict per processed
            image, in the order the files were processed.
        """
        print("Estimating depth...")
        predictions_list = []
        to_pil = transforms.ToPILImage()
        self._load_model()
        try:
            # Sorted so the list index used by ``visualize`` maps to the same
            # file on every run; os.listdir order is arbitrary.
            for file_name in sorted(os.listdir(image_path)):
                if not file_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue

                # The context manager closes the file handle. convert("RGB")
                # guarantees an (h, w, 3) array: a grayscale image would load
                # as a 2-D array, which the 'h w c' rearrange below rejects.
                with Image.open(os.path.join(image_path, file_name)) as image:
                    x = np.array(image.convert("RGB"))

                x = einops.rearrange(x, 'h w c -> c h w')
                x = x / 127.5 - 1  # scale 8-bit pixel values into [-1, 1]
                x = torch.tensor(x, dtype=torch.float32)[None]  # add batch dim

                with torch.no_grad():
                    depth = self.model.predict_depth(
                        x.to(self.device), num_steps=2, ensemble_size=4
                    )  # returns a tensor

                # Tensor.cpu() returns a copy rather than moving in place, so
                # the result has to be rebound.
                depth = depth.cpu()
                predictions_list.append({"depth": to_pil(depth.squeeze())})

                del x, depth
                torch.cuda.empty_cache()
        finally:
            # Free the accelerator even when an image fails mid-run.
            self._unload_model()

        print("Depth estimation complete.")
        return predictions_list

    def visualize(self, predictions_list: list) -> None:
        """Save every predicted depth image as ``depth_<index>.png``.

        Files are written relative to the current working directory.

        Args:
            predictions_list: Output of ``estimate_depth``.
        """
        for i, prediction in enumerate(predictions_list):
            prediction["depth"].save(f"depth_{i}.png")
# Estimator = DepthEstimator() | |
# predictions = Estimator.estimate_depth(Estimator.image_dir) | |
# Estimator.visualize(predictions) | |