File size: 10,057 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib
import shutil
import subprocess
from typing import Tuple

from nemo import __version__ as NEMO_VERSION
from nemo import constants
from nemo.utils import logging


def resolve_cache_dir() -> pathlib.Path:
    """
    Utility method to resolve a cache directory for NeMo that can be overriden by an environment variable.

    Example:
        NEMO_CACHE_DIR="~/nemo_cache_dir/" python nemo_example_script.py

    Returns:
        A Path object, resolved to the absolute path of the cache directory. If no override is provided,
        uses an inbuilt default which adapts to nemo versions strings.
    """
    override_dir = os.environ.get(constants.NEMO_ENV_CACHE_DIR, "")
    if override_dir == "":
        path = pathlib.Path.joinpath(pathlib.Path.home(), f'.cache/torch/NeMo/NeMo_{NEMO_VERSION}')
    else:
        path = pathlib.Path(override_dir).resolve()
    return path


def is_datastore_path(path) -> bool:
    """Check if a path is from a data object store.
    Currently, only AIStore is supported.
    """
    return path.startswith('ais://')


def is_datastore_cache_shared() -> bool:
    """Check if store cache is shared.
    """
    # Assume cache is shared by default, e.g., as in resolve_cache_dir (~/.cache)
    cache_shared = int(os.environ.get(constants.NEMO_ENV_DATA_STORE_CACHE_SHARED, 1))

    if cache_shared == 0:
        return False
    elif cache_shared == 1:
        return True
    else:
        raise ValueError(f'Unexpected value of env {constants.NEMO_ENV_DATA_STORE_CACHE_SHARED}')


def ais_cache_base() -> str:
    """Return path to local cache for AIS.
    """
    override_dir = os.environ.get(constants.NEMO_ENV_DATA_STORE_CACHE_DIR, "")
    if override_dir == "":
        cache_dir = resolve_cache_dir().as_posix()
    else:
        cache_dir = pathlib.Path(override_dir).resolve().as_posix()

    if cache_dir.endswith(NEMO_VERSION):
        # Prevent re-caching dataset after upgrading NeMo
        cache_dir = os.path.dirname(cache_dir)
    return os.path.join(cache_dir, 'ais')


def ais_endpoint() -> str:
    """Get configured AIS endpoint.
    """
    return os.getenv('AIS_ENDPOINT')


def bucket_and_object_from_uri(uri: str) -> Tuple[str, str]:
    """Parse a path to determine bucket and object path.

    Args:
        uri: Full path to an object on an object store

    Returns:
        Tuple of strings (bucket_name, object_path)
    """
    if not is_datastore_path(uri):
        raise ValueError(f'Provided URI is not a valid store path: {uri}')
    uri_parts = pathlib.PurePath(uri).parts
    bucket = uri_parts[1]
    object_path = pathlib.PurePath(*uri_parts[2:])

    return str(bucket), str(object_path)


def ais_endpoint_to_dir(endpoint: str) -> str:
    """Convert AIS endpoint to a valid dir name.
    Used to build cache location.

    Args:
        endpoint: AIStore endpoint in format https://host:port
    
    Returns:
        Directory formed as `host/port`.
    """
    if not endpoint.startswith('http://'):
        raise ValueError(f'Unexpected format for ais endpoint: {endpoint}')

    endpoint = endpoint.replace('http://', '')
    host, port = endpoint.split(':')
    return os.path.join(host, port)


def ais_binary() -> str:
    """Return location of `ais` binary.
    """
    path = shutil.which('ais')

    if path is not None:
        logging.debug('Found AIS binary at %s', path)
        return path

    logging.warning('AIS binary not found with `which ais`.')

    # Double-check if it exists at the default path
    default_path = '/usr/local/bin/ais'
    if os.path.isfile(default_path):
        logging.info('ais available at the default path: %s', default_path)
        return default_path
    else:
        raise RuntimeError(f'AIS binary not found.')


def datastore_path_to_local_path(store_path: str) -> str:
    """Convert a data store path to a path in a local cache.

    Args:
        store_path: a path to an object on an object store

    Returns:
        Path to the same object in local cache.
    """
    if store_path.startswith('ais://'):
        endpoint = ais_endpoint()
        if endpoint is None:
            raise RuntimeError(f'AIS endpoint not set, cannot resolve {store_path}')

        local_ais_cache = os.path.join(ais_cache_base(), ais_endpoint_to_dir(endpoint))
        store_bucket, store_object = bucket_and_object_from_uri(store_path)
        local_path = os.path.join(local_ais_cache, store_bucket, store_object)
    else:
        raise ValueError(f'Unexpected store path format: {store_path}')

    return local_path


def get_datastore_object(path: str, force: bool = False, num_retries: int = 5) -> str:
    """Download an object from a store path and return the local path.
    If the input `path` is a local path, then nothing will be done, and
    the original path will be returned.

    Args:
        path: path to an object
        force: force download, even if a local file exists
        num_retries: number of retries if the get command fails
    
    Returns:
        Local path of the object.
    """
    if path.startswith('ais://'):
        endpoint = ais_endpoint()
        if endpoint is None:
            raise RuntimeError(f'AIS endpoint not set, cannot resolve {path}')

        local_path = datastore_path_to_local_path(store_path=path)

        if not os.path.isfile(local_path) or force:
            # Either we don't have the file in cache or we force download it
            # Enhancement: if local file is present, check some tag and compare against remote
            local_dir = os.path.dirname(local_path)
            if not os.path.isdir(local_dir):
                os.makedirs(local_dir, exist_ok=True)

            cmd = [ais_binary(), 'get', path, local_path]

            # for now info, later debug
            logging.debug('Downloading from AIS')
            logging.debug('\tendpoint    %s', endpoint)
            logging.debug('\tpath:       %s', path)
            logging.debug('\tlocal path: %s', local_path)
            logging.debug('\tcmd:        %s', subprocess.list2cmdline(cmd))

            done = False
            for n in range(num_retries):
                if not done:
                    try:
                        # Use stdout=subprocess.DEVNULL to prevent showing AIS command on each line
                        subprocess.check_call(cmd, stdout=subprocess.DEVNULL)
                        done = True
                    except subprocess.CalledProcessError as err:
                        logging.warning('Attempt %d of %d failed with: %s', n + 1, num_retries, str(err))

            if not done:
                raise RuntimeError('Download failed: %s', subprocess.list2cmdline(cmd))

        return local_path

    else:
        # Assume the file is local
        return path


class DataStoreObject:
    """A simple class for handling objects in a data store.
    Currently, this class supports objects on AIStore.

    Args:
        store_path: path to a store object
        local_path: path to a local object, may be used to upload local object to store
        get: get the object from a store
    """

    def __init__(self, store_path: str, local_path: str = None, get: bool = False):
        if local_path is not None:
            raise NotImplementedError('Specifying a local path is currently not supported.')

        self._store_path = store_path
        self._local_path = local_path

        if get:
            self.get()

    @property
    def store_path(self) -> str:
        """Return store path of the object.
        """
        return self._store_path

    @property
    def local_path(self) -> str:
        """Return local path of the object.
        """
        return self._local_path

    def get(self, force: bool = False) -> str:
        """Get an object from the store to local cache and return the local path.

        Args:
            force: force download, even if a local file exists

        Returns:
            Path to a local object.
        """
        if not self.local_path:
            # Assume the object needs to be downloaded
            self._local_path = get_datastore_object(self.store_path, force=force)
        return self.local_path

    def put(self, force: bool = False) -> str:
        """Push to remote and return the store path

        Args:
            force: force download, even if a local file exists

        Returns:
            Path to a (remote) object object on the object store.
        """
        raise NotImplementedError()

    def __str__(self):
        """Return a human-readable description of the object.
        """
        description = f'{type(self)}: store_path={self.store_path}, local_path={self.local_path}'
        return description


def datastore_path_to_webdataset_url(store_path: str):
    """Convert store_path to a WebDataset URL.

    Args:
        store_path: path to buckets on store

    Returns:
        URL which can be directly used with WebDataset.
    """
    if store_path.startswith('ais://'):
        url = f'pipe:ais get {store_path} - || true'
    else:
        raise ValueError(f'Unknown store path format: {store_path}')

    return url


def datastore_object_get(store_object: DataStoreObject) -> bool:
    """A convenience wrapper for multiprocessing.imap.

    Args:
        store_object: An instance of DataStoreObject

    Returns:
        True if get() returned a path.
    """
    return store_object.get() is not None