qq1990's picture
Upload 6 files
60840ab verified
import torch
from torch import Tensor, nn
def rollout_iter(
nsteps: int,
model: nn.Module,
batch: dict[str, Tensor | int | float],
) -> Tensor:
"""A helper function for performing autoregressive rollout.
Args:
nsteps (int): The number of rollout steps to take
model (nn.Module): A model.
batch (dict): A data dictionary common to the Prithvi models.
Raises:
ValueError: If the number of steps isn't positive.
Returns:
Tensor: the output of the model after nsteps autoregressive iterations.
"""
if nsteps < 1:
raise ValueError("'nsteps' shouold be a positive int.")
xlast = batch["x"][:, 1]
batch["lead_time"] = batch["lead_time"][..., 0]
# Save the masking ratio to be restored later
mask_ratio_tmp = model.mask_ratio_inputs
for step in range(nsteps):
# After first step, turn off masking
if step > 0:
model.mask_ratio_inputs = 0.0
batch["static"] = batch["statics"][:, step]
batch["climate"] = batch["climates"][:, step]
batch["y"] = batch["ys"][:, step]
out = model(batch)
batch["x"] = torch.cat((xlast[:, None], out[:, None]), dim=1)
xlast = out
# Restore the masking ratio
model.mask_ratio_inputs = mask_ratio_tmp
return xlast