Spaces:
Sleeping
Sleeping
File size: 1,841 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
from collections import OrderedDict
from transformers import CONFIG_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_NAMES_MAPPING, TOKENIZER_MAPPING
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, BertConverter, XLMRobertaConverter
from transformers.models.auto.modeling_auto import auto_class_factory
from .models.layoutlmv2 import (
LayoutLMv2Config,
LayoutLMv2ForRelationExtraction,
LayoutLMv2ForTokenClassification,
LayoutLMv2Tokenizer,
LayoutLMv2TokenizerFast,
)
from .models.layoutxlm import (
LayoutXLMConfig,
LayoutXLMForRelationExtraction,
LayoutXLMForTokenClassification,
LayoutXLMTokenizer,
LayoutXLMTokenizerFast,
)
# Add the "layoutlmv2" and "layoutxlm" model-type keys to the transformers
# lookup tables imported above: one entry resolving each key to its config
# class, and one giving each key a display name.
# NOTE(review): this mutates library-level tables at import time; presumably
# they are dict-like (OrderedDict) in the pinned transformers release - confirm.
CONFIG_MAPPING.update(
    {
        "layoutlmv2": LayoutLMv2Config,
        "layoutxlm": LayoutXLMConfig,
    }
)
MODEL_NAMES_MAPPING.update(
    {
        "layoutlmv2": "LayoutLMv2",
        "layoutxlm": "LayoutXLM",
    }
)
# Associate each config class with its (tokenizer, fast tokenizer) class pair.
TOKENIZER_MAPPING.update(
    {
        LayoutLMv2Config: (LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast),
        LayoutXLMConfig: (LayoutXLMTokenizer, LayoutXLMTokenizerFast),
    }
)
# Keyed by tokenizer class *name* (a string): LayoutLMv2 reuses the BERT
# converter and LayoutXLM reuses the XLM-RoBERTa converter.
SLOW_TO_FAST_CONVERTERS.update(
    [
        ("LayoutLMv2Tokenizer", BertConverter),
        ("LayoutXLMTokenizer", XLMRobertaConverter),
    ]
)
# Extend the token-classification table imported from transformers with the
# two layout-aware model classes, keyed by their config class.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.update(
    {
        LayoutLMv2Config: LayoutLMv2ForTokenClassification,
        LayoutXLMConfig: LayoutXLMForTokenClassification,
    }
)

# Relation extraction has no table imported from transformers in this module,
# so a new config-class -> model-class mapping is defined here. Insertion
# order is LayoutLMv2 first, then LayoutXLM.
MODEL_FOR_RELATION_EXTRACTION_MAPPING = OrderedDict()
MODEL_FOR_RELATION_EXTRACTION_MAPPING[LayoutLMv2Config] = LayoutLMv2ForRelationExtraction
MODEL_FOR_RELATION_EXTRACTION_MAPPING[LayoutXLMConfig] = LayoutXLMForRelationExtraction
# Build this module's own auto classes with the transformers factory helper.
# AutoModelForTokenClassification is created from the token-classification
# mapping that this module extends with the LayoutLMv2 / LayoutXLM entries.
AutoModelForTokenClassification = auto_class_factory(
    "AutoModelForTokenClassification",
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    head_doc="token classification",
)

# Auto class for the relation-extraction mapping defined in this module.
AutoModelForRelationExtraction = auto_class_factory(
    "AutoModelForRelationExtraction",
    MODEL_FOR_RELATION_EXTRACTION_MAPPING,
    head_doc="relation extraction",
)
|