from itertools import groupby
from string import punctuation
from typing import List
from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein
class Merger:
"""
合并编辑操作,从Token-Level转换为Span-Level
"""
def __init__(self,
granularity: str = "word",
merge: bool = False):
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧."
self.punctuation = punctuation + chinese_punct
        self.not_merge_token = list(self.punctuation)
self.granularity = granularity
self.merge = merge
@staticmethod
def _merge_edits(seq, tag="X"):
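        # Collapse a run of token-level edits into one span-level edit tagged `tag`.
        # Each edit is a 5-tuple: (op, src_start, src_end, tgt_start, tgt_end).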
if seq:
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
else:
return seq
@staticmethod
def _check_revolve(span_a, span_b):
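        # Rotation test: span_b is a rotation of span_a iff it occurs inside
        # span_a doubled (the caller guarantees equal lengths).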
span_a = span_a + span_a
return span_b in span_a
def _process_seq(self, seq, src_tokens, tgt_tokens):
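        # Merge a run of same-type edits: pure D, I or M runs collapse into a
        # single edit; mixed D+I runs stay split; S runs merge only at
        # character granularity; any other mix collapses into a single S.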
if len(seq) <= 1:
return seq
ops = [op[0] for op in seq]
if set(ops) == {"D"} or set(ops) == {"I"}:
return self._merge_edits(seq, set(ops).pop())
if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}:
# do not merge this pattern_from_qua.txt
return seq
if set(ops) == {"S"}:
if self.granularity == "word":
return seq
else:
return self._merge_edits(seq, "S")
if set(ops) == {"M"}:
return self._merge_edits(seq, "M")
return self._merge_edits(seq, "S")
def __call__(self,
align_obj,
src: List,
tgt: List,
verbose: bool = False):
"""
Based on ERRANT's merge, adapted for Chinese
"""
src_tokens = [x[0] for x in src]
tgt_tokens = [x[0] for x in tgt]
edits = []
        # Split the alignment into groups of M, T and the rest (T carries a trailing number)
        # TODO: once an inserted/deleted/substituted span contains punctuation, do not merge it with other edits
        # TODO: missing-component edits should likewise not be merged with other edits
for op, group in groupby(
align_obj,
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False,
):
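            # Grouping key: M and T runs keep their own op; D, I and S all map
            # to False so that they are grouped together for merging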
group = list(group)
# T is always split TODO: Evaluate this
if op == "T":
for seq in group:
edits.append(seq)
# Process D, I and S subsequence
else:
# Turn the processed sequence into edits
processed = self._process_seq(group, src_tokens, tgt_tokens)
for seq in processed:
edits.append(seq)
filtered_edits = []
i = 0
while i < len(edits):
e1 = edits[i][0][0]
if i < len(edits) - 2:
e2 = edits[i + 1][0][0]
e3 = edits[i + 2][0][0]
# Find "S M S" patterns
# Ex:
# S M S
# 冬阴功 对 外国人
# 外国人 对 冬阴功
if e1 == "S" and e2 == "M" and e3 == "S":
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
if min([len(w1), len(w2), len(w3), len(w4)]) == 1:
if w1 == w4 and w2 == w3:
group = [edits[i], edits[i + 1], edits[i + 2]]
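                            # Tag "T<n>": a transposition covering n source tokens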
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
# Find "D M I" or "I M D" patterns
# Ex:
# D M I
# 旅游 去 陌生 的 地方
# 去 陌生 的 地方 旅游
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"):
if e1 == "D":
delete_token = src_tokens[edits[i][1]: edits[i][2]]
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
else:
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
a, b = "".join(delete_token), "".join(insert_token)
if len(a) < len(b):
a, b = b, a
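                    # Keep a as the longer span so the length check below is one-sided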
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1:
if len(b) == 1:
if a == b:
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)):
group = [edits[i], edits[i + 1], edits[i + 2]]
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1]))
for seq in processed:
filtered_edits.append(seq)
i += 3
else:
filtered_edits.append(edits[i])
i += 1
else:
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
else:
if e1 != "M":
filtered_edits.append(edits[i])
i += 1
# In rare cases with word-level tokenization, the following error can occur:
# M D S M
# 有 時 住 上層
# 有 時住 上層
# Which results in S: 時住 --> 時住
# We need to filter this case out
second_filter = []
        for edit in filtered_edits:  # avoid mismatches caused by word-segmentation errors
span1 = "".join(src_tokens[edit[1] : edit[2]])
span2 = "".join(tgt_tokens[edit[3] : edit[4]])
if span1 != span2:
if edit[0] == "S":
                b = True  # keep the edit unchanged unless overlap trimming succeeds
# In rare cases with word-level tokenization, the following error can occur:
# S I I M
# 负责任 老师
# 负 责任 的 老师
# Which results in S: 负责任 --> 负 责任 的
# We need to convert this edit to I: --> 的
                    # The source and target spans share a common prefix
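                    # Greedily match source tokens against the start of the target
                    # span, then locate the target token boundary covering the
                    # same prefix so both spans can be trimmed consistently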
common_str = ""
tmp_new_start_1 = edit[1]
for i in range(edit[1], edit[2]):
if not span2.startswith(common_str + src_tokens[i]):
break
common_str += src_tokens[i]
tmp_new_start_1 = i + 1
new_start_1, new_start_2 = edit[1], edit[3]
if common_str:
tmp_str = ""
for i in range(edit[3], edit[4]):
tmp_str += tgt_tokens[i]
if tmp_str == common_str:
new_start_1, new_start_2 = tmp_new_start_1, i + 1
# second_filter.append(("S", new_start_1, edit[2], i + 1, edit[4]))
b = False
break
elif len(tmp_str) > len(common_str):
break
                    # The source and target spans share a common suffix
common_str = ""
new_end_1, new_end_2 = edit[2], edit[4]
tmp_new_end_1 = edit[2]
for i in reversed(range(new_start_1, edit[2])):
if not span2.endswith(src_tokens[i] + common_str):
break
common_str = src_tokens[i] + common_str
tmp_new_end_1 = i
if common_str:
tmp_str = ""
for i in reversed(range(new_start_2, edit[4])):
tmp_str = tgt_tokens[i] + tmp_str
if tmp_str == common_str:
new_end_1, new_end_2 = tmp_new_end_1, i
b = False
break
elif len(tmp_str) > len(common_str):
break
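                    # If no overlap was trimmed, keep the edit as-is; otherwise
                    # re-type the trimmed span as a pure insert/delete/substitute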
if b:
second_filter.append(edit)
else:
if new_start_1 == new_end_1:
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2)
elif new_start_2 == new_end_2:
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2)
else:
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2)
second_filter.append(new_edit)
else:
second_filter.append(edit)
if verbose:
print("========== Parallels ==========")
print("".join(src_tokens))
print("".join(tgt_tokens))
print("========== Results ==========")
for edit in second_filter:
op = edit[0]
s = " ".join(src_tokens[edit[1]: edit[2]])
t = " ".join(tgt_tokens[edit[3]: edit[4]])
print(f"{op}:\t{s}\t-->\t{t}")
print("========== Infos ==========")
print(str(src))
print(str(tgt))
return second_filter
if __name__ == "__main__":
tokenizer = Tokenizer("char")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
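    # Pipeline: tokenize both sentences, align them, then merge token-level
    # edits into span-level edits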
sents = [
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace(
" ", ""),
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace(
" ", "")]
src, tgt = tokenizer(sents)
align_obj = alignment(src, tgt)
m = Merger()
m(align_obj, src, tgt, verbose=True) |