Spaces:
Paused
Paused
import torch | |
class Decorator(torch.nn.Module):
    """Learnable "decorator" tokens appended to a batch of text embeddings.

    The module owns a single parameter, ``weight``, of shape
    ``(num_tokens, token_size)``. ``forward`` concatenates these tokens (or
    zeros of the same shape) after the incoming embeddings along the
    second-to-last dimension.
    """

    def __init__(
        self,
        num_tokens: int = 4,
        token_size: int = 4096,
    ) -> None:
        """Create the decorator tokens.

        Args:
            num_tokens: Number of token rows appended by ``forward``.
            token_size: Size of each token's last dimension; must match the
                last dimension of the embeddings passed to ``forward``.
        """
        super().__init__()
        self.weight: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(num_tokens, token_size)
        )
        # ensure it is float32 (torch.randn follows the global default dtype)
        self.weight.data = self.weight.data.float()

    def forward(
        self,
        text_embeds: torch.Tensor,
        is_unconditional: bool = False,
    ) -> torch.Tensor:
        """Append the decorator tokens to ``text_embeds``.

        Args:
            text_embeds: Tensor of shape ``(batch, seq, token_size)``.
            is_unconditional: When true, append zeros instead of the learned
                tokens.

        Returns:
            Tensor of shape ``(batch, seq + num_tokens, token_size)`` with the
            dtype of ``text_embeds``.
        """
        # Keep the stored parameter in float32 no matter what dtype the module
        # was cast to or what dtype the input has. The comparison must be
        # against float32 itself: comparing against the input dtype would leave
        # a half/bf16 parameter untouched whenever the input is half/bf16 too.
        if self.weight.dtype != torch.float32:
            self.weight.data = self.weight.data.float()

        batch_size = text_embeds.shape[0]

        if is_unconditional:
            # Zero padding, allocated directly with the input's dtype/device.
            decorator_embeds = text_embeds.new_zeros(
                (batch_size, *self.weight.shape)
            )
        else:
            # Broadcast the tokens over the batch (a view, no copy), then match
            # the input's dtype and device so the concatenation is valid.
            decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1)
            decorator_embeds = decorator_embeds.to(
                device=text_embeds.device, dtype=text_embeds.dtype
            )

        return torch.cat((text_embeds, decorator_embeds), dim=-2)