# Source: HSMR / lib/info/show.py
# IsshikiHugh -- feat: CPU demo (5ac1897)
# Provides helpers that visualize data and save the result as a figure, giving a brief overview at a glance.
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Union, List, Dict
from pathlib import Path
from lib.utils.data import to_numpy
def show_distribution(
    data       : Dict,
    fn         : Union[str, Path],  # File name of the saved figure.
    bins       : int  = 100,        # Number of bins in the histogram.
    annotation : bool = False,
    title      : str  = 'Data Distribution',
    axis_names : List = ['Value', 'Frequency'],  # Only read, never mutated.
    bounds     : Optional[List] = None,  # Left and right bounds of the histogram.
):
    '''
    Visualize the distribution of the data using histogram.
    The data should be a dictionary with keys as the labels and values as the data.

    ### Args
    - data: Mapping from label to a series of values. Every value is converted
        with `to_numpy` and all of them are stacked, so the stacked result
        must be 2-D (one row per label, all rows of equal length).
    - fn: Where the figure is saved. When it is a `str` or `Path`, missing
        parent directories are created first.
    - bins: Number of histogram bins.
    - annotation: If True, write each value (2 decimals) at its x position
        on the baseline of the plot.
    - title: Figure title.
    - axis_names: `[x_label, y_label]`.
    - bounds: Optional `[left, right]` limits of the x axis.

    ### Raises
    - AssertionError: If the stacked data is not 2-D, or `bounds` is given
        but does not have exactly two entries.
    '''
    labels = list(data.keys())
    data = np.stack([ to_numpy(x) for x in data.values() ], axis=0)  # (N, K)
    assert data.ndim == 2, f"Data dimension should be 2, but got {data.ndim}."
    assert bounds is None or len(bounds) == 2, f"Bounds should be a list of length 2, but got {bounds}."

    # Make sure the fn's parent exists (same convention as `show_history`).
    # Anything that is not a str/Path is handed to matplotlib untouched.
    if isinstance(fn, (str, Path)):
        fn = Path(fn)
        fn.parent.mkdir(parents=True, exist_ok=True)

    # Preparation: `plt.hist` treats each column of a 2-D input as one dataset.
    data = data.transpose(1, 0)  # (K, N)

    try:
        # Plot.
        plt.hist(data, bins=bins, alpha=0.7, label=labels)
        if annotation:
            # `.flat` walks the (K, N) array in row-major order.
            for value in data.flat:
                plt.text(value, 0, f'{value:.2f}', va='bottom', fontsize=6)
        plt.title(title)
        plt.xlabel(axis_names[0])
        plt.ylabel(axis_names[1])
        plt.legend()
        if bounds:
            plt.xlim(bounds)

        # Save.
        plt.savefig(fn)
    finally:
        # Always release the current figure, otherwise a failure above would
        # leave a half-drawn figure that the next pyplot call draws onto.
        plt.close()
def show_history(
    data       : Dict,
    fn         : Union[str, Path],  # file name of the saved figure
    annotation : bool = False,
    title      : str  = 'Data History',
    axis_names : List = ['Time', 'Value'],  # only read, never mutated
    ex_starts  : Dict[str, int] = {},  # starting points of the history if not starting from 0 (only read)
):
    '''
    Visualize the value of changing across time.
    The history should be a dictionary with keys as the metric names and values as the metric values.

    ### Args
    - data: Mapping from metric name to its sequence of values. Every value
        is converted with `to_numpy` and must support `len()`.
    - fn: Where the figure is saved. Missing parent directories are created.
    - annotation: If True, write each value (2 decimals) next to its point.
    - title: Figure title.
    - axis_names: `[x_label, y_label]`.
    - ex_starts: Optional per-metric x offset; metrics that are not listed
        start at 0.
    '''
    # Make sure the fn's parent exists.
    if isinstance(fn, str):
        fn = Path(fn)
    fn.parent.mkdir(parents=True, exist_ok=True)

    # Preparation.
    names   = list(data.keys())
    series  = [ to_numpy(x) for x in data.values() ]
    lengths = [ len(x) for x in series ]
    starts  = [ ex_starts.get(name, 0) for name in names ]

    try:
        # Plot every metric against its own (possibly shifted) time axis.
        for name, values, start, length in zip(names, series, starts, lengths):
            plt.plot(range(start, start + length), values, label=name)
        if annotation:
            for values, start, length in zip(series, starts, lengths):
                for offset in range(length):
                    value = values[offset]
                    plt.text(start + offset, value, f'{value:.2f}', fontsize=6)
        plt.title(title)
        plt.xlabel(axis_names[0])
        plt.ylabel(axis_names[1])
        plt.legend()

        # Save.
        plt.savefig(fn)
    finally:
        # Always release the current figure, otherwise a failure above would
        # leave a half-drawn figure that the next pyplot call draws onto.
        plt.close()