from typing import List, Dict


def replace_state_dict_name_prefix(
    state_dict: Dict[str, object],
    old_prefix: str,
    new_prefix: str,
) -> Dict[str, object]:
    '''
    Replace the prefix of the keys in the state_dict.

    Every key starting with `old_prefix` is re-inserted with `old_prefix`
    swapped for `new_prefix`. The dict is modified in place and also returned.
    A renamed key that collides with an existing key which does NOT match
    `old_prefix` overwrites that entry.
    '''
    # Pop every matching entry before re-inserting any of them. Renaming one key
    # at a time lets a freshly renamed key clobber a matching key that has not
    # been processed yet (e.g. 'a' -> 'aa' with keys 'a1' and 'aa1').
    moved = []
    for old_name in list(state_dict):
        if old_name.startswith(old_prefix):
            new_name = new_prefix + old_name[len(old_prefix):]
            moved.append((new_name, state_dict.pop(old_name)))
    state_dict.update(moved)
    return state_dict


def match_prefix_and_remove_state_dict(
    state_dict: Dict[str, object],
    prefix: str,
) -> Dict[str, object]:
    '''
    Remove the keys in the state_dict that start with the prefix.

    The dict is modified in place and also returned.
    '''
    # Iterate over a snapshot of the keys because entries are deleted in the loop.
    for name in list(state_dict):
        if name.startswith(prefix):
            del state_dict[name]
    return state_dict


class StateDictTree:
    '''
    Nested-dict view of dotted key names.

    `self.tree` maps each dot-separated part to either a nested dict (an inner
    node) or a string holding the full dotted key (a leaf).
    '''

    def __init__(self, keys: List[str]):
        '''Build the tree from full dotted key names.'''
        self.tree = {}
        for key in keys:
            self._recursively_add_leaf(self.tree, key.split('.'), key)

    def rich_print(self, depth: int = -1):
        '''
        Print the tree with `rich`.

        A positive `depth` limits how many levels are rendered; any other value
        renders the whole tree.
        '''
        # Imported here so that `rich` is only needed when printing.
        from rich.tree import Tree
        from rich import print as rich_echo

        rich_tree = Tree('.')
        self._recursively_build_rich_tree(rich_tree, self.tree, 0, depth)
        rich_echo(rich_tree)

    def update_node_name(self, old_name: str, new_name: str):
        '''
        Input full node name and the whole node will be moved to the new name.

        A leaf is re-created under `new_name`. An inner node is moved together
        with all of its descendants, whose stored full keys are rewritten to
        start with `new_name`. Parents left empty by the move are kept.

        Raises:
            KeyError: `old_name` does not exist in the tree.
            AssertionError: `new_name` is already taken or runs through an
                existing leaf. The tree is left unchanged in that case.
        '''
        old_parts = old_name.split('.')
        new_parts = new_name.split('.')
        last_old_part = old_parts[-1]

        # 1. Detach the old node.
        parent = self.tree
        for part in old_parts[:-1]:
            child = parent.get(part)
            # A missing part or a leaf in the middle of the path both mean that
            # `old_name` is not a node of this tree.
            if not isinstance(child, dict):
                raise KeyError(f'Key {old_name} not found.')
            parent = child
        if last_old_part not in parent:
            raise KeyError(f'Key {old_name} not found.')
        detached = parent.pop(last_old_part)

        # 2. Attach it under the new name, restoring the old node on failure.
        try:
            self._insert_node(
                self.tree,
                new_parts,
                self._rebase_subtree(detached, new_name),
                new_name,
            )
        except AssertionError:
            parent[last_old_part] = detached
            raise

    def _recursively_add_leaf(self, node, parts, full_key):
        '''
        Add the leaf `full_key` below `node`, following `parts`.

        Raises:
            AssertionError: the key already exists or its path runs through an
                existing leaf.
        '''
        self._insert_node(node, parts, full_key, full_key)

    def _insert_node(self, root, parts, value, full_key):
        '''Store `value` at path `parts` below `root`, creating inner nodes.'''
        # Validate the whole path before touching anything so that a failure
        # leaves the tree unchanged. AssertionError is raised explicitly because
        # an `assert` statement is skipped when Python runs with -O.
        last_index = len(parts) - 1
        cursor = root
        for index, part in enumerate(parts):
            if part not in cursor:
                break  # The rest of the path does not exist yet.
            if index == last_index:
                raise AssertionError(f'Key {full_key} already exists.')
            cursor = cursor[part]
            if not isinstance(cursor, dict):
                raise AssertionError(
                    f'Key {full_key} conflicts with existing leaf {cursor}.'
                )

        cursor = root
        for part in parts[:-1]:
            cursor = cursor.setdefault(part, {})
        cursor[parts[-1]] = value

    def _rebase_subtree(self, node, full_name):
        '''Return a copy of `node` whose leaves hold keys rooted at `full_name`.'''
        if isinstance(node, dict):
            return {
                part: self._rebase_subtree(child, f'{full_name}.{part}')
                for part, child in node.items()
            }
        return full_name

    def _recursively_build_rich_tree(self, rich_node, dict_node, depth, max_depth: int = -1):
        '''
        Mirror `dict_node` into `rich_node`, children sorted by name.

        Once `depth` reaches a positive `max_depth`, the children are replaced
        by a single summary entry.
        '''
        if max_depth > 0 and depth >= max_depth:
            rich_node.add(f'... {len(dict_node)} more')
            return

        for key in sorted(dict_node):
            child = dict_node[key]
            rich_child = rich_node.add(key)
            if isinstance(child, dict):
                self._recursively_build_rich_tree(rich_child, child, depth + 1, max_depth)