File size: 3,009 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Helpers that visualize data as figures (histograms and line plots) to give a quick overview.

import torch
import numpy as np
import matplotlib.pyplot as plt

from typing import Optional, Union, List, Dict
from pathlib import Path

from lib.utils.data import to_numpy


def show_distribution(
    data       : Dict,
    fn         : Union[str, Path],  # File name of the saved figure.
    bins       : int  = 100,        # Number of bins in the histogram.
    annotation : bool = False,
    title      : str = 'Data Distribution',
    axis_names : List = ['Value', 'Frequency'],
    bounds     : Optional[List] = None,  # Left and right bounds of the histogram.
):
    '''
    Visualize the distribution of the data using histogram and save it to `fn`.
    The data should be a dictionary with keys as the labels and values as the data.

    ### Args
    - data: maps each legend label to its samples. Every value is passed through
      `to_numpy` and the results are stacked, so they must all share one length
      and stack into a 2-D array (one row per label).
    - fn: path of the image to write; its parent directory is created if missing.
    - bins: number of bins, forwarded to `plt.hist`.
    - annotation: when True, write every sample value (2 decimals) at y=0 of its x position.
    - title: figure title.
    - axis_names: `[x_label, y_label]`; only read, never modified.
    - bounds: optional `[left, right]` limits of the x axis.

    ### Raises
    - AssertionError: if the stacked data is not 2-D, or `bounds` is given without exactly 2 entries.
    '''
    labels = list(data.keys())
    samples = np.stack([ to_numpy(x) for x in data.values() ], axis=0)  # (n_labels, n_samples)
    assert samples.ndim == 2, f"Data dimension should be 2, but got {samples.ndim}."
    assert bounds is None or len(bounds) == 2, f"Bounds should be a list of length 2, but got {bounds}."

    # Make sure the fn's parent exists, so saving into a fresh directory does not fail.
    Path(fn).parent.mkdir(parents=True, exist_ok=True)

    # `plt.hist` treats each column of a 2-D input as one dataset.
    samples = samples.transpose(1, 0)  # (n_samples, n_labels)

    try:
        # Plot.
        plt.hist(samples, bins=bins, alpha=0.7, label=labels)
        if annotation:
            # Row-major walk over (n_samples, n_labels), one text per sample.
            for value in samples.ravel():
                plt.text(value, 0, f'{value:.2f}', va='bottom', fontsize=6)
        plt.title(title)
        plt.xlabel(axis_names[0])
        plt.ylabel(axis_names[1])
        plt.legend()
        # Compare against None explicitly: truthiness is ambiguous for array-like bounds.
        if bounds is not None:
            plt.xlim(bounds)
        # Save.
        plt.savefig(fn)
    finally:
        # Always release the current figure, otherwise a failed call would leave
        # its artists behind and the next pyplot call would draw on top of them.
        plt.close()



def show_history(
    data       : Dict,
    fn         : Union[str, Path],  # file name of the saved figure
    annotation : bool = False,
    title      : str  = 'Data History',
    axis_names : List = ['Time', 'Value'],
    ex_starts  : Dict[str, int] = {},  # starting points of the history if not starting from 0
):
    '''
    Visualize how values change across time as line plots and save the figure to `fn`.
    The history should be a dictionary with keys as the metric names and values as the metric values.

    ### Args
    - data: maps each metric name to its sequence of values. Every value is passed
      through `to_numpy`; sequences may have different lengths.
    - fn: path of the image to write; its parent directory is created if missing.
    - annotation: when True, write every value (2 decimals) next to its point.
    - title: figure title.
    - axis_names: `[x_label, y_label]`; only read, never modified.
    - ex_starts: maps a metric name to the x position of its first value; metrics
      not listed start at 0. Only read, never modified.
    '''
    # Make sure the fn's parent exists. `Path(...)` accepts both str and path-like inputs.
    fn = Path(fn)
    fn.parent.mkdir(parents=True, exist_ok=True)

    # Preparation.
    names  = list(data.keys())
    series = [ to_numpy(x) for x in data.values() ]
    starts = [ ex_starts.get(name, 0) for name in names ]

    try:
        # Plot.
        for name, start, values in zip(names, starts, series):
            plt.plot(range(start, start + len(values)), values, label=name)
        if annotation:
            for start, values in zip(starts, series):
                for offset, value in enumerate(values):
                    plt.text(start + offset, value, f'{value:.2f}', fontsize=6)

        plt.title(title)
        plt.xlabel(axis_names[0])
        plt.ylabel(axis_names[1])
        plt.legend()
        # Save.
        plt.savefig(fn)
    finally:
        # Always release the current figure, otherwise a failed call would leave
        # its artists behind and the next pyplot call would draw on top of them.
        plt.close()