|
from itertools import groupby |
|
from string import punctuation |
|
from typing import List |
|
from modules.tokenizer import Tokenizer |
|
from modules.alignment import Alignment, read_cilin, read_confusion |
|
import Levenshtein |
|
|
|
class Merger: |
|
""" |
|
    Merge edit operations, converting token-level edits into span-level edits.
|
""" |
|
|
|
def __init__(self, |
|
granularity: str = "word", |
|
merge: bool = False): |
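        """
        Args:
            granularity: edit granularity ("word" or "char"); at word level,
                runs of adjacent substitutions are kept separate rather than merged.
            merge: stored for interface compatibility; not referenced by the
                merging logic in this class.
        """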
|
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟–—‘'‛“”„‟…‧." |
|
self.punctuation = punctuation + chinese_punct |
|
        self.not_merge_token = list(self.punctuation)
|
self.granularity = granularity |
|
self.merge = merge |
|
|
|
@staticmethod |
|
def _merge_edits(seq, tag="X"): |
|
if seq: |
|
return [(tag, seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])] |
|
else: |
|
return seq |
|
|
|
@staticmethod |
|
def _check_revolve(span_a, span_b): |
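        """
        Check whether span_b occurs in span_a repeated twice; for equal-length
        spans this is True exactly when span_b is a rotation of span_a,
        e.g. _check_revolve("ab", "ba") -> True.
        """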
|
span_a = span_a + span_a |
|
return span_b in span_a |
|
|
|
def _process_seq(self, seq, src_tokens, tgt_tokens): |
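        """
        Merge a group of adjacent edits produced by the alignment:
        all-deletion/insertion/match runs collapse into a single D/I/M edit,
        mixed D+I runs are left as-is, all-substitution runs are merged only
        at char granularity, and any other mixture becomes one S edit.
        (src_tokens and tgt_tokens are currently unused.)
        """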
|
if len(seq) <= 1: |
|
return seq |
|
|
|
ops = [op[0] for op in seq] |
|
if set(ops) == {"D"} or set(ops) == {"I"}: |
|
return self._merge_edits(seq, set(ops).pop()) |
|
|
|
if set(ops) == {"D", "I"} or set(ops) == {"I", "D"}: |
|
|
|
return seq |
|
|
|
if set(ops) == {"S"}: |
|
if self.granularity == "word": |
|
return seq |
|
else: |
|
return self._merge_edits(seq, "S") |
|
|
|
if set(ops) == {"M"}: |
|
return self._merge_edits(seq, "M") |
|
|
|
return self._merge_edits(seq, "S") |
|
|
|
def __call__(self, |
|
align_obj, |
|
src: List, |
|
tgt: List, |
|
verbose: bool = False): |
|
""" |
|
Based on ERRANT's merge, adapted for Chinese |
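        Merges token-level alignment edits into span-level edits and returns a
        list of edits of the form (op, src_start, src_end, tgt_start, tgt_end).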
|
""" |
|
src_tokens = [x[0] for x in src] |
|
tgt_tokens = [x[0] for x in tgt] |
|
edits = [] |
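        # First pass: group alignment ops, keep transpositions ("T...") as-is,
        # and let _process_seq merge each remaining run (the groupby key is the
        # op's first character for M/T edits and False otherwise).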
|
|
|
|
|
|
|
for op, group in groupby( |
|
align_obj, |
|
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False, |
|
): |
|
group = list(group) |
|
|
|
if op == "T": |
|
for seq in group: |
|
edits.append(seq) |
|
|
|
else: |
|
|
|
processed = self._process_seq(group, src_tokens, tgt_tokens) |
|
for seq in processed: |
|
edits.append(seq) |
|
|
|
filtered_edits = [] |
|
i = 0 |
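        # Second pass: slide a three-edit window over the result and rewrite
        # "S M S" and "D M I" / "I M D" patterns that amount to a reordering
        # into a single "T<n>" edit, where n is the length of the source span.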
|
while i < len(edits): |
|
e1 = edits[i][0][0] |
|
|
|
if i < len(edits) - 2: |
|
e2 = edits[i + 1][0][0] |
|
e3 = edits[i + 2][0][0] |
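                # Window 1: "S M S" where the two substituted spans swap places
                # around the matched pivot (exactly, or within edit distance 1)
                # is really a transposition.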
|
|
|
|
|
|
|
|
|
|
|
|
|
if e1 == "S" and e2 == "M" and e3 == "S": |
|
w1 = "".join(src_tokens[edits[i][1]: edits[i][2]]) |
|
w2 = "".join(tgt_tokens[edits[i][3]: edits[i][4]]) |
|
w3 = "".join(src_tokens[edits[i + 2][1]: edits[i + 2][2]]) |
|
w4 = "".join(tgt_tokens[edits[i + 2][3]: edits[i + 2][4]]) |
|
if min([len(w1), len(w2), len(w3), len(w4)]) == 1: |
|
if w1 == w4 and w2 == w3: |
|
group = [edits[i], edits[i + 1], edits[i + 2]] |
|
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1])) |
|
for seq in processed: |
|
filtered_edits.append(seq) |
|
i += 3 |
|
else: |
|
filtered_edits.append(edits[i]) |
|
i += 1 |
|
else: |
|
if Levenshtein.distance(w1, w4) <= 1 and Levenshtein.distance(w2, w3) <= 1: |
|
group = [edits[i], edits[i + 1], edits[i + 2]] |
|
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1])) |
|
for seq in processed: |
|
filtered_edits.append(seq) |
|
i += 3 |
|
else: |
|
filtered_edits.append(edits[i]) |
|
i += 1 |
|
|
|
|
|
|
|
|
|
|
|
elif (e1 == "D" and (e2 == "M" or e2.startswith("T")) and e3 == "I") or (e1 == "I" and (e2 == "M" or e2.startswith("T")) and e3 == "D"): |
|
if e1 == "D": |
|
delete_token = src_tokens[edits[i][1]: edits[i][2]] |
|
insert_token = tgt_tokens[edits[i + 2][3]: edits[i + 2][4]] |
|
else: |
|
delete_token = src_tokens[edits[i + 2][1]: edits[i + 2][2]] |
|
insert_token = tgt_tokens[edits[i][3]: edits[i][4]] |
|
a, b = "".join(delete_token), "".join(insert_token) |
|
if len(a) < len(b): |
|
a, b = b, a |
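                    # Skip punctuation and require the span lengths to differ by at
                    # most one character.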
|
if a not in self.punctuation and b not in self.punctuation and len(a) - len(b) <= 1: |
|
if len(b) == 1: |
|
if a == b: |
|
group = [edits[i], edits[i + 1], edits[i + 2]] |
|
processed = self._merge_edits(group, "T" + str(edits[i+2][2] - edits[i][1])) |
|
for seq in processed: |
|
filtered_edits.append(seq) |
|
i += 3 |
|
else: |
|
filtered_edits.append(edits[i]) |
|
i += 1 |
|
else: |
|
if Levenshtein.distance(a, b) <= 1 or (len(a) == len(b) and self._check_revolve(a, b)): |
|
group = [edits[i], edits[i + 1], edits[i + 2]] |
|
processed = self._merge_edits(group, "T" + str(edits[i + 2][2] - edits[i][1])) |
|
for seq in processed: |
|
filtered_edits.append(seq) |
|
i += 3 |
|
else: |
|
filtered_edits.append(edits[i]) |
|
i += 1 |
|
else: |
|
filtered_edits.append(edits[i]) |
|
i += 1 |
|
                else:
                    # No transposition pattern at this window; keep the edit
                    # unless it is a pure match.
                    if e1 != "M":
                        filtered_edits.append(edits[i])
                    i += 1
            else:
                # Fewer than three edits remain, so no pattern can start here.
                if e1 != "M":
                    filtered_edits.append(edits[i])
                i += 1
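        # Second filter: for substitution edits whose spans differ, strip any
        # token-aligned common prefix/suffix and narrow the edit; edits whose
        # source and target spans are already identical are dropped.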
|
|
|
|
|
|
|
|
|
|
|
|
|
second_filter = [] |
|
for edit in filtered_edits: |
|
span1 = "".join(src_tokens[edit[1] : edit[2]]) |
|
span2 = "".join(tgt_tokens[edit[3] : edit[4]]) |
|
|
|
if span1 != span2: |
|
if edit[0] == "S": |
|
b = True |
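                    # b stays True (keep the edit unchanged) unless a token-aligned
                    # common prefix or suffix is found and the spans can be narrowed.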
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
common_str = "" |
|
tmp_new_start_1 = edit[1] |
|
for i in range(edit[1], edit[2]): |
|
if not span2.startswith(common_str + src_tokens[i]): |
|
break |
|
common_str += src_tokens[i] |
|
tmp_new_start_1 = i + 1 |
|
new_start_1, new_start_2 = edit[1], edit[3] |
|
if common_str: |
|
tmp_str = "" |
|
for i in range(edit[3], edit[4]): |
|
tmp_str += tgt_tokens[i] |
|
if tmp_str == common_str: |
|
new_start_1, new_start_2 = tmp_new_start_1, i + 1 |
|
|
|
b = False |
|
break |
|
elif len(tmp_str) > len(common_str): |
|
break |
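                    # Repeat the same trimming from the right-hand side (common suffix).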
|
|
|
common_str = "" |
|
new_end_1, new_end_2 = edit[2], edit[4] |
|
tmp_new_end_1 = edit[2] |
|
for i in reversed(range(new_start_1, edit[2])): |
|
if not span2.endswith(src_tokens[i] + common_str): |
|
break |
|
common_str = src_tokens[i] + common_str |
|
tmp_new_end_1 = i |
|
if common_str: |
|
tmp_str = "" |
|
for i in reversed(range(new_start_2, edit[4])): |
|
tmp_str = tgt_tokens[i] + tmp_str |
|
if tmp_str == common_str: |
|
new_end_1, new_end_2 = tmp_new_end_1, i |
|
b = False |
|
break |
|
elif len(tmp_str) > len(common_str): |
|
break |
|
if b: |
|
second_filter.append(edit) |
|
else: |
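                        # The narrowed edit becomes an insertion if the source span is
                        # now empty, a deletion if the target span is empty, and a
                        # tighter substitution otherwise.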
|
if new_start_1 == new_end_1: |
|
new_edit = ("I", new_start_1, new_end_1, new_start_2, new_end_2) |
|
elif new_start_2 == new_end_2: |
|
new_edit = ("D", new_start_1, new_end_1, new_start_2, new_end_2) |
|
else: |
|
new_edit = ("S", new_start_1, new_end_1, new_start_2, new_end_2) |
|
second_filter.append(new_edit) |
|
else: |
|
second_filter.append(edit) |
|
if verbose: |
|
print("========== Parallels ==========") |
|
print("".join(src_tokens)) |
|
print("".join(tgt_tokens)) |
|
print("========== Results ==========") |
|
for edit in second_filter: |
|
op = edit[0] |
|
s = " ".join(src_tokens[edit[1]: edit[2]]) |
|
t = " ".join(tgt_tokens[edit[3]: edit[4]]) |
|
print(f"{op}:\t{s}\t-->\t{t}") |
|
print("========== Infos ==========") |
|
print(str(src)) |
|
print(str(tgt)) |
|
return second_filter |
|
|
|
if __name__ == "__main__": |
|
tokenizer = Tokenizer("char") |
|
semantic_dict, semantic_class = read_cilin() |
|
confusion_dict = read_confusion() |
|
alignment = Alignment(semantic_dict, confusion_dict) |
|
sents = [ |
|
"所 以 印 度 对 全 世 界 人 没 有 说 服 不 要 吃 牛 肉 。".replace( |
|
" ", ""), |
|
"所 以 印 度 没 有 说 服 全 世 界 人 不 要 吃 牛 肉 。".replace( |
|
" ", "")] |
|
src, tgt = tokenizer(sents) |
|
align_obj = alignment(src, tgt) |
|
m = Merger() |
|
m(align_obj, src, tgt, verbose=True) |