from typing import Dict, Any, Callable
from collections import namedtuple
from easydict import EasyDict
import gym
import torch

from ding.torch_utils import to_device


class PolicyFactory:
    """
    Overview:
        Policy factory class, used to generate different policies for general purposes, such as the random action \
        policy, which is used for initial sample collecting for better exploration when ``random_collect_size`` > 0.
    Interfaces:
        ``get_random_policy``
    """

    @staticmethod
    def get_random_policy(
            policy: 'Policy.collect_mode',  # noqa
            action_space: 'gym.spaces.Space' = None,  # noqa
            forward_fn: Callable = None,
    ) -> 'Policy.collect_mode':  # noqa
""" | |
Overview: | |
According to the given action space, define the forward function of the random policy, then pack it with \ | |
other interfaces of the given policy, and return the final collect mode interfaces of policy. | |
Arguments: | |
- policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy. | |
- action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style. | |
- forward_fn (:obj:`Callable`): It action space is too complex, you can define your own forward function \ | |
and pass it to this function, note you should set ``action_space`` to ``None`` in this case. | |
Returns: | |
- random_policy (:obj:`Policy.collect_mode`): The collect mode intefaces of the random policy. | |
""" | |
        assert not (action_space is None and forward_fn is None)
        random_collect_function = namedtuple(
            'random_collect_function', [
                'forward',
                'process_transition',
                'get_train_sample',
                'reset',
                'get_attribute',
            ]
        )

        def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
            # Sample a random action for each environment in the input data dict.
            actions = {}
            for env_id in data:
                if not isinstance(action_space, list):
                    if isinstance(action_space, gym.spaces.Discrete):
                        action = torch.LongTensor([action_space.sample()])
                    elif isinstance(action_space, gym.spaces.MultiDiscrete):
                        action = [torch.LongTensor([v]) for v in action_space.sample()]
                    else:
                        action = torch.as_tensor(action_space.sample())
                    actions[env_id] = {'action': action}
                elif 'global_state' in data[env_id].keys():
                    # for smac: mask out unavailable actions before sampling
                    logit = torch.ones_like(data[env_id]['action_mask'])
                    logit[data[env_id]['action_mask'] == 0.0] = -1e8
                    dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
                    actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
                else:
                    # for gfootball: sample one discrete action per agent
                    actions[env_id] = {
                        'action': torch.as_tensor(
                            [action_space_agent.sample() for action_space_agent in action_space]
                        ),
                        'logit': torch.ones([len(action_space), action_space[0].n])
                    }
            return actions

        def reset(*args, **kwargs) -> None:
            pass

        if action_space is None:
            return random_collect_function(
                forward_fn, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
            )
        elif forward_fn is None:
            return random_collect_function(
                forward, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
            )
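

# Illustrative usage sketch (not part of the original module; the names below
# are hypothetical). If the action space is too complex for the generic
# branches in ``forward`` above, a custom ``forward_fn`` can be passed
# instead, with ``action_space`` left as ``None``:
#
#     def constant_forward_fn(data, *args, **kwargs):
#         # Return a fixed discrete action for every env id, matching the
#         # {env_id: {'action': tensor}} output format of ``forward``.
#         return {env_id: {'action': torch.LongTensor([0])} for env_id in data}
#
#     random_policy = PolicyFactory.get_random_policy(policy, forward_fn=constant_forward_fn)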


def get_random_policy(
        cfg: EasyDict,
        policy: 'Policy.collect_mode',  # noqa
        env: 'BaseEnvManager'  # noqa
) -> 'Policy.collect_mode':  # noqa
""" | |
Overview: | |
The entry function to get the corresponding random policy. If a policy needs special data items in a \ | |
transition, then return itself, otherwise, we will use ``PolicyFactory`` to return a general random policy. | |
Arguments: | |
- cfg (:obj:`EasyDict`): The EasyDict-type dict configuration. | |
- policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy. | |
- env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \ | |
action generation. | |
Returns: | |
- random_policy (:obj:`Policy.collect_mode`): The collect mode intefaces of the random policy. | |
""" | |
    if cfg.policy.get('transition_with_policy_data', False):
        return policy
    else:
        action_space = env.action_space
        return PolicyFactory.get_random_policy(policy, action_space=action_space)
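

if __name__ == '__main__':
    # Minimal usage sketch, assuming mocked-up inputs (hypothetical, for
    # illustration only): a real caller would pass an actual
    # ``Policy.collect_mode`` object and a ``BaseEnvManager``. The mocks below
    # only expose the attributes that ``get_random_policy`` actually touches.
    from types import SimpleNamespace

    mock_policy = SimpleNamespace(
        process_transition=lambda *args, **kwargs: None,
        get_train_sample=lambda *args, **kwargs: None,
        get_attribute=lambda *args, **kwargs: None,
    )
    mock_env = SimpleNamespace(action_space=gym.spaces.Discrete(4))
    cfg = EasyDict({'policy': {}})  # no ``transition_with_policy_data`` -> random policy
    random_policy = get_random_policy(cfg, mock_policy, mock_env)
    # Sample random discrete actions for two parallel envs (dict keys are env ids).
    print(random_policy.forward({0: None, 1: None}))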