Spaces:
Sleeping
Sleeping
from typing import Callable, Tuple, Union | |
import torch | |
from torch import Tensor | |
from ding.torch_utils import fold_batch, unfold_batch | |
from ding.rl_utils import generalized_lambda_returns | |
from ding.torch_utils.network.dreamer import static_scan | |
def q_evaluation(obss: Tensor, actions: Tensor, q_critic_fn: Callable[[Tensor, Tensor],
                                                                      Tensor]) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """
    Overview:
        Evaluate every (observation, action) pair of a trajectory with a Q critic. The leading \
        time and batch axes are merged before the critic is called and restored on its output.
    Arguments:
        - obss (:obj:`torch.Tensor`): the observations along the trajectory
        - actions (:obj:`torch.Tensor`): the actions along the trajectory
        - q_critic_fn (:obj:`Callable`): the unified API :math:`Q(S_t, A_t)`
    Returns:
        - q_value (:obj:`torch.Tensor`): the action-value function evaluated along the trajectory. \
            When the critic returns a ``list`` (twin critic), a two-element list holding the first \
            two restored outputs is returned instead.
    Shapes:
        :math:`N`: time step
        :math:`B`: batch size
        :math:`O`: observation dimension
        :math:`A`: action dimension

        - obss:        [N, B, O]
        - actions:     [N, B, A]
        - q_value:     [N, B]
    """
    flat_obss, batch_dims = fold_batch(obss, 1)
    flat_actions, _action_dims = fold_batch(actions, 1)
    critic_out = q_critic_fn(flat_obss, flat_actions)

    if not isinstance(critic_out, list):
        # single critic: one tensor to restore
        return unfold_batch(critic_out, batch_dims)

    # twin critic: restore both heads independently
    first_q = unfold_batch(critic_out[0], batch_dims)
    second_q = unfold_batch(critic_out[1], batch_dims)
    return [first_q, second_q]
def imagine(cfg, world_model, start, actor, horizon, repeats=None):
    """
    Overview:
        Roll the world model forward for ``horizon`` steps, choosing actions by sampling from the actor.
    Arguments:
        - cfg: config object; only ``cfg.imag_sample`` is read here (forwarded to ``img_step``).
        - world_model: object exposing ``dynamics`` with ``get_feat`` and ``img_step``.
        - start (:obj:`dict`): initial latent state; every value has its first two axes merged \
            into one before the rollout.
        - actor (:obj:`Callable`): maps a feature tensor to a distribution with a ``sample()`` method.
        - horizon (:obj:`int`): number of imagined steps.
        - repeats: not used by this function.
    Returns:
        - feats: features computed from the state at each step.
        - states (:obj:`dict`): per key, the flattened start state followed by all successor \
            states except the last one.
        - actions: the actions sampled at each step.
    """
    dynamics = world_model.dynamics

    def merge_leading_axes(tensor):
        # collapse axis 0 and axis 1 into a single axis, keep the rest untouched
        return tensor.reshape([-1] + list(tensor.shape[2:]))

    start = {key: merge_leading_axes(value) for key, value in start.items()}

    def step(prev, _index):
        state, _prev_feat, _prev_action = prev
        feat = dynamics.get_feat(state)
        # the actor receives detached features, so no gradient flows from the policy input into the state
        action = actor(feat.detach()).sample()
        next_state = dynamics.img_step(state, action, sample=cfg.imag_sample)
        return next_state, feat, action

    scan_inputs = [torch.arange(horizon)]
    scan_init = (start, None, None)
    succ, feats, actions = static_scan(step, scan_inputs, scan_init)

    # Shift successors by one step and prepend the start state.
    # NOTE(review): assumes static_scan stacks per-step outputs along dim 0 -- confirm.
    states = {}
    for key, rolled in succ.items():
        states[key] = torch.cat([start[key][None], rolled[:-1]], 0)
    return feats, states, actions
def compute_target(cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent):
    """
    Overview:
        Build the lambda-return target, the cumulative discount weights and the critic values for \
        an imagined trajectory.
    Arguments:
        - cfg: config object; ``cfg.discount`` and ``cfg.lambda_`` are read.
        - world_model: object exposing ``heads`` (mapping) and ``dynamics.get_feat``.
        - critic (:obj:`Callable`): maps features to a distribution with a ``mode()`` method.
        - imag_feat: imagined features fed to the critic.
        - imag_state: imagined latent states, only used when a ``"discount"`` head exists.
        - reward: imagined rewards.
        - actor_ent: not used by this function.
        - state_ent: not used by this function.
    Returns:
        - target: output of ``generalized_lambda_returns``.
        - weights: detached cumulative product of the discounts along dim 0, shifted so that the \
            first entry is 1.
        - value: critic values with the last step dropped.
    """
    has_discount_head = "discount" in world_model.heads
    if has_discount_head:
        head_input = world_model.dynamics.get_feat(imag_state)
        predicted = world_model.heads["discount"](head_input).mean
        # TODO whether to detach
        discount = (cfg.discount * predicted).detach()
    else:
        # no learned discount: use the constant from the config everywhere
        discount = cfg.discount * torch.ones_like(reward)

    value = critic(imag_feat).mode()
    # value(imag_horizon, 16*64, 1)
    # action(imag_horizon, 16*64, ch)
    # discount(imag_horizon, 16*64, 1)
    target = generalized_lambda_returns(value, reward[:-1], discount[:-1], cfg.lambda_)

    # weights[t] = product of discount[0..t-1]; a leading row of ones makes weights[0] == 1
    leading_ones = torch.ones_like(discount[:1])
    shifted_discount = torch.cat([leading_ones, discount[:-1]], 0)
    weights = torch.cumprod(shifted_discount, 0).detach()
    return target, weights, value[:-1]
def compute_actor_loss(
        cfg,
        actor,
        reward_ema,
        imag_feat,
        imag_state,
        imag_action,
        target,
        actor_ent,
        state_ent,
        weights,
        base,
):
    """
    Overview:
        Compute the actor loss on an imagined trajectory as the negative weighted mean of the \
        advantage, optionally augmented with entropy bonuses.
    Arguments:
        - cfg: config object; ``reward_EMA``, ``actor_entropy`` and ``actor_state_entropy`` are read.
        - actor (:obj:`Callable`): maps features to a distribution with an ``entropy()`` method.
        - reward_ema (:obj:`Callable`): returns ``(offset, scale)`` for a target and exposes \
            ``values``; only used when ``cfg.reward_EMA`` is truthy.
        - imag_feat: imagined features; detached before being passed to the actor.
        - imag_state: not used by this function.
        - imag_action: not used by this function.
        - target: return target for each imagined step.
        - actor_ent: ignored; the entropy is recomputed from the actor's current policy.
        - state_ent: state entropy, only used when ``cfg.actor_state_entropy > 0``.
        - weights: per-step weights; the last step is dropped before weighting.
        - base: baseline subtracted from the target to form the advantage.
    Returns:
        - actor_loss (:obj:`torch.Tensor`): scalar loss.
        - metrics (:obj:`dict`): scalar statistics gathered along the way.
    """
    metrics = {}
    policy = actor(imag_feat.detach())
    actor_ent = policy.entropy()

    # Q-val for actor is not transformed using symlog
    if cfg.reward_EMA:
        offset, scale = reward_ema(target)
        normed_target = (target - offset) / scale
        normed_base = (base - offset) / scale
        adv = normed_target - normed_base
        metrics.update(tensorstats(normed_target, "normed_target"))
        values = reward_ema.values
        metrics["EMA_005"] = values[0].item()
        metrics["EMA_095"] = values[1].item()
    else:
        # Without return normalization fall back to the plain advantage. Previously `adv` was left
        # undefined on this path and the function raised NameError.
        adv = target - base

    # `adv` is a fresh tensor in both branches, so the in-place additions below never touch
    # the caller's `target` or `base`.
    actor_target = adv
    if cfg.actor_entropy > 0:
        # drop the last step and add a trailing axis so the bonus lines up with the advantage
        actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None]
        actor_target += actor_entropy
        metrics["actor_entropy"] = torch.mean(actor_entropy).item()
    if cfg.actor_state_entropy > 0:
        state_entropy = cfg.actor_state_entropy * state_ent[:-1]
        actor_target += state_entropy
        metrics["actor_state_entropy"] = torch.mean(state_entropy).item()

    actor_loss = -torch.mean(weights[:-1] * actor_target)
    return actor_loss, metrics
class RewardEMA(object):
    """
    Overview:
        Exponential moving average of the 5% and 95% quantiles of the values it is called with. \
        Each call updates the running quantiles and returns an ``(offset, scale)`` pair, where \
        ``offset`` is the running low quantile and ``scale`` is the running quantile spread, \
        clamped to be at least 1.0.
    """

    def __init__(self, device, alpha=1e-2):
        self.device = device
        self.alpha = alpha
        # running [low, high] quantile estimates, starting from zero
        self.values = torch.zeros((2, )).to(device)
        # quantile levels that are tracked
        self.range = torch.tensor([0.05, 0.95]).to(device)

    def __call__(self, x):
        samples = x.detach().flatten()
        batch_quantiles = torch.quantile(samples, self.range)
        self.values = self.alpha * batch_quantiles + (1 - self.alpha) * self.values
        low, high = self.values[0], self.values[1]
        # never divide by less than 1.0, so small spreads do not blow up the normalized values
        spread = torch.clamp(high - low, min=1.0)
        return low.detach(), spread.detach()
def tensorstats(tensor, prefix=None):
    """
    Overview:
        Summarize a tensor by its mean, standard deviation, minimum and maximum.
    Arguments:
        - tensor (:obj:`torch.Tensor`): the tensor to summarize.
        - prefix (:obj:`str`): optional key prefix. When truthy, keys become ``{prefix}_{name}`` \
            and values are converted to Python scalars.
    Returns:
        - metrics (:obj:`dict`): with a prefix, a dict of Python scalars; without one, a dict \
            keyed ``mean``/``std``/``min``/``max`` whose values are 0-d numpy arrays.
    """
    reducers = (
        ('mean', torch.mean),
        ('std', torch.std),
        ('min', torch.min),
        ('max', torch.max),
    )
    stats = {label: reduce_fn(tensor).detach().cpu().numpy() for label, reduce_fn in reducers}
    if not prefix:
        # no prefix: hand back the raw 0-d numpy arrays
        return stats
    return {f'{prefix}_{label}': array.item() for label, array in stats.items()}