Spaces:
Running
Running
File size: 4,439 Bytes
bcc039b afedb16 bcc039b b0120da afedb16 bcc039b afedb16 bcc039b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import math
import sys
import time
from datetime import timedelta
import fsspec
from bytelatent.distributed import get_global_rank, get_is_slurm_job
class LogFormatter(logging.Formatter):
    """
    Custom logger for distributed jobs, displaying rank
    and preserving indent from the custom prefix format.

    Each record is rendered as ``<prefix><message>``. The prefix holds the
    (optional) rank, the level name, the local wall-clock time and the time
    elapsed since this formatter was created. Continuation lines of the
    message, the traceback and the stack info are indented by the width of
    the prefix so multi-line output stays aligned.
    """

    def __init__(self):
        # Initialise the base Formatter so the attributes it normally
        # provides (style, datefmt, ...) exist on this instance as well.
        super().__init__()
        self.start_time = time.time()
        self.rank = get_global_rank()
        self.show_rank = not get_is_slurm_job()  # srun has --label

    def formatTime(self, record, datefmt=None):
        """
        Return ``"<yy-mm-dd HH:MM:SS.ffffff> - <elapsed>"`` for ``record``.

        The first part is the record's creation time in local time with
        microseconds; ``<elapsed>`` is the time since this formatter was
        created, rounded to whole seconds.

        ``datefmt`` is accepted only for signature compatibility with
        ``logging.Formatter.formatTime`` and is ignored.
        """
        subsecond, seconds = math.modf(record.created)
        curr_date = (
            time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
            + f".{int(subsecond * 1_000_000):06d}"
        )
        delta = timedelta(seconds=round(record.created - self.start_time))
        return f"{curr_date} - {delta}"

    def formatPrefix(self, record):
        """Return the line prefix: optional rank, padded level name, time."""
        fmt_time = self.formatTime(record)
        if self.show_rank:
            return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
        return f"{record.levelname:<7} {fmt_time} - "

    @staticmethod
    def _append_indented(content: str, text: str, indent: str) -> str:
        """Append ``text`` to ``content`` on a new line, indenting every line of it."""
        lines = text.splitlines()
        if not lines:
            return content
        if content[-1:] != "\n":
            content = content + "\n" + indent
        # Joining with "\n" + indent leaves no trailing newline behind.
        return content + ("\n" + indent).join(lines)

    def formatMessage(self, record, indent: str = ""):
        """
        Return the record's message with continuation lines indented.

        Args:
            record: The log record to render.
            indent: String inserted after every newline of the message,
                traceback and stack info (normally spaces as wide as the
                prefix).

        Returns:
            The message, followed by the indented exception text and stack
            info when the record carries them, without a trailing newline
            added by this method.
        """
        content = record.getMessage()
        content = content.replace("\n", "\n" + indent)

        # Exception handling as in the default formatter, albeit with indenting
        # according to our custom prefix
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            content = self._append_indented(content, record.exc_text, indent)
        if record.stack_info:
            stack_text = self.formatStack(record.stack_info)
            content = self._append_indented(content, stack_text, indent)
        return content

    def format(self, record):
        """Return the fully formatted record: prefix plus indented message."""
        prefix = self.formatPrefix(record)
        indent = " " * len(prefix)
        content = self.formatMessage(record, indent)
        return prefix + content
def set_root_log_level(log_level: str):
    """
    Set the level of the root logger from a string.

    Args:
        log_level: Either a level name (case-insensitive, e.g. ``"info"``)
            or a numeric level given as a string (e.g. ``"20"``).

    If the level is rejected by ``Logger.setLevel``, a warning is logged on
    the root logger and the level falls back to ``logging.NOTSET``.
    """
    root = logging.getLogger()
    level: int | str = log_level.upper()
    try:
        level = int(log_level)
    except ValueError:
        # Not numeric: keep the upper-cased level name.
        pass
    try:
        root.setLevel(level)  # type: ignore
    except (TypeError, ValueError):
        # setLevel raises ValueError for an unknown name and TypeError for
        # an unsupported type; anything else is a real bug and propagates.
        root.warning(
            f"Failed to set logging level to {log_level}, using default 'NOTSET'"
        )
        root.setLevel(logging.NOTSET)
def init_logger(
    log_file: str | None = None,
    *,
    name: str | None = None,
    level: str = "INFO",
    fs: fsspec.AbstractFileSystem | None = None,
):
    """
    Setup logging.

    Replaces the handlers of the selected logger with a stdout handler
    (all records) and a stderr handler (warnings and above). When a log
    file is requested, global rank 0 additionally gets a file handler.

    Args:
        log_file: A file name to save file logs to. Only used on global
            rank 0; the file is opened in append mode.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use; applied to the root logger.
        fs: Optional filesystem object. When given, ``log_file`` is opened
            through ``fs.open`` instead of a plain ``logging.FileHandler``.
    """
    set_root_log_level(level)
    target = logging.getLogger(name)

    # stdout receives everything, stderr only warnings / errors and above.
    console_handlers = []
    for stream, threshold in (
        (sys.stdout, logging.NOTSET),
        (sys.stderr, logging.WARNING),
    ):
        handler = logging.StreamHandler(stream)
        handler.setLevel(threshold)
        handler.setFormatter(LogFormatter())
        console_handlers.append(handler)

    # Drop whatever handlers were installed before and use ours only.
    target.handlers.clear()
    target.handlers.extend(console_handlers)

    # File logging is limited to rank 0 and only when a path was given.
    if log_file is None or get_global_rank() != 0:
        return

    if fs is not None:
        file_handler = logging.StreamHandler(fs.open(log_file, mode="a"))
    else:
        file_handler = logging.FileHandler(log_file, "a")
    file_handler.setLevel(logging.NOTSET)
    file_handler.setFormatter(LogFormatter())

    # NOTE: the file handler is attached to the root logger, regardless of
    # which logger `name` selected above.
    logging.getLogger().addHandler(file_handler)
|