File size: 14,437 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
import numpy as np
from typing import List, Tuple, Dict
from modules.tokenizer import Tokenizer
import os
from string import punctuation
REAL_PATH = os.path.split(os.path.realpath(__file__))[0]
chinese_punct = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛“”„‟…‧﹏"
english_punct = punctuation
punct = chinese_punct + english_punct
def check_all_chinese(word):
"""
判断一个单词是否全部由中文组成
:param word:
:return:
"""
return all(['\u4e00' <= ch <= '\u9fff' for ch in word])
def read_cilin():
"""
Cilin 詞林 is a thesaurus with semantic information
"""
# TODO -- fix this path
lines = open(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", "data", "lawbench", "eval_assets", "cilin.txt"), "r", encoding="gbk").read().strip().split("\n")
semantic_dict = {}
semantic_classes = {}
for line in lines:
code, *words = line.split(" ")
for word in words:
semantic_dict[word] = code
# make reverse dict
if code in semantic_classes:
semantic_classes[code] += words
else:
semantic_classes[code] = words
return semantic_dict, semantic_classes
def read_confusion():
confusion_dict = {}
with open(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", "data", "lawbench", "eval_assets", "confusion_dict.txt"), "r", encoding="utf-8") as f:
for line in f:
li = line.rstrip('\n').split(" ")
confusion_dict[li[0]] = li[1:]
return confusion_dict
class Alignment:
"""
对齐错误句子和正确句子,
使用编辑距离算法抽取编辑操作
"""
def __init__(
self,
semantic_dict: Dict,
confusion_dict: Dict,
granularity: str = "word",
) -> None:
"""
构造函数
:param semantic_dict: 语义词典(大词林)
:param confusion_dict: 字符混淆集
"""
self.insertion_cost = 1
self.deletion_cost = 1
self.semantic_dict = semantic_dict
self.confusion_dict = confusion_dict
# Because we use character level tokenization, this doesn't currently use POS
self._open_pos = {} # 如果是词级别,还可以利用词性是否相同来计算cost
self.granularity = granularity # word-level or character-level
self.align_seqs = []
def __call__(self,
src: List[Tuple],
tgt: List[Tuple],
verbose: bool = False):
cost_matrix, oper_matrix = self.align(src, tgt)
align_seq = self.get_cheapest_align_seq(oper_matrix)
if verbose:
print("========== Seg. and POS: ==========")
print(src)
print(tgt)
print("========== Cost Matrix ==========")
print(cost_matrix)
print("========== Oper Matrix ==========")
print(oper_matrix)
print("========== Alignment ==========")
print(align_seq)
print("========== Results ==========")
for a in align_seq:
print(a[0], src[a[1]: a[2]], tgt[a[3]: a[4]])
return align_seq
def _get_semantic_class(self, word):
"""
NOTE: Based on the paper:
Improved-Edit-Distance Kernel for Chinese Relation Extraction
获取每个词语的语义类别(基于大词林,有三个级别)
"""
if word in self.semantic_dict:
code = self.semantic_dict[word]
high, mid, low = code[0], code[1], code[2:4]
return high, mid, low
else: # unknown
return None
@staticmethod
def _get_class_diff(a_class, b_class):
"""
d == 3 for equivalent semantics
d == 0 for completely different semantics
根据大词林的信息,计算两个词的语义类别的差距
"""
d = sum([a == b for a, b in zip(a_class, b_class)])
return d
def _get_semantic_cost(self, a, b):
"""
计算基于语义信息的替换操作cost
:param a: 单词a的语义类别
:param b: 单词b的语义类别
:return: 替换编辑代价
"""
a_class = self._get_semantic_class(a)
b_class = self._get_semantic_class(b)
# unknown class, default to 1
if a_class is None or b_class is None:
return 4
elif a_class == b_class:
return 0
else:
return 2 * (3 - self._get_class_diff(a_class, b_class))
def _get_pos_cost(self, a_pos, b_pos):
"""
计算基于词性信息的编辑距离cost
:param a_pos: 单词a的词性
:param b_pos: 单词b的词性
:return: 替换编辑代价
"""
if a_pos == b_pos:
return 0
elif a_pos in self._open_pos and b_pos in self._open_pos:
return 0.25
else:
return 0.499
def _get_char_cost(self, a, b, pinyin_a, pinyin_b):
"""
NOTE: This is a replacement of ERRANTS lemma cost for Chinese
计算基于字符相似度的编辑距离cost
"""
if not (check_all_chinese(a) and check_all_chinese(b)):
return 0.5
if len(a) > len(b):
a, b = b, a
pinyin_a, pinyin_b = pinyin_b, pinyin_a
if a == b:
return 0
else:
return self._get_spell_cost(a, b, pinyin_a, pinyin_b)
def _get_spell_cost(self, a, b, pinyin_a, pinyin_b):
"""
计算两个单词拼写相似度,分别由字形相似度和字音相似度组成
:param a: 单词a
:param b: 单词b,且单词a的长度小于等于b
:param pinyin_a: 单词a的拼音
:param pinyin_b: 单词b的拼音
:return: 替换操作cost
"""
count = 0
for i in range(len(a)):
for j in range(len(b)):
if a[i] == b[j] or (set(pinyin_a) & set(pinyin_b)) or (b[j] in self.confusion_dict.keys() and a[i] in self.confusion_dict[b[j]]) or (a[i] in self.confusion_dict.keys() and b[j] in self.confusion_dict[a[i]]):
count += 1
break
return (len(a) - count) / (len(a) * 2)
def get_sub_cost(self, a_seg, b_seg):
"""
Calculate the substitution cost between words a and b
计算两个单词替换操作的编辑cost,最大为2,等于一次删除和一次添加
"""
if a_seg[0] == b_seg[0]:
return 0
if self.granularity == "word": # 词级别可以额外利用词性信息
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
pos_cost = self._get_pos_cost(a_seg[1], b_seg[1])
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
return semantic_cost + pos_cost + char_cost
else: # 字级别只能利用字义信息(从大词林中获取)和字面相似度信息
semantic_cost = self._get_semantic_cost(a_seg[0], b_seg[0]) / 6.0
if a_seg[0] in punct and b_seg[0] in punct:
pos_cost = 0.0
elif a_seg[0] not in punct and b_seg[0] not in punct:
pos_cost = 0.25
else:
pos_cost = 0.499
# pos_cost = 0.0 if (a_seg[0] in punct and b_seg[0] in punct) or (a_seg[0] not in punct and b_seg[0] not in punct) else 0.5
char_cost = self._get_char_cost(a_seg[0], b_seg[0], a_seg[2], b_seg[2])
return semantic_cost + char_cost + pos_cost
def align(self,
src: List[Tuple],
tgt: List[Tuple]):
"""
Based on ERRANT's alignment
基于改进的动态规划算法,为原句子的每个字打上编辑标签,以便使它能够成功转换为目标句子。
编辑操作类别:
1) M:Match,即KEEP,即当前字保持不变
2) D:Delete,删除,即当前字需要被删除
3) I:Insert,插入,即当前字需要被插入
4) T:Transposition,移位操作,即涉及到词序问题
"""
cost_matrix = np.zeros((len(src) + 1, len(tgt) + 1)) # 编辑cost矩阵
oper_matrix = np.full(
(len(src) + 1, len(tgt) + 1), "O", dtype=object
) # 操作矩阵
# Fill in the edges
for i in range(1, len(src) + 1):
cost_matrix[i][0] = cost_matrix[i - 1][0] + 1
oper_matrix[i][0] = ["D"]
for j in range(1, len(tgt) + 1):
cost_matrix[0][j] = cost_matrix[0][j - 1] + 1
oper_matrix[0][j] = ["I"]
# Loop through the cost matrix
for i in range(len(src)):
for j in range(len(tgt)):
# Matches
if src[i][0] == tgt[j][0]: # 如果两个字相等,则匹配成功(Match),编辑距离为0
cost_matrix[i + 1][j + 1] = cost_matrix[i][j]
oper_matrix[i + 1][j + 1] = ["M"]
# Non-matches
else:
del_cost = cost_matrix[i][j + 1] + self.deletion_cost # 由删除动作得到的总cost
ins_cost = cost_matrix[i + 1][j] + self.insertion_cost # 由插入动作得到的总cost
sub_cost = cost_matrix[i][j] + self.get_sub_cost(
src[i], tgt[j]
) # 由替换动作得到的总cost
# Calculate transposition cost
# 计算移位操作的总cost
trans_cost = float("inf")
k = 1
while (
i - k >= 0
and j - k >= 0
and cost_matrix[i - k + 1][j - k + 1]
!= cost_matrix[i - k][j - k]
):
p1 = sorted([a[0] for a in src][i - k: i + 1])
p2 = sorted([b[0] for b in tgt][j - k: j + 1])
if p1 == p2:
trans_cost = cost_matrix[i - k][j - k] + k
break
k += 1
costs = [trans_cost, sub_cost, ins_cost, del_cost]
ind = costs.index(min(costs))
cost_matrix[i + 1][j + 1] = costs[ind]
# ind = costs.index(costs[ind], ind+1)
for idx, cost in enumerate(costs):
if cost == costs[ind]:
if idx == 0:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["T" + str(k + 1)]
else:
oper_matrix[i + 1][j + 1].append("T" + str(k + 1))
elif idx == 1:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["S"]
else:
oper_matrix[i + 1][j + 1].append("S")
elif idx == 2:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["I"]
else:
oper_matrix[i + 1][j + 1].append("I")
else:
if oper_matrix[i + 1][j + 1] == "O":
oper_matrix[i + 1][j + 1] = ["D"]
else:
oper_matrix[i + 1][j + 1].append("D")
return cost_matrix, oper_matrix
def _dfs(self, i, j, align_seq_now, oper_matrix, strategy="all"):
"""
深度优先遍历,获取最小编辑距离相同的所有序列
"""
if i + j == 0:
self.align_seqs.append(align_seq_now)
else:
ops = oper_matrix[i][j] # 可以类比成搜索一棵树从根结点到叶子结点的所有路径
if strategy != "all": ops = ops[:1]
for op in ops:
if op in {"M", "S"}:
self._dfs(i - 1, j - 1, align_seq_now + [(op, i - 1, i, j - 1, j)], oper_matrix, strategy)
elif op == "D":
self._dfs(i - 1, j, align_seq_now + [(op, i - 1, i, j, j)], oper_matrix, strategy)
elif op == "I":
self._dfs(i, j - 1, align_seq_now + [(op, i, i, j - 1, j)], oper_matrix, strategy)
else:
k = int(op[1:])
self._dfs(i - k, j - k, align_seq_now + [(op, i - k, i, j - k, j)], oper_matrix, strategy)
def get_cheapest_align_seq(self, oper_matrix):
"""
回溯获得编辑距离最小的编辑序列
"""
self.align_seqs = []
i = oper_matrix.shape[0] - 1
j = oper_matrix.shape[1] - 1
if abs(i - j) > 10:
self._dfs(i, j , [], oper_matrix, "first")
else:
self._dfs(i, j , [], oper_matrix, "all")
final_align_seqs = [seq[::-1] for seq in self.align_seqs]
return final_align_seqs
if __name__ == "__main__":
tokenizer = Tokenizer("word")
semantic_dict, semantic_class = read_cilin()
confusion_dict = read_confusion()
alignment = Alignment(semantic_dict, confusion_dict)
sents = ["首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 搾 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 6 粒 , 纯净 水 4量杯 、 香菜 半量杯 和 草菇 10 个 。".replace(" ", ""), "首先 , 我们 得 准备 : 大 虾六 到 九 只 、 盐 一 茶匙 、 已 榨 好 的 柠檬汁 三 汤匙 、 泰国 柠檬 叶三叶 、 柠檬 香草 一 根 、 鱼酱 两 汤匙 、 辣椒 六 粒 , 纯净 水 四 量杯 、 香菜 半量杯 和 草菇 十 个 。".replace(" ", "")]
src, tgt = tokenizer(sents)
alignment(src, tgt, verbose=True) |