lorocksUMD's picture
Upload 32 files
0ab266a verified
import glob
import os
import tarfile
from glob import glob
from io import BytesIO
from os.path import join
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from pathlib import Path
import tempfile
import shutil
from DenseAV.denseav.shared import batch
class Tarballer(Dataset):
def __init__(self, source, target, n):
source_path = Path(source)
self.frames = [f.relative_to(source_path) for f in source_path.rglob('*') if f.is_file()]
assert (len(self.frames) > 0)
self.source = source
self.target_dir = target
self.batched = list(batch(self.frames, n))
os.makedirs(self.target_dir, exist_ok=True)
def __len__(self):
return len(self.batched)
def __getitem__(self, item):
with tarfile.open(join(self.target_dir, f"{item}.tar"), "w") as tar:
for relpath in self.batched[item]:
abs_path = os.path.join(self.source, str(relpath)) # Convert to string here
with open(abs_path, "rb") as file:
file_content = file.read()
info = tarfile.TarInfo(name=str(relpath)) # Convert to string here
info.size = len(file_content)
tar.addfile(info, fileobj=BytesIO(file_content))
return 0
class UnTarballer:
def __init__(self, archive_dir, target_dir, remove_source=False):
self.tarballs = sorted(glob(join(archive_dir, "*.tar")))
self.target_dir = target_dir
self.remove_source = remove_source # New flag to determine if source tarball should be removed
os.makedirs(self.target_dir, exist_ok=True)
def __len__(self):
return len(self.tarballs)
def __getitem__(self, item):
with tarfile.open(self.tarballs[item], "r") as tar:
# Create a unique temporary directory inside the target directory
with tempfile.TemporaryDirectory(dir=self.target_dir) as tmpdirname:
tar.extractall(tmpdirname) # Extract to the temporary directory
# Move contents from temporary directory to final target directory
for src_dir, dirs, files in os.walk(tmpdirname):
dst_dir = src_dir.replace(tmpdirname, self.target_dir, 1)
os.makedirs(dst_dir, exist_ok=True)
for file_ in files:
src_file = os.path.join(src_dir, file_)
dst_file = os.path.join(dst_dir, file_)
shutil.move(src_file, dst_file)
# Remove the source tarball if the flag is set to True
if self.remove_source:
os.remove(self.tarballs[item])
return 0
def untar_all(archive_dir, target_dir, remove_source):
loader = DataLoader(UnTarballer(archive_dir, target_dir, remove_source), num_workers=24)
for _ in tqdm(loader):
pass
if __name__ == "__main__":
# loader = DataLoader(Tarballer(
# join("/pytorch-data", "audioset-raw", "audio"),
# join("/pytorch-data", "audioset-raw", "audio_archives")
# ), num_workers=24)
# loader = DataLoader(Tarballer(
# join("/pytorch-data", "audioset-raw", "frames"),
# join("/pytorch-data", "audioset-raw", "frame_archives"),
# 5000
# ), num_workers=24)
# loader = DataLoader(Tarballer(
# join("/pytorch-data", "ADE20KLabels"),
# join("/pytorch-data", "ADE20KLabelsAr"),
# 100
# ), num_workers=24)
#
# for _ in tqdm(loader):
# pass
#
# #
#
untar_all(
join("/pytorch-data", "audioset-raw", "frame_archives"),
join("/pytorch-data", "audioset-raw", "frames_4"))