# Plan2Align-NV / vecalign / metricx24 / evaluate_wmt24.py
# KuangDW
# add alignment and specify encoder
# dd05f29
# coding=utf-8
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluates the predictions from a MetricX model."""
import collections
import dataclasses
import json
import os
from mt_metrics_eval import data
from mt_metrics_eval import tasks
import numpy as np
import transformers
@dataclasses.dataclass
class Arguments:
    """Command-line arguments for the WMT24 evaluation script.

    Every field is a required string (none has a default); the `help`
    metadata entry carries the per-argument help text.
    """

    # One predictions file per evaluated language pair.
    en_de: str = dataclasses.field(
        metadata=dict(help="The en-de input file."),
    )
    en_es: str = dataclasses.field(
        metadata=dict(help="The en-es input file."),
    )
    ja_zh: str = dataclasses.field(
        metadata=dict(help="The ja-zh input file."),
    )

    # Destination of the JSON file holding the computed metrics.
    output_file: str = dataclasses.field(
        metadata=dict(help="The output file with evaluation metrics."),
    )
def _load_scores(
    input_file: str, num_segments: int,
) -> tuple[dict[str, list[float]], dict[str, list[float]]]:
    """Loads segment and system-level scores from a JSONL predictions file.

    Every non-blank line of `input_file` must be a JSON object with the keys
    "system_id", "segment_id" and "prediction". Each prediction is negated
    when read (presumably so that higher means better -- confirm against the
    consumer of these scores).

    Args:
      input_file: Path to the JSONL predictions file. It is read as UTF-8.
      num_segments: Length of every per-system segment list. Segment ids are
        looked up as the integers `0 .. num_segments - 1`; ids outside that
        range are ignored.

    Returns:
      A `(seg_scores, sys_scores)` pair. `seg_scores` maps each system id to
      a list of `num_segments` negated predictions, with `None` where the
      file has no line for that segment. `sys_scores` maps each system id to
      the mean of its non-`None` segment scores; note this is a single scalar
      per system (NaN when the system has no usable segment), not a list.

    Raises:
      json.JSONDecodeError: If a non-blank line is not valid JSON.
      KeyError: If a record lacks one of the required keys.
    """
    scores = collections.defaultdict(dict)
    # Explicit UTF-8: the records may carry non-ASCII text alongside the
    # scores, and the platform default codec is not guaranteed to decode it.
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            instance = json.loads(line)
            system_id = instance["system_id"]
            segment_id = instance["segment_id"]
            scores[system_id][segment_id] = -1 * instance["prediction"]

    # Dense per-system lists, `None` marking segments absent from the file.
    seg_scores = {
        system_id: [by_segment.get(i) for i in range(num_segments)]
        for system_id, by_segment in scores.items()
    }

    sys_scores = {}
    for system_id, values in seg_scores.items():
        present = [value for value in values if value is not None]
        # Avoid numpy's "mean of empty slice" warning when nothing is usable.
        sys_scores[system_id] = np.mean(present) if present else np.nan

    return seg_scores, sys_scores
def main() -> None:
    """Evaluates MetricX predictions on the WMT24 en-de, en-es and ja-zh sets.

    Parses an `Arguments` instance from the command line, downloads the MTME
    data, registers the scores loaded from the three prediction files as the
    metric "metricx-24-v2p6" on each WMT24 eval set, runs the `tasks.WMT24`
    task set and, when `args.output_file` is non-empty, writes the metric's
    average correlation to that file as JSON.
    """
    arg_parser = transformers.HfArgumentParser(Arguments)
    (args,) = arg_parser.parse_args_into_dataclasses()

    # Download MTME data
    data.Download()

    metric_name = "metricx-24-v2p6"
    wmt24_lps = ["en-de", "en-es", "ja-zh"]

    evs_dict = {}
    for lp in wmt24_lps:
        evs_dict[("wmt24", lp)] = data.EvalSet("wmt24", lp, True)

    # The gold MQM segment scores tell us how many segments each language
    # pair has. The count recorded is that of the last system iterated; no
    # entry is created when there are no gold scores at all.
    segment_counts_per_lp = {}
    for lp in wmt24_lps:
        gold_scores = evs_dict[("wmt24", lp)].Scores("seg", "mqm")
        for per_system_scores in gold_scores.values():
            segment_counts_per_lp[lp] = len(per_system_scores)

    prediction_files = {
        "en-de": args.en_de,
        "en-es": args.en_es,
        "ja-zh": args.ja_zh,
    }
    scores = {
        lp: _load_scores(path, segment_counts_per_lp[lp])
        for lp, path in prediction_files.items()
    }

    # Register the loaded scores as an additional metric by writing straight
    # into each eval set's private score tables and metric-name sets.
    for lp in wmt24_lps:
        evs = evs_dict[("wmt24", lp)]
        seg_scores, sys_scores = scores[lp]
        full_metric_name = f"{metric_name}-{evs.std_ref}"
        evs._scores["seg"][full_metric_name] = seg_scores  # pylint: disable=protected-access
        evs._scores["sys"][full_metric_name] = sys_scores  # pylint: disable=protected-access
        evs._metric_names.add(full_metric_name)  # pylint: disable=protected-access
        evs._metric_basenames.add(metric_name)  # pylint: disable=protected-access

    for evs in evs_dict.values():
        evs.SetPrimaryMetrics(evs.primary_metrics | {metric_name})

    wmt24_tasks, wts = tasks.WMT24(wmt24_lps, k=0)
    results = wmt24_tasks.Run(eval_set_dict=evs_dict)
    average_corrs = results.AverageCorrs(wts)
    metrics = {"average_correlation": average_corrs[metric_name]}

    # Nothing to write when no output path was supplied.
    if not args.output_file:
        return

    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    serialized = json.dumps(metrics, indent=2)
    with open(args.output_file, "w") as out:
        out.write(serialized)
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()