# RLOR-TSP / app.py
# cpwan's picture
# Update app.py
# 6cb459d
import numpy as np
import torch
import gym
from models.attention_model_wrapper import Agent
from wrappers.syncVectorEnvPomo import SyncVectorEnv
from wrappers.recordWrapper import RecordEpisodeStatistics
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
import gradio as gr
# Model: the TSP agent, restored from a fixed checkpoint file and run on CPU.
device = "cpu"
ckpt_path = "./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"
agent = Agent(name="tsp", device=device).to(device)
agent.load_state_dict(
    torch.load(ckpt_path, map_location=torch.device("cpu"))
)

# Environment: register the vectorised TSP env under `env_id` so that
# gym.make(env_id, ...) inside make_env() can construct it.
env_id, env_entry_point = "tsp-v0", "envs.tsp_vector_env:TSPVectorEnv"
seed = 0  # passed to make_env() for every environment built in inference()
gym.envs.register(id=env_id, entry_point=env_entry_point)
def make_env(env_id, seed, cfg=None):
    """Return a zero-argument factory that builds one seeded environment.

    The factory is meant to be handed to a vector-env constructor (see
    ``inference``), which calls it to create each sub-environment.

    Args:
        env_id: Registered gym environment id passed to ``gym.make``.
        seed: Seed applied to the environment and to its action and
            observation spaces.
        cfg: Optional mapping of extra keyword arguments forwarded to
            ``gym.make``. ``None`` (the default) means no extra arguments.
            It is read when the factory is called, not when it is created.

    Returns:
        A callable that creates a fresh environment wrapped with
        ``RecordEpisodeStatistics`` on every call.
    """

    def thunk():
        # A `None` sentinel replaces the former `cfg={}` default, which was a
        # single dict object shared by every call of make_env.
        env_kwargs = {} if cfg is None else cfg
        env = gym.make(env_id, **env_kwargs)
        env = RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk
def inference(data):
    """Run the loaded agent on one TSP instance and return the visiting order.

    Builds a single-environment vector env from the module-level ``env_id``
    and ``seed``, passing ``data`` through the env's ``eval_data="from_input"``
    option. It then steps the agent until the episode reports done. The env is
    always closed afterwards, even if a step raises.

    Args:
        data: Node coordinates; ``len(data)`` is used as ``max_nodes``.
            Presumably an ``(n_nodes, 2)`` float array -- TODO confirm against
            ``TSPVectorEnv``.

    Returns:
        ``(resulting_traj, final_return)``:

        - ``resulting_traj``: the action chosen at each step (first env,
          first trajectory) as a 1-D array.
        - ``final_return``: ``info[0]["episode"]["r"]`` from the last step.
    """
    envs = SyncVectorEnv(
        [
            make_env(
                env_id,
                seed,
                dict(n_traj=1, max_nodes=len(data), eval_data="from_input", eval_data_from_input=data),
            )
        ]
    )
    try:
        agent.eval()
        trajectories = []
        obs = envs.reset()
        done = np.array([False])
        while not done.all():
            with torch.no_grad():
                action, _ = agent(obs)
            # Convert once; the same array is fed to the env and recorded.
            action = action.cpu().numpy()
            obs, _, done, info = envs.step(action)
            trajectories.append(action)
        final_return = info[0]["episode"]["r"]
    finally:
        # Release the env on every call (previously leaked once per request).
        # NOTE(review): assumes the wrapper exposes gym's VectorEnv.close() -- confirm.
        envs.close()
    # Stacked shape is (steps, num_envs, n_traj); keep env 0, trajectory 0.
    resulting_traj = np.array(trajectories)[:, 0, 0]
    return resulting_traj, final_return
# Ten example (x, y) points, all within [0, 1]; used as the initial contents of
# the input table.
default_data = np.array(
    (
        (0.5488135, 0.71518937),
        (0.60276338, 0.54488318),
        (0.4236548, 0.64589411),
        (0.43758721, 0.891773),
        (0.96366276, 0.38344152),
        (0.79172504, 0.52889492),
        (0.56804456, 0.92559664),
        (0.07103606, 0.0871293),
        (0.0202184, 0.83261985),
        (0.77815675, 0.87001215),
    ),
    dtype=float,
)
# @title Helper functions for plotting
# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
def make_segments(x, y):
    """Turn point coordinates into the segment array expected by ``LineCollection``.

    Each consecutive pair of points, (x[i], y[i]) -> (x[i+1], y[i+1]), becomes
    one segment.

    Returns:
        Array of shape ``(num_points - 1, 2, 2)``, indexed as
        [segment, endpoint (start/end), coordinate (x/y)].
    """
    xy = np.array([x, y]).T.reshape(-1, 2)
    starts, ends = xy[:-1], xy[1:]
    return np.stack((starts, ends), axis=1)
def colorline(x, y, z=None, cmap=None, norm=None, linewidth=1, alpha=1.0):
    """Draw a line through the points (x, y), colouring each segment by ``z``.

    The line is added to the current pyplot axes (``plt.gca()``).

    Args:
        x, y: Point coordinates in drawing order.
        z: Colour values mapped through ``cmap``/``norm``; a scalar is
            accepted. Defaults to ``len(x)`` values evenly spaced on
            [0.3, 1.0], one more than the ``len(x) - 1`` segments.
        cmap: Colormap name or object. Defaults to ``"copper"``.
        norm: Normalization applied to ``z``. Defaults to a fresh
            ``Normalize(0.0, 1.0)`` created on each call.
        linewidth, alpha: Forwarded to ``LineCollection``.

    Returns:
        The ``LineCollection`` that was added to the axes.
    """
    # Defaults are built per call. Signature defaults would be evaluated once
    # at import, and a single mutable Normalize would then be shared by every
    # line (e.g. set_clim on one collection would change it for all).
    if cmap is None:
        cmap = plt.get_cmap("copper")
    if norm is None:
        norm = plt.Normalize(0.0, 1.0)
    if z is None:
        z = np.linspace(0.3, 1.0, len(x))
    # A scalar or 0-d array becomes a length-1 array, as LineCollection needs rank 1.
    z = np.atleast_1d(np.asarray(z))
    segments = make_segments(x, y)
    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)
    plt.gca().add_collection(lc)
    return lc
def plot(coords):
    """Draw the tour through ``coords`` (in visiting order) and return the figure.

    Args:
        coords: Points in visiting order; ``coords.T`` is unpacked into x and y,
            so exactly two columns are expected.

    Returns:
        A new matplotlib ``Figure`` with square axes showing the coloured path.
        The figure has already been released from pyplot's figure registry; the
        ``Figure`` object itself is still returned for rendering.
    """
    fig = plt.figure()
    x, y = coords.T
    colorline(x, y, cmap="Reds")
    plt.axis("square")
    # This runs once per web request. Without closing, pyplot keeps every figure
    # alive in its global registry, so server memory grows without bound.
    # Closing must happen after plt.axis(), which acts on the current figure.
    plt.close(fig)
    return fig
def run_inference(data):
    """Gradio handler: plan a tour for the table's points and report it.

    Args:
        data: Contents of the input ``gr.Dataframe``. ``.astype(float)`` and
            ``.to_numpy()`` are called on it, so a pandas DataFrame is assumed.

    Returns:
        A two-element list: the tour plot, and a text summary giving the
        visiting order and the tour length (the negated episode return).
    """
    coords = data.astype(float).to_numpy()
    tour, episode_return = inference(coords)
    tour_length = -episode_return[0]
    summary = f"Planned Tour:\t{tour}\nTotal tour length:\t{tour_length:.2f}"
    figure = plot(coords[tour])
    return [figure, summary]
# Web UI: a 10x2 editable table of (x, y) points in; tour plot and text summary out.
demo = gr.Interface(
    fn=run_inference,
    inputs=gr.Dataframe(
        label="Input",
        headers=["x", "y"],
        row_count=10,
        col_count=(2, "fixed"),
        max_rows=10,
        value=default_data.tolist(),
        overflow_row_behaviour="show_ends",
    ),
    outputs=[gr.Plot(label="Results Visualization"), gr.Code(label="Results", interactive=False)],
)

# Start the (blocking) server only when run as a script, e.g. `python app.py`,
# not as a side effect of importing this module.
if __name__ == "__main__":
    demo.launch()