Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,242 Bytes
3324de2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
from accelerate.test_utils.testing import get_backend
from PIL import Image
import os
import sys
from config import LOGS_DIR, DEPTH_FM_CHECKPOINT, DEPTH_FM_DIR
sys.path.append(DEPTH_FM_DIR + '/depthfm')
from dfm import DepthFM
from unet import UNetModel
import einops
import numpy as np
from torchvision import transforms
class DepthEstimator:
    """Run DepthFM monocular depth estimation over a directory of images.

    The model is loaded lazily on the first call to ``estimate_depth`` and is
    moved back to the CPU after every run so the accelerator is only occupied
    while predictions are being computed.
    """

    # File extensions (lower-case) that are treated as input images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, image_dir=LOGS_DIR):
        """Record the target device and the default image directory.

        Args:
            image_dir: Directory stored on the instance as ``self.image_dir``;
                defaults to ``LOGS_DIR`` from the project config.
        """
        # Only the first element of the tuple returned by get_backend() is
        # used; it is passed to ``.to(...)`` as the compute device.
        self.device, _, _ = get_backend()
        self.image_dir = image_dir
        self.model = None

    def _load_model(self):
        """Create the DepthFM model on first use, otherwise move it back to the device."""
        if self.model is None:
            self.model = DepthFM(DEPTH_FM_CHECKPOINT).to(self.device).eval()
        else:
            self.model = self.model.to(self.device).eval()

    def _unload_model(self):
        """Move the model to the CPU and release cached accelerator memory."""
        if self.model is not None:
            self.model = self.model.to("cpu")
            torch.cuda.empty_cache()

    def estimate_depth(self, image_path: str) -> list:
        """Predict a depth map for every image file found in ``image_path``.

        Files are selected by extension (case-insensitive match against
        ``IMAGE_EXTENSIONS``) and processed in sorted filename order so the
        position of each result in the returned list is reproducible.

        Args:
            image_path: Directory containing the input images.

        Returns:
            A list with one ``{"depth": PIL.Image}`` dict per processed image.
        """
        print("Estimating depth...")
        predictions_list = []
        # The converter is stateless, so build it once instead of per image.
        to_pil = transforms.ToPILImage()
        self._load_model()
        try:
            # os.listdir() order is arbitrary; sort for deterministic output.
            for img in sorted(os.listdir(image_path)):
                if not img.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                # The context manager closes the file handle; convert("RGB")
                # guarantees an (h, w, 3) array even for grayscale or RGBA
                # inputs, which the rearrange pattern below requires.
                with Image.open(os.path.join(image_path, img)) as image:
                    x = np.array(image.convert("RGB"))
                x = einops.rearrange(x, 'h w c -> c h w')
                # Rescale 8-bit pixel values from [0, 255] to [-1, 1].
                x = x / 127.5 - 1
                # [None] adds a leading batch dimension.
                x = torch.tensor(x, dtype=torch.float32)[None]
                with torch.no_grad():
                    depth = self.model.predict_depth(
                        x.to(self.device), num_steps=2, ensemble_size=4
                    )  # returns a tensor
                # Tensor.cpu() is not in-place: keep the returned copy so the
                # prediction really leaves the accelerator before conversion.
                depth = depth.cpu()
                PIL_image = to_pil(depth.squeeze())
                predictions_list.append({"depth": PIL_image})
                # Drop the per-image tensors before the next iteration to keep
                # peak accelerator memory low.
                del x, depth
                torch.cuda.empty_cache()
        finally:
            # Always release the accelerator, even if a prediction failed.
            self._unload_model()
        print("Depth estimation complete.")
        return predictions_list

    def visualize(self, predictions_list: list) -> None:
        """Save each predicted depth map as ``depth_<index>.png`` in the working directory.

        Args:
            predictions_list: List of ``{"depth": PIL.Image}`` dicts as
                returned by ``estimate_depth``.
        """
        for (i, prediction) in enumerate(predictions_list):
            prediction["depth"].save(f"depth_{i}.png")
# Estimator = DepthEstimator()
# predictions = Estimator.estimate_depth(Estimator.image_dir)
# Estimator.visualize(predictions)
|