import torch


class Decorator(torch.nn.Module):
    """Learnable block of extra tokens appended to a batch of text embeddings.

    The module owns a single parameter ``weight`` of shape
    ``(num_tokens, token_size)``, stored in float32 regardless of the dtype
    of the embeddings it is later concatenated with.
    """

    def __init__(
        self,
        num_tokens: int = 4,
        token_size: int = 4096,
    ) -> None:
        """Create the decorator tokens.

        Args:
            num_tokens: Number of tokens appended to every sequence.
            token_size: Size of each token (last dimension of the embeddings).
        """
        super().__init__()
        # `.float()` keeps the parameter float32 even if the process-wide
        # default dtype has been changed to something else.
        self.weight: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(num_tokens, token_size).float()
        )

    def forward(
        self, text_embeds: torch.Tensor, is_unconditional: bool = False
    ) -> torch.Tensor:
        """Append the decorator tokens to ``text_embeds``.

        Args:
            text_embeds: Embeddings to extend. Must be concatenable along
                dim ``-2`` with a ``(batch, num_tokens, token_size)`` tensor,
                i.e. shaped ``(batch, seq_len, token_size)``.
            is_unconditional: When true, zeros are appended in place of the
                learned tokens.

        Returns:
            ``text_embeds`` with ``num_tokens`` extra entries along dim ``-2``,
            in the dtype of ``text_embeds``.
        """
        # Keep the stored parameter in float32. The check must be against
        # float32 itself (not the input dtype): otherwise a module that was
        # cast to half precision and fed half-precision inputs would never
        # be restored.
        if self.weight.dtype != torch.float32:
            self.weight.data = self.weight.data.float()

        # Broadcast the tokens across the batch without copying.
        batch_size = text_embeds.shape[0]
        decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1)

        if is_unconditional:
            # Same shape as the learned tokens, but all zeros.
            decorator_embeds = torch.zeros_like(decorator_embeds)

        # `.to` returns the tensor unchanged when the dtype already matches.
        decorator_embeds = decorator_embeds.to(text_embeds.dtype)

        return torch.cat((text_embeds, decorator_embeds), dim=-2)