|
import numpy as np |
|
import torch |
|
import gym |
|
from models.attention_model_wrapper import Agent |
|
|
|
from wrappers.syncVectorEnvPomo import SyncVectorEnv |
|
from wrappers.recordWrapper import RecordEpisodeStatistics |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
from matplotlib.collections import LineCollection |
|
from matplotlib.colors import ListedColormap, BoundaryNorm |
|
|
|
import gradio as gr |
|
|
|
# --- Model ------------------------------------------------------------------
# The agent is built on `device` and its weights are restored from the
# checkpoint below; the checkpoint tensors are mapped onto the CPU on load.
device = "cpu"
ckpt_path = "./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"

agent = Agent(device=device, name="tsp").to(device)
agent.load_state_dict(
    torch.load(
        ckpt_path,
        map_location=torch.device("cpu"),
    )
)

# --- Environment ------------------------------------------------------------
# Register the environment under `env_id` so that `gym.make(env_id, ...)` can
# construct it from `env_entry_point`. `seed` is the seed handed to `make_env`.
env_id = "tsp-v0"
env_entry_point = "envs.tsp_vector_env:TSPVectorEnv"
seed = 0

gym.envs.register(id=env_id, entry_point=env_entry_point)
|
|
|
|
|
def make_env(env_id, seed, cfg=None):
    """Build a zero-argument factory that creates one seeded environment.

    Args:
        env_id: Id passed to ``gym.make``.
        seed: Seed applied to the environment and to its action and
            observation spaces.
        cfg: Optional mapping of keyword arguments forwarded to ``gym.make``.
            ``None`` (the default) forwards no extra arguments.

    Returns:
        A callable taking no arguments that creates the environment, wraps it
        in ``RecordEpisodeStatistics``, seeds it, and returns it.
    """
    # Use a None sentinel instead of a shared mutable `{}` default.
    env_kwargs = {} if cfg is None else cfg

    def thunk():
        env = gym.make(env_id, **env_kwargs)
        env = RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk
|
|
|
|
|
def inference(data):
    """Roll out the loaded agent on a single TSP instance.

    Args:
        data: Node coordinates of the instance. ``len(data)`` is passed to the
            environment as ``max_nodes`` and ``data`` itself as
            ``eval_data_from_input``.

    Returns:
        A tuple ``(resulting_traj, final_return)``. ``resulting_traj`` holds
        the action taken at every step, indexed at ``[0, 0]`` of each step's
        action array (assumes actions are laid out as (env, trajectory) --
        TODO confirm). ``final_return`` is ``info[0]["episode"]["r"]`` from
        the last step (presumably written by ``RecordEpisodeStatistics``).
    """
    env_cfg = dict(
        n_traj=1,
        max_nodes=len(data),
        eval_data="from_input",
        eval_data_from_input=data,
    )
    envs = SyncVectorEnv([make_env(env_id, seed, env_cfg)])

    trajectories = []
    agent.eval()
    obs = envs.reset()
    done = np.array([False])
    while not done.all():
        with torch.no_grad():
            # Only the action is used; the agent's second output is ignored.
            action, _ = agent(obs)
        # Convert once and reuse the array for stepping and for the record.
        action_np = action.cpu().numpy()
        obs, _, done, info = envs.step(action_np)
        trajectories.append(action_np)

    final_return = info[0]["episode"]["r"]
    resulting_traj = np.array(trajectories)[:, 0, 0]
    return resulting_traj, final_return
|
|
|
|
|
# Default instance pre-filled in the input table: ten (x, y) points, one row
# per node, all inside the unit square.
default_data = np.array(
    [
        (0.5488135, 0.71518937),
        (0.60276338, 0.54488318),
        (0.4236548, 0.64589411),
        (0.43758721, 0.891773),
        (0.96366276, 0.38344152),
        (0.79172504, 0.52889492),
        (0.56804456, 0.92559664),
        (0.07103606, 0.0871293),
        (0.0202184, 0.83261985),
        (0.77815675, 0.87001215),
    ]
)
|
|
|
|
|
|
|
|
|
|
|
def make_segments(x, y):
    """Pair consecutive (x, y) points into segments for a LineCollection.

    The result is an array shaped numlines x 2 (points per line) x 2 (x and
    y); segment ``i`` runs from point ``i`` to point ``i + 1``.
    """
    # One (1, 2) row per point, so start/end points can be joined on axis 1.
    vertices = np.transpose(np.array([x, y])).reshape(-1, 1, 2)
    starts = vertices[:-1]
    ends = vertices[1:]
    return np.hstack((starts, ends))
|
|
|
|
|
def colorline(
    x,
    y,
    z=None,
    cmap=plt.get_cmap("copper"),
    norm=plt.Normalize(0.0, 1.0),
    linewidth=1,
    alpha=1.0,
):
    """Draw the line through (x, y) on the current axes as a colored LineCollection.

    Args:
        x: x coordinates of the points, in drawing order.
        y: y coordinates of the points, in drawing order.
        z: Values mapped to colors. When omitted, ``len(x)`` evenly spaced
            values from 0.3 to 1.0 are used; a value without ``__iter__`` is
            wrapped into a one-element array.
        cmap: Colormap handed to the LineCollection.
        norm: Normalization handed to the LineCollection.
        linewidth: Line width handed to the LineCollection.
        alpha: Opacity handed to the LineCollection.

    Returns:
        The LineCollection that was added to the current axes.
    """
    if z is None:
        # Default: evenly spaced color values along the line.
        z = np.linspace(0.3, 1.0, len(x))
    elif not hasattr(z, "__iter__"):
        # A single number: wrap it so it can serve as the color array.
        z = np.array([z])
    color_values = np.asarray(z)

    collection = LineCollection(
        make_segments(x, y),
        array=color_values,
        cmap=cmap,
        norm=norm,
        linewidth=linewidth,
        alpha=alpha,
    )
    plt.gca().add_collection(collection)
    return collection
|
|
|
|
|
def plot(coords):
    """Draw the path through ``coords`` on a new figure and return that figure.

    Args:
        coords: Array whose transpose unpacks into the x and y sequences,
            i.e. one (x, y) row per point, in the order the line is drawn.

    Returns:
        The figure created for this plot.
    """
    fig = plt.figure()
    x, y = coords.T
    # colorline draws on the current axes, i.e. those of the figure just created.
    colorline(x, y, cmap="Reds")
    plt.axis("square")
    # NOTE(review): the figure is never closed here; pyplot keeps figures
    # alive until they are closed, so a long-running server may accumulate
    # them -- confirm whether the caller should close it after rendering.
    return fig
|
|
|
|
|
def run_inference(data):
    """Solve the instance given in the input table and report the result.

    Args:
        data: Table of node coordinates; it must support
            ``.astype(float).to_numpy()`` (e.g. a pandas DataFrame).

    Returns:
        A two-element list: the figure of the coordinates visited in tour
        order, and a text summary with the tour and its total length (the
        negated first entry of the episode return).
    """
    coords = data.astype(float).to_numpy()
    tour, episode_return = inference(coords)
    tour_length = -episode_return[0]
    summary = f"Planned Tour:\t{tour}\nTotal tour length:\t{tour_length:.2f}"
    figure = plot(coords[tour])
    return [figure, summary]
|
|
|
|
|
# Gradio UI: an editable table of (x, y) coordinates as input; a plot and a
# read-only text panel as outputs, both produced by `run_inference`.
demo = gr.Interface(
    fn=run_inference,
    inputs=gr.Dataframe(
        label="Input",
        headers=["x", "y"],
        row_count=10,
        col_count=(2, "fixed"),
        max_rows=10,
        value=default_data.tolist(),
        overflow_row_behaviour="show_ends",
    ),
    outputs=[
        gr.Plot(label="Results Visualization"),
        gr.Code(label="Results", interactive=False),
    ],
)

# Only start the (blocking) server when this file is run as a script, so that
# importing the module does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|