File size: 5,181 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from time import sleep

import wget
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy, StrategyRegistry

from nemo.utils import logging


def maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) -> str:
    """
    Helper function to download pre-trained weights from the cloud
    Args:
        url: (str) URL of storage
        filename: (str) what to download. The request will be issued to url/filename
        subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can
            be empty
        cache_dir: (str) a cache directory where to download. If not present, this function will attempt to create it.
            If None (default), then it will be $HOME/.cache/torch/NeMo
        refresh_cache: (bool) if True and cached file is present, it will delete it and re-fetch

    Returns:
        If successful - absolute local path to the downloaded file
        else - empty string
    """
    # try:
    if cache_dir is None:
        cache_location = Path.joinpath(Path.home(), ".cache/torch/NeMo")
    else:
        cache_location = cache_dir
    if subfolder is not None:
        destination = Path.joinpath(cache_location, subfolder)
    else:
        destination = cache_location

    if not os.path.exists(destination):
        os.makedirs(destination, exist_ok=True)

    destination_file = Path.joinpath(destination, filename)

    if os.path.exists(destination_file):
        logging.info(f"Found existing object {destination_file}.")
        if refresh_cache:
            logging.info("Asked to refresh the cache.")
            logging.info(f"Deleting file: {destination_file}")
            os.remove(destination_file)
        else:
            logging.info(f"Re-using file from: {destination_file}")
            return str(destination_file)
    # download file
    wget_uri = url + filename
    logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}")
    # NGC links do not work everytime so we try and wait
    i = 0
    max_attempts = 3
    while i < max_attempts:
        i += 1
        try:
            wget.download(wget_uri, str(destination_file))
            if os.path.exists(destination_file):
                return destination_file
            else:
                return ""
        except:
            logging.info(f"Download from cloud failed. Attempt {i} of {max_attempts}")
            sleep(0.05)
            continue
    raise ValueError("Not able to download url right now, please try again.")


class SageMakerDDPStrategy(DDPStrategy):
    @property
    def cluster_environment(self):
        env = LightningEnvironment()
        env.world_size = lambda: int(os.environ["WORLD_SIZE"])
        env.global_rank = lambda: int(os.environ["RANK"])
        return env

    @cluster_environment.setter
    def cluster_environment(self, env):
        # prevents Lightning from overriding the Environment required for SageMaker
        pass


def initialize_sagemaker() -> None:
    """
    Helper function to initiate sagemaker with NeMo.
    This function installs libraries that NeMo requires for the ASR toolkit + initializes sagemaker ddp.
    """

    StrategyRegistry.register(
        name='smddp', strategy=SageMakerDDPStrategy, process_group_backend="smddp", find_unused_parameters=False,
    )

    def _install_system_libraries() -> None:
        os.system('chmod 777 /tmp && apt-get update && apt-get install -y libsndfile1 ffmpeg')

    def _patch_torch_metrics() -> None:
        """
        Patches torchmetrics to not rely on internal state.
        This is because sagemaker DDP overrides the `__init__` function of the modules to do automatic-partitioning.
        """
        from torchmetrics import Metric

        def __new_hash__(self):
            hash_vals = [self.__class__.__name__, id(self)]
            return hash(tuple(hash_vals))

        Metric.__hash__ = __new_hash__

    _patch_torch_metrics()

    if os.environ.get("RANK") and os.environ.get("WORLD_SIZE"):
        import smdistributed.dataparallel.torch.distributed as dist

        # has to be imported, as it overrides torch modules and such when DDP is enabled.
        import smdistributed.dataparallel.torch.torch_smddp

        dist.init_process_group()

        if dist.get_local_rank():
            _install_system_libraries()
        return dist.barrier()  # wait for main process
    _install_system_libraries()
    return