File size: 16,740 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import enum
import logging as _logging
import sys
import threading
import warnings
from contextlib import contextmanager
from logging.handlers import MemoryHandler

from nemo.constants import NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, NEMO_ENV_VARNAME_TESTING
from nemo.utils.env_var_parsing import get_envbool
from nemo.utils.formatters.base import BaseNeMoFormatter, DebugNeMoFormatter
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.metaclasses import Singleton

__all__ = ["Logger", "LogMode"]


@enum.unique
class LogMode(enum.IntEnum):
    """Selects how often a given message may be emitted by ``Logger``."""

    # Log the message each time
    EACH = 0
    # Log the message only once. The same message will not be logged again.
    ONCE = 1


class Logger(metaclass=Singleton):

    # Severity levels, re-exported from the standard ``logging`` module so that
    # callers can write ``Logger.INFO`` instead of importing ``logging``.

    # Level 0
    NOTSET = _logging.NOTSET

    # Level 10
    DEBUG = _logging.DEBUG

    # Level 20
    INFO = _logging.INFO

    # Level 30
    WARNING = _logging.WARNING

    # Level 40
    ERROR = _logging.ERROR

    # Level 50
    CRITICAL = _logging.CRITICAL

    # Maps each numeric level value to its name.
    _level_names = {
        0: "NOTSET",
        10: "DEBUG",
        20: "INFO",
        30: "WARNING",
        40: "ERROR",
        50: "CRITICAL",
    }

    def __init__(self, capture_warnings=True):
        """Initialise the wrapper state and build the underlying ``nemo_logger`` logger.

        Args:
            capture_warnings: Forwarded to ``_define_logger``; when true, Python warnings
                are redirected through this logger.
        """
        # Set by _define_logger().
        self._logger = None
        # Multi-GPU runs run in separate processes, thread locks shouldn't be needed
        self._logger_lock = threading.Lock()
        # Maps a handler name (e.g. "stream_stdout", "memory_all") to the handler object.
        self._handlers = {}
        self.old_warnings_showwarning = None

        # These attributes are read by hooks that _define_logger() installs (the log
        # record factory reads ``self.rank``; ONCE-mode logging reads ``once_logged``),
        # so they must exist before the logger is defined.
        self.once_logged = set()
        self.rank = 0 if is_global_rank_zero() else "UNK"

        self._define_logger(capture_warnings)

    def _define_logger(self, capture_warnings=True):
        """Create and configure the ``nemo_logger`` logger if not already created. Called in init.

        Stream handlers are attached on global rank zero, or on every rank when the
        NEMO_TESTING environment flag is set (in which case a ``rank`` attribute is also
        injected into every log record and the debug formatter is used). Memory handlers
        are attached as buffers whose content is flushed to file once the matching file
        handler is added.

        Args:
            capture_warnings: Passed to ``captureWarnings`` once the logger exists.

        Returns:
            The underlying ``logging.Logger`` instance.
        """
        # Lock-free fast path for the common case where the logger already exists.
        if self._logger is not None:
            return self._logger

        with self._logger_lock:
            # Re-check under the lock: another caller may have completed the setup while
            # this one was waiting, and running it twice would attach duplicate handlers.
            if self._logger is not None:
                return self._logger

            testing = get_envbool(NEMO_ENV_VARNAME_TESTING, False)
            try:
                self._logger = _logging.getLogger("nemo_logger")
                # By default, silence all loggers except the logger for rank 0
                self.remove_stream_handlers()
                # If NEMO_TESTING is set, add a streamhandler to all ranks
                if testing:
                    old_factory = _logging.getLogRecordFactory()

                    def record_factory(*args, **kwargs):
                        # Tag every record with this process' rank.
                        record = old_factory(*args, **kwargs)
                        record.rank = self.rank
                        return record

                    _logging.setLogRecordFactory(record_factory)
                    self.add_stream_handlers(formatter=DebugNeMoFormatter)
                elif is_global_rank_zero():
                    self.add_stream_handlers()

                # Add memoryhandlers, essentially buffers. They are used to save messages that we will flush to file
                # once the appropriate file handlers are added.
                if is_global_rank_zero():
                    # Add a memoryhandler for error messages. Only logged on rank 0
                    memory_err = MemoryHandler(-1)
                    self._handlers["memory_err"] = memory_err
                    memory_err.addFilter(lambda record: record.levelno > _logging.INFO)
                    memory_err.setFormatter(BaseNeMoFormatter())
                    self._logger.addHandler(memory_err)

                # Add a memoryhandler for all messages on all ranks
                memory_all = MemoryHandler(-1)
                self._handlers["memory_all"] = memory_all
                memory_all.setFormatter(BaseNeMoFormatter())
                self._logger.addHandler(memory_all)

            finally:
                # Applied even when the setup above failed part-way.
                self.set_verbosity(verbosity_level=Logger.DEBUG if testing else Logger.INFO)
                self.captureWarnings(capture_warnings)

            # Keep records from also reaching handlers of ancestor loggers.
            self._logger.propagate = False
            return self._logger

    def remove_stream_handlers(self):
        """Detach the stdout and stderr ``StreamHandler`` objects from the logger, if registered.

        Raises:
            RuntimeError: If the underlying logger has not been created yet.
        """
        if self._logger is None:
            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")

        # A handler that was never registered is simply skipped.
        for key in ("stream_stdout", "stream_stderr"):
            if key in self._handlers:
                self._logger.removeHandler(self._handlers[key])
                del self._handlers[key]

    def add_stream_handlers(self, formatter=BaseNeMoFormatter):
        """Add StreamHandler that log to stdout and stderr to the logger. INFO and lower logs are streamed to stdout
        while WARNING and higher are streamed to stderr. If the NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR environment
        variable is set, all logs are sent to stderr instead.

        Stream handlers registered by a previous call are detached first, so calling this
        method repeatedly never leaves more than one set of stream handlers on the logger.

        Args:
            formatter: Formatter class (not instance); a fresh instance is created for each handler.

        Raises:
            RuntimeError: If the underlying logger has not been created yet.
        """
        if self._logger is None:
            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")

        # Without this, handlers from an earlier call would stay attached to the logger
        # (every message would be printed twice) while their entries in self._handlers
        # are overwritten below, leaving no way to remove them afterwards.
        self.remove_stream_handlers()

        if get_envbool(NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, False):
            # A single unfiltered handler: everything goes to stderr.
            new_handlers = {"stream_stdout": _logging.StreamHandler(sys.stderr)}
        else:
            stdout_handler = _logging.StreamHandler(sys.stdout)
            stdout_handler.addFilter(lambda record: record.levelno <= _logging.INFO)

            stderr_handler = _logging.StreamHandler(sys.stderr)
            stderr_handler.addFilter(lambda record: record.levelno > _logging.INFO)

            new_handlers = {"stream_stdout": stdout_handler, "stream_stderr": stderr_handler}

        for name, handler in new_handlers.items():
            handler.setFormatter(formatter())
            self._handlers[name] = handler
            self._logger.addHandler(handler)

    def reset_stream_handler(self, formatter=BaseNeMoFormatter):
        """Replace the current stream handlers with fresh ones that use ``formatter``.

        Args:
            formatter: Formatter class handed on to ``add_stream_handlers``.
        """
        self.remove_stream_handlers()
        self.add_stream_handlers(formatter)

    def add_file_handler(self, log_file):
        """Add a FileHandler to logger that logs all messages to a file. If the logger had a MemoryHandler at
        self._handlers["memory_all"], those buffered messages are flushed to the new file, and the MemoryHandler is
        detached from the logger and closed.

        Args:
            log_file: Path of the log file, passed to ``logging.FileHandler``.

        Raises:
            RuntimeError: If the underlying logger has not been created yet.
        """
        if self._logger is None:
            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")

        file_handler = _logging.FileHandler(log_file)
        file_handler.setFormatter(BaseNeMoFormatter())
        self._handlers["file"] = file_handler
        self._logger.addHandler(file_handler)

        memory_handler = self._handlers.pop("memory_all", None)
        if memory_handler is not None:
            # close() flushes and drops the target but does NOT detach the handler from
            # the logger. Left attached, it would keep buffering every later record with
            # nowhere to flush to, so remove it from the logger explicitly.
            self._logger.removeHandler(memory_handler)
            memory_handler.setTarget(file_handler)
            memory_handler.close()  # flush buffered records into the file

    def add_err_file_handler(self, log_file):
        """Add a FileHandler to logger that logs all WARNING and higher messages to a file. If the logger had a
        MemoryHandler at self._handlers["memory_err"], those buffered messages are flushed to the new file, and the
        MemoryHandler is detached from the logger and closed.

        Args:
            log_file: Path of the error log file, passed to ``logging.FileHandler``.

        Raises:
            RuntimeError: If the underlying logger has not been created yet.
        """
        if self._logger is None:
            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")

        err_handler = _logging.FileHandler(log_file)
        # Only records above INFO (WARNING and higher) reach this file.
        err_handler.addFilter(lambda record: record.levelno > _logging.INFO)
        err_handler.setFormatter(BaseNeMoFormatter())
        self._handlers["file_err"] = err_handler
        self._logger.addHandler(err_handler)

        memory_handler = self._handlers.pop("memory_err", None)
        if memory_handler is not None:
            # close() flushes and drops the target but does NOT detach the handler from
            # the logger. Left attached, it would keep buffering every later record with
            # nowhere to flush to, so remove it from the logger explicitly.
            self._logger.removeHandler(memory_handler)
            memory_handler.setTarget(err_handler)
            memory_handler.close()  # flush buffered records into the file

    def getEffectiveLevel(self):
        """Return the effective level of the underlying logger, or ``None`` if it does not exist."""
        if self._logger is None:
            return None
        return self._logger.getEffectiveLevel()

    def get_verbosity(self):
        """Alias of ``getEffectiveLevel``."""
        current_level = self.getEffectiveLevel()
        return current_level

    def setLevel(self, verbosity_level):
        """Apply ``verbosity_level`` as the threshold of the logger and of every handler attached to it.

        Does nothing when the underlying logger has not been created.
        """
        if self._logger is None:
            return

        self._logger.setLevel(verbosity_level)
        for attached in self._logger.handlers:
            attached.setLevel(verbosity_level)

    def set_verbosity(self, verbosity_level):
        """Alias of ``setLevel``."""
        self.setLevel(verbosity_level=verbosity_level)

    @contextmanager
    def patch_stderr_handler(self, stream):
        """Sends messages that should log to stderr to stream instead. Useful for unittests.

        The handler's original stream is restored when the ``with`` block exits, whether
        normally or through an exception. Exceptions raised inside the block propagate
        unchanged.

        Args:
            stream: Replacement stream for the "stream_stderr" handler.

        Yields:
            The ``stream`` that was passed in.

        Raises:
            RuntimeError: If the logger or the "stream_stderr" handler does not exist, or
                the handler currently has no stream.
        """
        handler = None if self._logger is None else self._handlers.get("stream_stderr")
        # Validate up front so a missing handler surfaces as RuntimeError and no restore
        # is attempted on something that was never patched.
        if handler is None or handler.stream is None:
            raise RuntimeError("Impossible to patch logging handlers if handler does not exist")

        old_stream = handler.stream

        # Port backwards set_stream() from python 3.7
        handler.acquire()
        try:
            handler.flush()
            handler.stream = stream
        finally:
            handler.release()

        try:
            yield stream
        finally:
            # Port backwards set_stream() from python 3.7
            handler.acquire()
            try:
                handler.flush()
                handler.stream = old_stream
            finally:
                handler.release()

    @contextmanager
    def patch_stdout_handler(self, stream):
        """Sends messages that should log to stdout to stream instead. Useful for unittests.

        The handler's original stream is restored when the ``with`` block exits, whether
        normally or through an exception. Exceptions raised inside the block propagate
        unchanged.

        Args:
            stream: Replacement stream for the "stream_stdout" handler.

        Yields:
            The ``stream`` that was passed in.

        Raises:
            RuntimeError: If the logger or the "stream_stdout" handler does not exist, or
                the handler currently has no stream.
        """
        handler = None if self._logger is None else self._handlers.get("stream_stdout")
        # Validate up front so a missing handler surfaces as RuntimeError and no restore
        # is attempted on something that was never patched.
        if handler is None or handler.stream is None:
            raise RuntimeError("Impossible to patch logging handlers if handler does not exist")

        old_stream = handler.stream

        # Port backwards set_stream() from python 3.7
        handler.acquire()
        try:
            handler.flush()
            handler.stream = stream
        finally:
            handler.release()

        try:
            yield stream
        finally:
            # Port backwards set_stream() from python 3.7
            handler.acquire()
            try:
                handler.flush()
                handler.stream = old_stream
            finally:
                handler.release()

    @contextmanager
    def temp_verbosity(self, verbosity_level):
        """Temporarily set the threshold for what messages will be logged.

        The verbosity in effect on entry is restored when the ``with`` block exits. When
        the underlying logger does not exist, the block simply runs with nothing changed.
        """
        if self._logger is None:
            yield
            return

        previous_level = self.get_verbosity()
        try:
            self.set_verbosity(verbosity_level)
            yield
        finally:
            self.set_verbosity(previous_level)

    def captureWarnings(self, capture):
        """
        Turn redirection of Python warnings into this logger on or off.

        With a true ``capture``, ``warnings.showwarning`` is replaced by ``_showwarning``
        and the previous function is kept for later. With a false ``capture``, the
        previously saved function is put back. Repeating the current state, or calling
        this before the logger exists, is a no-op.
        """
        if self._logger is None:
            return

        already_captured = self.old_warnings_showwarning is not None

        if capture and not already_captured:
            # Remember the current hook, then install ours.
            self.old_warnings_showwarning = warnings.showwarning
            warnings.showwarning = self._showwarning
        elif already_captured and not capture:
            # Put the remembered hook back.
            warnings.showwarning = self.old_warnings_showwarning
            self.old_warnings_showwarning = None

    def _showwarning(self, message, category, filename, lineno, file=None, line=None):
        """
        Implementation of showwarnings which redirects to logging.

        When ``file`` is None, the warning is formatted with ``warnings.formatwarning``
        and logged with level logging.WARNING. When the caller explicitly supplies
        ``file``, the warning is handed to the ``showwarning`` function that was active
        before capturing started, so that it still ends up in the requested file; this
        mirrors the behaviour of the standard library's ``logging.captureWarnings``.
        """
        if file is not None:
            # An explicit destination was requested: do not hijack it into the log.
            if self.old_warnings_showwarning is not None:
                self.old_warnings_showwarning(message, category, filename, lineno, file, line)
            return

        s = warnings.formatwarning(message, category, filename, lineno, line)
        self.warning("%s", s)

    def _logged_once(self, msg, mode):
        """Tell whether ``msg`` must be suppressed because it was already logged in ONCE mode.

        Args:
            msg: The message about to be logged.
            mode: A ``LogMode``; only ``LogMode.ONCE`` enables de-duplication.

        Returns:
            True if ``mode`` is ``LogMode.ONCE`` and a matching message was recorded
            before; False otherwise (the message is then recorded when in ONCE mode).
        """
        # NOTE(review): the first PREFIX_LEN characters are ignored when comparing
        # messages - presumably a variable prefix; confirm the intent with the callers.
        PREFIX_LEN = 12
        if mode != LogMode.ONCE:
            return False

        # Non-string messages cannot be sliced; compare them by their text form.
        text = msg if isinstance(msg, str) else str(msg)
        # A message no longer than the prefix would otherwise reduce to the empty key,
        # making every short message suppress all other short messages.
        key = text[PREFIX_LEN:] if len(text) > PREFIX_LEN else text

        if key in self.once_logged:
            return True
        self.once_logged.add(key)
        return False

    def debug(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """
        Emit 'msg % args' at severity 'DEBUG'.

        Nothing is emitted when the logger does not exist, when DEBUG is disabled, or
        when ``mode`` is ``LogMode.ONCE`` and the message was already logged that way.

        Exception information can be attached through the keyword argument exc_info
        with a true value, e.g.

        logger.debug("Houston, we have a %s", "thorny problem", exc_info=1)
        """
        if self._logger is None or not self._logger.isEnabledFor(Logger.DEBUG):
            return
        if self._logged_once(msg, mode):
            return
        # NOTE(review): Logger._log is called from this frame directly, presumably so the
        # reported call site is the user's code - confirm before adding indirection.
        self._logger._log(Logger.DEBUG, msg, args, **kwargs)

    def info(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """
        Emit 'msg % args' at severity 'INFO'.

        Nothing is emitted when the logger does not exist, when INFO is disabled, or
        when ``mode`` is ``LogMode.ONCE`` and the message was already logged that way.

        Exception information can be attached through the keyword argument exc_info
        with a true value, e.g.

        logger.info("Houston, we have a %s", "interesting problem", exc_info=1)
        """
        if self._logger is None or not self._logger.isEnabledFor(Logger.INFO):
            return
        if self._logged_once(msg, mode):
            return
        # NOTE(review): Logger._log is called from this frame directly, presumably so the
        # reported call site is the user's code - confirm before adding indirection.
        self._logger._log(Logger.INFO, msg, args, **kwargs)

    def warning(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """
        Emit 'msg % args' at severity 'WARNING'.

        Nothing is emitted when the logger does not exist, when WARNING is disabled, or
        when ``mode`` is ``LogMode.ONCE`` and the message was already logged that way.

        Exception information can be attached through the keyword argument exc_info
        with a true value, e.g.

        logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1)
        """
        if self._logger is None or not self._logger.isEnabledFor(Logger.WARNING):
            return
        if self._logged_once(msg, mode):
            return
        # NOTE(review): Logger._log is called from this frame directly, presumably so the
        # reported call site is the user's code - confirm before adding indirection.
        self._logger._log(Logger.WARNING, msg, args, **kwargs)

    def error(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """
        Emit 'msg % args' at severity 'ERROR'.

        Nothing is emitted when the logger does not exist, when ERROR is disabled, or
        when ``mode`` is ``LogMode.ONCE`` and the message was already logged that way.

        Exception information can be attached through the keyword argument exc_info
        with a true value, e.g.

        logger.error("Houston, we have a %s", "major problem", exc_info=1)
        """
        if self._logger is None or not self._logger.isEnabledFor(Logger.ERROR):
            return
        if self._logged_once(msg, mode):
            return
        # NOTE(review): Logger._log is called from this frame directly, presumably so the
        # reported call site is the user's code - confirm before adding indirection.
        self._logger._log(Logger.ERROR, msg, args, **kwargs)

    def critical(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """
        Emit 'msg % args' at severity 'CRITICAL'.

        Nothing is emitted when the logger does not exist, when CRITICAL is disabled, or
        when ``mode`` is ``LogMode.ONCE`` and the message was already logged that way.

        Exception information can be attached through the keyword argument exc_info
        with a true value, e.g.

        logger.critical("Houston, we have a %s", "major disaster", exc_info=1)
        """
        if self._logger is None or not self._logger.isEnabledFor(Logger.CRITICAL):
            return
        if self._logged_once(msg, mode):
            return
        # NOTE(review): Logger._log is called from this frame directly, presumably so the
        # reported call site is the user's code - confirm before adding indirection.
        self._logger._log(Logger.CRITICAL, msg, args, **kwargs)