import torch


class Decorator(torch.nn.Module):
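    """Learnable "decorator" tokens appended to text embeddings.

    A bank of ``num_tokens`` trainable embeddings of size ``token_size`` is
    concatenated to the end of each text embedding sequence; for the
    unconditional branch the appended tokens are zeroed out.
    """
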
    def __init__(
        self,
        num_tokens: int = 4,
        token_size: int = 4096,
    ) -> None:
        super().__init__()

        self.weight: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(num_tokens, token_size)
        )
        # keep the learnable tokens in float32 for training stability
        self.weight.data = self.weight.data.float()

    def forward(self, text_embeds: torch.Tensor, is_unconditional: bool = False) -> torch.Tensor:
        # restore the parameter to float32 if it was cast elsewhere
        if self.weight.dtype != torch.float32:
            self.weight.data = self.weight.data.float()
        # expand batch to match text_embeds
        batch_size = text_embeds.shape[0]
        decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1)
        if is_unconditional:
            # zero out the decorator tokens for the unconditional branch
            decorator_embeds = torch.zeros_like(decorator_embeds)

        # match the text embedding dtype, then append the decorator tokens
        # along the sequence dimension
        if decorator_embeds.dtype != text_embeds.dtype:
            decorator_embeds = decorator_embeds.to(text_embeds.dtype)
        text_embeds = torch.cat((text_embeds, decorator_embeds), dim=-2)

        return text_embeds
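

# --- usage sketch (illustrative, not part of the original module) ---
# Assumes text_embeds is shaped (batch, seq_len, token_size); the sequence
# length of 77 below is only an example value.
if __name__ == "__main__":
    decorator = Decorator(num_tokens=4, token_size=4096)
    text_embeds = torch.randn(2, 77, 4096, dtype=torch.float16)

    cond = decorator(text_embeds)                           # learned tokens appended
    uncond = decorator(text_embeds, is_unconditional=True)  # appended tokens zeroed

    print(cond.shape)    # torch.Size([2, 81, 4096])
    print(uncond.shape)  # torch.Size([2, 81, 4096])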