|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import enum |
|
import logging as _logging |
|
import sys |
|
import threading |
|
import warnings |
|
from contextlib import contextmanager |
|
from logging.handlers import MemoryHandler |
|
|
|
from nemo.constants import NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, NEMO_ENV_VARNAME_TESTING |
|
from nemo.utils.env_var_parsing import get_envbool |
|
from nemo.utils.formatters.base import BaseNeMoFormatter, DebugNeMoFormatter |
|
from nemo.utils.get_rank import is_global_rank_zero |
|
from nemo.utils.metaclasses import Singleton |
|
|
|
__all__ = ["Logger", "LogMode"] |
|
|
|
|
|
@enum.unique
class LogMode(enum.IntEnum):
    """Controls how often a message handed to a ``Logger`` method is emitted."""

    # Emit the message on every call (the default of the logging methods).
    EACH = 0
    # Emit the message only the first time it is seen; later repeats are dropped.
    ONCE = 1
|
|
|
|
|
class Logger(metaclass=Singleton):
    """Wrapper around the standard-library logger named ``"nemo_logger"``.

    The class is instantiated through the project's ``Singleton`` metaclass. It keeps
    every handler it attaches in ``self._handlers`` under a fixed key
    (``"stream_stdout"``, ``"stream_stderr"``, ``"memory_all"``, ``"memory_err"``,
    ``"file"``, ``"file_err"``), mirrors the standard severity constants, can redirect
    ``warnings`` output into the logger, and supports "log once" de-duplication through
    ``LogMode.ONCE``.
    """

    # Severity levels, mirrored from the standard ``logging`` module.
    NOTSET = _logging.NOTSET
    DEBUG = _logging.DEBUG
    INFO = _logging.INFO
    WARNING = _logging.WARNING
    ERROR = _logging.ERROR
    CRITICAL = _logging.CRITICAL

    # Numeric level -> human-readable name.
    _level_names = {
        0: "NOTSET",
        10: "DEBUG",
        20: "INFO",
        30: "WARNING",
        40: "ERROR",
        50: "CRITICAL",
    }

    def __init__(self, capture_warnings=True):
        """Create the underlying logger and attach the default handlers.

        Args:
            capture_warnings: When true, ``warnings.showwarning`` is redirected to this
                logger (see ``captureWarnings``).
        """
        self._logger = None
        # Guards the one-time creation of the underlying logger.
        self._logger_lock = threading.Lock()
        self._handlers = dict()
        self.old_warnings_showwarning = None
        # Assigned *before* _define_logger(): the record factory, the warnings hook and
        # the logging methods installed there read these attributes.
        self.once_logged = set()
        self.rank = 0 if is_global_rank_zero() else "UNK"
        self._define_logger(capture_warnings)

    def _define_logger(self, capture_warnings=True):
        """Create the logger if not already created. Called in ``__init__``.

        In testing mode (``NEMO_ENV_VARNAME_TESTING``) a record factory that stamps each
        record with ``self.rank`` is installed, stream handlers use
        ``DebugNeMoFormatter`` and the level is ``DEBUG``; otherwise stream handlers are
        only attached on global rank zero and the level is ``INFO``. On global rank zero
        two ``MemoryHandler`` buffers are attached so that records can later be flushed
        into file handlers.

        Args:
            capture_warnings: Forwarded to ``captureWarnings``.

        Returns:
            The existing logger when it was already defined, otherwise ``None``.
        """
        if self._logger is not None:
            return self._logger

        testing = get_envbool(NEMO_ENV_VARNAME_TESTING, False)

        with self._logger_lock:
            try:
                self._logger = _logging.getLogger("nemo_logger")
                # Drop any stream handlers this object may already track.
                self.remove_stream_handlers()

                if testing:
                    previous_factory = _logging.getLogRecordFactory()

                    def record_factory(*args, **kwargs):
                        # Stamp every record with the rank of this process.
                        record = previous_factory(*args, **kwargs)
                        record.rank = self.rank
                        return record

                    _logging.setLogRecordFactory(record_factory)
                    self.add_stream_handlers(formatter=DebugNeMoFormatter)
                elif is_global_rank_zero():
                    self.add_stream_handlers()

                if is_global_rank_zero():
                    # Buffers WARNING-and-above records until add_err_file_handler().
                    memory_err = MemoryHandler(-1)
                    memory_err.addFilter(lambda record: record.levelno > _logging.INFO)
                    memory_err.setFormatter(BaseNeMoFormatter())
                    self._handlers["memory_err"] = memory_err
                    self._logger.addHandler(memory_err)

                    # Buffers every record until add_file_handler().
                    memory_all = MemoryHandler(-1)
                    memory_all.setFormatter(BaseNeMoFormatter())
                    self._handlers["memory_all"] = memory_all
                    self._logger.addHandler(memory_all)
            finally:
                self.set_verbosity(verbosity_level=Logger.DEBUG if testing else Logger.INFO)
                self.captureWarnings(capture_warnings)

        # Records must not bubble up to the root logger's handlers.
        self._logger.propagate = False

    def remove_stream_handlers(self):
        """Remove the StreamHandlers that log to stdout and stderr from the logger.

        Raises:
            RuntimeError: If the underlying logger has not been created.
        """
        if self._logger is None:
            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")

        for key in ("stream_stdout", "stream_stderr"):
            if key in self._handlers:
                self._logger.removeHandler(self._handlers.pop(key))

    def add_stream_handlers(self, formatter=BaseNeMoFormatter):
        """Add StreamHandlers that log to stdout and stderr to the logger.

        INFO and lower logs are streamed to stdout while WARNING and higher are streamed
        to stderr. If the NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR environment variable
        is set, all logs are sent to stderr instead.

        Args:
            formatter: Formatter class (not instance); one instance is created per handler.

        Raises:
            RuntimeError: If the underlying logger has not been created.
        """
        if self._logger is None:
            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")

        if get_envbool(NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, False):
            # A single unfiltered handler sends everything to stderr.
            self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stderr)
        else:
            self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stdout)
            self._handlers["stream_stdout"].addFilter(lambda record: record.levelno <= _logging.INFO)

            self._handlers["stream_stderr"] = _logging.StreamHandler(sys.stderr)
            self._handlers["stream_stderr"].addFilter(lambda record: record.levelno > _logging.INFO)

        self._handlers["stream_stdout"].setFormatter(formatter())
        self._logger.addHandler(self._handlers["stream_stdout"])

        # "stream_stderr" is absent in redirect mode unless a previous call registered it.
        stderr_handler = self._handlers.get("stream_stderr")
        if stderr_handler is not None:
            stderr_handler.setFormatter(formatter())
            self._logger.addHandler(stderr_handler)

    def reset_stream_handler(self, formatter=BaseNeMoFormatter):
        """Remove, then re-add, the stream handlers using ``formatter``."""
        self.remove_stream_handlers()
        self.add_stream_handlers(formatter=formatter)

    def _flush_memory_handler(self, memory_key, target):
        """Replay records buffered under ``memory_key`` into ``target`` and retire the buffer."""
        memory_handler = self._handlers.pop(memory_key, None)
        if memory_handler is None:
            return
        memory_handler.setTarget(target)
        # close() flushes the buffered records into the target.
        memory_handler.close()
        # Detach the closed buffer; left attached it would keep accumulating every
        # future record with no target to drain into.
        self._logger.removeHandler(memory_handler)

    def add_file_handler(self, log_file):
        """Add a FileHandler to the logger that logs all messages to a file.

        If the logger had a MemoryHandler at ``self._handlers["memory_all"]``, those
        buffered messages are flushed to the new file, and the MemoryHandler is closed
        and detached from the logger.

        Args:
            log_file: Path handed to ``logging.FileHandler``.

        Raises:
            RuntimeError: If the underlying logger has not been created.
        """
        if self._logger is None:
            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")

        file_handler = _logging.FileHandler(log_file)
        file_handler.setFormatter(BaseNeMoFormatter())
        self._handlers["file"] = file_handler
        self._logger.addHandler(file_handler)

        self._flush_memory_handler("memory_all", file_handler)

    def add_err_file_handler(self, log_file):
        """Add a FileHandler that logs all WARNING and higher messages to a file.

        If the logger had a MemoryHandler at ``self._handlers["memory_err"]``, those
        buffered messages are flushed to the new file, and the MemoryHandler is closed
        and detached from the logger.

        Args:
            log_file: Path handed to ``logging.FileHandler``.

        Raises:
            RuntimeError: If the underlying logger has not been created.
        """
        if self._logger is None:
            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")

        err_handler = _logging.FileHandler(log_file)
        err_handler.addFilter(lambda record: record.levelno > _logging.INFO)
        err_handler.setFormatter(BaseNeMoFormatter())
        self._handlers["file_err"] = err_handler
        self._logger.addHandler(err_handler)

        self._flush_memory_handler("memory_err", err_handler)

    def getEffectiveLevel(self):
        """Return how much logging output will be produced (``None`` if there is no logger)."""
        if self._logger is not None:
            return self._logger.getEffectiveLevel()
        return None

    def get_verbosity(self):
        """See getEffectiveLevel"""
        return self.getEffectiveLevel()

    def setLevel(self, verbosity_level):
        """Set the threshold for what messages will be logged.

        The level is applied to the logger and to every handler currently attached to it.
        """
        if self._logger is None:
            return

        self._logger.setLevel(verbosity_level)
        for handler in self._logger.handlers:
            handler.setLevel(verbosity_level)

    def set_verbosity(self, verbosity_level):
        """See setLevel"""
        self.setLevel(verbosity_level)

    @staticmethod
    def _swap_handler_stream(handler, stream):
        """Flush ``handler`` and point it at ``stream`` while holding the handler's lock."""
        handler.acquire()
        try:
            handler.flush()
            handler.stream = stream
        finally:
            handler.release()

    @contextmanager
    def _patch_stream_handler(self, handler_key, stream):
        """Temporarily redirect the stream handler stored under ``handler_key`` to ``stream``.

        Raises:
            RuntimeError: If the logger, the handler, or the handler's stream is missing.
        """
        handler = None if self._logger is None else self._handlers.get(handler_key)
        if handler is None or handler.stream is None:
            raise RuntimeError("Impossible to patch logging handlers if handler does not exist")

        old_stream = handler.stream
        self._swap_handler_stream(handler, stream)
        # Only the caller's block is guarded: its exceptions propagate unchanged, and the
        # original stream is restored on the very handler that was patched.
        try:
            yield stream
        finally:
            self._swap_handler_stream(handler, old_stream)

    @contextmanager
    def patch_stderr_handler(self, stream):
        """Send messages that should log to stderr to ``stream`` instead. Useful for unittests.

        Yields:
            ``stream``.

        Raises:
            RuntimeError: If the logger or its ``"stream_stderr"`` handler does not exist.
        """
        with self._patch_stream_handler("stream_stderr", stream) as patched:
            yield patched

    @contextmanager
    def patch_stdout_handler(self, stream):
        """Send messages that should log to stdout to ``stream`` instead. Useful for unittests.

        Yields:
            ``stream``.

        Raises:
            RuntimeError: If the logger or its ``"stream_stdout"`` handler does not exist.
        """
        with self._patch_stream_handler("stream_stdout", stream) as patched:
            yield patched

    @contextmanager
    def temp_verbosity(self, verbosity_level):
        """Set a temporary threshold for what messages will be logged.

        The previous verbosity is restored when the ``with`` block exits, even on error.
        Without an underlying logger the block simply runs unchanged.
        """
        if self._logger is None:
            yield
            return

        previous_verbosity = self.get_verbosity()
        try:
            self.set_verbosity(verbosity_level)
            yield
        finally:
            self.set_verbosity(previous_verbosity)

    def captureWarnings(self, capture):
        """Enable or disable redirection of warnings to the logging package.

        If capture is true, redirect all warnings to the logging package.
        If capture is False, ensure that warnings are not redirected to logging
        but to their original destinations.
        """
        if self._logger is None:
            return

        if capture and self.old_warnings_showwarning is None:
            # Remember the current hook so it can be reinstated later.
            self.old_warnings_showwarning = warnings.showwarning
            warnings.showwarning = self._showwarning
        elif not capture and self.old_warnings_showwarning is not None:
            warnings.showwarning = self.old_warnings_showwarning
            self.old_warnings_showwarning = None

    def _showwarning(self, message, category, filename, lineno, file=None, line=None):
        """Implementation of ``showwarning`` which redirects to logging.

        It calls ``warnings.formatwarning`` and logs the resulting string with level
        ``logging.WARNING``. The ``file`` argument is accepted for signature
        compatibility and is not used.
        """
        formatted = warnings.formatwarning(message, category, filename, lineno, line)
        self.warning("%s", formatted)

    def _logged_once(self, msg, mode):
        """Tell whether ``msg`` must be suppressed because it was already logged once.

        Only ``LogMode.ONCE`` de-duplicates; the first sighting is recorded in
        ``self.once_logged`` and reported as not yet logged.

        Returns:
            ``True`` if the message was already seen in ONCE mode, ``False`` otherwise.
        """
        # NOTE(review): the first 12 characters are ignored when comparing messages, so
        # messages of 12 characters or fewer all share the same (empty) key. The reason
        # for the prefix is not visible in this file -- confirm before changing.
        PREFIX_LEN = 12
        if mode != LogMode.ONCE:
            return False

        key = msg[PREFIX_LEN:]
        if key in self.once_logged:
            return True
        self.once_logged.add(key)
        return False

    # NOTE: the level methods below call ``self._logger._log`` directly on purpose; an
    # extra wrapper frame in between could change the caller location that the logging
    # module reports for each record.

    def debug(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """Log 'msg % args' with severity 'DEBUG'.

        To pass exception information, use the keyword argument exc_info with
        a true value, e.g.

        logger.debug("Houston, we have a %s", "thorny problem", exc_info=1)

        Args:
            mode: ``LogMode.EACH`` logs every call; ``LogMode.ONCE`` drops repeats.
        """
        if self._logger is not None and self._logger.isEnabledFor(Logger.DEBUG) and not self._logged_once(msg, mode):
            self._logger._log(Logger.DEBUG, msg, args, **kwargs)

    def info(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """Log 'msg % args' with severity 'INFO'.

        To pass exception information, use the keyword argument exc_info with
        a true value, e.g.

        logger.info("Houston, we have a %s", "interesting problem", exc_info=1)

        Args:
            mode: ``LogMode.EACH`` logs every call; ``LogMode.ONCE`` drops repeats.
        """
        if self._logger is not None and self._logger.isEnabledFor(Logger.INFO) and not self._logged_once(msg, mode):
            self._logger._log(Logger.INFO, msg, args, **kwargs)

    def warning(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """Log 'msg % args' with severity 'WARNING'.

        To pass exception information, use the keyword argument exc_info with
        a true value, e.g.

        logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1)

        Args:
            mode: ``LogMode.EACH`` logs every call; ``LogMode.ONCE`` drops repeats.
        """
        if self._logger is not None and self._logger.isEnabledFor(Logger.WARNING) and not self._logged_once(msg, mode):
            self._logger._log(Logger.WARNING, msg, args, **kwargs)

    def error(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """Log 'msg % args' with severity 'ERROR'.

        To pass exception information, use the keyword argument exc_info with
        a true value, e.g.

        logger.error("Houston, we have a %s", "major problem", exc_info=1)

        Args:
            mode: ``LogMode.EACH`` logs every call; ``LogMode.ONCE`` drops repeats.
        """
        if self._logger is not None and self._logger.isEnabledFor(Logger.ERROR) and not self._logged_once(msg, mode):
            self._logger._log(Logger.ERROR, msg, args, **kwargs)

    def critical(self, msg, *args, mode=LogMode.EACH, **kwargs):
        """Log 'msg % args' with severity 'CRITICAL'.

        To pass exception information, use the keyword argument exc_info with
        a true value, e.g.

        logger.critical("Houston, we have a %s", "major disaster", exc_info=1)

        Args:
            mode: ``LogMode.EACH`` logs every call; ``LogMode.ONCE`` drops repeats.
        """
        if (
            self._logger is not None
            and self._logger.isEnabledFor(Logger.CRITICAL)
            and not self._logged_once(msg, mode)
        ):
            self._logger._log(Logger.CRITICAL, msg, args, **kwargs)
|
|