api-demo
/
opencompass-my-api
/opencompass
/datasets
/lawbench
/evaluation_functions
/ljp_article.py
import re | |
import cn2an | |
""" | |
task: law article prediction | |
metric: F1 score | |
法律判决预测-法条预测 | |
""" | |
def replace_match(match):
    """Return the text captured by the first group of a regex match.

    Used as a ``re.sub`` replacement callback: the whole matched span is
    replaced by the content of its first capture group.
    """
    captured = match.group(1)
    return captured
def _parse_reference_articles(answer):
    """Parse a reference string into a list of article numbers.

    The reference must start with "法条:刑法第" and end with "条"; the article
    numbers in between are separated by "、". Malformed references trip an
    assertion, exactly as in the original implementation.
    """
    assert answer.startswith("法条:刑法第"), f"answer: {answer}"
    assert answer.endswith("条"), f"answer: {answer}"
    answer = answer.replace("法条:刑法第", "")
    answer = answer.replace("条", "")
    article_numbers = []
    for answer_law_index in answer.split("、"):
        assert answer_law_index.isdigit(), f"answer_law_index: {answer_law_index}"
        answer_law_index_digit = int(answer_law_index)
        # The assertion message states that the Criminal Law has 490 articles in total.
        assert answer_law_index_digit <= 490, "刑法总共只有490条"
        article_numbers.append(answer_law_index_digit)
    return article_numbers


def _extract_predicted_articles(prediction):
    """Extract one article number from each "、"-separated chunk of a prediction.

    Chunks that contain no digits after normalisation are skipped. When a
    chunk contains several numbers, only the first one is kept.
    """
    article_numbers = []
    for chunk in prediction.split("、"):
        # NOTE(review): presumably this keeps cn2an from treating "万" as a
        # numeric multiplier in monetary amounts -- confirm.
        chunk = chunk.replace("万元", "元")
        # Drop paragraph references ("第...款"); only the article matters.
        # The match must not cross another "第": otherwise a chunk such as
        # "第二百六十四条第一款" would be deleted entirely, article included.
        chunk = re.sub(r'第[^第\n]*?款', "", chunk)
        # Keep only the inner text of "第...条" so cn2an sees a bare numeral.
        chunk = re.sub(r'第(.*?)条', replace_match, chunk)
        chunk = cn2an.transform(chunk, "cn2an")
        numbers = re.findall(r"\d+", chunk)
        if not numbers:
            continue
        # Only the first number in the chunk is used; the rest are ignored.
        article_numbers.append(int(numbers[0]))
    return article_numbers


def compute_ljp_article(data_dict):
    """
    Compute the average F1-score for law article prediction.

    Each example must provide a "prediction" string and a "refr" reference
    string. The reference lists articles of the Criminal Law of the People's
    Republic of China; the F1-score is computed between the set of predicted
    article numbers and the set of reference article numbers.

    Returns a dict with:
        'score': the mean F1-score over all examples,
        'abstention_rate': the fraction of examples whose prediction yielded
            no article number at all.

    An empty ``data_dict`` yields a score and abstention rate of 0 instead of
    raising ZeroDivisionError.
    """
    if len(data_dict) == 0:
        return {'score': 0, 'abstention_rate': 0}

    score_list, abstentions = [], 0
    for example in data_dict:
        prediction, answer = example["prediction"], example["refr"]
        gt_set = set(_parse_reference_articles(answer))
        pred_set = set(_extract_predicted_articles(prediction))

        if len(pred_set) == 0:
            abstentions += 1

        overlap = len(gt_set.intersection(pred_set))
        precision = overlap / len(pred_set) if len(pred_set) != 0 else 0
        recall = overlap / len(gt_set) if len(gt_set) != 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
        score_list.append(f1_score)

    average_f1 = sum(score_list) / len(score_list)
    return {'score': average_f1, 'abstention_rate': abstentions / len(data_dict)}