# Spaces: Running (page-header residue from the source this file was captured from)
import glob | |
import os | |
import tarfile | |
from glob import glob | |
from io import BytesIO | |
from os.path import join | |
from torch.utils.data import Dataset, DataLoader | |
from tqdm import tqdm | |
from pathlib import Path | |
import tempfile | |
import shutil | |
from DenseAV.denseav.shared import batch | |
class Tarballer(Dataset):
    """Map-style dataset that packs a directory tree into numbered tar files.

    All regular files under ``source`` are collected as paths relative to
    ``source`` and split into groups via ``batch(frames, n)``. Fetching item
    ``i`` writes group ``i`` to ``<target>/<i>.tar``, so wrapping an instance
    in a ``DataLoader`` with several workers builds the archives in parallel.

    Args:
        source: Root directory whose files are archived.
        target: Directory that receives the ``<i>.tar`` files (created if
            missing).
        n: Group size passed to ``batch``. This is presumably the maximum
            number of files per archive; TODO confirm against ``batch``.

    Raises:
        ValueError: If ``n`` is not positive or ``source`` contains no files.
    """

    def __init__(self, source, target, n):
        if n <= 0:
            raise ValueError(f"n must be a positive integer, got {n!r}")
        source_path = Path(source)
        # Sorted so the file -> archive assignment is reproducible across
        # runs; rglob() yields entries in filesystem order.
        self.frames = sorted(
            f.relative_to(source_path)
            for f in source_path.rglob("*")
            if f.is_file()
        )
        # Raise explicitly instead of `assert`, which is stripped under -O.
        if not self.frames:
            raise ValueError(f"no files found under {source!r}")
        self.source = source
        self.target_dir = target
        self.batched = list(batch(self.frames, n))
        os.makedirs(self.target_dir, exist_ok=True)

    def __len__(self):
        """Return the number of archives (one per group of files)."""
        return len(self.batched)

    def __getitem__(self, item):
        """Write archive ``<target>/<item>.tar`` and return 0.

        The return value is a dummy so the DataLoader has something to yield.
        The archive is first written to ``<item>.tar.tmp`` and renamed into
        place only when complete, so an interrupted worker never leaves a
        truncated ``<item>.tar`` that a later extraction would pick up.
        """
        final_path = join(self.target_dir, f"{item}.tar")
        tmp_path = final_path + ".tmp"
        try:
            with tarfile.open(tmp_path, "w") as tar:
                for relpath in self.batched[item]:
                    abs_path = os.path.join(self.source, str(relpath))
                    # Stream straight from the open file instead of reading
                    # it into memory. fstat() describes the file that open()
                    # resolved (symlinks followed), so each entry is a regular
                    # file holding the target's bytes.
                    with open(abs_path, "rb") as file:
                        st = os.fstat(file.fileno())
                        info = tarfile.TarInfo(name=str(relpath))
                        info.size = st.st_size
                        # Preserve the timestamp; the TarInfo default of 0
                        # would date extracted files to 1970.
                        info.mtime = int(st.st_mtime)
                        tar.addfile(info, fileobj=file)
            os.replace(tmp_path, final_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        return 0
class UnTarballer:
    """Map-style dataset that extracts one ``.tar`` archive per index.

    Item ``i`` extracts the ``i``-th archive (sorted by path) from
    ``archive_dir`` into a private temporary directory created inside
    ``target_dir``. It then moves every extracted file into ``target_dir``,
    keeping its relative path.

    Because the temporary directory lives under ``target_dir``, the moves are
    normally same-filesystem renames. Usable directly with a ``DataLoader``
    for parallel extraction.

    Args:
        archive_dir: Directory containing the ``*.tar`` files.
        target_dir: Destination directory (created if missing).
        remove_source: If True, delete each archive after its contents have
            been moved into ``target_dir``.
    """

    def __init__(self, archive_dir, target_dir, remove_source=False):
        # Sorted for a stable index -> archive mapping.
        self.tarballs = sorted(glob(join(archive_dir, "*.tar")))
        self.target_dir = target_dir
        self.remove_source = remove_source
        os.makedirs(self.target_dir, exist_ok=True)

    def __len__(self):
        """Return the number of archives found."""
        return len(self.tarballs)

    def __getitem__(self, item):
        """Extract archive ``item`` into ``target_dir`` and return 0 (dummy)."""
        tarball = self.tarballs[item]
        with tempfile.TemporaryDirectory(dir=self.target_dir) as tmpdirname:
            with tarfile.open(tarball, "r") as tar:
                if hasattr(tarfile, "data_filter"):
                    # The "data" filter rejects absolute paths, members that
                    # escape the destination, and special files. It is also
                    # the default from Python 3.14; older 3.12/3.13 warn
                    # when no filter is given.
                    tar.extractall(tmpdirname, filter="data")
                else:
                    tar.extractall(tmpdirname)
            # Mirror the extracted tree into the final target directory.
            for src_dir, _dirs, files in os.walk(tmpdirname):
                rel_dir = os.path.relpath(src_dir, tmpdirname)
                dst_dir = os.path.normpath(os.path.join(self.target_dir, rel_dir))
                os.makedirs(dst_dir, exist_ok=True)
                for file_ in files:
                    shutil.move(
                        os.path.join(src_dir, file_),
                        os.path.join(dst_dir, file_),
                    )
        # Delete only after the archive has been closed. Removing an open
        # file fails on Windows.
        if self.remove_source:
            os.remove(tarball)
        return 0
def untar_all(archive_dir, target_dir, remove_source=False, num_workers=24):
    """Extract every ``*.tar`` in ``archive_dir`` into ``target_dir`` in parallel.

    Each archive is one item of an ``UnTarballer``, processed by a
    ``DataLoader`` worker pool. Progress is shown with tqdm.

    Args:
        archive_dir: Directory containing the ``*.tar`` files.
        target_dir: Destination directory (created if missing).
        remove_source: If True, delete each archive after extracting it.
            Defaults to False, so two-argument calls work.
        num_workers: Number of DataLoader worker processes.
    """
    loader = DataLoader(
        UnTarballer(archive_dir, target_dir, remove_source),
        num_workers=num_workers,
    )
    # Iterating the loader is what drives the extraction; items are dummies.
    for _ in tqdm(loader):
        pass
if __name__ == "__main__":
    # Archiving runs used previously, kept for reference. Tarballer requires a
    # third argument ``n`` (group size per archive); the first example omits
    # it and needs one before being re-enabled.
    #
    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "audioset-raw", "audio"),
    #     join("/pytorch-data", "audioset-raw", "audio_archives")
    # ), num_workers=24)
    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "audioset-raw", "frames"),
    #     join("/pytorch-data", "audioset-raw", "frame_archives"),
    #     5000
    # ), num_workers=24)
    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "ADE20KLabels"),
    #     join("/pytorch-data", "ADE20KLabelsAr"),
    #     100
    # ), num_workers=24)
    #
    # for _ in tqdm(loader):
    #     pass

    # Keep the source archives; pass True to delete each .tar once extracted.
    untar_all(
        join("/pytorch-data", "audioset-raw", "frame_archives"),
        join("/pytorch-data", "audioset-raw", "frames_4"),
        remove_source=False,
    )