# Provides methods to visualize the information of data, giving a brief overview in figure.

import torch
import numpy as np
import matplotlib.pyplot as plt

from typing import Optional, Union, List, Dict
from pathlib import Path

from lib.utils.data import to_numpy


def _ensure_parent_dir(fn) -> None:
    '''Create the parent directory of `fn` when `fn` is a filesystem path.'''
    # Anything that is not a str/Path (e.g. an already opened file object
    # handed straight to `plt.savefig`) has no directory to prepare.
    if isinstance(fn, (str, Path)):
        Path(fn).parent.mkdir(parents=True, exist_ok=True)


def show_distribution(
    data: Dict,
    fn: Union[str, Path],  # File name of the saved figure.
    bins: int = 100,  # Number of bins in the histogram.
    annotation: bool = False,
    title: str = 'Data Distribution',
    axis_names: List = ['Value', 'Frequency'],
    bounds: Optional[List] = None,  # Left and right bounds of the histogram.
):
    '''
    Visualize the distribution of the data as a histogram and save it to `fn`.

    Args:
        data: Dictionary whose keys are the legend labels and whose values are
            the samples of each label. Every value goes through `to_numpy`
            (project helper; presumably returns a numpy array, confirm) and the
            results are stacked, so all series must share the same shape and
            the stacked array must be 2-D.
        fn: Where the figure is saved. Missing parent directories are created.
        bins: Number of histogram bins.
        annotation: When True, every sample is written as text (two decimals)
            at its x position on the baseline of the plot.
        title: Figure title.
        axis_names: Labels of the x axis (index 0) and the y axis (index 1).
            The default list is shared between calls; it is only read here.
        bounds: Optional `[left, right]` limits applied to the x axis.

    Raises:
        AssertionError: If the stacked data is not 2-D, or `bounds` is given
            but does not contain exactly two entries.
    '''
    labels = list(data.keys())
    samples = np.stack([to_numpy(x) for x in data.values()], axis=0)  # (N, K)
    assert samples.ndim == 2, f"Data dimension should be 2, but got {samples.ndim}."
    assert bounds is None or len(bounds) == 2, f"Bounds should be a list of length 2, but got {bounds}."

    # Same guarantee `show_history` gives: saving must not fail just because
    # the output directory has not been created yet.
    _ensure_parent_dir(fn)

    # `plt.hist` treats each column of a 2-D input as one data set, so the
    # series have to be laid out along the columns.
    samples = samples.T  # (K, N)

    try:
        plt.hist(samples, bins=bins, alpha=0.7, label=labels)
        if annotation:
            # `.flat` walks the (K, N) array row by row.
            for value in samples.flat:
                plt.text(value, 0, f'{value:.2f}', va='bottom', fontsize=6)
        plt.title(title)
        plt.xlabel(axis_names[0])
        plt.ylabel(axis_names[1])
        plt.legend()
        # Compare against None instead of relying on truthiness, which is
        # ambiguous for array-like bounds; the length was validated above.
        if bounds is not None:
            plt.xlim(bounds)

        plt.savefig(fn)
    finally:
        # Always release the current figure, otherwise a failure above would
        # leave it in pyplot's global state and leak into the next plot.
        plt.close()


def show_history(
    data: Dict,
    fn: Union[str, Path],  # file name of the saved figure
    annotation: bool = False,
    title: str = 'Data History',
    axis_names: List = ['Time', 'Value'],
    ex_starts: Dict[str, int] = {},  # starting points of the history if not starting from 0
):
    '''
    Visualize how values change across time as line plots and save them to `fn`.

    Args:
        data: Dictionary whose keys are the metric names (used as legend
            labels) and whose values are the metric histories. Every value
            goes through `to_numpy` (project helper; presumably returns a
            numpy array, confirm). Histories may differ in length.
        fn: Where the figure is saved. Missing parent directories are created.
        annotation: When True, every point is annotated with its value
            (two decimals).
        title: Figure title.
        axis_names: Labels of the x axis (index 0) and the y axis (index 1).
            The default list is shared between calls; it is only read here.
        ex_starts: Optional x offset per metric name; metrics that are not
            listed start at 0. The default dict is shared between calls; it
            is only read here.
    '''
    # Make sure the fn's parent exists.
    _ensure_parent_dir(fn)

    # Preparation.
    names = list(data.keys())
    series = [to_numpy(x) for x in data.values()]
    starts = [ex_starts.get(name, 0) for name in names]

    try:
        # Plot: each history is drawn against its own shifted time axis.
        for name, values, start in zip(names, series, starts):
            plt.plot(range(start, start + len(values)), values, label=name)
        if annotation:
            for values, start in zip(series, starts):
                for step, value in enumerate(values, start=start):
                    plt.text(step, value, f'{value:.2f}', fontsize=6)
        plt.title(title)
        plt.xlabel(axis_names[0])
        plt.ylabel(axis_names[1])
        plt.legend()

        plt.savefig(fn)
    finally:
        # Always release the current figure, otherwise a failure above would
        # leave it in pyplot's global state and leak into the next plot.
        plt.close()