# DI-engine: ding/policy/policy_factory.py
from typing import Dict, Any, Callable
from collections import namedtuple
from easydict import EasyDict
import gym
import torch
from ding.torch_utils import to_device


class PolicyFactory:
"""
Overview:
        Policy factory class, used to generate different policies for general purposes, such as the random action \
            policy, which is used for initial sample collection to improve exploration when ``random_collect_size`` > 0.
Interfaces:
``get_random_policy``
"""
@staticmethod
def get_random_policy(
policy: 'Policy.collect_mode', # noqa
action_space: 'gym.spaces.Space' = None, # noqa
forward_fn: Callable = None,
) -> 'Policy.collect_mode': # noqa
"""
Overview:
            According to the given action space, define the forward function of the random policy, then pack it with \
                the other interfaces of the given policy, and return the final collect mode interfaces of the \
                random policy.
Arguments:
- policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
- action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style.
            - forward_fn (:obj:`Callable`): If the action space is too complex, you can define your own forward \
                function and pass it to this method; note that you should set ``action_space`` to ``None`` in \
                this case.
Returns:
            - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
"""
assert not (action_space is None and forward_fn is None)
random_collect_function = namedtuple(
'random_collect_function', [
'forward',
'process_transition',
'get_train_sample',
'reset',
'get_attribute',
]
)
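
        # Default forward function: for every env id in the input data, sample a random action that matches
        # the given ``action_space`` (a single gym space, a SMAC-style masked discrete space, or a list of
        # per-agent spaces) and wrap it into the dict format expected by the collector.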
def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
actions = {}
for env_id in data:
if not isinstance(action_space, list):
if isinstance(action_space, gym.spaces.Discrete):
action = torch.LongTensor([action_space.sample()])
elif isinstance(action_space, gym.spaces.MultiDiscrete):
action = [torch.LongTensor([v]) for v in action_space.sample()]
else:
action = torch.as_tensor(action_space.sample())
actions[env_id] = {'action': action}
elif 'global_state' in data[env_id].keys():
                    # SMAC-style multi-agent data: sample a legal action uniformly from the unmasked entries
                    # of ``action_mask`` and also return the corresponding logit.
logit = torch.ones_like(data[env_id]['action_mask'])
logit[data[env_id]['action_mask'] == 0.0] = -1e8
dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
else:
                    # gfootball-style list of per-agent action spaces: sample one action per agent with a
                    # uniform (all-ones) logit.
actions[env_id] = {
'action': torch.as_tensor([action_space_agent.sample() for action_space_agent in action_space]),
'logit': torch.ones([len(action_space), action_space[0].n])
}
return actions
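
        # The random policy is stateless, so ``reset`` is a no-op.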
def reset(*args, **kwargs) -> None:
pass
if action_space is None:
return random_collect_function(
forward_fn, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
)
elif forward_fn is None:
return random_collect_function(
forward, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
)


def get_random_policy(
cfg: EasyDict,
policy: 'Policy.collect_mode', # noqa
env: 'BaseEnvManager' # noqa
) -> 'Policy.collect_mode': # noqa
"""
Overview:
        The entry function to get the corresponding random policy. If a policy needs special data items in a \
            transition, the policy itself is returned; otherwise, we will use ``PolicyFactory`` to return a \
            general random policy.
Arguments:
- cfg (:obj:`EasyDict`): The EasyDict-type dict configuration.
- policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
- env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \
action generation.
Returns:
        - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
"""
if cfg.policy.get('transition_with_policy_data', False):
return policy
else:
action_space = env.action_space
return PolicyFactory.get_random_policy(policy, action_space=action_space)
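

# A minimal usage sketch (not part of the original interface): it assumes a hypothetical stand-in for a
# ``Policy.collect_mode`` object that only provides the attributes accessed by the factory, and shows how
# ``PolicyFactory.get_random_policy`` wraps a discrete action space into a random collect-mode policy.
if __name__ == '__main__':
    _DummyCollectMode = namedtuple(
        '_DummyCollectMode', ['process_transition', 'get_train_sample', 'reset', 'get_attribute']
    )
    dummy_policy = _DummyCollectMode(
        process_transition=lambda *args, **kwargs: None,
        get_train_sample=lambda *args, **kwargs: None,
        reset=lambda *args, **kwargs: None,
        get_attribute=lambda *args, **kwargs: None,
    )
    random_policy = PolicyFactory.get_random_policy(dummy_policy, action_space=gym.spaces.Discrete(4))
    # Sample random actions for two fake environments (the observations are not used by the discrete branch).
    print(random_policy.forward({0: None, 1: None}))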