IsshikiHugh's picture
feat: CPU demo
5ac1897
import time
import atexit
import inspect
import torch
from typing import Optional, Union, List
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
def fold_path(fn:str):
    '''
    Fold a path like `from/to/file.py` to relative `f/t/file.py`.

    Every directory component is shortened to its first character; the last
    component (the file name) is kept untouched. A leading `/` is preserved,
    and a bare file name without any directory part is returned unchanged.
    '''
    # Split once: everything before the last '/' is a directory, the rest is the file name.
    *parents, leaf = fn.split('/')
    # Joining the leaf together with the parents avoids prepending a stray '/'
    # when there are no parent directories at all.
    return '/'.join([p[:1] for p in parents] + [leaf])
def summary_frame_info(frame:inspect.FrameInfo):
    ''' Render a FrameInfo object as `<function> @ <folded file path>:<line number>`. '''
    func_name = frame.function
    short_path = fold_path(frame.filename)  # Shorten the directories to keep the summary compact.
    line_no = frame.lineno
    return f'{func_name} @ {short_path}:{line_no}'
class TimeMonitorDisabled:
    '''
    No-op stand-in that accepts the same constructor arguments as `TimeMonitor`
    and exposes `tick`, `report`, `clear`, `dump_statistics`, the call syntax and
    the context manager protocol, all of which do nothing.
    '''

    # Names of the instance attributes that are bound to the no-op below.
    _NOOP_NAMES = ('tick', 'report', 'clear', 'dump_statistics')

    def foo(self, *args, **kwargs):
        ''' Swallow any arguments and return `None`. '''
        return None

    def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False):
        ''' Both arguments are accepted and ignored. '''
        for noop_name in self._NOOP_NAMES:
            setattr(self, noop_name, self.foo)

    def __call__(self, *args, **kwargs):
        ''' Naming a block is a no-op; the instance itself is handed back. '''
        return self

    def __enter__(self):
        ''' Entering a block records nothing. '''
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        ''' Leaving a block records nothing; returns `None`, so exceptions are not suppressed. '''
        return None
class TimeMonitor:
    '''
    Wall-clock monitor for (possibly nested) named code blocks.

    It is supposed to be used like this:

        time_monitor = TimeMonitor()
        with time_monitor('test_block', 'Block that does something.') as tm:
            do_something()
        time_monitor.report()

    State kept by the monitor:
    - `block_info`      : block uid -> {'records', 'position', 'desc'}.
    - `block_cost`      : '/'-joined block names -> accumulated elapsed seconds.
    - `finished_blocks` : block uids in the order the blocks were exited.
    '''

    def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False):
        '''
        ### Args
        - `log_folder`: folder for `readable.log` and `statistics.pkl`; created if missing.
            Nothing is written to disk when it is `None`.
        - `record_birth_block`: when True (and `log_folder` is given), a `monitor_birth` block is
            opened here and closed by the exit hook.
        '''
        if log_folder is not None:
            self.log_folder = Path(log_folder)
            self.log_folder.mkdir(parents=True, exist_ok=True)
            log_fn = self.log_folder / 'readable.log'
            # The log lines contain emoji, so the encoding must not depend on the platform default.
            self.log_fh = open(log_fn, 'w', encoding='utf-8')  # Log file handler.
            self.log_fh.write('=== New Exp ===\n')
        else:
            self.log_folder = None
        self.clear()
        self.current_block_uid_stack : List = []  # Unique block id stack for recording.
        self.current_block_aid_stack : List = []  # Block id stack for accumulated cost analysis.

        # Specially add a global start and end block.
        self.record_birth_block = record_birth_block and log_folder is not None
        if self.record_birth_block:
            self.__call__('monitor_birth', 'Since the monitor is constructed.')
            self.__enter__()

        # Register the exit hook to dump the data safely.
        atexit.register(self._die_hook)

    def __call__(self, block_name:str, block_desc:Optional[str]=None):
        '''
        Set up the name of the context for a block.

        The block is pushed onto the internal stacks here; timing starts in `__enter__`.
        Returns the monitor itself so that the call can be used directly in a `with` statement.
        '''
        # 1. Format the block name. '/' is reserved as the nesting separator.
        block_name = block_name.replace('/', '-').replace(' ', '-')
        block_name_recursive = '/'.join([s.split('/')[-1] for s in self.current_block_aid_stack] + [block_name])  # Tree structure block name.

        # 2. Get a unique name for the block record.
        block_postfixed = 0
        while f'{block_name_recursive}_{block_postfixed}' in self.block_info:
            block_postfixed += 1
        block_uid = f'{block_name_recursive}_{block_postfixed}'

        # 3. Get the caller frame information.
        caller_frame = inspect.stack()[1]
        block_position = summary_frame_info(caller_frame)

        # 4. Initialize the block information.
        self.current_block_uid_stack.append(block_uid)
        self.current_block_aid_stack.append(block_name)
        self.block_info[block_uid] = {
                'records'  : [],
                'position' : block_position,
                'desc'     : block_desc,
            }

        return self

    def __enter__(self):
        ''' Record the start timestamp of the block on top of the stack. '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, 'Start of the block.')
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        '''
        Record the end timestamp of the block on top of the stack, pop it, accumulate its
        cost and dump the statistics. Returns `None`, so exceptions are not suppressed.
        '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, 'End of the block.')
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)

        # Finish one block. The accumulation key is built before popping, so it includes this block.
        curr_block_uid = self.current_block_uid_stack.pop()
        curr_block_aid = '/'.join(self.current_block_aid_stack)
        self.current_block_aid_stack.pop()
        self.finished_blocks.append(curr_block_uid)

        # Only the first and the last record matter here; intermediate ticks are ignored.
        records = self.block_info[curr_block_uid]['records']
        elapsed = records[-1]['timestamp'] - records[0]['timestamp']
        self.block_cost[curr_block_aid] = self.block_cost.get(curr_block_aid, 0) + elapsed

        # Dump synchronously. Submitting to a freshly created thread pool inside a `with`
        # block waits for the job anyway, and scheduling new futures fails with a
        # RuntimeError when this method is reached from the atexit hook during shutdown.
        self.dump_statistics()

    def tick(self, desc:str=''):
        '''
        Record an intermediate timestamp. These records are only for in-block analysis,
        and will be ignored when analyzing in global view.
        '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, desc)
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
        return

    def report(self, level:Union[str, List[str]]='global'):
        '''
        Print the collected statistics with `rich`.

        ### Args
        - `level`: `'block'` (per-record details of every finished block), `'global'`
            (accumulated cost per block name), or a list of them, printed in the given order.
            Other values are silently skipped.
        '''
        import rich

        caller_frame = inspect.stack()[1]
        caller_info = summary_frame_info(caller_frame)

        if isinstance(level, str):
            level = [level]

        for lv in level:  # To make sure we can output in order.
            if lv == 'block':
                rich.print(f'[bold underline][EA-B][/bold underline] {caller_info} -> blocks level records:')
                for block_name in self.finished_blocks:
                    msg = '\t' + self._generate_block_msg(block_name).replace('\n\t', '\n\t\t')
                    rich.print(msg)
            elif lv == 'global':
                rich.print(f'[bold underline][EA-G][/bold underline] {caller_info} -> global efficiency analysis:')
                for block_name, cost in self.block_cost.items():
                    rich.print(f'\t{block_name}: {cost:.2f} sec')

    def clear(self):
        ''' Reset the collected statistics. The stacks of currently open blocks are not touched. '''
        self.finished_blocks = []
        self.block_info = {}
        self.block_cost = {}

    def dump_statistics(self):
        ''' Dump the logging raw data for post analysis. Does nothing without a log folder. '''
        if self.log_folder is None:
            return

        import pickle

        dump_fn = self.log_folder / 'statistics.pkl'
        with open(dump_fn, 'wb') as f:
            pickle.dump({
                    'finished_blocks' : self.finished_blocks,
                    'block_info'      : self.block_info,
                    'block_cost'      : self.block_cost,
                    'curr_aid_stack'  : self.current_block_aid_stack,  # nonempty when errors happen inside a block
                }, f)

        # TODO: Draw a graph to visualize the time consumption.

    @staticmethod
    def _sync_device():
        ''' Wait for pending CUDA work before taking a timestamp; a no-op when CUDA is unavailable. '''
        # An unconditional `torch.cuda.synchronize()` fails on machines without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _tick_record(self, caller_frame, desc:Optional[str]=''):
        '''
        Build one timestamp record and, when a log folder is set, append a line to the readable log.

        ### Returns
        - dict with keys `time` (formatted local time), `timestamp` (`time.time()` value),
          `position` (summary of `caller_frame`) and `desc`.
        '''
        # 1. Generate the record.
        self._sync_device()
        timestamp = time.time()
        readable_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))
        position = summary_frame_info(caller_frame)
        record = {
                'time'      : readable_time,
                'timestamp' : timestamp,
                'position'  : position,
                'desc'      : desc,
            }

        # 2. Log the record.
        if self.log_folder is not None:
            block_uid = self.current_block_uid_stack[-1]
            log_msg = f'[{readable_time}] 🗂️ {block_uid} 📌 {desc} 🌐 {position}'
            self.log_fh.write(log_msg + '\n')

        return record

    def _generate_block_msg(self, block_name):
        '''
        Format one block as a header line followed by one tab-indented line per record.
        Each record line shows the time elapsed since the previous record (`N/A` for the first one).
        '''
        block_info = self.block_info[block_name]
        block_position = block_info['position']
        block_desc = block_info['desc']
        records = block_info['records']

        msg = f'🗂️ {block_name} 📌 {block_desc} 🌐 {block_position}'
        for rid, record in enumerate(records):
            readable_time = record['time']
            tick_desc = record['desc']
            tick_position = record['position']
            if rid > 0:
                prev_record = records[rid-1]
                tick_elapsed = record['timestamp'] - prev_record['timestamp']
                tick_elapsed = f'{tick_elapsed:.2f} s'
            else:
                tick_elapsed = 'N/A'
            msg += f'\n\t[{readable_time}] ⏳ {tick_elapsed} 📌 {tick_desc} 🌐 {tick_position}'
        return msg

    def _die_hook(self):
        ''' Exit hook: close the birth block (if any), dump the statistics and close the log file. '''
        try:
            if self.record_birth_block:
                self.__exit__(None, None, None)
            self.dump_statistics()
        finally:
            # Close the log even if the final bookkeeping above raised.
            if self.log_folder is not None:
                self.log_fh.close()