File size: 4,194 Bytes
918d1df a4e57fd 4d5f005 a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df a4e57fd 918d1df 6cb459d a4e57fd 92ea6b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import numpy as np
import torch
import gym
from models.attention_model_wrapper import Agent
from wrappers.syncVectorEnvPomo import SyncVectorEnv
from wrappers.recordWrapper import RecordEpisodeStatistics
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
import gradio as gr
device = "cpu"
ckpt_path = "./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"
# Build the policy network on `device` and load the trained weights into it.
# map_location follows `device` (instead of a hard-coded CPU) so the checkpoint
# tensors are deserialised onto the same device the agent was moved to.
agent = Agent(device=device, name="tsp").to(device)
agent.load_state_dict(torch.load(ckpt_path, map_location=device))
# Register the custom TSP environment under `env_id` so that
# `gym.make(env_id, ...)` inside make_env can construct it.
env_id, env_entry_point = "tsp-v0", "envs.tsp_vector_env:TSPVectorEnv"
seed = 0  # seed handed to make_env by inference()
gym.envs.register(id=env_id, entry_point=env_entry_point)
def make_env(env_id, seed, cfg=None):
    """Return a zero-argument factory that builds one seeded environment.

    The factory (suitable for a vector-env constructor) calls
    ``gym.make(env_id, **cfg)``, wraps the result in ``RecordEpisodeStatistics``
    and seeds the env, its action space and its observation space with ``seed``.

    Args:
        env_id: registered gym environment id.
        seed: seed applied to the env and both of its spaces.
        cfg: optional keyword arguments forwarded to ``gym.make``; ``None``
            (the default) means no extra arguments. The dict is read when the
            factory is called, not copied here.

    Returns:
        A callable that creates and returns the wrapped environment.
    """
    # None sentinel instead of a shared mutable `{}` default.
    env_kwargs = {} if cfg is None else cfg

    def thunk():
        env = gym.make(env_id, **env_kwargs)
        env = RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk
def inference(data):
    """Roll out the trained agent on a single TSP instance.

    Args:
        data: node coordinates, one entry per node. They are passed to the
            environment as ``eval_data_from_input`` with ``max_nodes=len(data)``.

    Returns:
        ``(resulting_traj, final_return)``. ``resulting_traj`` holds the action
        the agent chose at each environment step, taken for environment 0 and
        trajectory 0. ``final_return`` is ``info[0]["episode"]["r"]`` from the
        last step, as recorded by the statistics wrapper.
    """
    # A single vector env running a single trajectory (n_traj=1).
    envs = SyncVectorEnv(
        [
            make_env(
                env_id,
                seed,
                dict(n_traj=1, max_nodes=len(data), eval_data="from_input", eval_data_from_input=data),
            )
        ]
    )
    # NOTE(review): `envs` is never closed; consider envs.close() if the
    # wrapper holds resources.
    trajectories = []
    agent.eval()
    obs = envs.reset()
    done = np.array([False])
    while not done.all():
        # ALGO LOGIC: action logic
        with torch.no_grad():
            action, _ = agent(obs)  # the second output (logits) is unused here
        # Convert once and reuse for both the env step and the record.
        step_action = action.cpu().numpy()
        obs, _, done, info = envs.step(step_action)
        trajectories.append(step_action)
    final_return = info[0]["episode"]["r"]
    # Assumes each recorded action is shaped (num_envs, n_traj); [:, 0, 0]
    # keeps env 0 / trajectory 0 for every step. TODO confirm.
    resulting_traj = np.array(trajectories)[:, 0, 0]
    return resulting_traj, final_return
# Ten example nodes (all inside the unit square) pre-filled into the input table.
default_data = np.array(
    [
        (0.5488135, 0.71518937),
        (0.60276338, 0.54488318),
        (0.4236548, 0.64589411),
        (0.43758721, 0.891773),
        (0.96366276, 0.38344152),
        (0.79172504, 0.52889492),
        (0.56804456, 0.92559664),
        (0.07103606, 0.0871293),
        (0.0202184, 0.83261985),
        (0.77815675, 0.87001215),
    ],
    dtype=float,
)
# @title Helper function for plotting
# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
def make_segments(x, y):
    """Convert a polyline into the segment array that ``LineCollection`` expects.

    Each point ``(x[i], y[i])`` is paired with its successor
    ``(x[i+1], y[i+1])``. ``n`` points therefore give an array of shape
    ``(n - 1, 2, 2)``: segment index x endpoint (start, end) x coordinate (x, y).
    """
    xy = np.column_stack([x, y])  # (n, 2): one row per point
    # Put each point next to the following one along a new "endpoint" axis.
    return np.stack([xy[:-1], xy[1:]], axis=1)
def colorline(x, y, z=None, cmap=plt.get_cmap("copper"), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):
    """Add a polyline through ``(x, y)`` whose segments are coloured via a colormap.

    Args:
        x, y: point coordinates; consecutive points are joined by a segment.
        z: colour value(s) mapped through ``norm`` and ``cmap``. With ``None``,
            one value per segment is spaced evenly over [0.3, 1.0], increasing
            along the path. A scalar gives every segment the same colour.
        cmap: colormap object or registered colormap name.
        norm: normalisation applied to ``z``. The ``cmap`` and ``norm``
            defaults are evaluated once, when the function is defined, and are
            shared by all calls.
        linewidth, alpha: forwarded to ``LineCollection``.

    Returns:
        The ``LineCollection``, already added to the current axes (``plt.gca()``).
    """
    segments = make_segments(x, y)
    if z is None:
        # One value per drawn segment (len(x) - 1 of them). Sizing this by
        # len(x) left the final value unused, so the gradient never reached 1.0.
        z = np.linspace(0.3, 1.0, len(segments))
    # Accept a single number as well as a sequence of values.
    z = np.atleast_1d(z)
    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)
    ax = plt.gca()
    ax.add_collection(lc)
    return lc
def plot(coords):
    """Draw ``coords`` as a red, colour-graded polyline on a new square figure.

    Args:
        coords: array unpacked as ``x, y = coords.T``, i.e. shape ``(n, 2)``.
            ``run_inference`` passes the nodes already ordered along the tour.

    Returns:
        The newly created matplotlib ``Figure``. The path is drawn open: no
        segment is added from the last point back to the first.
    """
    figure = plt.figure()
    xs, ys = coords.T
    colorline(xs, ys, cmap="Reds")
    # Same axes colorline drew on (the current axes of `figure`).
    figure.gca().axis("square")
    return figure
def run_inference(data):
    """Gradio callback: plan a tour for the coordinates entered in the table.

    Args:
        data: value of the ``gr.Dataframe`` input, converted here with
            ``.astype(float).to_numpy()``. It is presumably a pandas DataFrame
            with one ``x``/``y`` row per node; confirm against gradio's
            Dataframe ``type``.

    Returns:
        ``[figure, summary]``, matching the ``gr.Plot`` and ``gr.Code`` outputs.
    """
    coords = data.astype(float).to_numpy()
    tour, episode_return = inference(coords)
    # The displayed "tour length" is the negated episode return.
    summary = f"Planned Tour:\t{tour}\nTotal tour length:\t{-episode_return[0]:.2f}"
    figure = plot(coords[tour])  # nodes re-ordered into visiting order
    return [figure, summary]
# Assemble the UI. An editable 10x2 coordinate table (pre-filled with
# default_data) is the input; a plot and a read-only text summary are the
# outputs. launch() is called unconditionally, so running this module starts
# the app.
_input_table = gr.Dataframe(
    label="Input",
    headers=["x", "y"],
    row_count=10,
    col_count=(2, "fixed"),
    max_rows=10,
    value=default_data.tolist(),
    overflow_row_behaviour="show_ends",
)
_result_outputs = [
    gr.Plot(label="Results Visualization"),
    gr.Code(label="Results", interactive=False),
]
demo = gr.Interface(fn=run_inference, inputs=_input_table, outputs=_result_outputs)
demo.launch()
|