File size: 3,809 Bytes
e6d4b46
 
 
 
 
 
 
 
 
 
 
 
 
 
0ab266a
e6d4b46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import glob
import os
import tarfile
from glob import glob
from io import BytesIO
from os.path import join

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from pathlib import Path

import tempfile
import shutil

from DenseAV.denseav.shared import batch

class Tarballer(Dataset):
    """Map-style dataset whose items pack batches of files into tar archives.

    Items are fetched purely for their side effect: fetching item ``i`` writes
    ``<target>/<i>.tar`` holding the ``i``-th batch of files found under
    ``source``. Iterating an instance through a multi-worker ``DataLoader``
    therefore builds the archives in parallel.
    """

    def __init__(self, source, target, n):
        """Index every file under ``source`` and split the listing into batches.

        Args:
            source: Root directory, searched recursively for regular files.
            target: Directory receiving the ``<index>.tar`` files; created if
                it does not exist.
            n: Batch size forwarded to ``batch`` (presumably the number of
                files per archive -- confirm against ``batch``).

        Raises:
            ValueError: If no file exists under ``source``.
        """
        source_path = Path(source)
        # Sorted so that the file -> archive-index mapping is reproducible:
        # rglob yields entries in arbitrary filesystem order.
        self.frames = sorted(
            f.relative_to(source_path) for f in source_path.rglob('*') if f.is_file()
        )
        # Explicit raise rather than ``assert``, which is stripped under -O.
        if not self.frames:
            raise ValueError(f"No files found under {source!r}")
        self.source = source
        self.target_dir = target
        self.batched = list(batch(self.frames, n))
        os.makedirs(self.target_dir, exist_ok=True)

    def __len__(self):
        """Return the number of batches, i.e. the number of archives to write."""
        return len(self.batched)

    def __getitem__(self, item):
        """Write batch ``item`` to ``<target_dir>/<item>.tar``.

        Members are stored under their path relative to ``source``. Only the
        name and size are recorded; every other header field keeps the
        ``TarInfo`` default.

        Returns:
            Always ``0`` -- a placeholder value for the consuming loader.
        """
        with tarfile.open(join(self.target_dir, f"{item}.tar"), "w") as tar:
            for relpath in self.batched[item]:
                abs_path = os.path.join(self.source, str(relpath))
                # Forward slashes keep member names portable across platforms.
                info = tarfile.TarInfo(name=Path(relpath).as_posix())
                with open(abs_path, "rb") as file:
                    # Stream from the open handle instead of loading the whole
                    # file into memory first.
                    info.size = os.fstat(file.fileno()).st_size
                    tar.addfile(info, fileobj=file)

        return 0


class UnTarballer:
    """Indexable collection whose items unpack ``*.tar`` files into one directory.

    Items are fetched purely for their side effect: fetching item ``i``
    extracts the ``i``-th archive (in sorted filename order) into
    ``target_dir``. The class provides ``__len__`` and ``__getitem__`` so it
    can be handed to a ``DataLoader`` and processed by several workers.
    """

    def __init__(self, archive_dir, target_dir, remove_source=False):
        """Collect the archives to unpack and prepare the output directory.

        Args:
            archive_dir: Directory whose ``*.tar`` files are unpacked
                (non-recursive).
            target_dir: Directory receiving the extracted files; created if it
                does not exist.
            remove_source: When true, each tarball is deleted after it has
                been extracted successfully.
        """
        self.tarballs = sorted(glob(join(archive_dir, "*.tar")))
        self.target_dir = target_dir
        self.remove_source = remove_source
        os.makedirs(self.target_dir, exist_ok=True)

    def __len__(self):
        """Return the number of tarballs found in ``archive_dir``."""
        return len(self.tarballs)

    def __getitem__(self, item):
        """Extract tarball ``item`` into ``target_dir``.

        The archive is first unpacked into a private temporary directory
        created inside ``target_dir`` and its files are then moved into place,
        so half-written files never appear at their final location.

        Returns:
            Always ``0`` -- a placeholder value for the consuming loader.

        Raises:
            IndexError: If ``item`` is out of range.
            tarfile.TarError: If the archive is unreadable or, where
                extraction filters are available, contains a member the
                ``data`` filter rejects (e.g. a path escaping the destination).
        """
        tarball = self.tarballs[item]
        with tarfile.open(tarball, "r") as tar:
            # Create a unique temporary directory inside the target directory
            with tempfile.TemporaryDirectory(dir=self.target_dir) as tmpdirname:
                if hasattr(tarfile, "data_filter"):
                    # Refuse absolute paths, ``..`` traversal and links that
                    # point outside the destination.
                    tar.extractall(tmpdirname, filter="data")
                else:
                    # Interpreter predates extraction filters.
                    tar.extractall(tmpdirname)

                # Move contents from temporary directory to final target directory
                for src_dir, _dirs, files in os.walk(tmpdirname):
                    rel_dir = os.path.relpath(src_dir, tmpdirname)
                    if rel_dir == os.curdir:
                        dst_dir = self.target_dir
                    else:
                        dst_dir = os.path.join(self.target_dir, rel_dir)
                    os.makedirs(dst_dir, exist_ok=True)
                    for file_ in files:
                        shutil.move(
                            os.path.join(src_dir, file_),
                            os.path.join(dst_dir, file_),
                        )

        # Remove the source tarball only after extraction finished without error
        if self.remove_source:
            os.remove(tarball)

        return 0

def untar_all(archive_dir, target_dir, remove_source=False, num_workers=24):
    """Extract every ``*.tar`` in ``archive_dir`` into ``target_dir`` in parallel.

    The work is performed by ``UnTarballer`` items fetched through a
    ``DataLoader``; the loop body is empty because fetching an item is what
    triggers the extraction. Progress is shown with ``tqdm``.

    Args:
        archive_dir: Directory containing the tarballs.
        target_dir: Directory receiving the extracted files.
        remove_source: When true, each tarball is deleted after extraction.
            Defaults to ``False``, matching ``UnTarballer``.
        num_workers: Number of ``DataLoader`` worker processes.
    """
    loader = DataLoader(
        UnTarballer(archive_dir, target_dir, remove_source),
        num_workers=num_workers,
    )
    for _ in tqdm(loader):
        pass


if __name__ == "__main__":
    # Earlier archiving runs, kept for reference.
    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "audioset-raw", "audio"),
    #     join("/pytorch-data", "audioset-raw", "audio_archives")
    # ), num_workers=24)

    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "audioset-raw", "frames"),
    #     join("/pytorch-data", "audioset-raw", "frame_archives"),
    #     5000
    # ), num_workers=24)

    # loader = DataLoader(Tarballer(
    #     join("/pytorch-data", "ADE20KLabels"),
    #     join("/pytorch-data", "ADE20KLabelsAr"),
    #     100
    # ), num_workers=24)
    #
    # for _ in tqdm(loader):
    #     pass
    #
    # #
    #
    # remove_source is passed explicitly: omitting it raised TypeError because
    # the parameter is required by untar_all's signature. The archives are kept.
    untar_all(
        join("/pytorch-data", "audioset-raw", "frame_archives"),
        join("/pytorch-data", "audioset-raw", "frames_4"),
        remove_source=False)