import sys
import os
from toolkit.accelerator import get_accelerator


def print_acc(*args, **kwargs):
    """Print only when the accelerator reports this as the local main process.

    All positional and keyword arguments are forwarded unchanged to the
    built-in ``print``. On any other process the call is a no-op.
    """
    if get_accelerator().is_local_main_process:
        print(*args, **kwargs)


class Logger:
    """File-like object that tees every write to a stream and to a log file.

    Intended as a drop-in replacement for ``sys.stdout`` / ``sys.stderr``.

    Args:
        filename: Path of the log file; opened in append mode.
        terminal: Stream that keeps receiving the output. Defaults to the
            value of ``sys.stdout`` at construction time.
    """

    def __init__(self, filename, terminal=None):
        # Resolve the default at call time, not definition time, so the
        # logger wraps whatever stream is active when it is created.
        self.terminal = sys.stdout if terminal is None else terminal
        self.log = open(filename, 'a')

    def write(self, message):
        """Write ``message`` to the wrapped stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()  # Make sure it's written immediately

    def flush(self):
        """Flush both the wrapped stream and the log file."""
        self.terminal.flush()
        self.log.flush()


def setup_log_to_file(filename):
    """Mirror ``sys.stdout`` and ``sys.stderr`` into ``filename``.

    Only acts when the accelerator reports this as the local main process;
    on other processes the standard streams are left untouched.

    Args:
        filename: Path of the log file. Missing parent directories are
            created. A bare filename (no directory part) is written to the
            current working directory.
    """
    if not get_accelerator().is_local_main_process:
        return

    log_dir = os.path.dirname(filename)
    # dirname is '' for a bare filename; os.makedirs('') would raise.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Capture the real streams before replacing either one. Otherwise the
    # stderr logger would wrap the stdout logger, sending stderr output to
    # stdout and writing it to the log file twice.
    original_stdout, original_stderr = sys.stdout, sys.stderr
    sys.stdout = Logger(filename, original_stdout)
    sys.stderr = Logger(filename, original_stderr)