""" task: multiple choice classification metric: F1 score 婚姻文本分类 """ def compute_wbfl(data_dict): """ A reference (R) contains a list of options, each option is from the option_list. We will extract the options appearing in the prediction and convert them into a set (P). We compute the F1 score between the prediction (P) and the reference (R). """ score_list, abstentions = [], 0 option_list = ["婚后有子女", "限制行为能力子女抚养", "有夫妻共同财产", "支付抚养费", "不动产分割", "婚后分局", "二次起诉离婚", "按月给付抚养费", "准予离婚", "有夫妻共同债务", "婚前个人财产", "法定离婚", "不履行家庭义务", "存在非婚生子", "适当帮助", "不履行离婚协议", "损害赔偿", "感情不和分居满二年", "子女随非抚养权人生活", "婚后个人财产"] for example in data_dict: question, prediction, answer = example["origin_prompt"], example["prediction"], example["refr"] assert answer.startswith("类别:") and answer.endswith("。"), f"answer: {answer}, question: {question}" gt_list = (answer[3:-1].split("、")) for gt in gt_list: assert gt in option_list, f"gt: {gt}, question: {question}" gt_set = set(gt_list) prediction_list = [] for option in option_list: if option in prediction: prediction_list.append(option) if len(prediction_list) == 0: abstentions += 1 predict_set = set(prediction_list) precision = len(gt_set.intersection(predict_set)) / len(predict_set) if len(predict_set) != 0 else 0 recall = len(gt_set.intersection(predict_set)) / len(gt_set) if len(gt_set) != 0 else 0 f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0 score_list.append(f1_score) # compute the accuracy of score_list final_f1_score = sum(score_list) / len(score_list) return {'score': final_f1_score, 'abstention_rate': abstentions / len(data_dict)}