File size: 635 Bytes
719d0db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np
from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_tw_mask

class GroundTruthPCTSPTW(GroundTruthBase):
    """Ground-truth classifier configured for the "pctsptw" problem.

    The base class is initialised with "pctsptw" as the problem and with
    "tsp" and "pctsp" as its ``compared_problems``. How the base class uses
    these values is defined in ``GroundTruthBase``.
    """

    def __init__(self, solver_type):
        """Pass the fixed problem settings and ``solver_type`` to the base.

        Args:
            solver_type: Forwarded unchanged to ``GroundTruthBase.__init__``.
        """
        super().__init__("pctsptw", ["tsp", "pctsp"], solver_type)

    # @override
    def get_mask(self, tour, step, node_feats):
        """Build the mask for ``step`` of ``tour`` from two helper masks.

        Args:
            tour: Forwarded to ``get_visited_mask`` and ``get_tw_mask``.
            step: Forwarded to ``get_visited_mask`` and ``get_tw_mask``.
            node_feats: Forwarded to ``get_visited_mask`` and
                ``get_tw_mask``.

        Returns:
            The result of ``get_visited_mask(...) | get_tw_mask(...)``.
        """
        # NOTE(review): the polarity of each mask (which value marks a
        # selectable node) is defined by the helpers in ground_truth_base;
        # confirm that combining them with ``|`` is the intended rule.
        visited_mask = get_visited_mask(tour, step, node_feats)
        tw_mask = get_tw_mask(tour, step, node_feats)
        combined_mask = visited_mask | tw_mask
        return combined_mask