"""Gradio demo: plan a TSP tour with a pretrained attention-model agent.

The user edits a table of node coordinates. The agent rolls out one trajectory
through the vectorized ``tsp-v0`` environment. The app shows the visiting order,
the tour length (the negated episode return) and a plot of the route.
"""
import numpy as np
import torch
import gym
from models.attention_model_wrapper import Agent
from wrappers.syncVectorEnvPomo import SyncVectorEnv
from wrappers.recordWrapper import RecordEpisodeStatistics
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.figure import Figure
import gradio as gr

device = "cpu"
ckpt_path = "./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"
agent = Agent(device=device, name="tsp").to(device)
# Map stored tensors onto the configured device, so a checkpoint saved on a
# different device (e.g. a GPU) still loads here.
agent.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device)))

env_id = "tsp-v0"
env_entry_point = "envs.tsp_vector_env:TSPVectorEnv"
seed = 0

gym.envs.register(
    id=env_id,
    entry_point=env_entry_point,
)


def make_env(env_id, seed, cfg=None):
    """Return a thunk that builds one seeded, episode-recording environment.

    Args:
        env_id: registered gym id to instantiate.
        seed: seed applied to the env and to its action/observation spaces.
        cfg: keyword arguments forwarded to ``gym.make``; ``None`` means none.

    Returns:
        A zero-argument callable that creates the wrapped environment.
    """
    # Fresh dict per call instead of a mutable default shared across calls.
    cfg = {} if cfg is None else cfg

    def thunk():
        env = gym.make(env_id, **cfg)
        env = RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


def inference(data):
    """Roll out the agent on one TSP instance and return the visiting order.

    Args:
        data: array of node coordinates, one row per node (the UI supplies
            ``(x, y)`` pairs). Its length sets ``max_nodes``.

    Returns:
        ``(resulting_traj, final_return)``: the node index chosen at each step
        of the single trajectory, and the episode return that
        ``RecordEpisodeStatistics`` stored in the first env's info.
    """
    envs = SyncVectorEnv(
        [
            make_env(
                env_id,
                seed,
                dict(n_traj=1, max_nodes=len(data), eval_data="from_input", eval_data_from_input=data),
            )
        ]
    )
    trajectories = []
    agent.eval()
    obs = envs.reset()
    done = np.array([False])
    while not done.all():
        # ALGO LOGIC: action logic
        with torch.no_grad():
            action, _logits = agent(obs)
        action_np = action.cpu().numpy()
        obs, reward, done, info = envs.step(action_np)
        trajectories.append(action_np)
    final_return = info[0]["episode"]["r"]
    # Stacked actions are rank 3, presumably (steps, n_envs, n_traj) -- TODO
    # confirm. Keep env 0, trajectory 0.
    resulting_traj = np.array(trajectories)[:, 0, 0]
    return resulting_traj, final_return


default_data = np.array(
    [
        [0.5488135, 0.71518937],
        [0.60276338, 0.54488318],
        [0.4236548, 0.64589411],
        [0.43758721, 0.891773],
        [0.96366276, 0.38344152],
        [0.79172504, 0.52889492],
        [0.56804456, 0.92559664],
        [0.07103606, 0.0871293],
        [0.0202184, 0.83261985],
        [0.77815675, 0.87001215],
    ]
)


# @title Helper function for plotting
# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
def make_segments(x, y):
    """Pair consecutive points into line segments for ``LineCollection``.

    Returns:
        An array of shape ``(len(x) - 1, 2, 2)``: one segment per row, each
        holding its start and end ``(x, y)`` point.
    """
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    return segments


def colorline(
    x, y, z=None, cmap=plt.get_cmap("copper"), norm=None, linewidth=1, alpha=1.0, ax=None
):
    """Plot a polyline through ``(x, y)`` whose segments are colored by ``z``.

    Args:
        x, y: point coordinates, in drawing order.
        z: one color value per segment (a scalar is wrapped in an array).
            Defaults to values evenly spaced on [0.3, 1.0], so the shade
            changes along the direction of travel.
        cmap: colormap object or registered colormap name used to map ``z``.
        norm: normalization for ``z``. Defaults to a fresh
            ``Normalize(0.0, 1.0)``.
        linewidth, alpha: forwarded to ``LineCollection``.
        ax: axes to draw on. Defaults to ``plt.gca()``.

    Returns:
        The ``LineCollection`` added to the axes.
    """
    if z is None:
        # One value per segment (len(x) - 1 of them), not one per point.
        z = np.linspace(0.3, 1.0, max(len(x) - 1, 0))
    # Special case if a single number:
    if not hasattr(z, "__iter__"):  # to check for numerical input -- this is a hack
        z = np.array([z])
    z = np.asarray(z)
    if norm is None:
        # Created per call: a Normalize default evaluated once at import would be
        # one mutable object shared by every collection using the default.
        norm = plt.Normalize(0.0, 1.0)
    if ax is None:
        ax = plt.gca()
    segments = make_segments(x, y)
    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)
    ax.add_collection(lc)
    return lc


def plot(coords):
    """Draw the route through ``coords`` (rows in visiting order); return the figure.

    The figure is built with ``matplotlib.figure.Figure`` instead of
    ``plt.figure()``. Pyplot keeps every figure it creates until it is
    explicitly closed, so one figure per request would leak memory in a
    long-running server. This version also avoids touching pyplot's global
    state from request-handling threads.
    """
    fig = Figure()
    ax = fig.subplots()
    x, y = coords.T
    colorline(x, y, cmap="Reds", ax=ax)
    ax.axis("square")
    return fig


def run_inference(data):
    """Gradio callback: plan a tour for the coordinates entered in the table.

    Args:
        data: the table from the ``gr.Dataframe`` input (pandas-like, given the
            ``.astype``/``.to_numpy`` calls), with ``x`` and ``y`` columns.

    Returns:
        ``[figure, text]`` for the plot output and the results text output.
        The reported tour length is the negated episode return.
    """
    data = data.astype(float).to_numpy()
    resulting_traj, final_return = inference(data)
    result_text = f"Planned Tour:\t{resulting_traj}\nTotal tour length:\t{-final_return[0]:.2f}"
    return [plot(data[resulting_traj]), result_text]


demo = gr.Interface(
    run_inference,
    gr.Dataframe(
        label="Input",
        headers=["x", "y"],
        row_count=10,
        col_count=(2, "fixed"),
        max_rows=10,
        value=default_data.tolist(),
        overflow_row_behaviour="show_ends",
    ),
    [gr.Plot(label="Results Visualization"), gr.Code(label="Results", interactive=False)],
)
demo.launch()