Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
from transformers import AutoConfig | |
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config | |
from .sampling import cosine_schedule, mask_by_random_topk | |
from .phi import PhiForCausalLM | |
try: | |
import xformers.ops as xops | |
is_xformers_available = True | |
except ImportError: | |
is_xformers_available = False | |
class Showo(ModelMixin, ConfigMixin):
    """Show-o model: a Phi causal LM over a joint text + image-token vocabulary.

    Wraps ``PhiForCausalLM`` with its token embeddings resized to ``vocab_size``.
    Two decoding helpers are provided:

    * ``t2i_generate``: MaskGIT-style iterative decoding of image tokens.
    * ``mmu_generate``: autoregressive sampling of text tokens.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        w_clip_vit,
        vocab_size,
        llm_vocab_size,
        llm_model_path='',
        codebook_size=8192,
        num_vq_tokens=256,
        **kwargs,
    ):
        """Build the Phi backbone and, optionally, the vision-feature projector.

        The constructor arguments are recorded in ``self.config`` by the
        diffusers-style ``@register_to_config`` decorator (assumed semantics of the
        project's ``modeling_utils``). ``mmu_generate`` relies on
        ``self.config.w_clip_vit``.

        Args:
            w_clip_vit: If true, also create ``mm_projector``, an MLP
                1024 -> 2048 -> 2048. It presumably maps CLIP-ViT features into
                the LLM embedding space (TODO confirm).
            vocab_size: Size of the joint vocabulary. The backbone's token
                embeddings are resized to it. The last id (``vocab_size - 1``)
                is registered as ``mask_token_id``.
            llm_vocab_size: Text vocabulary size; not used directly in this
                constructor.
            llm_model_path: Passed to ``AutoConfig.from_pretrained``. The Phi
                model is constructed from that config, so this constructor does
                not load pretrained LLM weights.
            codebook_size: Not used directly in this constructor.
            num_vq_tokens: Not used directly in this constructor.
            **kwargs: Not used directly in this constructor.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.register_to_config(mask_token_id=vocab_size - 1)
        config = AutoConfig.from_pretrained(llm_model_path)
        self.showo = PhiForCausalLM(config)
        self.showo.resize_token_embeddings(self.vocab_size)
        self.output_size = self.vocab_size

        # Read the argument, not ``self.w_clip_vit``: that attribute is never assigned.
        if w_clip_vit:
            self.mm_projector = torch.nn.Sequential(
                torch.nn.Linear(1024, 2048),
                torch.nn.GELU(),
                torch.nn.Linear(2048, 2048),
            )

    def _set_gradient_checkpointing(self, module, value=False):
        """Record whether gradient checkpointing is enabled on this model.

        Presumably invoked by the ModelMixin enable/disable helpers, possibly once
        per submodule (TODO confirm). Only sets ``self.gradient_checkpointing``;
        ``module`` is not modified and the flag is not forwarded to ``self.showo``.
        """
        self.gradient_checkpointing = value

    def forward(
        self,
        input_ids,
        input_embeddings=None,
        attention_mask=None,
        labels=None,
        label_smoothing=0.0,
        config=None,
        labels_mask_text=None,
        labels_mask_image=None,
        **kwargs,
    ):
        """Return the backbone's logits.

        If ``input_embeddings`` is given, it is fed to the backbone instead of
        ``input_ids``. Loss computation is not implemented: passing ``labels``
        raises ``NotImplementedError`` after the forward pass.
        ``label_smoothing``, ``config``, ``labels_mask_text``,
        ``labels_mask_image`` and ``kwargs`` are accepted but unused here.
        """
        if input_embeddings is None:
            logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
        else:
            logits = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)['logits']

        if labels is not None:
            raise NotImplementedError

        return logits

    def t2i_generate(
        self,
        input_ids: torch.LongTensor = None,
        uncond_input_ids: torch.LongTensor = None,
        attention_mask=None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=0,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        uni_prompting=None,
        config=None,
        **kwargs,
    ):
        """Iteratively decode the masked image tokens in ``input_ids`` (MaskGIT-style).

        Follows the parallel decoding loop of
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
        with one deviation in temperature annealing (see the NOTE in the loop).

        The image tokens occupy ``input_ids[:, -(num_vq_tokens + 1):-1]``, i.e.
        the ``num_vq_tokens`` positions before the final one. They are stored
        offset by ``llm_vocab_size + 10``. Positions equal to
        ``self.config.mask_token_id`` are the ones still to be generated.

        Args:
            input_ids: Prompt + image-token sequence. **Modified in place**: after
                every step its image slice holds the partially decoded tokens.
            uncond_input_ids: Unconditional sequence for classifier-free guidance.
                Only its first ``max_seq_length + 1`` tokens are used; the rest is
                re-copied from ``input_ids`` every step.
            attention_mask: Forwarded to ``forward``. With guidance it must cover
                the concatenated (cond, uncond) batch.
            temperature: Initial temperature for the random confidence-based
                re-masking done by ``mask_by_random_topk``. It is not applied to
                the logits.
            timesteps: Number of decoding steps; must be >= 1.
            guidance_scale: CFG weight. Guidance is applied only when it is > 0
                and ``uncond_input_ids`` is given.
            noise_schedule: Maps the progress ratio ``(step + 1) / timesteps``
                (a tensor in (0, 1]) to the fraction of the ``num_vq_tokens``
                positions to re-mask.
            generator: RNG used for sampling and re-masking.
            uni_prompting: Unused here.
            config: Experiment config. Reads ``config.model.showo.num_vq_tokens``
                and ``config.model.showo.llm_vocab_size``. When
                ``uncond_input_ids`` is given it also reads
                ``config.dataset.preprocessing.max_seq_length``.
            **kwargs: Unused here.

        Returns:
            Image token ids sampled at the last step, shape
            ``(batch, num_vq_tokens)``. The ids are codebook-local, i.e. without
            the ``llm_vocab_size + 10`` offset.

        Raises:
            ValueError: If ``timesteps < 1``.
        """
        if timesteps < 1:
            raise ValueError(f"timesteps must be >= 1, got {timesteps}")

        mask_token_id = self.config.mask_token_id
        seq_len = config.model.showo.num_vq_tokens
        # First joint-vocab id of the image codebook. The +10 presumably skips
        # special tokens placed after the text vocabulary (TODO confirm).
        vq_offset = config.model.showo.llm_vocab_size + 10
        image_span = slice(-(seq_len + 1), -1)

        # Codebook-local copy of the current image tokens; masked slots keep mask_token_id.
        vq_ids = input_ids[:, image_span].clone()
        vq_ids = torch.where(vq_ids == mask_token_id, mask_token_id, vq_ids - vq_offset)

        if uncond_input_ids is not None:
            prefix_len = config.dataset.preprocessing.max_seq_length + 1
            uncond_prefix = uncond_input_ids[:, :prefix_len]
        use_cfg = uncond_input_ids is not None and guidance_scale > 0

        for step in range(timesteps):
            if use_cfg:
                # Refresh the unconditional sequence with the current image tokens.
                uncond_input_ids = torch.cat([uncond_prefix, input_ids[:, prefix_len:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
                # MUSE-style CFG: cond + scale * (cond - uncond).
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
            else:
                logits = self(input_ids, attention_mask=attention_mask)
            # Keep image positions and codebook logits. The final vocab column is
            # dropped; it is presumably the mask token (mask_token_id == vocab_size - 1).
            logits = logits[:, image_span, vq_offset:-1]

            probs = logits.softmax(dim=-1)
            flat_probs = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(flat_probs, 1, generator=generator)[:, 0].view(*logits.shape[:-1])
            unknown_map = vq_ids == mask_token_id
            # Tokens decoded in earlier steps stay fixed.
            sampled_ids = torch.where(unknown_map, sampled_ids, vq_ids)

            # Fraction of the sequence to re-mask for the next round.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))

            # Confidence of each chosen token. Fixed tokens get the dtype maximum so
            # they are never picked for re-masking.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]).squeeze(-1)
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Clamp to [1, unknown - 1]: re-mask at least one token, and keep at
            # least one newly decoded token this round.
            mask_len = torch.max(
                torch.tensor([1], device=logits.device),
                torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
            )

            # NOTE(review): ``temperature`` is reassigned here, so the annealing
            # compounds across steps (T * prod_i (1 - i / timesteps)). MaskGIT uses
            # T * (1 - ratio) afresh each step. Kept as-is to preserve the sampling
            # behaviour of existing checkpoints; confirm intent.
            temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

            # Write back: re-masked slots get mask_token_id, the rest their joint-vocab id.
            input_ids[:, image_span] = torch.where(masking, mask_token_id, sampled_ids + vq_offset)
            vq_ids = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids

    def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
        """Autoregressively sample up to ``max_new_tokens`` tokens.

        Takes a conditioning sequence: token ids ``idx`` (LongTensor of shape
        ``(b, t)``) or ``input_embeddings``. Each sampled token is fed back into
        the model. Call this in ``model.eval()`` mode.

        Args:
            idx: Conditioning token ids. Extended with each sampled token when
                ``self.config.w_clip_vit`` is false.
            input_embeddings: Conditioning embeddings. When given they are used
                instead of ``idx`` in the forward pass. Extended with the
                embedding of each sampled token when ``self.config.w_clip_vit``
                is true.
            attention_mask: Required. Judging by how it is extended below, it is
                an additive mask (0 = attend, dtype minimum = blocked). After
                ``squeeze()`` it must be 2-D ``(L, L)``, which implies batch size 1.
            max_new_tokens: Maximum number of tokens to sample.
            temperature: Logits are divided by it before the softmax; must be nonzero.
            top_k: If given, sample only among the ``top_k`` largest logits.
            eot_token: Stop once this id is sampled. The id is still included
                in the result.

        Returns:
            List of sampled token ids (0-dim tensors) for the first batch element.

        Raises:
            ValueError: If both ``idx`` and ``input_embeddings`` are None.
        """
        if idx is not None:
            device = idx.device
        elif input_embeddings is not None:
            device = input_embeddings.device
        else:
            raise ValueError("mmu_generate requires either `idx` or `input_embeddings`.")

        result = []
        for _ in range(max_new_tokens):
            logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)

            # Grow the (L, L) mask to (L+1, L+1) for the next step. Existing rows
            # are blocked from the new position. The new row copies the previous
            # last row and may attend to itself.
            L = attention_mask.shape[-1]
            attention_mask = attention_mask.squeeze()
            blocked_col = torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min
            widened = torch.hstack([attention_mask, blocked_col])  # L, L+1
            new_row = torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0)
            attention_mask = torch.vstack([widened, new_row])  # L+1, L+1

            # Logits at the final position, scaled by the temperature.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            result.append(idx_next[0][0])

            # Append the sampled token to whichever input stream the model consumes.
            if self.config.w_clip_vit:
                idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
                input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
            else:
                idx = torch.cat((idx, idx_next), dim=1)

            if eot_token is not None and idx_next.cpu() == eot_token:
                break

        return result