Q-Learning Agent playing Taxi-v3

This is a trained model of a Q-Learning agent playing Taxi-v3.

Usage

from huggingface_sb3 import load_from_hub
import gymnasium as gym
from tqdm import tqdm
import numpy as np
import pickle

def greedy_policy(Qtable, state):
  """Return the greedy (exploitation-only) action for ``state``.

  :param Qtable: Q-table indexed as ``Qtable[state, action]``
  :param state: Index of the current state (a row of ``Qtable``)
  :return: Index of the action with the largest Q-value; on ties
      ``np.argmax`` picks the first (lowest-index) maximum
  """
  q_values = Qtable[state, :]
  return np.argmax(q_values)

def evaluate_agent(env: gym.Env, max_steps: int, n_eval_episodes: int, Q: np.ndarray, seed: list[int]):
  """
  Evaluate a greedy Q-table policy and return the mean and std of episode rewards.

  Each episode follows ``greedy_policy`` (argmax over ``Q``, no exploration)
  for at most ``max_steps`` steps, stopping early as soon as the environment
  reports ``terminated`` or ``truncated``.

  :param env: The evaluation environment
  :param max_steps: Maximum number of steps per episode
  :param n_eval_episodes: Number of episodes to evaluate the agent
  :param Q: The Q-table, indexed as ``Q[state, action]``
  :param seed: Per-episode reset seeds (for Taxi-v3). ``seed[i]`` seeds
      episode ``i``, so it must contain at least ``n_eval_episodes`` entries.
      ``None`` or an empty sequence resets without a seed. A list or a 1-D
      NumPy array of integers is accepted.
  :return: ``(mean_reward, std_reward)`` as Python floats. ``std_reward`` is
      the population standard deviation (``np.std`` default). Both are ``nan``
      when ``n_eval_episodes`` is 0.
  """
  # A bare ``if seed:`` raises "truth value is ambiguous" for a NumPy array,
  # so test for presence explicitly; lists behave exactly as before.
  use_seeds = seed is not None and len(seed) > 0

  episode_rewards = []
  for episode in tqdm(range(n_eval_episodes)):
      if use_seeds:
          # Gymnasium only accepts a Python ``int`` seed, not e.g. ``np.int64``.
          state, _ = env.reset(seed=int(seed[episode]))
      else:
          state, _ = env.reset()

      total_rewards_ep = 0
      for _ in range(max_steps):
          action = greedy_policy(Q, state)
          state, reward, terminated, truncated, _ = env.step(action)
          total_rewards_ep += reward
          if terminated or truncated:
              break

      episode_rewards.append(total_rewards_ep)

  mean_reward = np.mean(episode_rewards)
  std_reward = np.std(episode_rewards)

  return float(mean_reward), float(std_reward)

if __name__ == "__main__":
  file_path = load_from_hub(repo_id="BobChuang/q-Taxi-v1-5x5", filename="q-learning.pkl")
  # SECURITY: pickle.load can execute arbitrary code embedded in the file.
  # Only load checkpoints from Hub repositories you trust.
  with open(file_path, "rb") as f:
      model = pickle.load(f)

  # The unpickled object is expected to be a dict with the keys read below;
  # presumably produced by the matching training script — verify.
  env = gym.make(model["env_id"], render_mode="rgb_array")
  try:
      max_steps = model["max_steps"]
      n_eval_episodes = model["n_eval_episodes"]
      qtable = model["qtable"]
      eval_seed = model["eval_seed"]

      mean_reward, std_reward = evaluate_agent(env, max_steps, n_eval_episodes, qtable, eval_seed)
  finally:
      # Release the environment's resources even if evaluation fails.
      env.close()

  print(f"\n{ mean_reward = }, { std_reward = }")
Downloads last month

-

Downloads are not tracked for this model (see "How to track").
Video Preview
loading

Evaluation results