import base64 |
|
import gzip |
|
import hashlib |
|
import io |
|
import json |
|
import math |
|
import os |
|
import os.path as osp |
|
import random |
|
import re |
|
import sqlite3 |
|
import sys |
|
import tempfile |
|
import uuid |
|
import warnings |
|
from functools import lru_cache, partial |
|
from typing import Any, BinaryIO, Dict, Optional, TypeVar, Union |
|
from urllib.parse import quote, urlparse |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
from .wids_dl import download_and_open |
|
from .wids_lru import LRUCache |
|
from .wids_mmtar import MMIndexedTar |
|
from .wids_specs import load_dsdesc_and_resolve, urldir |
|
from .wids_tar import TarFileReader, find_index_file |
|
|
|
try: |
|
from torch.utils.data import Dataset, Sampler |
|
except ImportError: |
|
|
|
class Dataset: |
|
pass |
|
|
|
class Sampler: |
|
pass |
|
|
|
|
|
T = TypeVar("T") |
|
|
|
T_co = TypeVar("T_co", covariant=True) |
|
|
|
|
|
def compute_file_md5sum(fname: Union[str, BinaryIO], chunksize: int = 1000000) -> str: |
|
"""Compute the md5sum of a file in chunks. |
|
|
|
Parameters |
|
---------- |
|
fname : Union[str, BinaryIO] |
|
Filename or file object |
|
chunksize : int, optional |
|
Chunk size in bytes, by default 1000000 |
|
|
|
Returns |
|
------- |
|
str |
|
MD5 sum of the file |
|
|
|
Examples |
|
-------- |
|
>>> compute_file_md5sum("test.txt") |
|
'd41d8cd98f00b204e9800998ecf8427e' |
|
""" |
|
md5 = hashlib.md5() |
|
if isinstance(fname, str): |
|
with open(fname, "rb") as f: |
|
for chunk in iter(lambda: f.read(chunksize), b""): |
|
md5.update(chunk) |
|
else: |
|
fname.seek(0) |
|
for chunk in iter(lambda: fname.read(chunksize), b""): |
|
md5.update(chunk) |
|
return md5.hexdigest() |
|
|
|
|
|
def compute_num_samples(fname): |
|
    ds = IndexedTarSamples(path=fname)
|
return len(ds) |
|
|
|
|
|
def splitname(fname): |
|
"""Returns the basename and extension of a filename""" |
|
assert "." in fname, "Filename must have an extension" |
|
|
|
basename, extension = os.path.splitext(fname) |
|
return basename, extension |
|
|
|
|
|
|
|
|
|
def group_by_key(names): |
|
"""Group the file names by key. |
|
|
|
Args: |
|
names: A list of file names. |
|
|
|
Returns: |
|
A list of lists of indices, where each sublist contains indices of files |
|
with the same key. |
|
""" |
|
groups = [] |
|
kmaps = {} |
|
for i, fname in enumerate(names): |
|
|
|
if "." not in fname: |
|
print(f"Warning: Ignoring file {fname} (no '.')") |
|
continue |
|
if fname == ".": |
|
            print("Warning: Ignoring the '.' file.")
|
continue |
|
key, ext = splitname(fname) |
|
if key not in kmaps: |
|
kmaps[key] = [] |
|
kmaps[key].append(i) |
|
for k, v in kmaps.items(): |
|
groups.append(v) |
|
return groups |
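

# Illustrative sketch of how group_by_key organizes a shard listing; the file
# names below are hypothetical.
def _example_group_by_key():
    names = ["sample1.jpg", "sample1.txt", "sample2.jpg", "sample2.txt"]
    # Files sharing a basename form one sample, so this returns [[0, 1], [2, 3]].
    return group_by_key(names)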
|
|
|
|
|
def default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True): |
|
"""A default decoder for webdataset. |
|
|
|
This handles common file extensions: .txt, .cls, .cls2, |
|
.jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl. |
|
These are the most common extensions used in webdataset. |
|
For other extensions, users can provide their own decoder. |
|
|
|
    Args:
        sample: sample to decode; a shallow copy is made and returned
        format: how to decode images ("PIL" or "numpy")

    Returns:
        The decoded sample.
    """
|
sample = dict(sample) |
|
for key, stream in sample.items(): |
|
extensions = key.split(".") |
|
if len(extensions) < 1: |
|
continue |
|
extension = extensions[-1] |
|
if extension in ["gz"]: |
|
decompressed = gzip.decompress(stream.read()) |
|
stream = io.BytesIO(decompressed) |
|
if len(extensions) < 2: |
|
sample[key] = stream |
|
continue |
|
extension = extensions[-2] |
|
if key.startswith("__"): |
|
continue |
|
elif extension in ["txt", "text"]: |
|
value = stream.read() |
|
sample[key] = value.decode("utf-8") |
|
elif extension in ["cls", "cls2"]: |
|
value = stream.read() |
|
sample[key] = int(value.decode("utf-8")) |
|
elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]: |
|
if format == "PIL": |
|
import PIL.Image |
|
|
|
sample[key] = PIL.Image.open(stream) |
|
            elif format == "numpy":
                import numpy as np
                import PIL.Image

                sample[key] = np.asarray(PIL.Image.open(stream))
|
else: |
|
raise ValueError(f"Unknown format: {format}") |
|
elif extension == "json": |
|
import json |
|
|
|
value = stream.read() |
|
sample[key] = json.loads(value) |
|
elif extension == "npy": |
|
import numpy as np |
|
|
|
sample[key] = np.load(stream) |
|
elif extension == "mp": |
|
import msgpack |
|
|
|
value = stream.read() |
|
sample[key] = msgpack.unpackb(value, raw=False) |
|
elif extension in ["pt", "pth"]: |
|
import torch |
|
|
|
sample[key] = torch.load(stream) |
|
elif extension in ["pickle", "pkl"]: |
|
import pickle |
|
|
|
sample[key] = pickle.load(stream) |
|
        elif extension == "mp4":
            # Keep video files as raw byte streams; decoding is left to the caller.
            sample[key] = io.BytesIO(stream.read())
|
return sample |
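

# Illustrative sketch of default_decoder on an in-memory sample; the keys and
# payloads below are hypothetical and mirror what IndexedTarSamples produces
# (extension keys keep their leading dot).
def _example_default_decoder():
    sample = {
        "__key__": "sample1",
        ".txt": io.BytesIO(b"hello"),
        ".cls": io.BytesIO(b"3"),
        ".json": io.BytesIO(b'{"label": 1}'),
    }
    decoded = default_decoder(sample, format="PIL")
    # decoded[".txt"] == "hello", decoded[".cls"] == 3, decoded[".json"] == {"label": 1}
    return decoded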
|
|
|
|
|
def update_dict_with_extend(original_dict, update_dict): |
|
for key, value in update_dict.items(): |
|
if key in original_dict and isinstance(original_dict[key], list) and isinstance(value, list): |
|
original_dict[key].extend(value) |
|
else: |
|
original_dict[key] = value |
|
|
|
|
|
open_itfs = {} |
|
|
|
|
|
class IndexedTarSamples: |
|
"""A class that accesses samples in a tar file. The tar file must follow |
|
WebDataset conventions. The tar file is indexed when the IndexedTarSamples |
|
object is created. The samples are accessed by index using the __getitem__ |
|
method. The __getitem__ method returns a dictionary containing the files |
|
    for the sample. The key for each file is the extension of the file name
    (including the leading dot). The key "__key__" is reserved for the key of
    the sample (the basename of each file without the extension). For example,
    if the tar file contains the files "sample1.jpg" and "sample1.txt", then
    the sample with key "sample1" is returned as the dictionary
    {".jpg": ..., ".txt": ...}.
|
""" |
|
|
|
def __init__( |
|
self, |
|
*, |
|
path=None, |
|
stream=None, |
|
md5sum=None, |
|
expected_size=None, |
|
use_mmap=True, |
|
index_file=find_index_file, |
|
): |
|
assert path is not None or stream is not None |
|
|
|
|
|
self.path = path |
|
stream = self.stream = stream or open(path, "rb") |
|
|
|
|
|
if md5sum is not None: |
|
stream.seek(0) |
|
got = compute_file_md5sum(stream) |
|
assert got == md5sum, f"MD5 sum mismatch: expected {md5sum}, got {got}" |
|
stream.seek(0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if use_mmap: |
|
self.reader = MMIndexedTar(stream) |
|
else: |
|
self.reader = TarFileReader(stream, index_file=index_file) |
|
|
|
|
|
all_files = self.reader.names() |
|
|
|
|
|
self.samples = group_by_key(all_files) |
|
|
|
|
|
|
|
|
|
if expected_size is not None: |
|
assert len(self) == expected_size, f"Expected {expected_size} samples, got {len(self)}" |
|
|
|
self.uuid = str(uuid.uuid4()) |
|
|
|
def close(self): |
|
self.reader.close() |
|
if not self.stream.closed: |
|
self.stream.close() |
|
|
|
def __len__(self): |
|
return len(self.samples) |
|
|
|
def __getitem__(self, idx): |
|
|
|
try: |
|
indexes = self.samples[idx] |
|
except IndexError as e: |
|
print(f"[wids-debug] curr idx: {idx}, total sample length: {len(self.samples)} {e}") |
|
raise e |
|
sample = {} |
|
key = None |
|
for i in indexes: |
|
|
|
fname, data = self.reader.get_file(i) |
|
|
|
k, ext = splitname(fname) |
|
|
|
key = key or k |
|
assert key == k |
|
sample[ext] = data |
|
|
|
sample["__key__"] = key |
|
return sample |
|
|
|
def __str__(self): |
|
return f"<IndexedTarSamples-{id(self)} {self.path}>" |
|
|
|
def __repr__(self): |
|
return str(self) |
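

# Illustrative usage sketch; "shard-000000.tar" is a hypothetical WebDataset-style tar file.
def _example_indexed_tar_samples():
    ds = IndexedTarSamples(path="shard-000000.tar")
    sample = ds[0]
    # sample["__key__"] is the basename; the other keys are file extensions with
    # their leading dot, e.g. sample[".jpg"], sample[".txt"].
    ds.close()
    return sample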
|
|
|
|
|
def hash_localname(dldir="/tmp/_wids_cache"): |
|
os.makedirs(dldir, exist_ok=True) |
|
|
|
connection = sqlite3.connect(os.path.join(dldir, "cache.db")) |
|
cursor = connection.cursor() |
|
cursor.execute("CREATE TABLE IF NOT EXISTS cache (url TEXT PRIMARY KEY, path TEXT, checksum TEXT)") |
|
connection.commit() |
|
|
|
def f(shard): |
|
"""Given a URL, return a local name for the shard.""" |
|
if shard.startswith("pipe:"): |
|
|
|
hex32 = base64.urlsafe_b64encode(hashlib.sha256(shard.encode()).digest())[:32].decode() |
|
return os.path.join(dldir, "pipe__" + hex32) |
|
else: |
|
|
|
dirname = urldir(shard) |
|
hex16 = base64.urlsafe_b64encode(hashlib.sha256(dirname.encode()).digest())[:16].decode() |
|
|
|
cachename = "data__" + hex16 + "__" + os.path.basename(urlparse(shard).path) |
|
checksum = None |
|
cursor.execute( |
|
"INSERT OR REPLACE INTO cache VALUES (?, ?, ?)", |
|
(shard, cachename, checksum), |
|
) |
|
connection.commit() |
|
return os.path.join(dldir, cachename) |
|
|
|
return f |
|
|
|
|
|
def cache_localname(cachedir): |
|
os.makedirs(cachedir, exist_ok=True) |
|
|
|
def f(shard): |
|
"""Given a URL, return a local name for the shard.""" |
|
path = urlparse(shard).path |
|
fname = os.path.basename(path) |
|
return os.path.join(cachedir, fname) |
|
|
|
return f |
|
|
|
|
|
def default_localname(dldir="/tmp/_wids_cache"): |
|
os.makedirs(dldir, exist_ok=True) |
|
|
|
def f(shard): |
|
"""Given a URL, return a local name for the shard.""" |
|
cachename = quote(shard, safe="+-") |
|
return os.path.join(dldir, cachename) |
|
|
|
return f |
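

# Illustrative sketch comparing the cache-naming strategies; the URL and cache
# directories below are hypothetical.
def _example_localnames():
    url = "https://example.com/data/shard-000000.tar"
    # default_localname quotes the whole URL into the file name:
    #   /tmp/_wids_cache/https%3A%2F%2Fexample.com%2Fdata%2Fshard-000000.tar
    p1 = default_localname("/tmp/_wids_cache")(url)
    # cache_localname keeps only the basename of the URL path:
    #   /tmp/my_cache/shard-000000.tar
    p2 = cache_localname("/tmp/my_cache")(url)
    return p1, p2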
|
|
|
|
|
class LRUShards: |
|
"""A class that manages a cache of shards. The cache is a LRU cache that |
|
stores the local names of the shards as keys and the downloaded paths as |
|
values. The shards are downloaded to a directory specified by dldir. |
|
The local name of a shard is computed by the localname function, which |
|
takes the shard URL as an argument. If keep is True, the downloaded files |
|
are not deleted when they are no longer needed. |
|
""" |
|
|
|
def __init__(self, lru_size, keep=False, localname=default_localname()): |
|
self.localname = localname |
|
|
|
self.lru = LRUCache(lru_size, release_handler=self.release_handler) |
|
|
|
self.reset_stats() |
|
|
|
def reset_stats(self): |
|
self.accesses = 0 |
|
self.misses = 0 |
|
|
|
def __len__(self): |
|
return len(self.lru) |
|
|
|
def release_handler(self, key, value): |
|
value.close() |
|
|
|
def clear(self): |
|
self.lru.clear() |
|
|
|
def get_shard(self, url): |
|
assert isinstance(url, str) |
|
self.accesses += 1 |
|
if url not in self.lru: |
|
local = self.localname(url) |
|
with download_and_open(url, local) as stream: |
|
itf = IndexedTarSamples(path=local, stream=stream) |
|
self.lru[url] = itf |
|
self.misses += 1 |
|
self.last_missed = True |
|
else: |
|
self.last_missed = False |
|
return self.lru[url] |
|
|
|
|
|
def interpret_transformations(transformations): |
|
"""Interpret the transformations argument. |
|
|
|
This takes care of transformations specified as string shortcuts |
|
and returns a list of callables. |
|
""" |
|
if not isinstance(transformations, list): |
|
transformations = [transformations] |
|
|
|
result = [] |
|
|
|
for transformation in transformations: |
|
if transformation == "PIL": |
|
transformation = partial(default_decoder, format="PIL") |
|
elif transformation == "numpy": |
|
transformation = partial(default_decoder, format="numpy") |
|
else: |
|
assert callable(transformation) |
|
result.append(transformation) |
|
|
|
return result |
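

# Illustrative sketch: string shortcuts are expanded to default_decoder partials,
# while callables are passed through unchanged.
def _example_interpret_transformations():
    transforms = interpret_transformations(["PIL", lambda s: sorted(s.keys())])
    # -> [partial(default_decoder, format="PIL"), <lambda>]
    return transforms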
|
|
|
|
|
def hash_dataset_name(input_string):
    """Compute a SHA-256 hash of the input string and return the first 16
    characters of its URL-safe base64 encoding."""
|
|
|
hash_object = hashlib.sha256(input_string.encode()) |
|
hash_digest = hash_object.digest() |
|
|
|
|
|
base64_encoded_hash = base64.urlsafe_b64encode(hash_digest) |
|
|
|
|
|
return base64_encoded_hash[:16].decode("ascii") |
|
|
|
|
|
@lru_cache(maxsize=16) |
|
def lru_json_load(fpath): |
|
with open(fpath) as fp: |
|
return json.load(fp) |
|
|
|
|
|
class ShardListDataset(Dataset[T]): |
|
"""An indexable dataset based on a list of shards. |
|
|
|
The dataset is either given as a list of shards with optional options and name, |
|
or as a URL pointing to a JSON descriptor file. |
|
|
|
Datasets can reference other datasets via `source_url`. |
|
|
|
    Shard references within a dataset are resolved relative to an explicitly
|
given `base` property, or relative to the URL from which the dataset |
|
descriptor was loaded. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
shards, |
|
*, |
|
cache_size=int(1e12), |
|
cache_dir=None, |
|
lru_size=10, |
|
dataset_name=None, |
|
localname=None, |
|
transformations="PIL", |
|
keep=False, |
|
base=None, |
|
options=None, |
|
): |
|
"""Create a ShardListDataset. |
|
|
|
Args: |
|
            shards: a path or URL pointing to a JSON dataset descriptor file
                (passing a list of shards directly is not currently supported)
|
cache_size: the number of shards to keep in the cache |
|
lru_size: the number of shards to keep in the LRU cache |
|
localname: a function that maps URLs to local filenames |
|
|
|
Note that there are two caches: an on-disk directory, and an in-memory LRU cache. |
|
""" |
|
if options is None: |
|
options = {} |
|
super().__init__() |
|
|
|
|
|
|
|
if isinstance(shards, (str, io.IOBase)): |
|
if base is None and isinstance(shards, str): |
|
shards = osp.expanduser(shards) |
|
base = urldir(shards) |
|
self.base = base |
|
self.spec = load_dsdesc_and_resolve(shards, options=options, base=base) |
|
self.shards = self.spec.get("shardlist", []) |
|
self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards)) |
|
        else:
            raise NotImplementedError("Only a path/URL to a JSON descriptor file is supported.")
|
|
|
self.lengths = [shard["nsamples"] for shard in self.shards] |
|
self.cum_lengths = np.cumsum(self.lengths) |
|
self.total_length = self.cum_lengths[-1] |
|
|
|
if cache_dir is not None: |
|
|
|
|
|
self.cache_dir = cache_dir |
|
self.localname = cache_localname(cache_dir) |
|
elif localname is not None: |
|
|
|
self.cache_dir = None |
|
self.localname = localname |
|
else: |
|
            self.cache_dir = os.environ.get("WIDS_CACHE", "~/.cache/_wids_cache")
|
self.cache_dir = osp.expanduser(self.cache_dir) |
|
self.localname = default_localname(self.cache_dir) |
|
|
|
        self.data_info = (
            f"[WebShardedList] {str(shards)}, base: {self.base}, name: {self.spec.get('name')}, "
            f"nfiles: {str(len(self.shards))}"
        )
|
        if int(os.environ.get("WIDS_VERBOSE", 0)):
            nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
            nsamples = sum(shard["nsamples"] for shard in self.shards)
            self.data_info += f"nbytes: {str(nbytes)}, samples: {str(nsamples)}, cache: {self.cache_dir} "

self.transformations = interpret_transformations(transformations) |
|
|
|
if lru_size > 200: |
|
warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors") |
|
self.cache = LRUShards(lru_size, localname=self.localname, keep=keep) |
|
|
|
def add_transform(self, transform): |
|
"""Add a transformation to the dataset.""" |
|
self.transformations.append(transform) |
|
return self |
|
|
|
def __len__(self): |
|
"""Return the total number of samples in the dataset.""" |
|
return self.total_length |
|
|
|
def get_stats(self): |
|
"""Return the number of cache accesses and misses.""" |
|
return self.cache.accesses, self.cache.misses |
|
|
|
def check_cache_misses(self): |
|
"""Check if the cache miss rate is too high.""" |
|
accesses, misses = self.get_stats() |
|
if accesses > 100 and misses / accesses > 0.3: |
|
|
|
self.check_cache_misses = lambda: None |
|
            print(f"Warning: ShardListDataset has a cache miss rate of {misses / accesses:.1%}")
|
|
|
def get_shard(self, index): |
|
"""Get the shard and index within the shard corresponding to the given index.""" |
|
|
|
shard_idx = np.searchsorted(self.cum_lengths, index, side="right") |
|
|
|
|
|
|
|
if shard_idx == 0: |
|
inner_idx = index |
|
else: |
|
inner_idx = index - self.cum_lengths[shard_idx - 1] |
|
|
|
|
|
desc = self.shards[shard_idx] |
|
url = desc["url"] |
|
        if url.startswith(("https://", "http://", "gs://", "/", "~")):
            # Absolute URLs and absolute/home-relative paths are used as-is.
            pass
|
else: |
|
|
|
if self.base is None and "base_path" not in self.spec: |
|
raise FileNotFoundError("passing a relative path in shardlist but no base found.") |
|
base_path = self.spec["base_path"] if "base_path" in self.spec else self.base |
|
url = osp.abspath(osp.join(osp.expanduser(base_path), url)) |
|
|
|
desc["url"] = url |
|
try: |
|
shard = self.cache.get_shard(url) |
|
except UnicodeDecodeError as e: |
|
print("UnicodeDecodeError:", desc) |
|
raise e |
|
return shard, inner_idx, desc |
|
|
|
def __getitem__(self, index): |
|
"""Return the sample corresponding to the given index.""" |
|
shard, inner_idx, desc = self.get_shard(index) |
|
sample = shard[inner_idx] |
|
|
|
|
|
self.check_cache_misses() |
|
|
|
sample["__dataset__"] = desc.get("dataset") |
|
sample["__index__"] = index |
|
sample["__shard__"] = desc["url"] |
|
sample["__shardindex__"] = inner_idx |
|
|
|
|
|
for transform in self.transformations: |
|
sample = transform(sample) |
|
|
|
return sample |
|
|
|
def close(self): |
|
"""Close the dataset.""" |
|
self.cache.clear() |
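

# Illustrative usage sketch; "train.json" and the cache directory are hypothetical
# and stand in for a real wids dataset descriptor.
def _example_shard_list_dataset():
    ds = ShardListDataset("train.json", cache_dir="/tmp/my_wids_cache", transformations="PIL")
    sample = ds[0]
    # sample carries the decoded files plus bookkeeping keys such as
    # sample["__key__"], sample["__shard__"], and sample["__index__"].
    sampler = ChunkedSampler(ds, chunksize=2000, shuffle=True)
    return ds, sampler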
|
|
|
|
|
class ShardListDatasetMulti(ShardListDataset): |
|
"""An indexable dataset based on a list of shards. |
|
|
|
The dataset is either given as a list of shards with optional options and name, |
|
or as a URL pointing to a JSON descriptor file. |
|
|
|
Datasets can reference other datasets via `source_url`. |
|
|
|
    Shard references within a dataset are resolved relative to an explicitly
|
given `base` property, or relative to the URL from which the dataset |
|
descriptor was loaded. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
shards, |
|
*, |
|
cache_size=int(1e12), |
|
cache_dir=None, |
|
lru_size=10, |
|
dataset_name=None, |
|
localname=None, |
|
transformations="PIL", |
|
keep=False, |
|
base=None, |
|
options=None, |
|
sort_data_inseq=False, |
|
num_replicas=None, |
|
): |
|
"""Create a ShardListDataset. |
|
|
|
Args: |
|
            shards: a path/URL, or a list of paths/URLs, pointing to JSON dataset descriptor files
|
cache_size: the number of shards to keep in the cache |
|
lru_size: the number of shards to keep in the LRU cache |
|
localname: a function that maps URLs to local filenames |
|
|
|
Note that there are two caches: an on-disk directory, and an in-memory LRU cache. |
|
""" |
|
if options is None: |
|
options = {} |
|
|
|
|
|
|
|
shards_lists = shards if isinstance(shards, list) else [shards] |
|
bases = base if isinstance(base, list) else [base] * len(shards_lists) |
|
self.spec = {} |
|
self.shards = [] |
|
self.num_per_dir = {} |
|
for base, shards in zip(bases, shards_lists): |
|
if isinstance(shards, (str, io.IOBase)): |
|
if base is None and isinstance(shards, str): |
|
shards = osp.expanduser(shards) |
|
base = urldir(shards) |
|
self.base = base |
|
_spec = load_dsdesc_and_resolve(shards, options=options, base=base) |
|
update_dict_with_extend(self.spec, _spec) |
|
self.num_per_dir[os.path.basename(os.path.dirname(shards))] = sum( |
|
[shard["nsamples"] for shard in _spec.get("shardlist", [])] |
|
) |
|
            else:
                raise NotImplementedError("Only a path/URL to a JSON descriptor file is supported.")
|
|
|
if sort_data_inseq and len(self.spec.get("shardlist", [])) > 0: |
|
num_replicas = num_replicas or dist.get_world_size() |
|
self.spec["shardlist"] = split_and_recombine(self.spec["shardlist"], num_replicas) |
|
|
|
self.shards.extend(self.spec.get("shardlist", [])) |
|
self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards)) |
|
|
|
self.lengths = [shard["nsamples"] for shard in self.shards] |
|
self.cum_lengths = np.cumsum(self.lengths) |
|
self.total_length = self.cum_lengths[-1] |
|
|
|
if cache_dir is not None: |
|
|
|
|
|
self.cache_dir = cache_dir |
|
self.localname = cache_localname(cache_dir) |
|
elif localname is not None: |
|
|
|
self.cache_dir = None |
|
self.localname = localname |
|
else: |
|
            self.cache_dir = os.environ.get("WIDS_CACHE", "~/.cache/_wids_cache")
|
self.cache_dir = osp.expanduser(self.cache_dir) |
|
self.localname = default_localname(self.cache_dir) |
|
|
|
        self.data_info = (
            f"[WebShardedList] {str(shards)}, base: {self.base}, name: {self.spec.get('name')}, "
            f"nfiles: {str(len(self.shards))}"
        )
|
        if int(os.environ.get("WIDS_VERBOSE", 0)):
            nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
            nsamples = sum(shard["nsamples"] for shard in self.shards)
            self.data_info += f"nbytes: {str(nbytes)}, samples: {str(nsamples)}, cache: {self.cache_dir} "
|
self.transformations = interpret_transformations(transformations) |
|
|
|
if lru_size > 200: |
|
warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors") |
|
self.cache = LRUShards(lru_size, localname=self.localname, keep=keep) |
|
|
|
|
|
def split_and_recombine(lst, n): |
|
from collections import OrderedDict |
|
|
|
def extract_prefix(i): |
|
return i["url"].split("/")[-2] |
|
|
|
unique_parts = list(OrderedDict((extract_prefix(item), None) for item in lst).keys()) |
|
split_dict = {part: [] for part in unique_parts} |
|
|
|
for part in unique_parts: |
|
part_list = [item for item in lst if extract_prefix(item) == part] |
|
chunk_size = max(1, len(part_list) // n) |
|
chunks = [part_list[i * chunk_size : (i + 1) * chunk_size] for i in range(n)] |
|
|
|
|
|
if len(part_list) % n != 0: |
|
chunks[-1].extend(part_list[n * chunk_size :]) |
|
|
|
split_dict[part] = chunks |
|
|
|
recombined_list = [] |
|
for i in range(n): |
|
for part in unique_parts: |
|
recombined_list.extend(split_dict[part][i]) |
|
|
|
return recombined_list |
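

# Illustrative sketch; the shard URLs are hypothetical. With n=2 replicas the
# shards are regrouped so that each replica's contiguous slice mixes both
# directories: [dirA/shard-0, dirB/shard-0, dirA/shard-1, dirB/shard-1].
def _example_split_and_recombine():
    shardlist = [
        {"url": "dirA/shard-0.tar"},
        {"url": "dirA/shard-1.tar"},
        {"url": "dirB/shard-0.tar"},
        {"url": "dirB/shard-1.tar"},
    ]
    return split_and_recombine(shardlist, 2)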
|
|
|
|
|
def lengths_to_ranges(lengths): |
|
"""Convert a list of lengths to a list of ranges.""" |
|
ranges = [] |
|
start = 0 |
|
for length in lengths: |
|
ranges.append((start, start + length)) |
|
start += length |
|
return ranges |
|
|
|
|
|
def intersect_range(a, b): |
|
"""Return the intersection of the two half-open integer intervals.""" |
|
result = max(a[0], b[0]), min(a[1], b[1]) |
|
if result[0] >= result[1]: |
|
return None |
|
return result |
|
|
|
|
|
def intersect_ranges(rangelist, r): |
|
"""Return the intersection of the half-open integer interval r with the list of half-open integer intervals.""" |
|
result = [] |
|
for a in rangelist: |
|
x = intersect_range(a, r) |
|
if x is not None: |
|
result.append(x) |
|
return result |
|
|
|
|
|
def iterate_ranges(ranges, rng, indexshuffle=True, shardshuffle=True): |
|
"""Iterate over the ranges in a random order.""" |
|
shard_indexes = list(range(len(ranges))) |
|
if shardshuffle: |
|
rng.shuffle(shard_indexes) |
|
for i in shard_indexes: |
|
lo, hi = ranges[i] |
|
sample_indexes = list(range(lo, hi)) |
|
if indexshuffle: |
|
rng.shuffle(sample_indexes) |
|
yield from sample_indexes |
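

# Illustrative sketch of the range helpers used by the samplers below.
def _example_iterate_ranges():
    ranges = lengths_to_ranges([3, 2])  # -> [(0, 3), (3, 5)]
    rng = random.Random(0)
    # Yields a permutation of 0..4 in which indexes from the same range stay together.
    return list(iterate_ranges(ranges, rng))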
|
|
|
|
|
class ShardListSampler(Sampler): |
|
"""A sampler that samples consistent with a ShardListDataset. |
|
|
|
This sampler is used to sample from a ShardListDataset in a way that |
|
preserves locality. |
|
|
|
This returns a permutation of the indexes by shard, then a permutation of |
|
indexes within each shard. This ensures that the data is accessed in a |
|
way that preserves locality. |
|
|
|
    Note that how the data ends up being split between multiple workers depends
    on the details of the DataLoader. Generally, each worker will tend to load
    samples from the same shard.
|
|
|
Other more sophisticated shard-aware samplers are possible and will likely |
|
be added. |
|
""" |
|
|
|
def __init__(self, dataset, *, lengths=None, seed=0, shufflefirst=False): |
|
if lengths is None: |
|
lengths = list(dataset.lengths) |
|
self.ranges = lengths_to_ranges(lengths) |
|
self.seed = seed |
|
self.shufflefirst = shufflefirst |
|
self.epoch = 0 |
|
|
|
def __iter__(self): |
|
self.rng = random.Random(self.seed + 1289738273 * self.epoch) |
|
shardshuffle = self.shufflefirst or self.epoch > 0 |
|
yield from iterate_ranges(self.ranges, self.rng, shardshuffle=shardshuffle) |
|
self.epoch += 1 |
|
|
|
|
|
ShardedSampler = ShardListSampler |
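

# Illustrative usage sketch; `ds` is assumed to be a ShardListDataset.
def _example_shard_list_sampler(ds):
    from torch.utils.data import DataLoader

    sampler = ShardListSampler(ds, seed=0)
    # batch_size=None yields individual samples; batching/collation is left to the caller.
    return DataLoader(ds, sampler=sampler, batch_size=None, num_workers=4)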
|
|
|
|
|
class ChunkedSampler(Sampler): |
|
"""A sampler that samples in chunks and then shuffles the samples within each chunk. |
|
|
|
This preserves locality of reference while still shuffling the data. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dataset, |
|
*, |
|
num_samples=None, |
|
chunksize=2000, |
|
seed=0, |
|
shuffle=False, |
|
shufflefirst=False, |
|
): |
|
if isinstance(num_samples, int): |
|
lo, hi = 0, num_samples |
|
elif num_samples is None: |
|
lo, hi = 0, len(dataset) |
|
else: |
|
lo, hi = num_samples |
|
self.ranges = [(i, min(i + chunksize, hi)) for i in range(lo, hi, chunksize)] |
|
self.seed = seed |
|
self.shuffle = shuffle |
|
self.shufflefirst = shufflefirst |
|
self.epoch = 0 |
|
|
|
def set_epoch(self, epoch): |
|
self.epoch = epoch |
|
|
|
def __iter__(self): |
|
self.rng = random.Random(self.seed + 1289738273 * self.epoch) |
|
shardshuffle = self.shufflefirst or self.epoch > 0 |
|
yield from iterate_ranges( |
|
self.ranges, |
|
self.rng, |
|
indexshuffle=self.shuffle, |
|
shardshuffle=(self.shuffle and shardshuffle), |
|
) |
|
self.epoch += 1 |
|
|
|
    def __len__(self):
        return sum(hi - lo for lo, hi in self.ranges)
|
|
|
|
|
def DistributedChunkedSampler( |
|
dataset: Dataset, |
|
*, |
|
num_replicas: Optional[int] = None, |
|
num_samples: Optional[int] = None, |
|
rank: Optional[int] = None, |
|
shuffle: bool = True, |
|
shufflefirst: bool = False, |
|
seed: int = 0, |
|
    drop_last: Optional[bool] = None,
|
chunksize: int = 1000000, |
|
) -> ChunkedSampler: |
|
"""Return a ChunkedSampler for the current worker in distributed training. |
|
|
|
Reverts to a simple ChunkedSampler if not running in distributed mode. |
|
|
|
Since the split among workers takes place before the chunk shuffle, |
|
workers end up with a fixed set of shards they need to download. The |
|
more workers, the fewer shards are used by each worker. |
|
""" |
|
if drop_last is not None: |
|
warnings.warn("DistributedChunkedSampler does not support drop_last, thus it will be ignored") |
|
if not dist.is_initialized(): |
|
warnings.warn("DistributedChunkedSampler is called without distributed initialized; assuming single process") |
|
num_replicas = 1 |
|
rank = 0 |
|
else: |
|
num_replicas = num_replicas or dist.get_world_size() |
|
        rank = rank if rank is not None else dist.get_rank()
|
assert rank >= 0 and rank < num_replicas |
|
|
|
num_samples = num_samples or len(dataset) |
|
worker_chunk = (num_samples + num_replicas - 1) // num_replicas |
|
worker_start = rank * worker_chunk |
|
worker_end = min(worker_start + worker_chunk, num_samples) |
|
return ChunkedSampler( |
|
dataset, |
|
num_samples=(worker_start, worker_end), |
|
chunksize=chunksize, |
|
seed=seed, |
|
shuffle=shuffle, |
|
shufflefirst=shufflefirst, |
|
) |
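

# Illustrative sketch of the per-rank split; the sizes are hypothetical. With
# num_samples=10000 and num_replicas=4, rank 0 covers [0, 2500), rank 1 covers
# [2500, 5000), and so on, each chunk-shuffled locally.
def _example_distributed_chunked_sampler(ds):
    sampler = DistributedChunkedSampler(ds, chunksize=2000, shuffle=True)
    sampler.set_epoch(0)
    return sampler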
|
|
|
|
|
class DistributedRangedSampler(Sampler):
    """A sampler that assigns each rank a contiguous, unshuffled range of indexes.

    This preserves locality of reference and supports resuming from a given
    position within the epoch via set_start().
    """
|
|
|
def __init__( |
|
self, |
|
dataset: Dataset, |
|
num_replicas: Optional[int] = None, |
|
num_samples: Optional[int] = None, |
|
rank: Optional[int] = None, |
|
        drop_last: Optional[bool] = None,
|
): |
|
if drop_last is not None: |
|
            warnings.warn("DistributedRangedSampler does not support drop_last; it will be ignored")
|
if not dist.is_initialized(): |
|
            warnings.warn(
                "DistributedRangedSampler is called without distributed initialized; assuming single process"
            )
|
num_replicas = 1 |
|
rank = 0 |
|
else: |
|
num_replicas = num_replicas or dist.get_world_size() |
|
            rank = rank if rank is not None else dist.get_rank()
|
assert rank >= 0 and rank < num_replicas |
|
num_samples = num_samples or len(dataset) |
|
self.worker_chunk = num_samples // num_replicas |
|
self.worker_start = rank * self.worker_chunk |
|
self.worker_end = min((rank + 1) * self.worker_chunk, num_samples) |
|
self.ranges = range(self.worker_start, self.worker_end) |
|
self.epoch = 0 |
|
self.step_start = 0 |
|
|
|
def set_epoch(self, epoch): |
|
self.epoch = epoch |
|
|
|
def __len__(self): |
|
return len(self.ranges) |
|
|
|
def set_start(self, start): |
|
self.step_start = start |
|
|
|
def __iter__(self): |
|
yield from self.ranges[self.step_start :] |
|
self.epoch += 1 |
|
|
|
|
|
class DistributedLocalSampler(DistributedSampler): |
|
def __iter__(self): |
|
if self.shuffle: |
|
|
|
g = torch.Generator() |
|
g.manual_seed(self.seed + self.epoch) |
|
indices = torch.randperm(len(self.dataset), generator=g).tolist() |
|
else: |
|
indices = list(range(len(self.dataset))) |
|
|
|
if not self.drop_last: |
|
|
|
padding_size = self.total_size - len(indices) |
|
if padding_size <= len(indices): |
|
indices += indices[:padding_size] |
|
else: |
|
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] |
|
else: |
|
|
|
indices = indices[: self.total_size] |
|
assert len(indices) == self.total_size |
|
|
|
|
|
|
|
chunk_size = self.total_size // self.num_replicas |
|
begin_idx = chunk_size * self.rank |
|
stop_idx = chunk_size * (self.rank + 1) |
|
indices = indices[begin_idx:stop_idx] |
|
|
|
|
|
assert len(indices) == self.num_samples |
|
return iter(indices) |
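

# Illustrative sketch contrasting this sampler with torch's DistributedSampler;
# the dataset size is hypothetical. For len(ds) == 8 and two replicas,
# DistributedSampler (shuffle=False) interleaves indexes across ranks
# (rank 0: 0, 2, 4, 6), whereas DistributedLocalSampler hands each rank one
# contiguous block (rank 0: 0, 1, 2, 3), which preserves shard locality.
def _example_distributed_local_sampler(ds):
    sampler = DistributedLocalSampler(ds, num_replicas=2, rank=0, shuffle=False)
    return list(iter(sampler))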
|
|