MichielPronk committed on
Commit 07def9e (verified)
1 Parent(s): 9950495

Added script to load a model and compute metrics on a given file

Files changed (1)
  1. compute_metrics.py +256 -0
compute_metrics.py ADDED
@@ -0,0 +1,256 @@
# Standalone script that loads a trained model, converts its predictions to
# hard labels and calculates the IoU and correlation scores using the
# settings of our best model in SemEval 2025 Task 3 (Mu-SHROOM).
import argparse
import collections

import jsonlines
import numpy as np
from datasets import load_dataset
from scipy.stats import spearmanr
from tqdm.auto import tqdm
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, TrainingArguments, Trainer

def add_answers_column(example):
    """Build a SQuAD-style answers column from the gold hard labels."""
    starts, texts = [], []
    for hard_label in example["hard_labels"]:
        starts.append(hard_label[0])
        texts.append(example["context"][hard_label[0]:hard_label[1]])
    example["answers"] = {"answer_start": starts, "text": texts}
    return example

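# For illustration (hypothetical values): hard_labels [[7, 12]] produce
# example["answers"] == {"answer_start": [7], "text": [example["context"][7:12]]},
# i.e. the SQuAD-style answers column used by extractive QA preprocessing.
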
def to_dataset(file_path):
    """Load an evaluation JSONL file and rename its columns to the QA format."""
    mushroom = load_dataset("json", data_files=file_path)["train"]
    mushroom = mushroom.rename_column("model_output_text", "context")
    mushroom = mushroom.rename_column("model_input", "question")
    if "hard_labels" in mushroom.column_names:
        mushroom = mushroom.map(add_answers_column)
    else:
        print("No hard labels found in the evaluation data: only generating predictions.")

    return mushroom

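# The input JSONL is assumed to contain at least "id", "model_input" and
# "model_output_text" per record; "hard_labels" (and, for scoring,
# "soft_labels") are only present in labelled evaluation data.
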
def preprocess_examples(examples, tokenizer):
    """Tokenize question/context pairs and keep offsets only for context tokens."""
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        # Mask the offsets of tokens that are not part of the context, so that
        # post-processing can skip candidate spans falling in the question
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


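# Sketch of the resulting features (assuming the tokenizer settings above):
# a long example is split into overlapping windows of 384 tokens with a
# stride of 128, each window keeps an "example_id" pointing back to its
# source example, and "offset_mapping" is None for every non-context token.
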
def score_iou(ref_dict, pred_dict):
    """
    Computes intersection-over-union between reference and predicted hard
    labels, for a single datapoint.

    Arguments:
        ref_dict (dict): a gold reference datapoint,
        pred_dict (dict): a model's prediction

    Returns:
        float: The IoU, or 1.0 if neither the reference nor the prediction
            contains hallucinations
    """
    # ensure the prediction is correctly matched to its reference
    assert ref_dict['id'] == pred_dict['id']
    # convert annotations to sets of indices
    ref_indices = {idx for span in ref_dict['hard_labels'] for idx in range(*span)}
    pred_indices = {idx for span in pred_dict['hard_labels'] for idx in range(*span)}
    # avoid division by zero
    if not pred_indices and not ref_indices:
        return 1.
    # otherwise compute & return IoU
    return len(ref_indices & pred_indices) / len(ref_indices | pred_indices)

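# Worked example with hypothetical spans: ref hard_labels [[0, 5]] and pred
# hard_labels [[3, 8]] yield index sets {0..4} and {3..7}; the intersection
# has 2 indices and the union 8, so score_iou returns 2 / 8 = 0.25.
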
def score_cor(ref_dict, pred_dict):
    """computes Spearman correlation between predicted and reference soft labels, for a single datapoint.
    inputs:
    - ref_dict: a gold reference datapoint,
    - pred_dict: a model's prediction
    returns:
    the Spearman correlation, or a binarized exact match (0.0 or 1.0) if the reference or prediction contains no variation
    """
    # ensure the prediction is correctly matched to its reference
    assert ref_dict['id'] == pred_dict['id']
    # convert annotations to vectors of observations
    ref_vec = [0.] * ref_dict['text_len']
    pred_vec = [0.] * ref_dict['text_len']
    for span in ref_dict['soft_labels']:
        for idx in range(span['start'], span['end']):
            ref_vec[idx] = span['prob']
    for span in pred_dict['soft_labels']:
        for idx in range(span['start'], span['end']):
            pred_vec[idx] = span['prob']
    # constant series (i.e., no hallucination) => cor is undef
    if len({round(flt, 8) for flt in pred_vec}) == 1 or len({round(flt, 8) for flt in ref_vec}) == 1:
        return float(len({round(flt, 8) for flt in ref_vec}) == len({round(flt, 8) for flt in pred_vec}))
    # otherwise compute Spearman's rho
    return spearmanr(ref_vec, pred_vec).correlation

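# Note on the constant-series branch above: if neither side marks any
# hallucination both vectors stay all zeros and the function returns 1.0
# (exact match); if only one of the two is constant it returns 0.0.
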
def infer_soft_labels(hard_labels):
    """reformat hard labels into soft labels with prob 1"""
    return [
        {
            'start': start,
            'end': end,
            'prob': 1.0,
        }
        for start, end in hard_labels
    ]

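# For instance, hard labels [[3, 9]] become
# [{'start': 3, 'end': 9, 'prob': 1.0}].
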
def find_possible_spans(answers, example):
    """
    Creates and filters possible hallucination spans.

    Arguments:
        answers (list): List containing dictionaries with spans as text and
            logit scores.
        example: The instance which is being predicted. The context is used
            to map the predicted text to the start and end indices of the
            target context.

    Returns:
        tuple: A list of hard labels and the corresponding soft labels.
    """
    # Keep every candidate span whose summed logit score is above 0.8 times
    # the best score
    best_answer = max(answers, key=lambda x: x["logit_score"])
    threshold = best_answer["logit_score"] * 0.8
    hard_labels = []
    for answer in answers:
        if answer["logit_score"] > threshold:
            start_index = example["context"].index(answer["text"])
            end_index = start_index + len(answer["text"])
            hard_labels.append([start_index, end_index])
    soft_labels = infer_soft_labels(hard_labels)
    return hard_labels, soft_labels

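# Hypothetical illustration of the filtering above: with a best summed logit
# score of 10.0 the threshold is 8.0, so every candidate span scoring above
# 8.0 is kept as a hallucination span.
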
def compute_metrics(start_logits, end_logits, features, examples, predictions_file):
    """
    Processes the model predictions into spans, writes them to a predictions
    file and, if gold labels are available, calculates IoU and correlation.

    Arguments:
        start_logits (list): Logits of all start positions.
        end_logits (list): Logits of all end positions.
        features (Dataset): Dataset containing features of questions and context.
        examples (Dataset): Dataset containing the examples, with gold labels if available.
        predictions_file (str): Path of the JSONL file to write the predictions to.

    Returns:
        None
    """
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            # Consider the 20 most likely start and end positions
            start_indexes = np.argsort(start_logit)[-1: -20 - 1: -1].tolist()
            end_indexes = np.argsort(end_logit)[-1: -20 - 1: -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > 30 tokens
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > 30
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0]: offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Turn the candidate answers into hard and soft label spans
        if len(answers) > 0:
            hard_labels, soft_labels = find_possible_spans(answers, example)
            predicted_answers.append(
                {"id": example_id, "hard_labels": hard_labels, "soft_labels": soft_labels}
            )
        else:
            predicted_answers.append({"id": example_id, "hard_labels": [], "soft_labels": []})

    with jsonlines.open(predictions_file, mode="w") as writer:
        writer.write_all(predicted_answers)

    if "answers" in examples.column_names:
        true_answers = [{"id": ex["id"], "hard_labels": ex["hard_labels"], "soft_labels": ex["soft_labels"],
                         "text_len": len(ex["context"])} for ex in examples]
        ious = np.array([score_iou(r, d) for r, d in zip(true_answers, predicted_answers)])
        cors = np.array([score_cor(r, d) for r, d in zip(true_answers, predicted_answers)])

        print(f"IOU: {ious.mean():.8f}, COR: {cors.mean():.8f}")
    else:
        print("Evaluation data contained no answers. No scores to show.")

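# The predictions file is written as JSONL, one object per example, e.g.
# (hypothetical values):
# {"id": 7, "hard_labels": [[3, 9]], "soft_labels": [{"start": 3, "end": 9, "prob": 1.0}]}
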
def main(model_path, evaluation_file_path, output_file):
    model = AutoModelForQuestionAnswering.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Initialize a Trainer, used here only for batched prediction
    args = TrainingArguments(
        output_dir="output_dir",
        per_device_eval_batch_size=16,
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
    )

    mushroom_dataset = to_dataset(evaluation_file_path)
    features = mushroom_dataset.map(
        preprocess_examples,
        batched=True,
        remove_columns=mushroom_dataset.column_names,
        fn_kwargs={"tokenizer": tokenizer}
    )

    predictions, _, _ = trainer.predict(features)
    start_logits, end_logits = predictions
    compute_metrics(start_logits, end_logits, features, mushroom_dataset, output_file)


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('model_name', type=str)
    p.add_argument('evaluation_file_path', type=str)
    p.add_argument('output_file', type=str)
    a = p.parse_args()
    main(a.model_name, a.evaluation_file_path, a.output_file)
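
# Example invocation (hypothetical paths; the positional arguments match the
# argparse definition above):
#   python compute_metrics.py path/to/qa-model mushroom_val.jsonl predictions.jsonl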