File size: 2,976 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import numpy as np
from typing import Dict, List


def disassemble_dict(d, keep_dim=False):
    '''
    Split a dictionary of equally long sequences into a list of per-index dictionaries.

    Every value of `d` must report the same `len()`; call that length N. The result
    is a list of N dictionaries sharing the keys of `d`, where the i-th dictionary
    holds the i-th entry of every value.

    - If not `keep_dim`: each value is indexed as `v[i]`, i.e. {k: [...] * N} -> [{k: [...]}] * N.
    - If `keep_dim`: each value is indexed as `v[[i]]`, i.e. {k: [...] * N} -> [{k: [[...]]}] * N.
      This form needs values that accept a list as index (e.g. numpy arrays);
      a plain Python list does not.

    Raises `AssertionError` when the values do not all share one length
    (this includes an empty `d`, which has no length at all).
    '''
    sizes = {len(val) for val in d.values()}
    assert len(sizes) == 1, 'The lengths of the values should be the same!'
    (count,) = sizes

    pieces = []
    for idx in range(count):
        # Indexing with `[idx]` instead of `idx` keeps the leading axis (length 1).
        selector = [idx] if keep_dim else idx
        pieces.append({key: val[selector] for key, val in d.items()})
    return pieces


def assemble_dict(d, expand_dim=False, keys=None):
    '''
    Pack a list of dictionaries into one dictionary.
    If expand dim, perform stack, else, perform concat.

    Args:
    - `d`: non-empty list of dictionaries whose values are numpy arrays or torch tensors.
    - `expand_dim`: if True, the values are stacked along a new leading axis;
      otherwise they are concatenated along the existing leading axis.
    - `keys`: the keys to pack; defaults to all keys of `d[0]`.

    Returns a dictionary mapping every requested key to the merged value.

    The backend (numpy or torch) is chosen from a single sample, `d[0][keys[0]]`,
    and then used for every key.

    Raises:
    - `IndexError` if `d` or `keys` is empty (there is no sample to inspect).
    - `TypeError` if the sample is neither a `np.ndarray` nor a `torch.Tensor`.
    '''
    keys = list(d[0].keys()) if keys is None else keys
    sample = d[0][keys[0]]

    if isinstance(sample, np.ndarray):
        merge = np.stack if expand_dim else np.concatenate
        return {k: merge([item[k] for item in d], axis=0) for k in keys}

    if isinstance(sample, torch.Tensor):
        merge = torch.stack if expand_dim else torch.cat
        return {k: merge([item[k] for item in d], dim=0) for k in keys}

    # Fail loudly instead of implicitly returning None for unsupported value types.
    raise TypeError(
        f'assemble_dict only supports np.ndarray or torch.Tensor values, got {type(sample).__name__}.'
    )


def _is_no_match(res):
    '''Tell whether `res` is the empty dict that `filter_dict` returns when nothing matched.'''
    return isinstance(res, dict) and not res


def filter_dict(d:Dict, keys:List, full:bool=False, strict:bool=False):
    '''
    Use path-like syntax to filter the embedded dictionary.
    The `'*'` string is regarded as a wildcard, and will return the matched keys.
    For control flags:
    - If `full`, return the full path, otherwise, only return the matched values.
    - If `strict`, raise error if the key is not found, otherwise, simply ignore.

    Args:
    - `d`: the (nested) dictionary to filter.
    - `keys`: the path, one entry per nesting level; an empty path matches `d` itself.
    - `full`: keep the non-wildcard keys of the path in the result.
    - `strict`: re-raise lookup errors instead of skipping the unmatched branch.

    Returns the matched value(s), or an empty dict when nothing matched.
    Only an empty dict counts as "nothing matched": matched values that are falsy
    (`0`, `''`, `None`, ...) or that have no well-defined truth value (multi-element
    numpy arrays / torch tensors) are kept as they are.

    Eg.
    - `x = {'fruit': {'yellow': 'banana', 'red': 'apple'}, 'recycle': {'yellow': 'trash', 'blue': 'recyclable'}}`
        - `filter_dict(x, ['*', 'yellow'])` -> `{'fruit': 'banana', 'recycle': 'trash'}`
        - `filter_dict(x, ['*', 'yellow'], full=True)` -> `{'fruit': {'yellow': 'banana'}, 'recycle': {'yellow': 'trash'}}`
        - `filter_dict(x, ['*', 'blue'])` -> `{'recycle': 'recyclable'}`
        - `filter_dict(x, ['*', 'blue'], strict=True)` -> `KeyError: 'blue'`
    '''

    # Path exhausted: the current node is the match.
    if not keys:
        return d

    cur_key, rest_keys = keys[0], keys[1:]
    ret = {}
    if cur_key == '*':
        for match in d.keys():
            try:
                res = filter_dict(d[match], rest_keys, full=full, strict=strict)
            except Exception:
                # Best-effort mode skips branches that cannot be followed.
                if strict:
                    raise
                continue
            # Checked outside the `try` and without `bool(res)`, so that arrays
            # and falsy values are neither dropped nor turned into errors.
            if not _is_no_match(res):
                ret[match] = res
    else:
        try:
            res = filter_dict(d[cur_key], rest_keys, full=full, strict=strict)
        except Exception:
            if strict:
                raise
            return ret
        if not _is_no_match(res):
            ret = {cur_key: res} if full else res

    return ret