File size: 3,900 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest import mock

import pytest

from nemo import __version__ as NEMO_VERSION
from nemo.utils.data_utils import (
    ais_binary,
    ais_endpoint_to_dir,
    bucket_and_object_from_uri,
    datastore_path_to_webdataset_url,
    is_datastore_path,
    resolve_cache_dir,
)


class TestDataUtils:
    """Unit tests for the helpers in ``nemo.utils.data_utils``."""

    @pytest.mark.unit
    def test_resolve_cache_dir(self):
        """Test resolving the cache dir from the cache-dir environment variable.

        The name of the environment variable is patched to a test-only name, and
        ``os.environ`` is patched so the values set here are rolled back when the
        test finishes instead of leaking into other tests.
        """
        TEST_NEMO_ENV_CACHE_DIR = 'TEST_NEMO_ENV_CACHE_DIR'

        with mock.patch('nemo.constants.NEMO_ENV_CACHE_DIR', TEST_NEMO_ENV_CACHE_DIR), mock.patch.dict(os.environ):

            envar_to_resolved_path = {
                # Absolute paths are used verbatim.
                '/path/to/cache': '/path/to/cache',
                # Relative paths are expected to be resolved against the current working dir.
                'relative/path': os.path.join(os.getcwd(), 'relative/path'),
                # An empty value is expected to resolve to the default, version-specific cache dir.
                '': os.path.expanduser(f'~/.cache/torch/NeMo/NeMo_{NEMO_VERSION}'),
            }

            for envar, expected_path in envar_to_resolved_path.items():
                # Set envar (restored automatically by mock.patch.dict on exit)
                os.environ[TEST_NEMO_ENV_CACHE_DIR] = envar
                # Check path
                uut_path = resolve_cache_dir().as_posix()
                assert uut_path == expected_path, f'Expected: {expected_path}, got {uut_path}'

    @pytest.mark.unit
    def test_is_datastore_path(self):
        """Test checking whether a path points to a datastore."""
        # Positive examples
        assert is_datastore_path('ais://positive/example')
        # Negative examples
        assert not is_datastore_path('ais/negative/example')
        assert not is_datastore_path('/negative/example')
        assert not is_datastore_path('negative/example')

    @pytest.mark.unit
    def test_bucket_and_object_from_uri(self):
        """Test splitting a datastore URI into its bucket and object parts."""
        # Positive examples
        assert bucket_and_object_from_uri('ais://bucket/object') == ('bucket', 'object')
        assert bucket_and_object_from_uri('ais://bucket_2/object/is/here') == ('bucket_2', 'object/is/here')

        # Negative examples: invalid URI
        with pytest.raises(ValueError):
            bucket_and_object_from_uri('/local/file')

        with pytest.raises(ValueError):
            bucket_and_object_from_uri('local/file')

    @pytest.mark.unit
    def test_ais_endpoint_to_dir(self):
        """Test converting an AIS endpoint URL to a relative dir."""
        assert ais_endpoint_to_dir('http://local:123') == os.path.join('local', '123')
        assert ais_endpoint_to_dir('http://1.2.3.4:567') == os.path.join('1.2.3.4', '567')

        # Negative example: endpoint without a scheme
        with pytest.raises(ValueError):
            ais_endpoint_to_dir('local:123')

    @pytest.mark.unit
    def test_ais_binary(self):
        """Test locating the AIS CLI binary."""
        # Positive example: binary is found via shutil.which
        with mock.patch('shutil.which', lambda x: '/test/path/ais'):
            assert ais_binary() == '/test/path/ais'

        # Negative example: AIS binary cannot be found
        with mock.patch('shutil.which', lambda x: None), mock.patch('os.path.isfile', lambda x: False):
            with pytest.raises(RuntimeError):
                ais_binary()

    @pytest.mark.unit
    def test_datastore_path_to_webdataset_url(self):
        """Test conversion of a datastore path to a URL for WebDataset."""
        assert datastore_path_to_webdataset_url('ais://test/path') == 'pipe:ais get ais://test/path - || true'