Spaces:
Sleeping
Sleeping
import time | |
import torch | |
import inspect | |
def fold_path(fn:str):
    '''
    Fold a path like `from/to/file.py` to relative `f/t/file.py`.

    Every directory component is shortened to its first character; the final
    component (the file name) is kept in full. Only `/` is treated as a separator.

    A path without any directory component is returned unchanged, so a bare
    `file.py` stays `file.py` instead of gaining a leading `/`.
    '''
    parts = fn.split('/')
    # Join folded directories and the file name in one go, so that an empty
    # directory list does not produce a stray leading separator.
    return '/'.join([p[:1] for p in parts[:-1]] + [parts[-1]])
def summary_frame_info(frame:inspect.FrameInfo):
    ''' Render a FrameInfo object as `<function> @ <folded file path>:<line number>`. '''
    func_name = frame.function
    short_path = fold_path(frame.filename)
    line_no = frame.lineno
    return f'{func_name} @ {short_path}:{line_no}'
class GPUMonitor():
    '''
    This monitor is designed for GPU memory analysis. It records the peak memory usage in a period of time.
    A snapshot will record the peak memory usage until the snapshot is taken. (After init / reset / previous snapshot.)

    Attributes:
        device      : device passed to the `torch.cuda` memory queries; `None` lets torch use the current device.
        records     : list of the record dicts produced by `snapshot()`, oldest first.
        max_peak_MB : largest `peak` (in MB) among the snapshots taken since the last `clear()`.
        max_peak    : same value as `max_peak_MB`; set by `snapshot()` and kept for backward compatibility.
        log_fn      : path of the text file each snapshot line is appended to.
    '''

    def __init__(self, device=None):
        '''
        Start a fresh monitoring period.

        Args:
            device: device whose memory is monitored. Defaults to `None`, which
                makes every `torch.cuda` query use the current device.
        '''
        # `reset()` reads `self.device`, so it has to exist first.
        self.device = device
        self.reset()
        self.clear()
        self.log_fn = 'gpu_monitor.log'

    def snapshot(self, desc:str='snapshot'):
        '''
        Record the peak memory usage since the last init / reset / snapshot.

        The record is stored in `self.records`, appended to the log file, and
        the peak statistics are reset so the next snapshot starts a new period.

        Args:
            desc: free-form label stored with the record.

        Returns:
            The record dict with keys `until`, `until_raw`, `position`, `peak`,
            `peak_msg`, `free`, `total`, `free_msg`, `total_msg` and `desc`.
            Memory values are in MB.
        '''
        timestamp = time.time()
        # Only function / filename / lineno are used, so skip loading source context.
        caller_frame = inspect.stack(context=0)[1]
        peak_MB = torch.cuda.max_memory_allocated(self.device) / 1024 / 1024
        # Query the same device as the peak statistics so the numbers belong together.
        free_mem, total_mem = torch.cuda.mem_get_info(self.device)
        free_mem_MB, total_mem_MB = free_mem / 1024 / 1024, total_mem / 1024 / 1024

        record = {
            'until'     : time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp)),
            'until_raw' : timestamp,
            'position'  : summary_frame_info(caller_frame),
            'peak'      : peak_MB,
            'peak_msg'  : f'{peak_MB:.2f} MB',
            'free'      : free_mem_MB,
            'total'     : total_mem_MB,
            'free_msg'  : f'{free_mem_MB:.2f} MB',
            'total_msg' : f'{total_mem_MB:.2f} MB',
            'desc'      : desc,
        }

        # Keep the running maximum in the attribute that `clear()` initializes.
        self.max_peak_MB = max(self.max_peak_MB, peak_MB)
        # Older code read `max_peak`; keep it available with the same value.
        self.max_peak = self.max_peak_MB

        self.records.append(record)
        self._update_log(record)

        self.reset()
        return record

    def report_latest(self, k:int=1):
        '''
        Print the latest `k` records, preceded by a line naming the call site.

        Args:
            k: number of most recent records to print; values <= 0 print none.
        '''
        caller_info = summary_frame_info(inspect.stack(context=0)[1])
        self._report(k, caller_info)

    def report_all(self):
        ''' Print every stored record, preceded by a line naming the call site. '''
        # Resolve the caller here so the header names the code that asked for
        # the report rather than this method.
        caller_info = summary_frame_info(inspect.stack(context=0)[1])
        self._report(len(self.records), caller_info)

    def reset(self):
        ''' Reset the peak memory statistics, starting a new measuring period. '''
        torch.cuda.reset_peak_memory_stats(self.device)

    def clear(self):
        ''' Drop all stored records and reset the running maximum peak. '''
        self.records = []
        self.max_peak_MB = 0

    def _report(self, k, caller_info):
        ''' Print a header built from `caller_info`, then the latest `k` records. '''
        import rich

        rich.print(f'{caller_info} -> latest {k} records:')
        # `records[-0:]` would be the whole list, so non-positive `k` is handled explicitly.
        latest = self.records[-k:] if k > 0 else []
        for record in latest:
            rich.print(self._generate_log_msg(record))

    def _generate_log_msg(self, record):
        ''' Format one record as a single log line: time, peak (MB), description, position. '''
        until = record['until']
        peak = record['peak']
        desc = record['desc']
        position = record['position']
        # NOTE(review): the icons in this line were mis-encoded in the reviewed copy;
        # these glyphs are a best-effort reconstruction, confirm against the intended format.
        msg = f'[{until}] ⛰️ {peak:>8.2f} MB 📝 {desc} 📍 {position}'
        return msg

    def _update_log(self, record):
        ''' Append the log line of `record` to `self.log_fn`. '''
        msg = self._generate_log_msg(record)
        # Explicit UTF-8: the log line contains non-ASCII characters, which a
        # locale-dependent default encoding may not be able to write.
        with open(self.log_fn, 'a', encoding='utf-8') as f:
            f.write(msg + '\n')