TwT-6's picture
Upload 2667 files
256a159 verified
import re
import cn2an
"""
task: law article prediction
metric: F1 score
法律判决预测-法条预测
"""
def replace_match(match):
return match.group(1)
def compute_ljp_article(data_dict):
"""
Compute the F1-score
A reference contains a list of articles of the Criminal Law of the People's Republic of China.
We compute the F1-score between the prediction and the reference.
"""
score_list, abstentions = [], 0
for example in data_dict:
question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"]
assert answer.startswith("法条:刑法第"), f"answer: {answer}"
assert answer.endswith("条"), f"answer: {answer}"
answer = answer.replace("法条:刑法第", "")
answer = answer.replace("条", "")
answer_law_indices = answer.split("、")
answer_law_index_digit_list = []
for answer_law_index in answer_law_indices:
assert answer_law_index.isdigit(), f"answer_law_index: {answer_law_index}"
answer_law_index_digit = int(answer_law_index)
assert answer_law_index_digit <= 490, "刑法总共只有490条"
answer_law_index_digit_list.append(answer_law_index_digit)
prediction_law_chunks = prediction.split("、")
prediction_law_index_digit_list = []
for prediction_law_chunk in prediction_law_chunks:
prediction_law_chunk = prediction_law_chunk.replace("万元", "元")
# delete phrase starts with "第" and ends with "款", we don't care about it in the answer
prediction_law_chunk = re.sub(r'第(.*?)款', "", prediction_law_chunk)
# keep only the digits in the phrase starts with "第" and ends with "条", otherwise cn may fail to convert
prediction_law_chunk = re.sub(r'第(.*?)条', replace_match, prediction_law_chunk)
prediction_law_chunk = cn2an.transform(prediction_law_chunk, "cn2an")
# find digtis in prediction_law_chunk
prediction_law_section_numbers = re.findall(r"\d+", prediction_law_chunk)
if len(prediction_law_section_numbers) == 0:
continue
if len(prediction_law_section_numbers) != 1:
# in this case, we only take the first number, and reject the others
pass
prediction_law_index_digit = int(prediction_law_section_numbers[0])
prediction_law_index_digit_list.append(prediction_law_index_digit)
gt_set = set(answer_law_index_digit_list)
pred_set = set(prediction_law_index_digit_list)
if len(pred_set) == 0:
abstentions += 1
precision = len(gt_set.intersection(pred_set)) / len(pred_set) if len(pred_set) != 0 else 0
recall = len(gt_set.intersection(pred_set)) / len(gt_set) if len(gt_set) != 0 else 0
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
score_list.append(f1_score)
# compute the accuracy of score_list
average_f1 = sum(score_list) / len(score_list)
return {'score': average_f1, 'abstention_rate': abstentions/len(data_dict)}