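"""Entry point for fine-tuning and evaluating LayoutLM-style question answering
models (BERT- and RoBERTa-based variants) on the WebSRC dataset."""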
import os
import sys

sys.path.append(os.getcwd())

import shutil
import logging

import torch
import torch.nn as nn
import torch.distributed as dist
from transformers import (
    BertTokenizer,
    RobertaTokenizer,
)

from args import args
from model import (
    Layoutlmv1ForQuestionAnswering,
    Layoutlmv1Config,
    Layoutlmv1Config_roberta,
    Layoutlmv1ForQuestionAnswering_roberta,
)
from util import set_seed, set_exp_folder, check_screen
from trainer import train, evaluate  # choose a specific train function
# from data.datasets.docvqa import DocvqaDataset
from websrc import get_websrc_dataset
def main(args):
    set_seed(args)
    set_exp_folder(args)

    # Set up a logger that writes to both a file and stdout
    logging.basicConfig(
        filename="{}/output/{}/log.txt".format(args.output_dir, args.exp_name),
        level=logging.INFO,
        format='[%(asctime)s.%(msecs)03d] %(message)s',
        datefmt='%H:%M:%S',
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info('Args: %s', args)
    # Get config, model, and tokenizer classes for the chosen backbone
    if args.model_type == 'bert':
        config_class, model_class, tokenizer_class = (
            Layoutlmv1Config, Layoutlmv1ForQuestionAnswering, BertTokenizer
        )
    elif args.model_type == 'roberta':
        config_class, model_class, tokenizer_class = (
            Layoutlmv1Config_roberta, Layoutlmv1ForQuestionAnswering_roberta, RobertaTokenizer
        )
    else:
        raise ValueError("Unsupported model_type: {}".format(args.model_type))

    config = config_class.from_pretrained(
        args.model_name_or_path, cache_dir=args.cache_dir
    )
    config.add_linear = args.add_linear

    tokenizer = tokenizer_class.from_pretrained(
        args.model_name_or_path, cache_dir=args.cache_dir
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Report model size
    parameters = sum(p.numel() for p in model.parameters())
    logging.info("Total params: %.2fM", parameters / 1e6)
    ## Start training
    if args.do_train:
        dataset_web = get_websrc_dataset(args, tokenizer)
        logging.info(f'Web dataset is successfully loaded. Length: {len(dataset_web)}')
        train(args, dataset_web, model, tokenizer)
    ## Start evaluating
    if args.do_eval:
        logging.info('Start evaluating')
        dataset_web, examples, features = get_websrc_dataset(
            args, tokenizer, evaluate=True, output_examples=True
        )
        logging.info(f'[Eval] Web dataset is successfully loaded. Length: {len(dataset_web)}')
        evaluate(args, dataset_web, examples, features, model, tokenizer)
    ## Start testing
    if args.do_test:
        pass  # not implemented yet


if __name__ == '__main__':
    main(args)
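
# Example invocation (a sketch, not verified against args.py; the exact flag
# names defined there are an assumption, inferred from the attributes used above):
#   python main.py --model_type bert --model_name_or_path bert-base-uncased \
#       --output_dir ./runs --exp_name websrc_bert --do_train --do_eval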