File size: 2,757 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import List, Dict

def replace_state_dict_name_prefix(state_dict:Dict[str, object], old_prefix:str, new_prefix:str):
    ''' Rename, in place, every key beginning with `old_prefix` so that it begins with `new_prefix`.

    The same dict object is mutated and returned. Each renamed entry is
    popped and re-inserted, so it ends up after the untouched entries.
    '''
    cut = len(old_prefix)
    # Decide which keys to rename up front, so the dict is never resized while being iterated.
    matched = [key for key in state_dict if key.startswith(old_prefix)]
    for key in matched:
        state_dict[new_prefix + key[cut:]] = state_dict.pop(key)
    return state_dict


def match_prefix_and_remove_state_dict(state_dict:Dict[str, object], prefix:str):
    ''' Drop, in place, every entry whose key begins with `prefix`; return the same dict. '''
    # Collect the matching keys first so the dict is not resized during iteration.
    doomed = [key for key in state_dict if key.startswith(prefix)]
    for key in doomed:
        del state_dict[key]
    return state_dict


class StateDictTree:
    ''' Nested-dict view of dotted state-dict key names.

    Every key is split on '.'. All parts but the last become nested dicts
    inside `self.tree`; the last part maps to the full dotted key (a `str`
    leaf).
    '''

    def __init__(self, keys:List[str]):
        ''' Build the tree from the given dotted key names. '''
        self.tree = {}
        for key in keys:
            parts = key.split('.')
            self._recursively_add_leaf(self.tree, parts, key)

    def rich_print(self, depth:int=-1):
        ''' Print the tree to the terminal using `rich`.

        A positive `depth` limits how many levels are expanded (deeper levels
        are summarised as "... N more"); any other value prints everything.
        '''
        # Imported here so that `rich` is only needed when printing.
        from rich.tree import Tree
        from rich import print
        rich_tree = Tree('.')
        self._recursively_build_rich_tree(rich_tree, self.tree, 0, depth)
        print(rich_tree)

    def update_node_name(self, old_name:str, new_name:str):
        ''' Move the node at `old_name` (a leaf or a whole sub-tree) to `new_name`.

        Both names are full dotted names. When the node is a sub-tree, every
        leaf below it has its stored full key rewritten to the new location.
        If the move fails, the tree is restored to its previous content.

        Raises:
            KeyError: `old_name` does not name an existing node.
            AssertionError: `new_name` already exists.
        '''
        old_parts = old_name.split('.')
        new_parts = new_name.split('.')
        # 1. Locate and detach the old node.
        try:
            parent = None
            node = self.tree
            for part in old_parts:
                parent = node
                node = node[part]
            moved = parent.pop(old_parts[-1])
        except (KeyError, TypeError) as err:
            # TypeError: the path runs through a leaf (a str), so it does not exist either.
            raise KeyError(f'Key {old_name} not found.') from err
        # 2. Reserve the new location with a leaf; this creates the intermediate
        #    nodes and rejects a name that is already taken.
        try:
            self._recursively_add_leaf(self.tree, new_parts, new_name)
        except (AssertionError, TypeError):
            # Put the detached node back so a failed rename loses nothing.
            parent[old_parts[-1]] = moved
            raise
        # 3. A sub-tree replaces the placeholder leaf, with its leaves re-keyed.
        if isinstance(moved, dict):
            dest_parent = self.tree
            for part in new_parts[:-1]:
                dest_parent = dest_parent[part]
            dest_parent[new_parts[-1]] = self._rebase_subtree(moved, new_name)

    def _rebase_subtree(self, node, full_key):
        ''' Return a copy of `node` whose leaves hold full keys rooted at `full_key`. '''
        if not isinstance(node, dict):
            return full_key
        return {
            part: self._rebase_subtree(child, f'{full_key}.{part}')
            for part, child in node.items()
        }

    def _recursively_add_leaf(self, node, parts, full_key):
        ''' Insert `full_key` as a leaf under `node`, following `parts` and creating missing levels.

        Raises:
            AssertionError: the leaf position is already occupied.
        '''
        cur_part, rest_parts = parts[0], parts[1:]
        if len(rest_parts) == 0:
            # Raised explicitly instead of via `assert` so the check is still
            # enforced under `python -O`; the exception type stays the same.
            if cur_part in node:
                raise AssertionError(f'Key {full_key} already exists.')
            node[cur_part] = full_key
        else:
            if cur_part not in node:
                node[cur_part] = {}
            self._recursively_add_leaf(node[cur_part], rest_parts, full_key)

    def _recursively_build_rich_tree(self, rich_node, dict_node, depth, max_depth:int=-1):
        ''' Mirror `dict_node` into `rich_node`, children in sorted order.

        When `max_depth` is positive and `depth` has reached it, a single
        "... N more" child is added instead of descending further.
        '''
        if max_depth > 0 and depth >= max_depth:
            rich_node.add(f'... {len(dict_node)} more')
            return

        keys = sorted(dict_node.keys())
        for key in keys:
            next_dict_node = dict_node[key]
            next_rich_node = rich_node.add(key)
            # Only dicts are sub-trees; a str leaf has nothing below it.
            if isinstance(next_dict_node, dict):
                self._recursively_build_rich_tree(next_rich_node, next_dict_node, depth+1, max_depth)