Spaces:
Running
Running
File size: 3,809 Bytes
e6d4b46 0ab266a e6d4b46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import glob
import os
import tarfile
from glob import glob
from io import BytesIO
from os.path import join
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from pathlib import Path
import tempfile
import shutil
from DenseAV.denseav.shared import batch
class Tarballer(Dataset):
    """Map-style dataset whose items are written as tar archives when indexed.

    All regular files found recursively under ``source`` are collected as
    paths relative to ``source`` and grouped with ``batch(frames, n)``.
    Indexing item ``i`` writes group ``i`` to ``<target>/<i>.tar``; the value
    returned by indexing is always ``0``.

    Args:
        source: Root directory whose files are archived.
        target: Directory that receives the ``<index>.tar`` files; created
            if it does not exist.
        n: Second argument forwarded to ``batch`` (presumably the number of
            files per archive -- confirm against ``batch``'s definition).

    Raises:
        AssertionError: If no regular file exists under ``source``.
    """

    def __init__(self, source, target, n):
        source_path = Path(source)
        # Store paths relative to the source root so they double as the
        # member names inside the archives.
        self.frames = [f.relative_to(source_path) for f in source_path.rglob('*') if f.is_file()]
        assert len(self.frames) > 0, f"no files found under {source!r}"
        self.source = source
        self.target_dir = target
        self.batched = list(batch(self.frames, n))
        os.makedirs(self.target_dir, exist_ok=True)

    def __len__(self):
        """Return the number of file groups, i.e. archives to be written."""
        return len(self.batched)

    def __getitem__(self, item):
        """Write the files of group ``item`` to ``<target_dir>/<item>.tar``.

        Only the member name and size are set on each ``TarInfo``; all other
        metadata keeps the ``TarInfo`` defaults.

        Returns:
            Always ``0``.
        """
        archive_path = join(self.target_dir, f"{item}.tar")
        with tarfile.open(archive_path, "w") as tar:
            for relpath in self.batched[item]:
                member_name = str(relpath)
                abs_path = os.path.join(self.source, member_name)
                with open(abs_path, "rb") as file:
                    info = tarfile.TarInfo(name=member_name)
                    # Stream straight from the file handle rather than
                    # buffering the whole file in memory first; addfile()
                    # copies exactly ``info.size`` bytes from ``fileobj``.
                    info.size = os.fstat(file.fileno()).st_size
                    tar.addfile(info, fileobj=file)
        return 0
class UnTarballer:
    """Indexable collection of tarballs that are extracted when indexed.

    The ``*.tar`` files directly inside ``archive_dir`` are listed in sorted
    order. Indexing item ``i`` extracts tarball ``i`` into a temporary
    directory created inside ``target_dir`` and then moves every extracted
    file to the same relative location under ``target_dir``.

    Args:
        archive_dir: Directory containing the ``*.tar`` files.
        target_dir: Directory that receives the extracted files; created if
            it does not exist.
        remove_source: If true, delete each tarball after it has been
            extracted and its files moved.
    """

    def __init__(self, archive_dir, target_dir, remove_source=False):
        self.tarballs = sorted(glob(join(archive_dir, "*.tar")))
        self.target_dir = target_dir
        self.remove_source = remove_source
        os.makedirs(self.target_dir, exist_ok=True)

    def __len__(self):
        """Return the number of tarballs found in ``archive_dir``."""
        return len(self.tarballs)

    def __getitem__(self, item):
        """Extract tarball ``item`` into ``target_dir``.

        Returns:
            Always ``0``.
        """
        tarball = self.tarballs[item]
        with tarfile.open(tarball, "r") as tar:
            # Stage the extraction in a unique temporary directory inside the
            # target directory, then move the files into their final place.
            with tempfile.TemporaryDirectory(dir=self.target_dir) as tmpdirname:
                if hasattr(tarfile, "data_filter"):
                    # Use the "data" extraction filter where the running
                    # Python supports it: it rejects members that would land
                    # outside the destination and avoids the deprecation
                    # warning for filter-less extraction.
                    tar.extractall(tmpdirname, filter="data")
                else:
                    tar.extractall(tmpdirname)

                for src_dir, _dirs, files in os.walk(tmpdirname):
                    # Mirror the directory layout of the staging area under
                    # the target directory.
                    rel_dir = os.path.relpath(src_dir, tmpdirname)
                    if rel_dir == os.curdir:
                        dst_dir = self.target_dir
                    else:
                        dst_dir = os.path.join(self.target_dir, rel_dir)
                    os.makedirs(dst_dir, exist_ok=True)
                    for file_ in files:
                        shutil.move(os.path.join(src_dir, file_), os.path.join(dst_dir, file_))

        # The archive is closed at this point, so it can be deleted safely.
        if self.remove_source:
            os.remove(tarball)

        return 0
def untar_all(archive_dir, target_dir, remove_source=False, num_workers=24):
    """Extract every ``*.tar`` in ``archive_dir`` into ``target_dir``.

    An ``UnTarballer`` is wrapped in a ``DataLoader`` and iterated to
    completion under a ``tqdm`` progress bar; the loaded values themselves
    are discarded.

    Args:
        archive_dir: Directory containing the ``*.tar`` files.
        target_dir: Directory that receives the extracted files.
        remove_source: If true, each tarball is deleted after extraction.
            Defaults to ``False``, matching ``UnTarballer``.
        num_workers: Value passed to ``DataLoader(num_workers=...)``.
    """
    loader = DataLoader(
        UnTarballer(archive_dir, target_dir, remove_source),
        num_workers=num_workers,
    )
    for _ in tqdm(loader):
        pass
if __name__ == "__main__":
    # Archiving examples, kept commented out for reference.
    # NOTE: Tarballer requires a third argument ``n``; the first example
    # below omits it.
    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "audioset-raw", "audio"),
    #     join("/pytorch-data", "audioset-raw", "audio_archives")
    # ), num_workers=24)
    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "audioset-raw", "frames"),
    #     join("/pytorch-data", "audioset-raw", "frame_archives"),
    #     5000
    # ), num_workers=24)
    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "ADE20KLabels"),
    #     join("/pytorch-data", "ADE20KLabelsAr"),
    #     100
    # ), num_workers=24)
    #
    # for _ in tqdm(loader):
    #     pass

    # untar_all declares remove_source as a parameter; pass it explicitly so
    # the call does not fail for a missing argument. The source tarballs are
    # kept.
    untar_all(
        join("/pytorch-data", "audioset-raw", "frame_archives"),
        join("/pytorch-data", "audioset-raw", "frames_4"),
        remove_source=False,
    )
|