Spaces:
Sleeping
Sleeping
File size: 2,976 Bytes
5ac1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import torch
import numpy as np
from typing import Dict, List
def disassemble_dict(d, keep_dim=False):
    '''
    Unpack a dictionary into a list of dictionaries. The values should be of the same length.
    If not keep dim: {k: [...] * N} -> [{k: [...]}] * N.
    If keep dim: {k: [...] * N} -> [{k: [[...]]}] * N.

    Args:
        d: dictionary whose values are sized, indexable sequences (e.g. `np.ndarray`,
            `torch.Tensor`, `list`, `tuple`) that all share the same length N.
        keep_dim: if True, every item keeps a leading axis of size 1 (`v[[i]]` for
            arrays / tensors, `v[i:i + 1]` for lists / tuples); otherwise items are `v[i]`.

    Returns:
        A list of N dictionaries with the same keys as `d`; `[]` if `d` is empty.

    Raises:
        ValueError: if the values do not all have the same length.
    '''
    lengths = {len(v) for v in d.values()}
    if not lengths:
        # No keys at all -> nothing to split.
        return []
    if len(lengths) != 1:
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise ValueError('The lengths of the values should be the same!')
    (n,) = lengths

    def _take_keep_dim(v, i):
        # Fancy indexing `v[[i]]` keeps the axis for numpy / torch but raises on plain
        # Python sequences, which keep the axis via a length-1 slice instead.
        if isinstance(v, (list, tuple)):
            return v[i:i + 1]
        return v[[i]]

    if keep_dim:
        return [{k: _take_keep_dim(v, i) for k, v in d.items()} for i in range(n)]
    return [{k: v[i] for k, v in d.items()} for i in range(n)]
def assemble_dict(d, expand_dim=False, keys=None):
    '''
    Pack a list of dictionaries into one dictionary (the inverse of `disassemble_dict`).
    If expand dim, perform stack (new leading axis), else, perform concat (along axis 0).

    The join operation is chosen per key from the type of that key's value in `d[0]`:
    `np.ndarray` -> np.stack / np.concatenate, `torch.Tensor` -> torch.stack / torch.cat,
    `list` -> a list of the values (stack) or the chained items (concat).

    Args:
        d: list of dictionaries that all contain the requested keys.
        expand_dim: stack along a new axis 0 if True, otherwise concatenate along axis 0.
        keys: keys to assemble; defaults to the keys of `d[0]`.

    Returns:
        A dictionary mapping each key to its joined values; `{}` if `d` is empty.

    Raises:
        TypeError: if a value type is not one of the supported types above.
    '''
    if not d:
        return {}
    keys = list(d[0].keys()) if keys is None else keys
    # Dispatch per key (not just on the first key) so dicts mixing arrays and tensors work.
    return {k: _join_values([item[k] for item in d], k, expand_dim) for k in keys}


def _join_values(values, key, expand_dim):
    '''Stack (`expand_dim`) or concatenate `values` along axis 0, dispatching on `type(values[0])`.'''
    first = values[0]
    if isinstance(first, np.ndarray):
        return np.stack(values, axis=0) if expand_dim else np.concatenate(values, axis=0)
    if isinstance(first, torch.Tensor):
        return torch.stack(values, dim=0) if expand_dim else torch.cat(values, dim=0)
    if isinstance(first, list):
        return list(values) if expand_dim else [x for v in values for x in v]
    # Previously this case fell through and the whole call returned None.
    raise TypeError(f'Unsupported value type for key {key!r}: {type(first).__name__}')
def filter_dict(d:Dict, keys:List, full:bool=False, strict:bool=False):
    '''
    Use path-like syntax to filter the embedded dictionary.
    The `'*'` string is regarded as a wildcard, and will return the matched keys.
    For control flags:
    - If `full`, return the full path, otherwise, only return the matched values.
    - If `strict`, raise error if the key is not found, otherwise, simply ignore.
    Matched values are kept even when they are falsy (0, '', [], {}) or are numpy arrays /
    torch tensors. If nothing matches, `{}` is returned. An empty `keys` returns `d` itself.
    Eg.
    - `x = {'fruit': {'yellow': 'banana', 'red': 'apple'}, 'recycle': {'yellow': 'trash', 'blue': 'recyclable'}}`
    - `filter_dict(x, ['*', 'yellow'])` -> `{'fruit': 'banana', 'recycle': 'trash'}`
    - `filter_dict(x, ['*', 'yellow'], full=True)` -> `{'fruit': {'yellow': 'banana'}, 'recycle': {'yellow': 'trash'}}`
    - `filter_dict(x, ['*', 'blue'])` -> `{'recycle': 'recyclable'}`
    - `filter_dict(x, ['*', 'blue'], strict=True)` -> `KeyError: 'blue'`
    - `filter_dict({'a': {'b': 0}}, ['a', 'b'])` -> `0`
    '''
    found, value = _filter_dict(d, keys, full, strict)
    # Keep the historical "nothing matched" result of an empty dict.
    return value if found else {}


def _filter_dict(d, keys, full, strict):
    '''
    Recursive worker for `filter_dict`; returns `(found, value)`.
    An explicit `found` flag (instead of the truthiness of `value`) keeps legitimate falsy
    results and numpy arrays / torch tensors (whose truth value is ambiguous) from being
    dropped or raising.
    '''
    if not keys:
        return True, d
    cur_key, rest_keys = keys[0], keys[1:]
    if cur_key == '*':
        matched = {}
        for match in d.keys():
            try:
                found, res = _filter_dict(d[match], rest_keys, full, strict)
            except Exception:
                # Best-effort by design: non-strict mode ignores paths that fail to resolve.
                if strict:
                    raise
                continue
            if found:
                matched[match] = res
        # A wildcard that matched nothing counts as "not found" so parents prune it.
        return bool(matched), matched
    try:
        found, res = _filter_dict(d[cur_key], rest_keys, full, strict)
    except Exception:
        if strict:
            raise
        return False, {}
    if not found:
        return False, {}
    return True, ({cur_key: res} if full else res)