# Uploaded to the Hugging Face Hub via huggingface_hub (revision 76133cf).
import os
import json
import math
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
class BlenderDatasetBase():
    """Shared loader for NeRF-synthetic (Blender) scenes.

    Subclasses mix this into a torch ``Dataset`` / ``IterableDataset`` and
    call :meth:`setup` to populate camera poses, images and foreground masks
    as device tensors.
    """

    def setup(self, config, split):
        """Read ``transforms_{split}.json`` and cache every frame on device.

        Args:
            config: experiment config; must provide ``root_dir``,
                ``near_plane`` / ``far_plane`` and either ``img_wh`` or
                ``img_downscale``.
            split: split name used to pick the transforms file
                (e.g. 'train' / 'val' / 'test').

        Raises:
            KeyError: if neither ``img_wh`` nor ``img_downscale`` is set.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()

        # Blender renders carry an alpha channel usable as a foreground mask.
        self.has_mask = True
        self.apply_mask = True

        with open(os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            meta = json.load(f)

        # Native render resolution; the classic NeRF-synthetic scenes are
        # 800x800 and their transforms files omit 'w'/'h'.
        if 'w' in meta and 'h' in meta:
            W, H = int(meta['w']), int(meta['h'])
        else:
            W, H = 800, 800

        if 'img_wh' in self.config:
            w, h = self.config.img_wh
            # The requested size must preserve the source aspect ratio.
            assert round(W / w * h) == H
        elif 'img_downscale' in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)
        self.near, self.far = self.config.near_plane, self.config.far_plane

        # Pinhole focal length scaled to the working resolution.
        self.focal = 0.5 * w / math.tan(0.5 * meta['camera_angle_x'])

        # Ray directions are identical for every frame (shared intrinsics;
        # principal point assumed at the image center).
        self.directions = \
            get_ray_directions(self.w, self.h, self.focal, self.focal, self.w//2, self.h//2).to(self.rank) # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for i, frame in enumerate(meta['frames']):
            c2w = torch.from_numpy(np.array(frame['transform_matrix'])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png")
            # Force RGBA so img[..., -1] below is genuinely the alpha mask.
            # Fix: without this, a frame saved without an alpha channel would
            # silently use its blue channel as the mask. RGBA inputs are
            # unchanged; RGB inputs get a fully-opaque alpha.
            img = Image.open(img_path).convert('RGBA')
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4)

            self.all_fg_masks.append(img[..., -1]) # (h, w)
            self.all_images.append(img[...,:3])

        self.all_c2w, self.all_images, self.all_fg_masks = \
            torch.stack(self.all_c2w, dim=0).float().to(self.rank), \
            torch.stack(self.all_images, dim=0).float().to(self.rank), \
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank)
class BlenderDataset(Dataset, BlenderDatasetBase):
    """Map-style view of one Blender split; items carry only the frame index.

    The actual tensors (poses, images, masks) live on the dataset object and
    are accessed by the model via the index.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One item per loaded frame.
        return self.all_images.shape[0]

    def __getitem__(self, index):
        return dict(index=index)
class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    """Endless stream over a Blender split used for training.

    Each yielded item is an empty dict: the training step draws its own rays
    from the tensors cached by ``setup``; the loop decides when to stop.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}
@datasets.register('blender')
class BlenderDataModule(pl.LightningDataModule):
    """Lightning data module wiring the Blender datasets to dataloaders."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Instantiate the datasets needed for *stage* (all of them if None)."""
        if stage in [None, 'fit']:
            self.train_dataset = BlenderIterableDataset(self.config, self.config.train_split)
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, 'test']:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, 'predict']:
            # Prediction replays the training trajectory.
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Build a DataLoader; ray sampling happens in the model, so batches stay size 1."""
        sampler = None
        return DataLoader(
            dataset,
            # Fix: os.cpu_count() is documented to return None when the CPU
            # count is undeterminable, which DataLoader rejects — fall back
            # to single-process loading in that case.
            num_workers=os.cpu_count() or 0,
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)