from typing import *
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import torch
from transformers import AutoTokenizer, CLIPTextModel
from ....utils import dist_utils
class TextConditionedMixin:
    """
    Mixin for text-conditioned models.

    Lazily loads a CLIP text encoder and uses it to turn prompt strings into
    conditioning embeddings; also caches the empty-string ("null") embedding
    and passes it along as ``neg_cond`` for classifier-free guidance.

    Args:
        text_cond_model: Hugging Face model name of the CLIP text encoder.
    """
    def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
        super().__init__(*args, **kwargs)
        self.text_cond_model_name = text_cond_model
        self.text_cond_model = None  # loaded lazily on first encode_text() call

    def _init_text_cond_model(self):
        """
        Load the CLIP tokenizer and text model, move the model to CUDA,
        and cache the embedding of the empty prompt as the null condition.
        """
        # Load under a local-master-first barrier so that, per node, only one
        # process downloads/populates the Hub cache before the others read it.
        with dist_utils.local_master_first():
            model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
            tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
        model.eval()
        model = model.cuda()
        self.text_cond_model = {
            'model': model,
            'tokenizer': tokenizer,
        }
        # Pre-compute the unconditional ("") embedding reused as neg_cond.
        self.text_cond_model['null_cond'] = self.encode_text([''])

    def encode_text(self, text: List[str]) -> torch.Tensor:
        """
        Encode a batch of prompt strings with the CLIP text model.

        Args:
            text: Non-empty list of prompt strings.

        Returns:
            The encoder's ``last_hidden_state``, shape (batch, 77, hidden_dim)
            on the CUDA device (77 = the fixed ``max_length`` used below).

        Raises:
            AssertionError: If ``text`` is not a non-empty list of strings.
        """
        # Check non-emptiness and every element (the previous check indexed
        # text[0], which raised IndexError on an empty list and ignored the
        # types of the remaining elements).
        assert isinstance(text, list) and len(text) > 0 and all(isinstance(t, str) for t in text), \
            "TextConditionedMixin only supports list of strings as cond"
        if self.text_cond_model is None:
            self._init_text_cond_model()
        encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
        tokens = encoding['input_ids'].cuda()
        # NOTE(review): this runs with autograd enabled; if the text encoder is
        # meant to stay frozen, wrapping in torch.no_grad() would save memory —
        # confirm against training code before changing.
        embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
        return embeddings

    def get_cond(self, cond, **kwargs):
        """
        Get the conditioning data for training.

        Args:
            cond: List of prompt strings.
            **kwargs: Forwarded to the base implementation; ``neg_cond`` is
                added here as the batch-expanded null-text embedding.
        """
        cond = self.encode_text(cond)
        kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
        cond = super().get_cond(cond, **kwargs)
        return cond

    def get_inference_cond(self, cond, **kwargs):
        """
        Get the conditioning data for inference.

        Args:
            cond: List of prompt strings.
            **kwargs: Forwarded to the base implementation; ``neg_cond`` is
                added here as the batch-expanded null-text embedding.
        """
        cond = self.encode_text(cond)
        kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
        cond = super().get_inference_cond(cond, **kwargs)
        return cond