File size: 1,635 Bytes
719d0db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
import numpy as np
from models.classifiers.ground_truth.ground_truth_tsptw import GroundTruthTSPTW
from models.classifiers.ground_truth.ground_truth_pctsp import GroundTruthPCTSP
from models.classifiers.ground_truth.ground_truth_pctsptw import GroundTruthPCTSPTW
from models.classifiers.ground_truth.ground_truth_cvrp import GroundTruthCVRP
from models.classifiers.ground_truth.ground_truth_cvrptw import GroundTruthCVRPTW

class GroundTruth(nn.Module):
    """Select and wrap a problem-specific ground-truth implementation.

    The concrete implementation is chosen from ``problem`` at construction
    time. ``forward``, ``get_inputs`` and ``solve`` all delegate to it
    unchanged.

    Args:
        problem: Problem name. Supported values are ``"tsptw"``,
            ``"pctsp"``, ``"pctsptw"``, ``"cvrp"`` and ``"cvrptw"``.
        solver_type: Passed unchanged to the selected implementation's
            constructor.

    Raises:
        NotImplementedError: If ``problem`` is not one of the supported names.
    """

    def __init__(self, problem, solver_type):
        super().__init__()
        self.problem = problem
        self.solver_type = solver_type
        # Each concrete class is referenced only on its own branch, so only
        # the selected implementation is constructed.
        if problem == "tsptw":
            self.ground_truth = GroundTruthTSPTW(solver_type)
        elif problem == "pctsp":
            self.ground_truth = GroundTruthPCTSP(solver_type)
        elif problem == "pctsptw":
            self.ground_truth = GroundTruthPCTSPTW(solver_type)
        elif problem == "cvrp":
            self.ground_truth = GroundTruthCVRP(solver_type)
        elif problem == "cvrptw":
            self.ground_truth = GroundTruthCVRPTW(solver_type)
        else:
            # Name the rejected value so a typo is easy to spot.
            raise NotImplementedError(
                f"Unsupported problem {problem!r}; expected one of "
                "'tsptw', 'pctsp', 'pctsptw', 'cvrp', 'cvrptw'."
            )

    def forward(self, inputs, annotation=False, parallel=False):
        """Call the wrapped ground-truth implementation on ``inputs``.

        ``annotation`` and ``parallel`` are forwarded positionally and
        unchanged. Their meaning is defined by the problem-specific class.

        Returns:
            Whatever the wrapped implementation returns.
        """
        return self.ground_truth(inputs, annotation, parallel)

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Delegate to the wrapped implementation's ``get_inputs``.

        All arguments are forwarded positionally and unchanged. Their
        expected formats are defined by the problem-specific class.

        Returns:
            Whatever the wrapped implementation's ``get_inputs`` returns.
        """
        return self.ground_truth.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def solve(self, step, input_tour, node_feats, instance_name=None):
        """Delegate to the wrapped implementation's ``solve``.

        All arguments are forwarded positionally and unchanged. Their
        expected formats are defined by the problem-specific class.

        Returns:
            Whatever the wrapped implementation's ``solve`` returns.
        """
        return self.ground_truth.solve(step, input_tour, node_feats, instance_name)