|
from ddpg import Agent |
|
import gymnasium as gym |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import torch |
|
import argparse |
|
from train import TrainingLoop |
|
from captum.attr import (IntegratedGradients, LayerConductance, NeuronAttribution) |
|
|
|
# Command-line entry point.
#
# The requested action is parsed *before* any environment/agent setup so that
# `--help` or an invalid choice exits immediately via argparse, instead of
# first paying for (and possibly failing in) TrainingLoop construction and
# create_agent().
parser = argparse.ArgumentParser(description="Choose a function to run.")
parser.add_argument(
    "function",
    choices=["train", "load-trained", "attribute", "video"],
    help="The function to run.",
)
args = parser.parse_args()

# Build the training harness once; every action below operates on it.
# NOTE(review): the keyword arguments are forwarded to TrainingLoop (defined in
# train.py); presumably they configure the gymnasium environment — confirm.
# NOTE(review): recent gymnasium releases deprecate "LunarLander-v2" in favour
# of "LunarLander-v3"; confirm against the pinned gymnasium version.
training_loop = TrainingLoop(env_spec="LunarLander-v2", continuous=True, gravity=-10)
training_loop.create_agent()

# Dispatch on the chosen action. argparse's `choices` guarantees that exactly
# one of these branches matches.
if args.function == "train":
    training_loop.train()
elif args.function == "load-trained":
    training_loop.load_trained()
elif args.function == "attribute":
    # Results are kept at module level so they remain available after the
    # script runs (e.g. in an interactive session).
    frames, attributions = training_loop.explain_trained(option="2", num_iterations=10)
elif args.function == "video":
    training_loop.render_video(20)
|
|
|
|