# NOTE: the three lines below are artifacts of the Hugging Face file page
# this script was scraped from; kept as comments so the file stays valid Python.
# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
import os
import sys
sys.path.append(os.getcwd())
import torch
import torch.nn as nn
import shutil
import logging
import torch.distributed as dist
from transformers import (
BertTokenizer,
RobertaTokenizer
)
from args import args
from model import (
Layoutlmv1ForQuestionAnswering,
Layoutlmv1Config,
Layoutlmv1Config_roberta,
Layoutlmv1ForQuestionAnswering_roberta
)
from util import set_seed, set_exp_folder, check_screen
from trainer import train, evaluate # choose a specific train function
# from data.datasets.docvqa import DocvqaDataset
from websrc import get_websrc_dataset
def main(args):
    """Train and evaluate a LayoutLM-style QA model on the WebSRC dataset.

    Args:
        args: parsed command-line namespace (see ``args.py``). Fields read
            here: ``model_type`` ('bert' or 'roberta'), ``model_name_or_path``,
            ``cache_dir``, ``output_dir``, ``exp_name``, ``add_linear``,
            ``do_train``, ``do_test``.

    Raises:
        ValueError: if ``args.model_type`` is not 'bert' or 'roberta'.
    """
    set_seed(args)
    set_exp_folder(args)

    # Set up logging to a per-experiment file, mirrored to stdout.
    logging.basicConfig(
        filename="{}/output/{}/log.txt".format(args.output_dir, args.exp_name),
        level=logging.INFO,
        format='[%(asctime)s.%(msecs)03d] %(message)s',
        datefmt='%H:%M:%S',
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info('Args ' + str(args))

    # Select config/model/tokenizer classes for the chosen backbone.
    if args.model_type == 'bert':
        config_class, model_class, tokenizer_class = (
            Layoutlmv1Config, Layoutlmv1ForQuestionAnswering, BertTokenizer)
    elif args.model_type == 'roberta':
        config_class, model_class, tokenizer_class = (
            Layoutlmv1Config_roberta, Layoutlmv1ForQuestionAnswering_roberta,
            RobertaTokenizer)
    else:
        # Fail fast: previously an unsupported model_type left config_class
        # unbound and crashed below with a confusing NameError.
        raise ValueError(
            "Unsupported model_type: {!r} (expected 'bert' or 'roberta')".format(
                args.model_type))

    config = config_class.from_pretrained(
        args.model_name_or_path, cache_dir=args.cache_dir
    )
    config.add_linear = args.add_linear
    tokenizer = tokenizer_class.from_pretrained(
        args.model_name_or_path, cache_dir=args.cache_dir
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Report parameter count via the logger (it already streams to stdout),
    # keeping output consistent with the rest of the script.
    parameters = sum(p.numel() for p in model.parameters())
    logging.info("Total params: %.2fM", parameters / 1e6)

    ## Start training
    if args.do_train:
        dataset_web = get_websrc_dataset(args, tokenizer)
        logging.info(f'Web dataset is successfully loaded. Length : {len(dataset_web)}')
        train(args, dataset_web, model, tokenizer)

    # NOTE(review): evaluation runs unconditionally — the original
    # `if args.do_eval:` guard is commented out in the source. Behavior is
    # preserved here; confirm whether the guard should be restored.
    logging.info('Start evaluating')
    dataset_web, examples, features = get_websrc_dataset(
        args, tokenizer, evaluate=True, output_examples=True)
    logging.info(f'[Eval] Web dataset is successfully loaded. Length : {len(dataset_web)}')
    evaluate(args, dataset_web, examples, features, model, tokenizer)

    ## Start testing
    if args.do_test:
        pass  # TODO: test-set inference not implemented
# Script entry point: `args` is the module-level namespace imported from args.py.
if __name__ == '__main__':
    main(args)