import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from spar3d.models.utils import BaseModule


class TriplaneLearnablePositionalEmbedding(BaseModule):
    """Learnable positional embeddings for a triplane, exposed as a flat token sequence."""

    @dataclass
    class Config(BaseModule.Config):
        plane_size: int = 96
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        # One learnable feature map per plane, scaled by 1/sqrt(C) at initialization.
        self.embeddings = nn.Parameter(
            torch.randn(
                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
                dtype=torch.float32,
            )
            / math.sqrt(self.cfg.num_channels)
        )

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        # Broadcast the three planes over the batch and flatten them into a
        # single token sequence of length Nt = 3 * plane_size**2.
        return rearrange(
            repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
        )

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        # Inverse of forward(): unflatten the token sequence back into the
        # three spatial planes.
        batch_size, Ct, Nt = tokens.shape
        assert Nt == self.cfg.plane_size**2 * 3
        assert Ct == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )
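

# --- Usage sketch (illustrative addition, not part of the upstream file) ----
# A minimal round-trip check of the tokenize / detokenize reshapes above,
# written against raw tensors so it does not depend on how BaseModule is
# constructed. The small sizes here are arbitrary; the default Config uses
# 3 planes of 1024 channels at 96x96.
if __name__ == "__main__":
    plane_size, num_channels, batch = 8, 16, 2
    planes = torch.randn(3, num_channels, plane_size, plane_size)

    # forward(): broadcast the planes over the batch, then flatten to tokens.
    tokens = rearrange(
        repeat(planes, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch),
        "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
    )
    assert tokens.shape == (batch, num_channels, 3 * plane_size**2)

    # detokenize(): recover the per-plane layout from the token sequence.
    recovered = rearrange(
        tokens,
        "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
        Np=3,
        Hp=plane_size,
        Wp=plane_size,
    )
    assert recovered.shape == (batch, 3, num_channels, plane_size, plane_size)
    assert torch.equal(recovered[0], planes)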