# Provenance note (Hugging Face upload artifact): uploaded by Tzktz,
# "Upload 7664 files", commit 6fc683c (verified).
import logging
import torch
from typing import Dict
from functools import partial
from transformers.utils.logging import enable_explicit_format
from transformers.trainer_callback import PrinterCallback
from transformers import (
AutoTokenizer,
HfArgumentParser,
EvalPrediction,
Trainer,
set_seed,
PreTrainedTokenizerFast
)
from logger_config import logger, LoggerCallback
from config import Arguments
from trainers import BiencoderTrainer
from loaders import RetrievalDataLoader
from collators import BiencoderCollator
from metrics import accuracy, batch_mrr
from models import BiencoderModel
def _common_setup(args: Arguments):
    """Configure logging verbosity and seed all RNGs for this run."""
    # Only the main process (index 0) keeps INFO-level logging; workers
    # are quieted to warnings so multi-process logs stay readable.
    is_main_process = args.process_index <= 0
    if not is_main_process:
        logger.setLevel(logging.WARNING)
    enable_explicit_format()
    set_seed(args.seed)
def _compute_metrics(args: Arguments, eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute retrieval metrics (MRR, top-1 and top-3 accuracy) from eval predictions."""
    # The last entry of predictions is the score matrix; field order is
    # kept consistent with BiencoderOutput.
    score_matrix = torch.tensor(eval_pred.predictions[-1]).float()
    num_queries, num_candidates = score_matrix.shape
    # The positive passage for query i presumably sits at column
    # i * train_n_passages, wrapped into the candidate dimension — this
    # mirrors how the collator lays out passages (confirm against loader).
    target = torch.arange(0, num_queries, dtype=torch.long) * args.train_n_passages
    target = target % num_candidates
    topk_metrics = accuracy(output=score_matrix, target=target, topk=(1, 3))
    mrr = batch_mrr(output=score_matrix, target=target)
    return {'mrr': mrr, 'acc1': topk_metrics[0], 'acc3': topk_metrics[1]}
def main():
    """Entry point: parse arguments, build tokenizer/model/data, then train and/or evaluate.

    Reads configuration via HfArgumentParser into a single `Arguments`
    dataclass; honours `args.do_train` and `args.do_eval` flags.
    """
    parser = HfArgumentParser((Arguments,))
    args: Arguments = parser.parse_args_into_dataclasses()[0]
    _common_setup(args)
    logger.info('Args={}'.format(str(args)))

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: BiencoderModel = BiencoderModel.build(args=args)
    logger.info(model)
    logger.info('Vocab size: {}'.format(len(tokenizer)))

    data_collator = BiencoderCollator(
        tokenizer=tokenizer,
        # Pad to a multiple of 8 under fp16 so tensor cores are used efficiently.
        pad_to_multiple_of=8 if args.fp16 else None)

    retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer)
    train_dataset = retrieval_data_loader.train_dataset
    eval_dataset = retrieval_data_loader.eval_dataset

    trainer: Trainer = BiencoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        data_collator=data_collator,
        # Bind args so compute_metrics matches the (EvalPrediction) -> dict signature.
        compute_metrics=partial(_compute_metrics, args),
        tokenizer=tokenizer,
    )
    # Replace the default stdout printer with the project's logger callback.
    trainer.remove_callback(PrinterCallback)
    trainer.add_callback(LoggerCallback)
    # Back-references — presumably let the loader/model read trainer state
    # (e.g. current epoch/step); confirm in RetrievalDataLoader/BiencoderModel.
    retrieval_data_loader.trainer = trainer
    model.trainer = trainer

    if args.do_train:
        _run_training(trainer, train_dataset)
    if args.do_eval:
        _run_evaluation(trainer, eval_dataset)
    return


def _run_training(trainer: Trainer, train_dataset) -> None:
    """Run training, save the final model, and log/persist train metrics."""
    train_result = trainer.train()
    trainer.save_model()
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)


def _run_evaluation(trainer: Trainer, eval_dataset) -> None:
    """Run evaluation and log/persist eval metrics."""
    logger.info("*** Evaluate ***")
    metrics = trainer.evaluate(metric_key_prefix="eval")
    metrics["eval_samples"] = len(eval_dataset)
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
if __name__ == "__main__":
main()