ramimu's picture
Upload 586 files
1c72248 verified
raw
history blame contribute delete
806 Bytes
import sys
import os
from toolkit.accelerator import get_accelerator
def print_acc(*args, **kwargs):
    """Call ``print(*args, **kwargs)`` only on the local main process.

    When the accelerator reports that this is not the local main process,
    nothing is printed.
    """
    is_main = get_accelerator().is_local_main_process
    if not is_main:
        return
    print(*args, **kwargs)
class Logger:
    """Tee text written to a terminal stream into an append-mode log file.

    Designed to be installed as ``sys.stdout`` / ``sys.stderr``: every
    ``write`` goes to ``terminal`` and to the log file. The log file is
    flushed after each write.

    Args:
        filename: Path of the log file; opened in append mode as UTF-8.
        terminal: Stream to echo to. Defaults to whatever ``sys.stdout`` is
            at construction time (the original behavior).
    """

    def __init__(self, filename, terminal=None):
        self.terminal = sys.stdout if terminal is None else terminal
        # Explicit UTF-8: the locale default (e.g. cp1252 on Windows) cannot
        # encode many characters that get printed, which would make the
        # log write raise UnicodeEncodeError.
        self.log = open(filename, 'a', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the terminal and the log; return its length."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()  # Make sure it's written immediately
        # Mirror TextIO.write, which returns the number of characters written.
        return len(message)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file (the terminal stream is left open). Idempotent."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Delegate stream attributes this class doesn't define (isatty,
        # fileno, encoding, ...) to the terminal, so code probing
        # sys.stdout / sys.stderr keeps working. Guard 'terminal' itself to
        # avoid infinite recursion before __init__ has set it.
        if name == 'terminal':
            raise AttributeError(name)
        return getattr(self.terminal, name)
def setup_log_to_file(filename):
    """Tee stdout and stderr into ``filename`` (append mode).

    Only has an effect when the accelerator reports the local main process.
    The parent directory is created if needed.

    Stdout output still appears on the original stdout, and stderr output
    still appears on the original stderr. Each line from either stream is
    appended to the log file exactly once.
    """
    if not get_accelerator().is_local_main_process:
        return

    log_dir = os.path.dirname(filename)
    # A bare filename has no directory component; os.makedirs('') would
    # raise. exist_ok avoids a race between checking and creating.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Capture stderr before replacing anything. A Logger's terminal defaults
    # to the *current* sys.stdout, so building the stderr tee after stdout
    # has been swapped would echo stderr to stdout and log every line twice.
    original_stderr = sys.stderr
    sys.stdout = Logger(filename)
    stderr_logger = Logger(filename)
    stderr_logger.terminal = original_stderr
    sys.stderr = stderr_logger