Tzktz's picture
Upload 7664 files
6fc683c verified
raw
history blame contribute delete
841 Bytes
import os
import logging
from transformers.trainer_callback import TrainerCallback
def _setup_logger(data_dir='./data/'):
    """Configure and return the root logger.

    Sets the root logger's level to INFO and replaces any handlers it
    already has with two new handlers that share one format: a
    ``StreamHandler`` (stderr by default) and a ``FileHandler`` that
    appends to ``log.txt`` inside *data_dir*.

    Args:
        data_dir: Directory that receives ``log.txt``; it is created if it
            does not exist. Defaults to ``'./data/'`` (relative to the
            current working directory), the previously hard-coded location.

    Returns:
        The root ``logging.Logger``.
    """
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)

    os.makedirs(data_dir, exist_ok=True)
    # os.path.join avoids the doubled separator that '{}/log.txt' produced
    # with the trailing-slash default; an explicit UTF-8 encoding keeps
    # non-ASCII log content from depending on the platform's locale.
    file_handler = logging.FileHandler(
        os.path.join(data_dir, 'log.txt'), encoding='utf-8'
    )
    file_handler.setFormatter(log_format)

    # Assigning the list (rather than calling addHandler) discards any
    # handlers already attached, so repeated calls do not duplicate output.
    logger.handlers = [console_handler, file_handler]
    return logger
# Configure the root logger at import time (console output plus a log.txt
# file under ./data/, which is created if missing) and keep a module-level
# handle to it for LoggerCallback.
logger = _setup_logger()
class LoggerCallback(TrainerCallback):
    """Trainer callback that writes each logged metrics dict to ``logger``.

    Output is emitted only when ``state.is_world_process_zero`` is true
    (presumably the main process in distributed training).
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log *logs*, minus its ``total_flos`` entry, on process zero.

        Args:
            args: Training arguments supplied by the trainer (unused).
            state: Trainer state; only ``is_world_process_zero`` is read.
            control: Trainer control object (unused).
            logs: Mapping of metric names to values, or ``None``. It is not
                modified.
            **kwargs: Additional keyword arguments from the trainer (ignored).
        """
        # ``logs`` defaults to None; the previous unconditional ``pop``
        # raised AttributeError in that case.
        if logs is None or not state.is_world_process_zero:
            return
        # Filter into a copy instead of popping from ``logs``: the same dict
        # object may be passed on to other callbacks, which should still see
        # every key.
        filtered = {k: v for k, v in logs.items() if k != "total_flos"}
        logger.info(filtered)