File size: 963 Bytes
828992f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import os
import random
import torch
import copy
import colbert.utils.distributed as distributed
from colbert.utils.parser import Arguments
from colbert.utils.runs import Run
from colbert.training.training import train
def main():
    """Parse the training command-line arguments, validate them, and train.

    Registers the model, model-training, and training-input argument groups
    on an ``Arguments`` parser, checks the parsed values, derives
    ``args.lazy`` from ``args.collection``, and calls ``train(args)`` inside
    a ``Run.context`` block.

    Raises:
        AssertionError: If ``args.bsize`` is not a multiple of
            ``args.accumsteps``, or if ``args.query_maxlen`` or
            ``args.doc_maxlen`` is greater than 512.
    """
    parser = Arguments(description='Training ColBERT with <query, positive passage, negative passage> triples.')

    parser.add_model_parameters()
    parser.add_model_training_parameters()
    parser.add_training_input()

    args = parser.parse()

    # The checks below raise explicitly instead of using `assert` statements:
    # `assert` is removed when Python runs with -O, which would let invalid
    # settings through to training. AssertionError is kept as the exception
    # type so the failure mode is unchanged in a normal run.
    if args.bsize % args.accumsteps != 0:
        raise AssertionError(((args.bsize, args.accumsteps),
                              "The batch size must be divisible by the number of gradient accumulation steps."))

    # NOTE(review): 512 is presumably the encoder's maximum sequence length;
    # confirm against the model configuration.
    if args.query_maxlen > 512:
        raise AssertionError(f"query_maxlen must be at most 512, got {args.query_maxlen}")
    if args.doc_maxlen > 512:
        raise AssertionError(f"doc_maxlen must be at most 512, got {args.doc_maxlen}")

    # `lazy` is True exactly when a collection argument was supplied.
    args.lazy = args.collection is not None

    # NOTE(review): consider_failed_if_interrupted=False presumably keeps an
    # interrupted run from being recorded as failed; confirm against Run.
    with Run.context(consider_failed_if_interrupted=False):
        train(args)
# Start training only when this file is executed as a script; importing it
# as a module must not trigger a run.
if __name__ == "__main__":
    main()
|