Spaces:
Build error
Build error
File size: 1,046 Bytes
fa9854e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from importlib.metadata import requires
import torch
import torch.nn as nn
from .registry import register_model
from .vlpencoder import LanguageEncoder
class FixLanguageEncoder(LanguageEncoder):
    """``LanguageEncoder`` variant whose text-encoding entry points skip autograd.

    Differences from the parent class, all visible in this definition:

    * After the parent constructor runs, ``logit_scale`` is (re)assigned to a
      0-dim parameter holding 1.0 with ``requires_grad=False``.
    * ``get_text_embeddings``, ``get_text_token_embeddings``,
      ``forward_language`` and ``forward_language_token`` delegate unchanged
      to the parent implementation but execute inside ``torch.no_grad()``.

    Every other inherited method is left as-is. ``no_grad`` only affects the
    four calls above; it does not change ``requires_grad`` on any parameters
    the parent constructor creates.
    """

    def __init__(self, *args, **kwargs):
        # All construction is delegated to the parent; only logit_scale is overridden.
        super().__init__(*args, **kwargs)
        # NOTE(review): presumably the parent multiplies similarity logits by
        # logit_scale; pinning it to a frozen 1.0 keeps that factor constant.
        # Confirm against LanguageEncoder.
        unit_scale = torch.ones(())
        self.logit_scale = nn.Parameter(unit_scale, requires_grad=False)

    def get_text_embeddings(self, *args, **kwargs):
        """Return the parent's text embeddings, computed without gradient tracking."""
        with torch.no_grad():
            return super().get_text_embeddings(*args, **kwargs)

    def get_text_token_embeddings(self, *args, **kwargs):
        """Return the parent's token-level text embeddings without gradient tracking."""
        with torch.no_grad():
            return super().get_text_token_embeddings(*args, **kwargs)

    def forward_language(self, *args, **kwargs):
        """Run the parent's ``forward_language`` without gradient tracking."""
        with torch.no_grad():
            return super().forward_language(*args, **kwargs)

    def forward_language_token(self, *args, **kwargs):
        """Run the parent's ``forward_language_token`` without gradient tracking."""
        with torch.no_grad():
            return super().forward_language_token(*args, **kwargs)
@register_model
def get_language_model(cfg, **kwargs):
    """Build a ``FixLanguageEncoder`` from ``cfg`` for the model registry.

    Only ``cfg`` is passed to the constructor. Any extra keyword arguments
    are accepted but ignored; presumably they exist so that every registered
    entry point can share a uniform call signature (verify against the
    registry's callers).
    """
    encoder = FixLanguageEncoder(cfg)
    return encoder