File size: 2,720 Bytes
6fc683c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import sys
sys.path.append(os.getcwd())
import torch
import torch.nn as nn
import shutil
import logging
import torch.distributed as dist


from transformers import (
    BertTokenizer,
    RobertaTokenizer
)

from args import args
from model import (  
    Layoutlmv1ForQuestionAnswering,
    Layoutlmv1Config,
    Layoutlmv1Config_roberta,
    Layoutlmv1ForQuestionAnswering_roberta
)
from util import set_seed, set_exp_folder, check_screen
from trainer import train, evaluate # choose a specific train function
# from data.datasets.docvqa import DocvqaDataset
from websrc import get_websrc_dataset

def main(args):
    """Train and/or evaluate a LayoutLMv1-style QA model on the WebSRC dataset.

    Args:
        args: parsed command-line namespace (from ``args.py``). Fields read
            here: ``model_type`` ('bert' or 'roberta'), ``model_name_or_path``,
            ``cache_dir``, ``add_linear``, ``output_dir``, ``exp_name``,
            ``do_train``, ``do_test``.

    Raises:
        ValueError: if ``args.model_type`` is not 'bert' or 'roberta'.
    """
    set_seed(args)
    set_exp_folder(args)

    # Set up logger: write to the experiment folder and echo to stdout.
    log_path = os.path.join(args.output_dir, 'output', args.exp_name, 'log.txt')
    logging.basicConfig(filename=log_path, level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info('Args %s', args)

    # Select config/model/tokenizer classes for the requested backbone.
    # Fail fast on unknown model types instead of a later NameError.
    if args.model_type == 'bert':
        config_class, model_class, tokenizer_class = Layoutlmv1Config, Layoutlmv1ForQuestionAnswering, BertTokenizer
    elif args.model_type == 'roberta':
        config_class, model_class, tokenizer_class = Layoutlmv1Config_roberta, Layoutlmv1ForQuestionAnswering_roberta, RobertaTokenizer
    else:
        raise ValueError(
            "Unsupported model_type {!r}; expected 'bert' or 'roberta'".format(args.model_type)
        )

    config = config_class.from_pretrained(
        args.model_name_or_path, cache_dir=args.cache_dir
    )
    config.add_linear = args.add_linear

    tokenizer = tokenizer_class.from_pretrained(
        args.model_name_or_path, cache_dir=args.cache_dir
    )

    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Log (rather than print) so the count also lands in log.txt.
    parameters = sum(p.numel() for p in model.parameters())
    logging.info('Total params: %.2fM', parameters / 1e6)

    ## Start training
    if args.do_train:
        dataset_web = get_websrc_dataset(args, tokenizer)
        logging.info('Web dataset is successfully loaded. Length : %d', len(dataset_web))
        train(args, dataset_web, model, tokenizer)

    # NOTE(review): the original `if args.do_eval:` guard is commented out,
    # so evaluation always runs — preserved as-is; confirm this is intended.
    logging.info('Start evaluating')
    dataset_web, examples, features = get_websrc_dataset(args, tokenizer, evaluate=True, output_examples=True)
    logging.info('[Eval] Web dataset is successfully loaded. Length : %d', len(dataset_web))
    evaluate(args, dataset_web, examples, features, model, tokenizer)

    ## Start testing
    if args.do_test:
        # Not implemented yet.
        pass


if __name__ == '__main__':
    # Script entry point: `args` is the namespace parsed at import time in args.py.
    main(args)