Spaces:
Sleeping
Sleeping
import logging | |
import torch | |
from typing import Dict | |
from functools import partial | |
from transformers.utils.logging import enable_explicit_format | |
from transformers.trainer_callback import PrinterCallback | |
from transformers import ( | |
AutoTokenizer, | |
HfArgumentParser, | |
EvalPrediction, | |
Trainer, | |
set_seed, | |
PreTrainedTokenizerFast | |
) | |
from logger_config import logger, LoggerCallback | |
from config import Arguments | |
from trainers import BiencoderTrainer | |
from loaders import RetrievalDataLoader | |
from collators import BiencoderCollator | |
from metrics import accuracy, batch_mrr | |
from models import BiencoderModel | |
def _common_setup(args: Arguments):
    """Apply process-wide logging and seeding setup before any work starts.

    Args:
        args: Parsed run arguments; ``process_index`` and ``seed`` are read.
    """
    # NOTE(review): nesting was reconstructed from a whitespace-stripped copy
    # of this file; only the log-level change is taken to be conditional.
    # Processes with a non-zero index log at WARNING level only.
    if args.process_index > 0:
        logger.setLevel(logging.WARNING)

    enable_explicit_format()
    set_seed(args.seed)
def _compute_metrics(args: Arguments, eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute MRR and top-1 / top-3 accuracy from the evaluation scores.

    Args:
        args: Run arguments; ``train_n_passages`` is read to build the labels.
        eval_pred: Evaluation output; the last element of ``predictions`` is
            used as the score matrix.

    Returns:
        A dict with the keys ``'mrr'``, ``'acc1'`` and ``'acc3'``.
    """
    # The scores are expected to be the last field of the predictions,
    # consistent with BiencoderOutput.
    preds = eval_pred.predictions
    scores = torch.tensor(preds[-1]).float()

    # Target column for row i is (i * train_n_passages) modulo the number of
    # score columns.
    # NOTE(review): presumably the positive passage of query i sits at that
    # column and the modulo wraps indices into the local score width --
    # confirm against the trainer / collator.
    labels = torch.arange(0, scores.shape[0], dtype=torch.long) * args.train_n_passages
    labels = labels % scores.shape[1]

    # topk_metrics is indexed in the same order as topk=(1, 3).
    topk_metrics = accuracy(output=scores, target=labels, topk=(1, 3))
    mrr = batch_mrr(output=scores, target=labels)

    return {'mrr': mrr, 'acc1': topk_metrics[0], 'acc3': topk_metrics[1]}
def main():
    """Parse arguments, build the biencoder pipeline, then train and/or evaluate.

    Training runs when ``args.do_train`` is set and evaluation runs when
    ``args.do_eval`` is set; each stage logs and saves its own metrics
    through the trainer.
    """
    parser = HfArgumentParser((Arguments,))
    args: Arguments = parser.parse_args_into_dataclasses()[0]
    _common_setup(args)
    logger.info('Args={}'.format(str(args)))

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: BiencoderModel = BiencoderModel.build(args=args)
    logger.info(model)
    logger.info('Vocab size: {}'.format(len(tokenizer)))

    # Pad to a multiple of 8 only when fp16 is enabled.
    data_collator = BiencoderCollator(
        tokenizer=tokenizer,
        pad_to_multiple_of=8 if args.fp16 else None)

    retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer)
    train_dataset = retrieval_data_loader.train_dataset
    eval_dataset = retrieval_data_loader.eval_dataset

    # Datasets are handed to the trainer only for the stages that will run.
    trainer: Trainer = BiencoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        data_collator=data_collator,
        compute_metrics=partial(_compute_metrics, args),
        tokenizer=tokenizer,
    )
    # Swap the default printer callback for the project's logger callback.
    trainer.remove_callback(PrinterCallback)
    trainer.add_callback(LoggerCallback)

    # Give the data loader and the model a back-reference to the trainer.
    retrieval_data_loader.trainer = trainer
    model.trainer = trainer

    if args.do_train:
        train_result = trainer.train()
        trainer.save_model()

        metrics = train_result.metrics
        metrics["train_samples"] = len(train_dataset)

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

    if args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval")
        metrics["eval_samples"] = len(eval_dataset)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()