File size: 2,526 Bytes
a446b0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import random

import comet.src.train.train as base_train
import comet.src.train.batch as batch
import comet.src.evaluate.atomic_evaluate as evaluate
# import comet.src.evaluate.atomic_generate as gen


def make_trainer(opt, *args):
    """Factory entry point: build the trainer defined in this module.

    Args:
        opt: Configuration object handed to the trainer constructor.
        *args: Extra positional arguments, forwarded unchanged.

    Returns:
        A new ``AtomicGenerationIteratorTrainer`` instance.
    """
    trainer = AtomicGenerationIteratorTrainer(opt, *args)
    return trainer


class AtomicGenerationIteratorTrainer(base_train.IteratorTrainer):
    """Iterator trainer that plugs the atomic evaluator and batch function
    into ``base_train.IteratorTrainer``.

    The methods below rely on attributes that are not assigned in this class
    (``self.opt``, ``self.losses``, ``self.top_score``, ``self.samplers``,
    ``self.batch_variables``, ``self.data_loader``); presumably the base
    class sets them up -- verify against ``IteratorTrainer``.
    """

    def __init__(self, opt, *args):
        """Initialise the base trainer, then create the train-loss slots.

        Args:
            opt: Configuration object. ``opt.data.get("categories", [])`` is
                read to obtain the category names.
            *args: Forwarded unchanged to the base-class constructor.
        """
        super(AtomicGenerationIteratorTrainer, self).__init__(opt, *args)

        # The counter dict returned by initialize_losses is not needed here;
        # only the side effect on self.losses["train"] is.
        self.initialize_losses(opt.data.get("categories", []))

    def set_evaluator(self, opt, model, data_loader):
        """Create the evaluator and store it on ``self.evaluator``."""
        self.evaluator = evaluate.make_evaluator(
            opt, model, data_loader)

    # Generator hook is disabled together with the commented-out
    # atomic_generate import at the top of the file.
    # def set_generator(self, opt, model, data_loader, scores, reward=None):
    #     self.generator = gen.make_generator(
    #         opt, model, data_loader, scores, reward)

    def set_sampler(self, opt):
        """Lazily create the sampler named by ``opt.train.static.samp``.

        A sampler is built at most once per name and cached in
        ``self.samplers``.
        """
        # `sampling` is not imported at module level, so the original code
        # raised NameError here.  Import it locally so this fix does not
        # depend on any other change to the file.
        # NOTE(review): module path assumed to be the sampler module that
        # sits next to atomic_evaluate -- confirm against the package layout.
        import comet.src.evaluate.sampler as sampling

        sampler_name = opt.train.static.samp
        if sampler_name not in self.samplers:
            self.samplers[sampler_name] = sampling.make_sampler(
                sampler_name, opt, self.data_loader, batch_mode=True)
        # NOTE(review): this stores the whole name->sampler dict, not the
        # selected sampler.  Kept as-is because the consumer of
        # batch_variables["sampler"] is not visible here -- confirm intent.
        self.batch_variables["sampler"] = self.samplers

    def batch(self, opt, *args):
        """Run one batch through ``batch.batch_atomic_generate``.

        Returns:
            The ``"loss"``, ``"nums"`` and ``"reset"`` entries of the batch
            output, as a 3-tuple in that order.
        """
        outputs = batch.batch_atomic_generate(opt, *args)

        token_loss = outputs["loss"]
        nums = outputs["nums"]
        reset = outputs["reset"]

        return token_loss, nums, reset

    def initialize_losses(self, categories):
        """Reset the train-loss record and build matching zeroed counters.

        ``self.losses["train"]`` is replaced with a dict mapping
        ``"total_micro"``, ``"total_macro"`` and, for every category,
        ``"<category>_micro"`` / ``"<category>_macro"`` to ``[0]``.

        Args:
            categories: Iterable of category names.

        Returns:
            A dict with the same keys, each mapped to ``0``.
        """
        self.losses["train"] = {
            "total_micro": [0],
            "total_macro": [0]
        }

        nums = {"total_micro": 0, "total_macro": 0}

        for category in categories:
            micro_name = "{}_micro".format(category)
            macro_name = "{}_macro".format(category)

            self.losses["train"][micro_name] = [0]
            self.losses["train"][macro_name] = [0]

            nums[micro_name] = 0
            nums[macro_name] = 0

        return nums

    def update_top_score(self, opt):
        """Record the current epoch as best if its tracked score is lower.

        ``self.top_score`` is an ``(epoch, score)`` tuple, or ``None`` before
        the first update.  Lower scores win.  The previous and the resulting
        value are both printed.

        Args:
            opt: Unused; the epoch is read from ``self.opt``.  Kept for
                interface compatibility.
        """
        print(self.top_score)
        # Read the score once instead of re-evaluating it for the comparison
        # and again for the assignment.
        current_score = self.get_tracked_score()
        if self.top_score is None or current_score < self.top_score[-1]:
            self.top_score = (self.opt.train.dynamic.epoch, current_score)
        print(self.top_score)

    def get_tracked_score(self):
        """Return the dev ``"total_micro"`` entry for the current epoch.

        The epoch (``self.opt.train.dynamic.epoch``) is used as the index
        into ``self.losses["dev"]["total_micro"]``.
        """
        return self.losses["dev"]["total_micro"][self.opt.train.dynamic.epoch]

    def counter(self, nums):
        """Return the ``"total_macro"`` entry of ``nums``."""
        return nums["total_macro"]