|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ...configuration_utils import ConfigMixin, register_to_config |
|
from ...models.modeling_utils import ModelMixin |
|
from ...utils import BaseOutput |
|
|
|
|
|
@dataclass |
|
class ReduxImageEncoderOutput(BaseOutput): |
|
image_embeds: Optional[torch.Tensor] = None |
|
|
|
|
|
class ReduxImageEncoder(ModelMixin, ConfigMixin): |
|
@register_to_config |
|
def __init__( |
|
self, |
|
redux_dim: int = 1152, |
|
txt_in_features: int = 4096, |
|
) -> None: |
|
super().__init__() |
|
|
|
self.redux_up = nn.Linear(redux_dim, txt_in_features * 3) |
|
self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features) |
|
|
|
def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput: |
|
projected_x = self.redux_down(nn.functional.silu(self.redux_up(x))) |
|
|
|
return ReduxImageEncoderOutput(image_embeds=projected_x) |
|
|