Spaces:
Sleeping
Sleeping
import time | |
import atexit | |
import inspect | |
import torch | |
from typing import Optional, Union, List | |
from pathlib import Path | |
from concurrent.futures import ThreadPoolExecutor | |
def fold_path(fn:str) -> str:
    '''
    Fold a path like `from/to/file.py` to `f/t/file.py`.

    Every directory component is shortened to its first character and the file
    name is kept intact. A leading '/' (absolute path) is preserved. A bare file
    name with no directory part is returned unchanged.
    '''
    *dirs, base = fn.split('/')
    if not dirs:
        # No directory part: don't prepend a '/' that would make it look absolute.
        return base
    return '/'.join([d[:1] for d in dirs] + [base])
def summary_frame_info(frame:inspect.FrameInfo):
    '''
    Render a frame as `<function> @ <folded file path>:<line number>`.
    '''
    location = fold_path(frame.filename)
    return f'{frame.function} @ {location}:{frame.lineno}'
class TimeMonitorDisabled:
    '''
    No-op stand-in for `TimeMonitor`.

    Calling it, using it as a context manager, and its `tick` / `report` /
    `clear` / `dump_statistics` entry points all accept any arguments and do
    nothing. Instrumented code can therefore run unchanged with timing off.
    '''

    def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False):
        # The arguments are accepted only to mirror TimeMonitor's signature.
        for attr_name in ('tick', 'report', 'clear', 'dump_statistics'):
            setattr(self, attr_name, self.foo)

    def foo(self, *args, **kwargs):
        ''' Ignore every argument and return None. '''
        return None

    def __call__(self, *args, **kwargs):
        # Block naming is ignored; hand back the same object for `with`.
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Falsy return value: exceptions raised in the block still propagate.
        return None
class TimeMonitor:
    '''
    Hierarchical wall-clock profiler for (possibly nested) code blocks.

    It is supposed to be used like this:
        time_monitor = TimeMonitor()
        with time_monitor('test_block', 'Block that does something.') as tm:
            do_something()
        time_monitor.report()

    Each block's elapsed time is accumulated in `block_cost` under the
    '/'-joined names of all enclosing blocks. If `log_folder` is given:
    - every timestamp is appended to `<log_folder>/readable.log`;
    - the raw statistics are pickled to `<log_folder>/statistics.pkl` after
      each finished block and again from the exit hook.
    With `record_birth_block=True` (only honoured when `log_folder` is set), a
    `monitor_birth` block is opened at construction and closed at interpreter exit.
    '''
    def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False):
        if log_folder is not None:
            self.log_folder = Path(log_folder) if isinstance(log_folder, str) else log_folder
            self.log_folder.mkdir(parents=True, exist_ok=True)
            log_fn = self.log_folder / 'readable.log'
            # Explicit UTF-8: log lines contain non-ASCII glyphs that a platform
            # default encoding (e.g. cp1252) may be unable to encode.
            self.log_fh = open(log_fn, 'w', encoding='utf-8')  # Log file handler.
            self.log_fh.write('=== New Exp ===\n')
        else:
            self.log_folder = None
        self.clear()
        self.current_block_uid_stack : List = []  # Unique block id stack for recording.
        self.current_block_aid_stack : List = []  # Block id stack for accumulated cost analysis.
        # Specially add a global start and end block.
        self.record_birth_block = record_birth_block and log_folder is not None
        if self.record_birth_block:
            self.__call__('monitor_birth', 'Since the monitor is constructed.')
            self.__enter__()
        # Register the exit hook to dump the data safely.
        atexit.register(self._die_hook)

    def __call__(self, block_name:str, block_desc:Optional[str]=None):
        '''
        Set up the name of the context for a block.

        Returns self, so it is meant to be followed by `with` (which calls `__enter__`).
        '''
        # 1. Format the block name ('/' is reserved as the nesting separator).
        block_name = block_name.replace('/', '-').replace(' ', '-')
        block_name_recursive = '/'.join([s.split('/')[-1] for s in self.current_block_aid_stack] + [block_name])  # Tree structure block name.
        # 2. Get a unique name `<recursive name>_<n>` for the block record. A
        #    per-name counter avoids rescanning 0..n on every call, which made
        #    blocks repeated in a loop quadratic. The probe loop is kept as a
        #    fallback in case `block_info` was populated by other means.
        block_postfixed = self._block_name_counts.get(block_name_recursive, 0)
        while f'{block_name_recursive}_{block_postfixed}' in self.block_info:
            block_postfixed += 1
        self._block_name_counts[block_name_recursive] = block_postfixed + 1
        # 3. Get the caller frame information.
        caller_frame = inspect.stack()[1]
        block_position = summary_frame_info(caller_frame)
        # 4. Initialize the block information.
        self.current_block_uid_stack.append(f'{block_name_recursive}_{block_postfixed}')
        self.current_block_aid_stack.append(block_name)
        self.block_info[self.current_block_uid_stack[-1]] = {
            'records'  : [],
            'position' : block_position,
            'desc'     : block_desc,
        }
        return self

    def __enter__(self):
        ''' Record the start timestamp of the innermost block set up by `__call__`. '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, 'Start of the block.')
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        '''
        Record the end timestamp of the innermost block and accumulate its cost.
        Then dump statistics. Returns None, so exceptions raised in the block propagate.
        '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, 'End of the block.')
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
        # Finish one block.
        curr_block_uid = self.current_block_uid_stack.pop()
        curr_block_aid = '/'.join(self.current_block_aid_stack)  # Includes the block being closed.
        self.current_block_aid_stack.pop()
        self.finished_blocks.append(curr_block_uid)
        records = self.block_info[curr_block_uid]['records']
        elapsed = records[-1]['timestamp'] - records[0]['timestamp']
        self.block_cost[curr_block_aid] = self.block_cost.get(curr_block_aid, 0) + elapsed
        # Dump synchronously. A background thread would have to be joined
        # before the block data mutates again, so it buys no overlap. A
        # ThreadPoolExecutor also refuses new work once interpreter shutdown has
        # begun, which is exactly when `_die_hook` reaches this code.
        self.dump_statistics()

    def tick(self, desc:str=''):
        '''
        Record an intermediate timestamp in the current block. These records are
        only for in-block analysis and are ignored when analyzing in global view.
        '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, desc)
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)

    def report(self, level:Union[str, List[str]]='global'):
        '''
        Print the collected timings with `rich`.

        ### Args
        - level: 'block' prints every finished block with its tick records.
          'global' prints the accumulated cost per nested block path. A list
          prints several levels in the given order. Other values are ignored.
        '''
        import rich
        caller_frame = inspect.stack()[1]
        caller_info = summary_frame_info(caller_frame)
        if isinstance(level, str):
            level = [level]
        for lv in level:  # To make sure we can output in order.
            if lv == 'block':
                rich.print(f'[bold underline][EA-B][/bold underline] {caller_info} -> blocks level records:')
                for block_name in self.finished_blocks:
                    msg = '\t' + self._generate_block_msg(block_name).replace('\n\t', '\n\t\t')
                    rich.print(msg)
            elif lv == 'global':
                rich.print(f'[bold underline][EA-G][/bold underline] {caller_info} -> global efficiency analysis:')
                for block_name, cost in self.block_cost.items():
                    rich.print(f'\t{block_name}: {cost:.2f} sec')

    def clear(self):
        ''' Drop all block records, accumulated costs and block-name counters. '''
        self.finished_blocks = []
        self.block_info = {}
        self.block_cost = {}
        self._block_name_counts = {}  # Recursive block name -> next postfix to try.

    def dump_statistics(self):
        ''' Pickle the raw logging data to `<log_folder>/statistics.pkl` for post analysis; no-op without a log folder. '''
        if self.log_folder is None:
            return
        dump_fn = self.log_folder / 'statistics.pkl'
        with open(dump_fn, 'wb') as f:
            import pickle
            pickle.dump({
                'finished_blocks' : self.finished_blocks,
                'block_info'      : self.block_info,
                'block_cost'      : self.block_cost,
                'curr_aid_stack'  : self.current_block_aid_stack,  # Names of blocks still open at dump time.
            }, f)

    # TODO: Draw a graph to visualize the time consumption.
    def _tick_record(self, caller_frame, desc:Optional[str]=''):
        ''' Build a timestamp record for `caller_frame` and append it to the readable log if enabled. '''
        # 1. Generate the record.
        # Wait for queued CUDA work so the timestamp includes GPU time. This is
        # skipped when CUDA is unavailable, because `synchronize` raises there.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        timestamp = time.time()
        readable_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))
        position = summary_frame_info(caller_frame)
        record = {
            'time'      : readable_time,
            'timestamp' : timestamp,
            'position'  : position,
            'desc'      : desc,
        }
        # 2. Log the record.
        if self.log_folder is not None:
            block_uid = self.current_block_uid_stack[-1]
            # NOTE(review): the glyphs below look like UTF-8 emoji mis-decoded via a legacy codepage; confirm against the original source.
            log_msg = f'[{readable_time}] ποΈ {block_uid} π {desc} π {position}'
            self.log_fh.write(log_msg + '\n')
        return record

    def _generate_block_msg(self, block_name):
        ''' Format a block header line plus one line per record, with the time elapsed since the previous record. '''
        block_info = self.block_info[block_name]
        block_position = block_info['position']
        block_desc = block_info['desc']
        records = block_info['records']
        msg = f'ποΈ {block_name} π {block_desc} π {block_position}'
        for rid, record in enumerate(records):
            readable_time = record['time']
            tick_desc = record['desc']
            tick_position = record['position']
            if rid > 0:
                prev_record = records[rid-1]
                tick_elapsed = record['timestamp'] - prev_record['timestamp']
                tick_elapsed = f'{tick_elapsed:.2f} s'
            else:
                tick_elapsed = 'N/A'  # The first record has no predecessor.
            msg += f'\n\t[{readable_time}] β³ {tick_elapsed} π {tick_desc} π {tick_position}'
        return msg

    def _die_hook(self):
        ''' atexit hook: close the birth block (if recorded), dump final statistics, and always close the log. '''
        try:
            if self.record_birth_block:
                self.__exit__(None, None, None)
            self.dump_statistics()
        finally:
            if self.log_folder is not None:
                self.log_fh.close()