File size: 6,676 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from char_smi import CharFuncs
from collections import namedtuple
from pypinyin import pinyin, Style
import os
# Immutable record describing one classified edit:
#   op   - error-type label, e.g. "W", "R:NOUN", "M:OTHER", "S:SPELL"
#   toks - the target-side text of the edit, or "-NONE-" for removals
#   inds - (start, end) token offsets of the edit in the source sentence
Correction = namedtuple("Correction", ("op", "toks", "inds"))
# Location of the character metadata file handed to CharFuncs. It is resolved
# relative to this module: five directory levels up, then
# data/lawbench/eval_assets/char_meta.txt.
_CHAR_META_PATH = os.path.join(
    os.path.dirname(__file__),
    "..", "..", "..", "..", "..",
    "data", "lawbench", "eval_assets", "char_meta.txt",
)
# Shared similarity helper used by check_spell_error below.
char_smi = CharFuncs(_CHAR_META_PATH)

def check_spell_error(src_span: str,
                      tgt_span: str,
                      threshold: float = 0.8) -> bool:
    """Decide whether replacing ``src_span`` by ``tgt_span`` is a spelling error.

    The decision is made as follows:

    * spans of different length are never spelling errors;
    * spans made of exactly the same characters in a different order
      (characters swapped inside the word) always are;
    * otherwise every position where the characters differ must be
      "confusable": both characters have to be present in ``char_smi.data``
      and either their shape similarity plus pronunciation similarity
      reaches ``threshold`` or they share at least one pinyin reading.

    :param src_span: text of the erroneous span (no separators)
    :param tgt_span: text of the corrected span (no separators)
    :param threshold: minimum combined shape + pronunciation similarity
    :return: True if the substitution is considered a spelling error
    """
    if len(src_span) != len(tgt_span):
        return False

    # Same characters, different order: transposition inside the word.
    if sorted(src_span) == sorted(tgt_span):
        return True

    for wrong, right in zip(src_span, tgt_span):
        if wrong == right:
            continue

        # Characters unknown to the similarity table cannot be compared.
        if not (wrong in char_smi.data and right in char_smi.data):
            return False

        shape_score = char_smi.shape_similarity(wrong, right)
        sound_score = char_smi.pronunciation_similarity(wrong, right)
        if shape_score + sound_score >= threshold:
            continue

        # Below the threshold: still accept the pair when the two characters
        # have a reading in common (all heteronym readings are considered).
        wrong_readings = set(pinyin(wrong, style=Style.NORMAL, heteronym=True)[0])
        right_readings = set(pinyin(right, style=Style.NORMAL, heteronym=True)[0])
        if not (wrong_readings & right_readings):
            return False

    return True

class Classifier:
    """
    Error-type classifier for edit operations.

    Every edit is mapped to a ``Correction`` whose ``op`` depends on the first
    character of the edit tag:

    * ``T`` -> ``W``
    * ``D`` -> ``R`` (redundant tokens; ``toks`` is ``"-NONE-"``)
    * ``I`` -> ``M`` (missing tokens)
    * ``S`` -> ``S`` (substituted tokens)

    With ``granularity == "word"`` the R/M/S labels are refined with a
    part-of-speech based subtype (``R:NOUN``, ``M:OTHER``, ``S:SPELL`` ...);
    with any other granularity only the bare operation label is used.
    """

    # Coarse category for each fine-grained POS tag; tags not listed here
    # fall into "OTHER".
    # NOTE(review): the tags look like the LTP POS tagset - confirm.
    # TODO: suffix words ("k") are too rare, so they stay in "OTHER" for now.
    _POS_CATEGORIES = {
        "n": "NOUN",
        "nd": "NOUN",
        "nh": "NOUN-NE",
        "ni": "NOUN-NE",
        "nl": "NOUN-NE",
        "ns": "NOUN-NE",
        "nt": "NOUN-NE",
        "nz": "NOUN-NE",
        "v": "VERB",
        "a": "ADJ",
        "b": "ADJ",
        "c": "CONJ",
        "r": "PRON",
        "d": "ADV",
        "u": "AUX",
        "m": "NUM",
        "p": "PREP",
        "q": "QUAN",
        "wp": "PUNCT",
    }

    def __init__(self,
                 granularity: str = "word"):
        """
        :param granularity: ``"word"`` enables POS-based subtypes; any other
            value yields operation-only labels.
        """
        self.granularity = granularity

    @staticmethod
    def get_pos_type(pos):
        """
        Map a fine-grained POS tag to its coarse category.

        :param pos: POS tag of a token
        :return: one of NOUN, NOUN-NE, VERB, ADJ, CONJ, PRON, ADV, AUX, NUM,
            PREP, QUAN, PUNCT, or OTHER for any unlisted tag
        """
        return Classifier._POS_CATEGORIES.get(pos, "OTHER")

    def _word_level_op(self, prefix, span_len, tagged_tokens, index, tgt_span):
        """
        Build the word-granularity label ``prefix:SUBTYPE`` for one edit.

        :param prefix: operation letter ("R", "M" or "S")
        :param span_len: number of tokens covered on the side being inspected
        :param tagged_tokens: sentence info whose items carry the POS tag at
            position 1
        :param index: position of the single token whose POS decides the
            subtype
        :param tgt_span: target text of the edit
        :return: the refined error-type label
        """
        if span_len > 1:
            # Multi-token spans are not subdivided yet.
            return "{:s}:OTHER".format(prefix)
        pos = self.get_pos_type(tagged_tokens[index][1])
        if pos == "NOUN-NE":
            # Named entities are not distinguished from ordinary nouns.
            pos = "NOUN"
        if tgt_span == "[缺失成分]":
            # The placeholder target overrides the POS-based subtype.
            pos = "MC"
        return "{:s}:{:s}".format(prefix, pos)

    def __call__(self,
                 src,
                 tgt,
                 edits,
                 verbose: bool = False):
        """
        Assign an error type to every edit operation.

        :param src: erroneous sentence info; each item has the token at
            position 0 and its POS tag at position 1
        :param tgt: corrected sentence info, same layout as ``src``
        :param edits: edit operations, each laid out as
            ``(tag, src_start, src_end, tgt_start, tgt_end)``
        :param verbose: print the resulting corrections
        :return: list aligned with ``edits``; each entry is a ``Correction``,
            or ``None`` when the edit tag starts with none of T/D/I/S
        """
        results = []
        src_tokens = [item[0] for item in src]
        tgt_tokens = [item[0] for item in tgt]
        word_level = self.granularity == "word"

        for edit in edits:
            tag, src_start, src_end, tgt_start, tgt_end = edit[:5]
            src_span = " ".join(src_tokens[src_start:src_end])
            tgt_span = " ".join(tgt_tokens[tgt_start:tgt_end])
            src_inds = (src_start, src_end)
            kind = tag[0]

            cor = None
            if kind == "T":
                cor = Correction("W", tgt_span, src_inds)
            elif kind == "D":
                # Redundant tokens: the subtype comes from the source token.
                if word_level:
                    op = self._word_level_op(
                        "R", src_end - src_start, src, src_start, tgt_span)
                else:
                    op = "R"
                cor = Correction(op, "-NONE-", src_inds)
            elif kind == "I":
                # Missing tokens: the subtype comes from the target token.
                if word_level:
                    op = self._word_level_op(
                        "M", tgt_end - tgt_start, tgt, tgt_start, tgt_span)
                else:
                    op = "M"
                cor = Correction(op, tgt_span, src_inds)
            elif kind == "S":
                if not word_level:
                    op = "S"
                elif check_spell_error(src_span.replace(" ", ""),
                                       tgt_span.replace(" ", "")):
                    # TODO: named-entity spelling errors are not separated
                    # from ordinary spelling errors for now.
                    op = "S:SPELL"
                else:
                    op = self._word_level_op(
                        "S", tgt_end - tgt_start, tgt, tgt_start, tgt_span)
                cor = Correction(op, tgt_span, src_inds)
            # Unrecognised tags keep their None placeholder so that the
            # result stays aligned with ``edits``.
            results.append(cor)

        if verbose:
            print("========== Corrections ==========")
            for cor in results:
                if cor is None:
                    # Nothing to report for an unclassified edit; formatting
                    # it would raise AttributeError.
                    continue
                print("Type: {:s}, Position: {:d} -> {:d}, Target: {:s}".format(cor.op, cor.inds[0], cor.inds[1], cor.toks))
        return results

# print(pinyin("朝", style=Style.NORMAL))