Spaces:
Build error
Build error
import torch | |
from torch import Tensor, nn | |
def rollout_iter( | |
nsteps: int, | |
model: nn.Module, | |
batch: dict[str, Tensor | int | float], | |
) -> Tensor: | |
"""A helper function for performing autoregressive rollout. | |
Args: | |
nsteps (int): The number of rollout steps to take | |
model (nn.Module): A model. | |
batch (dict): A data dictionary common to the Prithvi models. | |
Raises: | |
ValueError: If the number of steps isn't positive. | |
Returns: | |
Tensor: the output of the model after nsteps autoregressive iterations. | |
""" | |
if nsteps < 1: | |
raise ValueError("'nsteps' shouold be a positive int.") | |
xlast = batch["x"][:, 1] | |
batch["lead_time"] = batch["lead_time"][..., 0] | |
# Save the masking ratio to be restored later | |
mask_ratio_tmp = model.mask_ratio_inputs | |
for step in range(nsteps): | |
# After first step, turn off masking | |
if step > 0: | |
model.mask_ratio_inputs = 0.0 | |
batch["static"] = batch["statics"][:, step] | |
batch["climate"] = batch["climates"][:, step] | |
batch["y"] = batch["ys"][:, step] | |
out = model(batch) | |
batch["x"] = torch.cat((xlast[:, None], out[:, None]), dim=1) | |
xlast = out | |
# Restore the masking ratio | |
model.mask_ratio_inputs = mask_ratio_tmp | |
return xlast | |