import torch
import logging

from transformers.modeling_utils import cached_path, WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME

logger = logging.getLogger(__name__)


def get_checkpoint_from_transformer_cache(
        archive_file, pretrained_model_name_or_path, pretrained_model_archive_map,
        cache_dir, force_download, proxies, resume_download,
):
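    """Resolve a pretrained weights file through the transformers cache and load it.

    Mirrors the resolution step of transformers' from_pretrained: cached_path
    downloads archive_file (or returns the cached copy), and the checkpoint is
    loaded onto CPU so it can be converted before any device placement.
    """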
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
                                            proxies=proxies, resume_download=resume_download)
    except EnvironmentError:
        if pretrained_model_name_or_path in pretrained_model_archive_map:
            msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                archive_file)
        else:
            msg = "Model name '{}' was not found in model name list ({}). " \
                  "We assumed '{}' was a path or url to model weight files named one of {} but " \
                  "couldn't find any such file at this path or url.".format(
                pretrained_model_name_or_path,
                ', '.join(pretrained_model_archive_map.keys()),
                archive_file,
                [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
        raise EnvironmentError(msg)

    if resolved_archive_file == archive_file:
        logger.info("loading weights file {}".format(archive_file))
    else:
        logger.info("loading weights file {} from cache at {}".format(
            archive_file, resolved_archive_file))

    return torch.load(resolved_archive_file, map_location='cpu')


def hf_roberta_to_hf_bert(state_dict):
    logger.info(" * Convert Huggingface RoBERTa format to Huggingface BERT format * ")

    new_state_dict = {}

    for key in state_dict:
        value = state_dict[key]
        if key == 'roberta.embeddings.position_embeddings.weight':
            # RoBERTa reserves the first two rows of the position table for
            # its padding offset (padding_idx + 1); BERT counts from 0.
            value = value[2:]
        if key == 'roberta.embeddings.token_type_embeddings.weight':
            # RoBERTa only has a single token type, so its 1-row table is
            # dropped and the BERT model keeps its own token_type_embeddings.
            continue
        if key.startswith('roberta'):
            key = 'bert.' + key[8:]
        elif key.startswith('lm_head'):
            # lm_head.dense and lm_head.layer_norm form the prediction-head
            # transform; the remaining lm_head tensors (decoder, bias) map
            # directly under cls.predictions.
            if 'layer_norm' in key or 'dense' in key:
                key = 'cls.predictions.transform.' + key[8:]
            else:
                key = 'cls.predictions.' + key[8:]
            key = key.replace('layer_norm', 'LayerNorm')

        new_state_dict[key] = value

    return new_state_dict
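

# Illustrative examples of the RoBERTa -> BERT key mapping above:
#   roberta.encoder.layer.0.attention.self.query.weight
#       -> bert.encoder.layer.0.attention.self.query.weight
#   lm_head.dense.weight      -> cls.predictions.transform.dense.weight
#   lm_head.layer_norm.weight -> cls.predictions.transform.LayerNorm.weight
#   lm_head.decoder.weight    -> cls.predictions.decoder.weight
#   lm_head.bias              -> cls.predictions.bias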


def hf_distilbert_to_hf_bert(state_dict):
    logger.info(" * Convert Huggingface DistilBERT format to Huggingface BERT format * ")

    new_state_dict = {}

    for key in state_dict:
        value = state_dict[key]
        if key.startswith('distilbert'):
            # Map DistilBERT module names onto their BERT equivalents.
            key = 'bert.' + key[11:]
            key = key.replace('transformer.layer', 'encoder.layer')
            key = key.replace('attention.q_lin', 'attention.self.query')
            key = key.replace('attention.k_lin', 'attention.self.key')
            key = key.replace('attention.v_lin', 'attention.self.value')
            key = key.replace('attention.out_lin', 'attention.output.dense')
            key = key.replace('sa_layer_norm', 'attention.output.LayerNorm')
            key = key.replace('ffn.lin1', 'intermediate.dense')
            key = key.replace('ffn.lin2', 'output.dense')
            key = key.replace('output_layer_norm', 'output.LayerNorm')
        elif key.startswith('vocab_transform'):
            # The MLM head: vocab_transform / vocab_layer_norm / vocab_projector
            # correspond to BERT's cls.predictions.* modules.
            key = key.replace('vocab_transform', 'cls.predictions.transform.dense')
        elif key.startswith('vocab_layer_norm'):
            key = key.replace('vocab_layer_norm', 'cls.predictions.transform.LayerNorm')
        elif key.startswith('vocab_projector'):
            key = key.replace('vocab_projector', 'cls.predictions.decoder')

        new_state_dict[key] = value

    return new_state_dict


def hf_bert_to_hf_bert(state_dict):
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if key.startswith('cls'):
            # The cls.* tensors are the pretraining prediction heads; the
            # downstream model predicts indices with a freshly initialized
            # head, so the pretrained head weights are dropped.
            continue
        new_state_dict[key] = value
    return new_state_dict


def hf_layoutlm_to_hf_bert(state_dict):
    logger.info(" * Convert Huggingface LayoutLM format to Huggingface BERT format * ")

    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if key.startswith('layoutlm'):
            key = 'bert.' + key[9:]
        elif key.startswith('cls'):
            # Drop the pretraining-head (cls.*) weights, as in
            # hf_bert_to_hf_bert; the downstream model brings its own head.
            continue
        new_state_dict[key] = value
    return new_state_dict


state_dict_convert = {
    'bert': hf_bert_to_hf_bert,
    'unilm': hf_bert_to_hf_bert,
    'minilm': hf_bert_to_hf_bert,
    'layoutlm': hf_layoutlm_to_hf_bert,
    'roberta': hf_roberta_to_hf_bert,
    'xlm-roberta': hf_roberta_to_hf_bert,
    'distilbert': hf_distilbert_to_hf_bert,
}
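

# Usage sketch (illustrative; the URL and `model` below are placeholders,
# not part of this module):
#
#     state_dict = get_checkpoint_from_transformer_cache(
#         archive_file='https://.../roberta-base/pytorch_model.bin',
#         pretrained_model_name_or_path='roberta-base',
#         pretrained_model_archive_map={},
#         cache_dir=None, force_download=False,
#         proxies=None, resume_download=False,
#     )
#     bert_state_dict = state_dict_convert['roberta'](state_dict)
#     missing, unexpected = model.load_state_dict(bert_state_dict, strict=False)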