"""Train and evaluate a cross-encoder reranker with the HuggingFace Trainer."""
import logging
import torch
from typing import Dict
from transformers.utils.logging import enable_explicit_format
from transformers.trainer_callback import PrinterCallback
from transformers import (
AutoTokenizer,
HfArgumentParser,
EvalPrediction,
Trainer,
set_seed,
PreTrainedTokenizerFast
)
from logger_config import logger, LoggerCallback
from config import Arguments
from trainers.reranker_trainer import RerankerTrainer
from loaders import CrossEncoderDataLoader
from collators import CrossEncoderCollator
from metrics import accuracy
from models import Reranker
def _common_setup(args: Arguments) -> None:
    """Prepare logging and seeding before training starts.

    Quietens non-main processes to WARNING so multi-process runs do not
    duplicate log lines, switches transformers to explicit log formatting,
    and seeds all RNGs from ``args.seed`` for reproducibility.
    """
    if args.process_index > 0:
        logger.setLevel(logging.WARNING)
    enable_explicit_format()
    set_seed(args.seed)
def _compute_metrics(eval_pred: EvalPrediction) -> Dict:
    """Return top-1 accuracy for an evaluation batch.

    ``eval_pred.predictions`` may be a tuple (e.g. extra model outputs);
    in that case the logits are taken from the last element.
    """
    predictions = eval_pred.predictions
    if isinstance(predictions, tuple):
        predictions = predictions[-1]

    logits = torch.tensor(predictions).float()
    targets = torch.tensor(eval_pred.label_ids).long()
    # accuracy() returns a list of top-k accuracies; index 0 is top-1.
    top1 = accuracy(output=logits, target=targets)[0]
    return {'acc': top1}
def main():
    """Entry point: parse arguments, build the reranker, then train and/or evaluate."""
    parser = HfArgumentParser((Arguments,))
    args: Arguments = parser.parse_args_into_dataclasses()[0]
    _common_setup(args)
    logger.info('Args={}'.format(str(args)))

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: Reranker = Reranker.from_pretrained(
        all_args=args,
        pretrained_model_name_or_path=args.model_name_or_path,
        num_labels=1)
    logger.info(model)
    logger.info('Vocab size: {}'.format(len(tokenizer)))

    # Pad to a multiple of 8 under fp16 so tensor-core kernels stay efficient.
    collator = CrossEncoderCollator(
        tokenizer=tokenizer,
        pad_to_multiple_of=8 if args.fp16 else None)
    data_loader = CrossEncoderDataLoader(args=args, tokenizer=tokenizer)
    train_dataset = data_loader.train_dataset
    eval_dataset = data_loader.eval_dataset

    trainer: Trainer = RerankerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        data_collator=collator,
        compute_metrics=_compute_metrics,
        tokenizer=tokenizer,
    )
    # Swap the default stdout printer for the project's logger-based callback.
    trainer.remove_callback(PrinterCallback)
    trainer.add_callback(LoggerCallback)
    data_loader.trainer = trainer

    if args.do_train:
        train_result = trainer.train()
        trainer.save_model()
        metrics = train_result.metrics
        metrics["train_samples"] = len(train_dataset)
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

    if args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval")
        metrics["eval_samples"] = len(eval_dataset)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()