OnlyFlow / onlyflow /data /dataset_itr.py
arlaz's picture
initial commit
9bb001a
raw
history blame contribute delete
2.53 kB
import functools
import os
from io import BytesIO
import torch
import torchvision
import torchvision.transforms.v2 as transforms
import webdataset as wds
def _video_shortener(video_tensor, length):
    """Return a random contiguous window of ``length`` frames from a video.

    Args:
        video_tensor: Tensor indexed by frame along its first dimension.
        length: Number of consecutive frames to keep.

    Returns:
        The slice ``video_tensor[start:start + length]`` where ``start`` is
        drawn uniformly from every position at which a full window fits.

    Raises:
        RuntimeError: Raised by ``torch.randint`` when the video has fewer
            than ``length`` frames (no valid start position exists).
    """
    # torch.randint's upper bound is exclusive. Without the +1 the last valid
    # window could never be picked, and a video of exactly `length` frames
    # would fail because the sampling range would be empty.
    num_starts = video_tensor.shape[0] - length + 1
    start = int(torch.randint(0, num_starts, (1,)))
    return video_tensor[start:start + length]
def select_video_extract(length=16):
    """Build a one-argument callable that crops a video to ``length`` frames.

    Args:
        length: Number of frames the returned callable keeps (default 16).

    Returns:
        A ``functools.partial`` wrapping ``_video_shortener`` with ``length``
        bound as a keyword argument, so it can be called with only the video.
    """
    extractor = functools.partial(_video_shortener, length=length)
    return extractor
def my_collate_fn(batch):
    """Collate a list of sample dicts into a single batch dict.

    The key set is taken from the first sample. Values under ``'video'`` are
    combined with ``torch.stack``; values under every other key are gathered
    into a plain list, in sample order.

    Args:
        batch: Non-empty sequence of dict samples that share the keys of the
            first sample.

    Returns:
        A dict with the same keys (in the first sample's order) mapping to
        the stacked tensor or the list of per-sample values.
    """

    def _gather(field):
        # Collect this field across the batch; only videos are stacked.
        column = [item[field] for item in batch]
        return torch.stack(column) if field == 'video' else column

    return {field: _gather(field) for field in batch[0]}
def map_mp4(sample):
    """Decode raw video bytes into a frame tensor.

    Args:
        sample: Bytes-like content of a video file.

    Returns:
        The first element of ``torchvision.io.read_video``'s result, i.e. the
        video frames, requested in ``"TCHW"`` layout.
    """
    # Wrap the in-memory bytes in a file-like object for read_video.
    stream = BytesIO(sample)
    decoded = torchvision.io.read_video(stream, output_format="TCHW", pts_unit='sec')
    # Only the frames are needed; the remaining elements are discarded.
    return decoded[0]
def map_txt(sample):
    """Decode a UTF-8 encoded bytes payload into a string.

    Args:
        sample: Bytes object holding UTF-8 text.

    Returns:
        The decoded ``str``.

    Raises:
        UnicodeDecodeError: If ``sample`` is not valid UTF-8.
    """
    text = sample.decode(encoding="utf-8", errors="strict")
    return text
class WebVidDataset(wds.DataPipeline):
    """WebDataset pipeline that yields batches of video clips and their texts.

    Samples are read from tar shards named ``webvid-uw-{<tar_index>}.tar``
    under ``root_path``. The ``mp4`` entry of each sample is decoded with
    ``map_mp4`` and the ``txt`` entry with ``map_txt``. The video is then cut
    to a random window of frames, resized, randomly cropped and (optionally)
    randomly flipped. Finally ``mp4``/``txt`` are renamed to
    ``video``/``text`` and samples are batched with ``my_collate_fn``.

    Args:
        batch_size: Number of samples per batch; the last batch may be
            smaller (``partial=True``).
        tar_index: Text placed between literal braces in the shard file name
            (presumably a webdataset brace pattern such as a shard range;
            confirm against the caller).
        root_path: Directory containing the shards.
        video_length: Number of frames per clip (default 16).
        video_size: Target size as an int (used for both dimensions) or an
            iterable of ints; every value must be divisible by 8.
        video_length_offset: Extra frames added to ``video_length`` when
            selecting the frame window.
        horizontal_flip: Whether to apply ``RandomHorizontalFlip``.
        seed: Forwarded to ``wds.SimpleShardList``.

    Raises:
        ValueError: If any value of ``video_size`` is not divisible by 8.
    """

    def __init__(self, batch_size, tar_index, root_path, video_length=16, video_size=256, video_length_offset=0,
                 horizontal_flip=True, seed=None):
        # The triple braces emit literal '{' and '}' around tar_index.
        self.dataset_full_path = os.path.join(root_path, f'webvid-uw-{{{tar_index}}}.tar')

        if isinstance(video_size, int):
            video_size = (video_size, video_size)
        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        video_transforms = [
            select_video_extract(length=video_length + video_length_offset),
            transforms.Resize(size=video_size),
            transforms.RandomCrop(size=video_size),
        ]
        # Add the flip only when requested. The previous fallback put the
        # *class* ``transforms.Identity`` (not an instance) in the list, so
        # Compose would have invoked its constructor with the video instead
        # of passing the video through unchanged.
        if horizontal_flip:
            video_transforms.append(transforms.RandomHorizontalFlip())

        self.pipeline = [
            wds.SimpleShardList('file:' + str(self.dataset_full_path), seed=seed),
            wds.shuffle(50),
            wds.split_by_node,
            wds.tarfile_to_samples(),
            wds.shuffle(100),
            wds.split_by_worker,
            wds.map_dict(
                mp4=map_mp4,
                txt=map_txt,
            ),
            wds.map_dict(
                mp4=transforms.Compose(video_transforms),
            ),
            wds.rename_keys(video="mp4", text='txt', keep_unselected=True),
            wds.batched(batch_size, collation_fn=my_collate_fn, partial=True),
        ]
        super().__init__(self.pipeline)

        self.batch_size = batch_size
        self.video_length = video_length
        self.video_size = video_size