Spaces:
Paused
Paused
File size: 747 Bytes
1c72248 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
import torch.nn as nn
class ReduxImageEncoder(torch.nn.Module):
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
device=None,
dtype=None,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device
self.dtype = dtype
self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = nn.Linear(
txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
x = self.redux_up(sigclip_embeds)
x = torch.nn.functional.silu(x)
projected_x = self.redux_down(x)
return projected_x
|