File size: 3,940 Bytes
52933b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
918d1df
 
 
52933b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
918d1df
 
 
52933b5
 
 
 
 
 
 
 
918d1df
 
 
52933b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gym
import numpy as np
from gym import spaces

from .tsp_data import TSPDataset


def assign_env_config(self, kwargs):
    """
    Set self.key = value, for each key in kwargs
    """
    for key, value in kwargs.items():
        setattr(self, key, value)


def dist(loc1, loc2):
    return ((loc1[:, 0] - loc2[:, 0]) ** 2 + (loc1[:, 1] - loc2[:, 1]) ** 2) ** 0.5


class TSPVectorEnv(gym.Env):
    def __init__(self, *args, **kwargs):
        self.max_nodes = 50
        self.n_traj = 50
        # if eval_data==True, load from 'test' set, the '0'th data
        self.eval_data = False
        self.eval_partition = "test"
        self.eval_data_idx = 0

        self.eval_data_from_input = None

        assign_env_config(self, kwargs)

        obs_dict = {"observations": spaces.Box(low=0, high=1, shape=(self.max_nodes, 2))}
        obs_dict["action_mask"] = spaces.MultiBinary(
            [self.n_traj, self.max_nodes]
        )  # 1: OK, 0: cannot go
        obs_dict["first_node_idx"] = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        obs_dict["last_node_idx"] = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        obs_dict["is_initial_action"] = spaces.Discrete(1)

        self.observation_space = spaces.Dict(obs_dict)
        self.action_space = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        self.reward_space = None

        self.reset()

    def seed(self, seed):
        np.random.seed(seed)

    def reset(self):
        self.visited = np.zeros((self.n_traj, self.max_nodes), dtype=bool)
        self.num_steps = 0
        self.last = np.zeros(self.n_traj, dtype=int)  # idx of the first elem
        self.first = np.zeros(self.n_traj, dtype=int)  # idx of the first elem

        if self.eval_data == 'from_input':
            self._load_orders_from_input()
        elif self.eval_data:
            self._load_orders()
        else:
            self._generate_orders()
        self.state = self._update_state()
        self.info = {}
        self.done = False
        return self.state

    def _load_orders_from_input(self):
        self.nodes = self.eval_data_from_input.copy()

    def _load_orders(self):
        self.nodes = np.array(TSPDataset[self.eval_partition, self.max_nodes, self.eval_data_idx])

    def _generate_orders(self):
        self.nodes = np.random.rand(self.max_nodes, 2)

    def step(self, action):

        self._go_to(action)  # Go to node 'action', modify the reward
        self.num_steps += 1
        self.state = self._update_state()

        # need to revisit the first node after visited all other nodes
        self.done = (action == self.first) & self.is_all_visited()

        return self.state, self.reward, self.done, self.info

    # Euclidean cost function
    def cost(self, loc1, loc2):
        return dist(loc1, loc2)

    def is_all_visited(self):
        # assumes no repetition in the first `max_nodes` steps
        return self.visited[:, :].all(axis=1)

    def _go_to(self, destination):
        dest_node = self.nodes[destination]
        if self.num_steps != 0:
            dist = self.cost(dest_node, self.nodes[self.last])
        else:
            dist = np.zeros(self.n_traj)
            self.first = destination

        self.last = destination

        self.visited[np.arange(self.n_traj), destination] = True
        self.reward = -dist

    def _update_state(self):
        obs = {"observations": self.nodes}  # n x 2 array
        obs["action_mask"] = self._update_mask()
        obs["first_node_idx"] = self.first
        obs["last_node_idx"] = self.last
        obs["is_initial_action"] = self.num_steps == 0
        return obs

    def _update_mask(self):
        # Only allow to visit unvisited nodes
        action_mask = ~self.visited
        # can only visit first node when all nodes have been visited
        action_mask[np.arange(self.n_traj), self.first] |= self.is_all_visited()
        return action_mask