Spaces:
Build error
Build error
File size: 1,385 Bytes
60840ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
from torch import Tensor, nn
def rollout_iter(
nsteps: int,
model: nn.Module,
batch: dict[str, Tensor | int | float],
) -> Tensor:
"""A helper function for performing autoregressive rollout.
Args:
nsteps (int): The number of rollout steps to take
model (nn.Module): A model.
batch (dict): A data dictionary common to the Prithvi models.
Raises:
ValueError: If the number of steps isn't positive.
Returns:
Tensor: the output of the model after nsteps autoregressive iterations.
"""
if nsteps < 1:
raise ValueError("'nsteps' shouold be a positive int.")
xlast = batch["x"][:, 1]
batch["lead_time"] = batch["lead_time"][..., 0]
# Save the masking ratio to be restored later
mask_ratio_tmp = model.mask_ratio_inputs
for step in range(nsteps):
# After first step, turn off masking
if step > 0:
model.mask_ratio_inputs = 0.0
batch["static"] = batch["statics"][:, step]
batch["climate"] = batch["climates"][:, step]
batch["y"] = batch["ys"][:, step]
out = model(batch)
batch["x"] = torch.cat((xlast[:, None], out[:, None]), dim=1)
xlast = out
# Restore the masking ratio
model.mask_ratio_inputs = mask_ratio_tmp
return xlast
|