Spaces:
Runtime error
Runtime error
File size: 4,881 Bytes
f670afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import io
import json
# import cv2
import boto3
from botocore.config import Config
import numpy as np
import torch.utils.data as data
from PIL import Image
import imageio
from botocore.exceptions import ClientError
from imaginaire.datasets.cache import Cache
from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
# Disable Pillow's decompression-bomb pixel-count limit so that very large
# images stored in the bucket can be opened without a warning/error.
Image.MAX_IMAGE_PIXELS = None
class ObjectStoreDataset(data.Dataset):
    r"""Dataset that opens and reads items from an AWS S3 bucket.

    Args:
        root (str): Path to the AWS S3 bucket, written as
            ``bucket_name/prefix/...``. The first component is the bucket
            name; the remainder is the key prefix under which
            ``all_filenames.json`` and ``metadata.json`` are stored.
        aws_credentials_file (str): Path to a JSON file whose contents are
            passed as keyword arguments to ``boto3.client('s3', ...)``.
        data_type (str): Which data type should this dataset load?
        cache (optional): Object exposing ``root`` and ``size_GB``
            attributes used to build a local ``Cache``. ``None`` disables
            caching.
    """

    def __init__(self, root, aws_credentials_file, data_type='', cache=None):
        # Cache. ``False`` marks "no cache configured"; ``_get_object`` relies
        # on its truthiness.
        self.cache = False
        if cache is not None:
            self.cache = Cache(cache.root, cache.size_GB)

        # Get bucket info, and keys to info about dataset.
        with open(aws_credentials_file) as fin:
            self.credentials = json.load(fin)
        parts = root.split('/')
        self.bucket = parts[0]
        prefix = '/'.join(parts[1:])
        self.all_filenames_key = prefix + '/all_filenames.json'
        self.metadata_key = prefix + '/metadata.json'

        # Get list of filenames (a mapping whose values are per-sequence
        # collections of files).
        filename_info = self._get_object(self.all_filenames_key)
        self.sequence_list = json.loads(filename_info.decode('utf-8'))

        # Total number of files across all sequences.
        self.length = sum(len(value) for value in self.sequence_list.values())

        # Read metadata: maps each data_type to its file extension
        # (consumed by ``getitem_by_path``).
        metadata_info = self._get_object(self.metadata_key)
        self.extensions = json.loads(metadata_info.decode('utf-8'))
        self.data_type = data_type
        print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type))

    def _get_object(self, key):
        r"""Download object from bucket.

        The local cache (when configured) is consulted first; S3 is only
        contacted on a cache miss, and only successful downloads are
        written back to the cache.

        Args:
            key (str): Key inside bucket.
        Returns:
            (bytes or bool): Object contents, or ``False`` if the object
            could not be fetched.
        """
        # Look up value in cache.
        object_content = self.cache.read(key) if self.cache else False
        if object_content:
            return object_content

        # Either no cache used or key not found in cache.
        # NOTE: max_attempts is effectively "retry forever".
        config = Config(connect_timeout=30,
                        signature_version="s3",
                        retries={"max_attempts": 999999})
        s3 = boto3.client('s3', **self.credentials, config=config)
        try:
            s3_response_object = s3.get_object(Bucket=self.bucket, Key=key)
            object_content = s3_response_object['Body'].read()
        except Exception as e:
            print('%s not found' % (key))
            print(e)
            # Do not cache the failure: otherwise a transient error would be
            # replayed from the cache on every later read of this key.
            return False

        # Save content to cache.
        if self.cache:
            self.cache.write(key, object_content)
        return object_content

    def getitem_by_path(self, path, data_type):
        r"""Load data item stored for key = path.

        The S3 key is ``<first path component>/<data_type>/<rest of path>``
        followed by ``.<ext>``, where ``ext = self.extensions[data_type]``.

        Args:
            path (str): Path into AWS S3 bucket, without data_type prefix.
            data_type (str): Key into self.extensions e.g. data/data_segmaps/...
        Returns:
            img (PIL.Image), img (numpy array) or buf (bytes): A PIL image for
            extensions in ``IMG_EXTENSIONS``, the result of
            ``imageio.imread`` for extensions in ``HDR_IMG_EXTENSIONS``,
            otherwise the raw object contents as returned by the download.
        """
        # Figure out decoding params.
        ext = self.extensions[data_type]
        is_image = ext in IMG_EXTENSIONS
        is_hdr = not is_image and ext in HDR_IMG_EXTENSIONS
        # Only JPEG files are forced to 3-channel RGB; TIFF and all other
        # formats keep the mode they were stored with. TIFF takes precedence.
        force_rgb = is_image and 'tif' not in ext and any(
            tag in ext for tag in ('JPEG', 'JPG', 'jpeg', 'jpg'))
        parts = path.split('/')
        key = parts[0] + '/' + data_type + '/' + '/'.join(parts[1:]) + '.' + ext

        # Get value from key.
        buf = self._get_object(key)

        # Decode and return.
        if is_image:
            # This is totally a hack.
            # We should have a better way to handle grayscale images.
            img = Image.open(io.BytesIO(buf))
            if force_rgb:
                img = img.convert('RGB')
            return img
        if is_hdr:
            try:
                imageio.plugins.freeimage.download()
                img = imageio.imread(buf)
            except Exception:
                # Report which item failed, then propagate the real error
                # instead of returning an unbound ``img``.
                print(path)
                raise
            return img  # Return a numpy array
        return buf

    def __len__(self):
        r"""Return the total number of files listed across all sequences."""
        return self.length
|