Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import numpy as np | |
from models.classifiers.ground_truth.ground_truth_tsptw import GroundTruthTSPTW | |
from models.classifiers.ground_truth.ground_truth_pctsp import GroundTruthPCTSP | |
from models.classifiers.ground_truth.ground_truth_pctsptw import GroundTruthPCTSPTW | |
from models.classifiers.ground_truth.ground_truth_cvrp import GroundTruthCVRP | |
from models.classifiers.ground_truth.ground_truth_cvrptw import GroundTruthCVRPTW | |
class GroundTruth(nn.Module):
    """Problem-agnostic front end over the per-problem ground-truth classes.

    Based on ``problem``, one of GroundTruthTSPTW, GroundTruthPCTSP,
    GroundTruthPCTSPTW, GroundTruthCVRP or GroundTruthCVRPTW is instantiated
    and stored in ``self.ground_truth``. Every public method below simply
    forwards its arguments to that instance.
    """

    # Problem names accepted by __init__ (also listed in the error message).
    SUPPORTED_PROBLEMS = ("tsptw", "pctsp", "pctsptw", "cvrp", "cvrptw")

    def __init__(self, problem: str, solver_type) -> None:
        """Select and build the problem-specific ground-truth implementation.

        Args:
            problem: one of ``SUPPORTED_PROBLEMS``.
            solver_type: passed unchanged to the selected implementation's
                constructor; its accepted values are defined there.

        Raises:
            NotImplementedError: if ``problem`` is not one of
                ``SUPPORTED_PROBLEMS``.
        """
        super().__init__()
        self.problem = problem
        self.solver_type = solver_type
        # The implementation classes are referenced only inside the matching
        # branch, so an unsupported name never touches any of them.
        if problem == "tsptw":
            self.ground_truth = GroundTruthTSPTW(solver_type)
        elif problem == "pctsp":
            self.ground_truth = GroundTruthPCTSP(solver_type)
        elif problem == "pctsptw":
            self.ground_truth = GroundTruthPCTSPTW(solver_type)
        elif problem == "cvrp":
            self.ground_truth = GroundTruthCVRP(solver_type)
        elif problem == "cvrptw":
            self.ground_truth = GroundTruthCVRPTW(solver_type)
        else:
            # Name the bad value and the valid choices so misconfiguration is
            # diagnosable from the traceback alone.
            raise NotImplementedError(
                f"Unsupported problem {problem!r}; expected one of: "
                f"{', '.join(self.SUPPORTED_PROBLEMS)}"
            )

    def forward(self, inputs, annotation: bool = False, parallel: bool = False):
        """Call the selected implementation with ``(inputs, annotation, parallel)``.

        The meaning of ``inputs`` and of the return value is defined by the
        problem-specific class.
        """
        return self.ground_truth(inputs, annotation, parallel)

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Delegate to the selected implementation's ``get_inputs``.

        ``dist_matrix`` defaults to ``None`` and is forwarded as-is.
        """
        return self.ground_truth.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def solve(self, step, input_tour, node_feats, instance_name=None):
        """Delegate to the selected implementation's ``solve``.

        ``instance_name`` defaults to ``None`` and is forwarded as-is.
        """
        return self.ground_truth.solve(step, input_tour, node_feats, instance_name)