daisuke.kikuta
first commit
719d0db
import torch
import torch.nn as nn
import numpy as np
from models.classifiers.ground_truth.ground_truth_tsptw import GroundTruthTSPTW
from models.classifiers.ground_truth.ground_truth_pctsp import GroundTruthPCTSP
from models.classifiers.ground_truth.ground_truth_pctsptw import GroundTruthPCTSPTW
from models.classifiers.ground_truth.ground_truth_cvrp import GroundTruthCVRP
from models.classifiers.ground_truth.ground_truth_cvrptw import GroundTruthCVRPTW
class GroundTruth(nn.Module):
    """Dispatch wrapper around a problem-specific ground-truth implementation.

    The concrete implementation is chosen from ``problem`` when the wrapper is
    constructed and stored in ``self.ground_truth``; ``forward``, ``get_inputs``
    and ``solve`` forward their arguments to it unchanged.
    """

    def __init__(self, problem, solver_type):
        """Select the ground-truth implementation for ``problem``.

        Args:
            problem: problem identifier; one of ``"tsptw"``, ``"pctsp"``,
                ``"pctsptw"``, ``"cvrp"`` or ``"cvrptw"``.
            solver_type: passed unchanged to the selected implementation's
                constructor (the accepted values are defined there).

        Raises:
            NotImplementedError: if ``problem`` is not a supported identifier.
        """
        super().__init__()
        self.problem = problem
        self.solver_type = solver_type
        if problem == "tsptw":
            self.ground_truth = GroundTruthTSPTW(solver_type)
        elif problem == "pctsp":
            self.ground_truth = GroundTruthPCTSP(solver_type)
        elif problem == "pctsptw":
            self.ground_truth = GroundTruthPCTSPTW(solver_type)
        elif problem == "cvrp":
            self.ground_truth = GroundTruthCVRP(solver_type)
        elif problem == "cvrptw":
            self.ground_truth = GroundTruthCVRPTW(solver_type)
        else:
            # Keep NotImplementedError (callers may catch it), but say what was
            # rejected and what is accepted instead of raising an empty error.
            raise NotImplementedError(
                f"Unsupported problem {problem!r}; expected one of: "
                "tsptw, pctsp, pctsptw, cvrp, cvrptw"
            )

    def forward(self, inputs, annotation=False, parallel=False):
        """Call the selected implementation with ``(inputs, annotation, parallel)``.

        Returns whatever the problem-specific implementation returns.
        """
        return self.ground_truth(inputs, annotation, parallel)

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Delegate to the selected implementation's ``get_inputs``.

        All arguments are forwarded positionally; ``dist_matrix`` defaults to
        ``None`` here and its interpretation is left to the implementation.
        """
        return self.ground_truth.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def solve(self, step, input_tour, node_feats, instance_name=None):
        """Delegate to the selected implementation's ``solve``.

        All arguments are forwarded positionally; ``instance_name`` defaults to
        ``None`` here and its interpretation is left to the implementation.
        """
        return self.ground_truth.solve(step, input_tour, node_feats, instance_name)