# coding=utf-8
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs inference with a MetricX model."""
import dataclasses
import json
import os

import datasets
import torch
import transformers

from . import models


@dataclasses.dataclass
class Arguments:
"""Prediction command-line arguments."""
tokenizer: str = dataclasses.field(
metadata={"help": "The name of the tokenizer"},
)
model_name_or_path: str = dataclasses.field(
metadata={
"help": (
"Path to pretrained model or model identifier from"
" huggingface.co/models"
)
},
)
max_input_length: int = dataclasses.field(
metadata={"help": "The maximum allowable input sequence length."},
)
batch_size: int = dataclasses.field(
metadata={"help": "The global prediction batch size."},
)
  input_file: str = dataclasses.field(
      metadata={"help": "The path to the jsonl input file."},
  )
output_file: str = dataclasses.field(
metadata={"help": "The output file with predictions."},
)
qe: bool = dataclasses.field(
metadata={"help": "Indicates the metric is a QE metric."},
default=False,
)
  device: str = dataclasses.field(
      metadata={"help": "The GPU ID(s) to expose via CUDA_VISIBLE_DEVICES."},
      default="0",
  )


def get_dataset(
input_file: str, tokenizer, max_input_length: int, device, is_qe: bool
):
"""Gets the test dataset for prediction.
If `is_qe` is true, the input data must have "hypothesis" and "source" fields.
If it is false, there must be "hypothesis" and "reference" fields.
Args:
input_file: The path to the jsonl input file.
tokenizer: The tokenizer to use.
max_input_length: The maximum input sequence length.
device: The ID of the device to put the PyTorch tensors on.
is_qe: Indicates whether the metric is a QE metric or not.
Returns:
The dataset.
"""
def _make_input(example):
if is_qe:
example["input"] = (
"source: "
+ example["source"]
+ " candidate: "
+ example["hypothesis"]
)
else:
example["input"] = (
"source: "
+ example["source"]
+ " candidate: "
+ example["hypothesis"]
+ " reference: "
+ example["reference"]
)
    return example

def _tokenize(example):
return tokenizer(
example["input"],
max_length=max_input_length,
truncation=True,
padding=False,
)
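
  # Strip the EOS token that the tokenizer appends; MetricX models expect
  # inputs without a trailing EOS (inferred from this helper's role, not
  # stated in the source).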
def _remove_eos(example):
example["input_ids"] = example["input_ids"][:-1]
example["attention_mask"] = example["attention_mask"][:-1]
    return example

ds = datasets.load_dataset("json", data_files={"test": input_file})
ds = ds.map(_make_input)
ds = ds.map(_tokenize)
ds = ds.map(_remove_eos)
ds.set_format(
type="torch",
columns=["input_ids", "attention_mask"],
device=device,
output_all_columns=True,
)
  return ds


def main() -> None:
parser = transformers.HfArgumentParser(Arguments)
(args,) = parser.parse_args_into_dataclasses()

  # CUDA_VISIBLE_DEVICES must be set before the first CUDA call so that
  # torch only sees the requested GPU(s).
  os.environ["CUDA_VISIBLE_DEVICES"] = args.device
  os.environ["NCCL_P2P_DISABLE"] = "1"
  os.environ["NCCL_IB_DISABLE"] = "1"

  if torch.cuda.is_available():
    device = torch.device("cuda:0")
    per_device_batch_size = args.batch_size // torch.cuda.device_count()
  else:
    device = torch.device("cpu")
    per_device_batch_size = args.batch_size
tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer)
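  # torch_dtype="auto" loads the weights in the dtype stored in the
  # checkpoint (e.g. bfloat16) rather than upcasting to float32.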
model = models.MT5ForRegression.from_pretrained(
args.model_name_or_path, torch_dtype="auto"
)
model.to(device)
model.eval()
ds = get_dataset(
args.input_file,
tokenizer,
args.max_input_length,
device,
args.qe,
)
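  # dataloader_pin_memory is disabled below, presumably because set_format
  # already places the tensors on the target device (an inference, not
  # stated in the source).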
training_args = transformers.TrainingArguments(
output_dir=os.path.dirname(args.output_file),
per_device_eval_batch_size=per_device_batch_size,
dataloader_pin_memory=False,
)
trainer = transformers.Trainer(
model=model,
args=training_args,
)
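  # Trainer.predict returns a PredictionOutput of (predictions, label_ids,
  # metrics); only the per-example regression scores are used here.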
predictions, _, _ = trainer.predict(test_dataset=ds["test"])
dirname = os.path.dirname(args.output_file)
if dirname:
os.makedirs(dirname, exist_ok=True)
  with open(args.output_file, "w", encoding="utf-8") as out:
for pred, example in zip(predictions, ds["test"]):
example["prediction"] = float(pred)
del example["input"]
del example["input_ids"]
del example["attention_mask"]
      out.write(json.dumps(example, ensure_ascii=False) + "\n")


if __name__ == "__main__":
main()