TwT-6's picture
Upload 2667 files
256a159 verified
from itertools import groupby
from string import punctuation
from typing import List
from modules.tokenizer import Tokenizer
from modules.alignment import Alignment, read_cilin, read_confusion
import Levenshtein
class Merger:
    """
    Merge edit operations, converting them from token level to span level.

    Calling an instance turns an alignment -- a sequence of operations indexed
    as ``(op, src_start, src_end, tgt_start, tgt_end)`` -- into a list of
    span-level edits with the same tuple layout. Matches ("M") are removed.
    """

    def __init__(self,
                 granularity: str = "word",
                 merge: bool = False):
        """
        Args:
            granularity: Only compared against ``"word"``: with ``"word"``,
                runs of adjacent substitutions stay separate edits; with any
                other value they are merged into a single ``"S"`` edit.
            merge: Stored as ``self.merge``; not read by this class.
        """
        # Full-width forms (U+FF01..U+FF5E) of every ASCII punctuation mark,
        # e.g. "，", "！", "？", "（". They are computed instead of spelled out:
        # the previous hand-written literal contained an ASCII '"' that
        # silently terminated the string, so only "!?。" was ever added.
        fullwidth_punct = "".join(chr(ord(ch) + 0xFEE0) for ch in punctuation)
        cjk_punct = "。｡｟｠｢｣､、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘’‛“”„‟…‧."
        chinese_punct = fullwidth_punct + cjk_punct
        # Kept as a single string: this class tests it with ``in``, which is a
        # substring check, so multi-character runs such as "()" also match.
        self.punctuation = punctuation + chinese_punct
        self.not_merge_token = list(self.punctuation)
        self.granularity = granularity
        self.merge = merge

    @staticmethod
    def _merge_edits(seq, tag="X"):
        """Collapse ``seq`` into one edit tagged ``tag``.

        The merged edit runs from the start of the first edit to the end of
        the last one, on both the source and target side. An empty ``seq`` is
        returned unchanged.
        """
        if seq:
            return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
        else:
            return seq

    @staticmethod
    def _check_revolve(span_a, span_b):
        """Return True if ``span_b`` is a substring of ``span_a`` doubled.

        For strings of equal length this means ``span_b`` is a rotation of
        ``span_a`` (e.g. "abc" / "cab").
        """
        span_a = span_a + span_a
        return span_b in span_a

    def _process_seq(self, seq, src_tokens, tgt_tokens):
        """Merge one run of non-transposition operations into edits.

        ``src_tokens``/``tgt_tokens`` are accepted for interface stability
        but not used.
        """
        if len(seq) <= 1:
            return seq
        op_set = {op[0] for op in seq}
        if op_set == {"D"} or op_set == {"I"}:
            # Pure deletions or pure insertions: one edit of that type.
            return self._merge_edits(seq, seq[0][0])
        if op_set == {"D", "I"}:
            # Mixed deletions and insertions are left unmerged.
            return seq
        if op_set == {"S"}:
            if self.granularity == "word":
                return seq
            return self._merge_edits(seq, "S")
        if op_set == {"M"}:
            return self._merge_edits(seq, "M")
        # Any other mixture becomes a single substitution.
        return self._merge_edits(seq, "S")

    def __call__(self,
                 align_obj,
                 src: List,
                 tgt: List,
                 verbose: bool = False):
        """
        Based on ERRANT's merge, adapted for Chinese.

        Args:
            align_obj: Iterable of alignment operations, each indexed as
                ``(op, src_start, src_end, tgt_start, tgt_end)`` where ``op``
                starts with "M", "S", "D", "I" or "T".
            src: Source tokens; ``element[0]`` of each item is the token text.
            tgt: Target tokens, same layout as ``src``.
            verbose: If True, print the sentence pair and resulting edits.

        Returns:
            List of span-level edit tuples
            ``(op, src_start, src_end, tgt_start, tgt_end)``; matches are
            dropped.
        """
        src_tokens = [x[0] for x in src]
        tgt_tokens = [x[0] for x in tgt]
        # TODO: once an inserted/deleted/substituted span contains
        # punctuation, do not merge it with other edits.
        # TODO: missing-component labels should not be merged either.
        edits = self._group_edits(align_obj, src_tokens, tgt_tokens)
        filtered_edits = self._find_transpositions(edits, src_tokens, tgt_tokens)
        second_filter = self._fix_tokenization_mismatch(
            filtered_edits, src_tokens, tgt_tokens)
        if verbose:
            self._print_edits(second_filter, src, tgt, src_tokens, tgt_tokens)
        return second_filter

    def _group_edits(self, align_obj, src_tokens, tgt_tokens):
        """Split the alignment into runs of M, T and the rest; merge each non-T run."""
        edits = []
        # Consecutive ops are grouped by first letter when it is "M" or "T"
        # (T carries a number after it); D/I/S ops all share the key False.
        for op, group in groupby(
                align_obj,
                lambda x: x[0][0] if x[0][0] in {"M", "T"} else False):
            group = list(group)
            if op == "T":
                # Transpositions are always kept as individual edits.
                edits.extend(group)
            else:
                edits.extend(self._process_seq(group, src_tokens, tgt_tokens))
        return edits

    def _merge_transposition(self, edits, i):
        """Merge ``edits[i:i + 3]`` into one "T<n>" edit, n = source span length."""
        return self._merge_edits(
            edits[i:i + 3], "T" + str(edits[i + 2][2] - edits[i][1]))

    def _is_transposition(self, edits, i, src_tokens, tgt_tokens):
        """Decide whether ``edits[i:i + 3]`` is a word-order swap.

        Requires ``i + 2 < len(edits)``. Two patterns are recognised:

        S M S -- e.g.  冬阴功 对 外国人  ->  外国人 对 冬阴功
        D M I / I M D (middle may also be T) -- e.g.
                       旅游 去 陌生 的 地方  ->  去 陌生 的 地方 旅游
        """
        e1 = edits[i][0][0]
        e2 = edits[i + 1][0][0]
        e3 = edits[i + 2][0][0]
        if e1 == "S" and e2 == "M" and e3 == "S":
            w1 = "".join(src_tokens[edits[i][1]: edits[i][2]])
            w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]])
            w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]])
            w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]])
            if min(len(w1), len(w2), len(w3), len(w4)) == 1:
                # Single characters must swap exactly.
                return w1 == w4 and w2 == w3
            return (Levenshtein.distance(w1, w4) <= 1
                    and Levenshtein.distance(w2, w3) <= 1)
        if e2 in ("M", "T") and ((e1 == "D" and e3 == "I")
                                 or (e1 == "I" and e3 == "D")):
            if e1 == "D":
                delete_token = src_tokens[edits[i][1]: edits[i][2]]
                insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]
            else:
                delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]]
                insert_token = tgt_tokens[edits[i][3]: edits[i][4]]
            a, b = "".join(delete_token), "".join(insert_token)
            if len(a) < len(b):
                a, b = b, a  # make ``a`` the longer string
            # Moved punctuation is never treated as a transposition.
            if a in self.punctuation or b in self.punctuation or len(a) - len(b) > 1:
                return False
            if len(b) == 1:
                return a == b
            return (Levenshtein.distance(a, b) <= 1
                    or (len(a) == len(b) and self._check_revolve(a, b)))
        return False

    def _find_transpositions(self, edits, src_tokens, tgt_tokens):
        """Replace word-order swaps with "T<n>" edits and drop matches."""
        filtered_edits = []
        i = 0
        while i < len(edits):
            if (i < len(edits) - 2
                    and self._is_transposition(edits, i, src_tokens, tgt_tokens)):
                filtered_edits.extend(self._merge_transposition(edits, i))
                i += 3
            else:
                if edits[i][0][0] != "M":
                    filtered_edits.append(edits[i])
                i += 1
        return filtered_edits

    def _fix_tokenization_mismatch(self, edits, src_tokens, tgt_tokens):
        """Remove or shrink edits that only exist because of tokenization differences.

        With word-level tokenization an edit can be a no-op, e.g.

            M   D   S   M
            有  時  住  上層
            有      時住 上層

        yields ``S: 時住 --> 時住``; such edits (equal text on both sides)
        are dropped. Substitutions sharing a prefix/suffix are trimmed by
        ``_trim_shared_affixes``.
        """
        second_filter = []
        for edit in edits:
            span1 = "".join(src_tokens[edit[1]: edit[2]])
            span2 = "".join(tgt_tokens[edit[3]: edit[4]])
            if span1 == span2:
                continue
            if edit[0] == "S":
                second_filter.append(
                    self._trim_shared_affixes(edit, span2, src_tokens, tgt_tokens))
            else:
                second_filter.append(edit)
        return second_filter

    @staticmethod
    def _trim_shared_affixes(edit, span2, src_tokens, tgt_tokens):
        """Strip the prefix/suffix a substitution shares with its target.

        Example with word-level tokenization:

            S      I     I   M
            负责任             老师
            负     责任   的  老师

        gives ``S: 负责任 --> 负 责任 的``, which is converted to
        ``I: --> 的``. The result is an I, D or S edit depending on which side
        becomes empty; ``edit`` is returned unchanged if nothing is trimmed.
        """
        changed = False
        # Shared prefix: longest run of leading source tokens whose joined
        # text is a prefix of the target span.
        common_str = ""
        tmp_new_start_1 = edit[1]
        for i in range(edit[1], edit[2]):
            if not span2.startswith(common_str + src_tokens[i]):
                break
            common_str += src_tokens[i]
            tmp_new_start_1 = i + 1
        new_start_1, new_start_2 = edit[1], edit[3]
        if common_str:
            # The prefix only counts if it also ends on a target token boundary.
            tmp_str = ""
            for j in range(edit[3], edit[4]):
                tmp_str += tgt_tokens[j]
                if tmp_str == common_str:
                    new_start_1, new_start_2 = tmp_new_start_1, j + 1
                    changed = True
                    break
                if len(tmp_str) > len(common_str):
                    break
        # Shared suffix, searched only after the (possibly trimmed) start.
        common_str = ""
        new_end_1, new_end_2 = edit[2], edit[4]
        tmp_new_end_1 = edit[2]
        for i in reversed(range(new_start_1, edit[2])):
            if not span2.endswith(src_tokens[i] + common_str):
                break
            common_str = src_tokens[i] + common_str
            tmp_new_end_1 = i
        if common_str:
            tmp_str = ""
            for j in reversed(range(new_start_2, edit[4])):
                tmp_str = tgt_tokens[j] + tmp_str
                if tmp_str == common_str:
                    new_end_1, new_end_2 = tmp_new_end_1, j
                    changed = True
                    break
                if len(tmp_str) > len(common_str):
                    break
        if not changed:
            return edit
        if new_start_1 == new_end_1:
            op = "I"
        elif new_start_2 == new_end_2:
            op = "D"
        else:
            op = "S"
        return (op, new_start_1, new_end_1, new_start_2, new_end_2)

    @staticmethod
    def _print_edits(edits, src, tgt, src_tokens, tgt_tokens):
        """Print the sentence pair, each edit, and the raw token inputs."""
        print("========== Parallels ==========")
        print("".join(src_tokens))
        print("".join(tgt_tokens))
        print("========== Results ==========")
        for edit in edits:
            op = edit[0]
            s = " ".join(src_tokens[edit[1]: edit[2]])
            t = " ".join(tgt_tokens[edit[3]: edit[4]])
            print(f"{op}:\t{s}\t-->\t{t}")
        print("========== Infos ==========")
        print(str(src))
        print(str(tgt))
if __name__ == "__main__":
    # Demo: align one learner sentence with its correction at character
    # level, then print the merged span-level edits.
    char_tokenizer = Tokenizer("char")
    cilin_dict, cilin_classes = read_cilin()
    confusion_table = read_confusion()
    aligner = Alignment(cilin_dict, confusion_table)
    spaced_pair = (
        "所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。",
        "所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。",
    )
    pair = [sentence.replace(" ", "") for sentence in spaced_pair]
    src_sent, tgt_sent = char_tokenizer(pair)
    alignment_ops = aligner(src_sent, tgt_sent)
    merger = Merger()
    merger(alignment_ops, src_sent, tgt_sent, verbose=True)