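"""Inference demo entry point for UniDisc.

Loads a saved config, builds the `Diffusion` model, and exposes
`inference(request)`, which masks the requested text/image spans and runs the
diffusion sampler to fill them in.
"""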
from __future__ import annotations

import copy
import math
import os
import pickle
import re
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from decoupled_utils import clear_cache
from model_utils import wrapped_batch_decode
from unidisc.tokenizers.image_tokenizers import decode_latents, get_image_batch

os.environ["UNIDISC_FORCE_CUDNN_SPDA_CONTEXT"] = "1" |
|
os.environ["UNIDISC_DISABLE_APEX_RMSNORM"] = "1" |
|
|
|
import hydra
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import PartialState
from einops import rearrange
from hydra import compose, initialize
from image_utils import Im
from omegaconf import OmegaConf, open_dict
from PIL import Image

import dataloader
from decoupled_utils import (breakpoint_on_error, gprint,
                             set_global_breakpoint, set_global_exists)
from demo.inference_utils import (convert_to_model_input, messages_to_batch,
                                  save_grid_image)
from demo.server import ChatMessage, ChatRequest, ContentPart
from utils import set_omega_conf_resolvers, set_torch_defaults

os.environ["HYDRA_FULL_ERROR"] = "1" |
|
|
|
set_global_breakpoint()
set_global_exists()
set_omega_conf_resolvers()


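# Compose the demo config from `configs/config.yaml` with CLI-style overrides
# (programmatic alternative to the `@hydra.main` entry point below).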
def get_cfg(overrides: Union[str, list[str]]):
    with initialize(version_base=None, config_path='configs'):
        cfg = compose(config_name='config.yaml', return_hydra_config=False, overrides=overrides)
    return cfg


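# Wrap the run in an `accelerate` Accelerator and record the device count and
# compute dtype on the config.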
def set_accelerator(config):
    from accelerate import Accelerator
    mixed_precision = "bf16"
    accelerator = Accelerator(mixed_precision=mixed_precision)
    compute_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        compute_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        compute_dtype = torch.bfloat16
    gprint(f"Compute dtype is: {compute_dtype}")
    with open_dict(config):
        config.trainer.devices = accelerator.num_processes
        config.trainer.dtype = str(compute_dtype)

    return config, accelerator


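# Build the Diffusion model (optionally saving or loading the resolved config
# first) and return an `inference` callable bound to it.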
def setup(config=None, save_config=False, demo_type="jan", device=None, profile_memory=True):
    if profile_memory:
        # Record allocator history and dump a trace on OOM for debugging.
        torch.cuda.memory._record_memory_history()
        from torchtnt.utils.oom import attach_oom_observer
        attach_oom_observer(output_dir=str(os.getcwd()), trace_max_entries=500000)

    set_torch_defaults()
    if config is not None:
        demo_type = "jan"
    config_path = Path(__file__).parent / f'outputs/config_{demo_type}.pkl'

    if save_config:
        # Persist the fully resolved config (pickle + YAML) and exit; a later
        # run with config=None reloads it from disk.
        OmegaConf.resolve(config)
        config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(config_path, 'wb') as f:
            pickle.dump(config, f)

        yaml_config_path = config_path.with_suffix('.yaml')
        with open(yaml_config_path, 'w') as yaml_file:
            OmegaConf.save(config, yaml_file)
        print(f"Saved config to {config_path}")
        exit()
    elif config is None:
        with open(config_path, 'rb') as f:
            config = pickle.load(f)
        print(f"Loaded config from {config_path}")

    config, accelerator = set_accelerator(config)
    from model import Diffusion
    device = PartialState().device if device is None else device
    model = Diffusion(config=config, tokenizer=dataloader.get_tokenizer(config), device=device)
    model.set_accelerator(accelerator)
    return partial(inference, config=config, model=model)


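# Literal token users type to mark a position to infill; "<mN>" expands to N
# consecutive copies (see expand_mask_tokens).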
mask_token = "<m>"


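# Expand shorthand "<mN>" in every text part into N literal mask tokens.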
def expand_mask_tokens(messages: List[ChatMessage]) -> List[ChatMessage]:
    def replace_match(match: re.Match) -> str:
        count = int(match.group(1))
        return mask_token * count

    def expand_mask_in_text(text: str) -> str:
        pattern = r'<m(\d+)>'
        return re.sub(pattern, replace_match, text)

    messages = copy.deepcopy(messages)
    for message in messages:
        for content in message.content:
            if content.type == "text":
                print(f"Input text: {content.text}")
                content.text = expand_mask_in_text(content.text)
                print(f"Expanded text: {content.text}")

    return messages


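# Build a fixed-layout batch (a text block of config.data.block_size tokens
# followed by the image tokens) for the non-interleaved model variant, from
# exactly one text part and one image part.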
def get_fixed_batch(config, tokenizer, model, input_data, resolution):
    assert len(input_data) == 2, "Input data must contain 2 messages"
    images = [content["image_url"] for content in input_data if content['type'] == "image_url"]
    texts = [content["text"] for content in input_data if content['type'] == "text"]
    assert len(images) == 1
    assert len(texts) == 1
    _img = Im(images[0])
    if _img.height != _img.width:
        _img = _img.square(resolution, resolution)
    elif _img.height != resolution or _img.width != resolution:
        _img = _img.resize(resolution, resolution)
    img_image_ids = get_image_batch(config, model.get_vae(), {"img": _img.torch[None]}, model.device)
    txt_input_ids = dataloader.tokenize_text(tokenizer, config.data.block_size, texts)

    data = {}
    seq_len = config.data.block_size + img_image_ids.shape[1]
    data["input_ids"] = txt_input_ids["input_ids"]
    data["attention_mask"] = txt_input_ids["attention_mask"]
    # Modality map: 0 = text positions, 1 = image positions.
    data["modality"] = torch.full((1, seq_len), dtype=torch.int64, fill_value=1)
    data["modality"][..., :data["input_ids"].shape[1]] = 0
    data["input_ids"] = torch.cat([data["input_ids"].to(model.device), img_image_ids], dim=-1)
    data["attention_mask"] = torch.cat([data["attention_mask"], torch.full((1, seq_len - data["attention_mask"].shape[1]), dtype=torch.bool, fill_value=True)], dim=-1).bool()
    data["img"] = Im(images[0]).torch[None]
    data["sample_ids"] = torch.full((1, seq_len), dtype=torch.int64, fill_value=0)
    for k in list(data.keys()):
        data[k] = data[k].to(model.device)

    # Shift image token ids into their own range of the shared vocabulary.
    data["input_ids"] = torch.where(
        (data["modality"] == 1) & (data["input_ids"] != -1),
        data["input_ids"] + config.data.img_token_shift,
        data["input_ids"]
    )

    return data


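# Core entry point: takes a ChatRequest, builds a (partially masked) token
# sequence, runs the diffusion sampler, and appends the result as an assistant
# message on a copy of the request.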
def inference(
    request: ChatRequest,
    config=None,
    model=None,
):
    messages = request.messages
    messages = expand_mask_tokens(messages)
    input_request = copy.deepcopy(request)

    # Propagate per-request sampling parameters into the (mutable) config.
    with open_dict(config):
        config.eval.top_p = request.top_p
        config.eval.temperature = request.temperature
        config.eval.maskgit_r_temp = request.maskgit_r_temp
        config.eval.cfg = request.cfg
        config.sampling.predictor = request.sampler
    model.sampler = config.sampling.predictor

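    # Route the request: decide whether we are inpainting an image region,
    # generating a fresh image, or generating text.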
    gen_img = False
    gen_txt = False
    resolution = request.resolution
    print(f"messages: {messages}")
    img_contains_mask = any(content.is_mask for msg in messages for content in msg.content)
    input_contains_img = any(content.type == "image_url" for msg in messages for content in msg.content)
    if img_contains_mask:
        print(f"img_contains_mask: {img_contains_mask}")
        recent_mask = [content.image_url for content in messages[-1].content if content.is_mask]
        recent_img = [content.image_url for content in messages[-1].content if (content.type == "image_url" and not content.is_mask)]
        assert len(recent_mask) == len(recent_img) == 1, "Number of masks must match number of images"
        recent_mask = recent_mask[0]
        recent_img = recent_img[0]
        # Strip mask parts from the conversation; only real content is tokenized.
        for msg in messages:
            msg.content = [content for content in msg.content if not content.is_mask]

    if any("<image>" in content.text for content in messages[-1].content if content.type == "text") or (not input_contains_img):
        print(f"Generating image: {messages[-1].content[-1].text}")
        gen_img = True
        messages[-1].content[-1].text = messages[-1].content[-1].text.replace("<image>", "")
        # Black placeholder image; all of its tokens are masked and re-generated.
        messages.append(ChatMessage(
            role="assistant",
            content=[ContentPart(
                type="image_url",
                image_url=Image.new("RGB", (resolution, resolution), color=(0, 0, 0))
            )]
        ))
    elif not any(mask_token in content.text for content in messages[-1].content if content.type == "text") and config.trainer.interleaved and not getattr(config.eval, "static_img_txt_demo", False):
        print(f"Generating {request.max_tokens} tokens of text")
        gen_txt = True
        # Placeholder text sized to max_tokens; its positions are masked below.
        messages.append(ChatMessage(
            role="assistant",
            content=[ContentPart(
                type="text", text="authentication" * request.max_tokens
            )]
        ))
    elif any(mask_token in content.text for content in messages[-2].content if content.type == "text") and getattr(config.eval, "static_img_txt_demo", False):
        print("Got user text input with mask tokens, generating text")
        gen_txt = True

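    # Move all image-bearing messages to the end of the conversation.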
    force_reorder = True
    mask_eos = True
    if force_reorder and input_contains_img:
        image_messages = [msg for msg in messages if any(content.type == "image_url" for content in msg.content)]
        messages = [msg for msg in messages if not any(content.type == "image_url" for content in msg.content)]
        messages.extend(image_messages)
        print("Reordered messages, images are now last")

    messages = convert_to_model_input(messages)
    print(f"input messages: {messages}")

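    # Ensure the literal mask token is registered with the tokenizer; it is
    # always the most recently added token, so its id is len(tokenizer) - 1.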
    all_special_tokens = {x.content for x in model.tokenizer.added_tokens_decoder.values()}
    if mask_token not in all_special_tokens:
        new_tokens = [mask_token]
        new_tokens = list(set(new_tokens) - set(model.tokenizer.get_vocab().keys()))
        model.tokenizer.add_special_tokens({"additional_special_tokens": new_tokens}, replace_additional_special_tokens=False)
        assert model.tokenizer.added_tokens_decoder[len(model.tokenizer) - 1].content == mask_token
        print(model.tokenizer(new_tokens, add_special_tokens=False))

    mask_token_id = len(model.tokenizer) - 1
    if config.trainer.interleaved:
        batch = messages_to_batch(config, model.tokenizer, model, messages, resolution=resolution)
    else:
        batch = get_fixed_batch(config, model.tokenizer, model, messages, resolution=resolution)

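    # x0_unmask marks positions whose tokens are *kept* (conditioned on);
    # every position set to False is masked out and re-generated by the sampler.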
    sampling_steps = request.sampling_steps
    sample_modality = batch["modality"]
    x0 = batch["input_ids"]
    x0_unmask = torch.ones_like(x0, dtype=torch.bool)
    txt_contains_mask = False
    for i in range(x0.shape[0]):
        if gen_img or img_contains_mask:
            # Find contiguous image (modality == 1) runs; the last run is the
            # image being generated or inpainted, so mask it out.
            modality_seq = sample_modality[i]
            changes = torch.diff(modality_seq)
            change_points = torch.where(changes != 0)[0] + 1
            change_points = torch.cat([torch.tensor([0], device=change_points.device), change_points, torch.tensor([len(modality_seq)], device=change_points.device)])

            sequences = []
            for start, end in zip(change_points[:-1], change_points[1:]):
                if modality_seq[start] == 1:
                    sequences.append((start.item(), end.item()))

            if sequences:
                last_start, last_end = sequences[-1]
                x0_unmask[i, last_start:last_end] = False
                print(f"Masked slice: {last_start}:{last_end}")
            else:
                print("WARNING: No sequences found")

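            # For inpainting, re-generate only the latent cells covered by the
            # user-provided mask image, downscaled to the token grid.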
            if img_contains_mask:
                def downscale_bool(arr: np.ndarray, D: int) -> np.ndarray:
                    if len(arr.shape) == 3:
                        print("Converting (H, W, C) to (H, W)")
                        arr = arr.sum(axis=-1)
                    H, W = arr.shape
                    assert H % D == 0 and W % D == 0, "H and W must be divisible by D"
                    # A cell is masked if any pixel in its D x D patch is set.
                    return arr.reshape(H // D, D, W // D, D).any(axis=(1, 3))

                _res = int(math.sqrt(last_end - last_start) * config.model.downscale_ratio)
                if recent_mask.size != (_res, _res):
                    print(f"WARNING!! recent_mask.size: {recent_mask.size}, last_end - last_start: {last_end - last_start}")
                mask_arr = downscale_bool(np.array(recent_mask.convert("RGB").resize((resolution, resolution), resample=Image.Resampling.NEAREST)).astype(np.bool_), config.model.downscale_ratio)
                print(Im(mask_arr).save())
                mask_arr = torch.from_numpy(mask_arr).to(x0_unmask.device).reshape(-1).nonzero().squeeze()
                mask_arr = mask_arr + last_start
                # Keep the whole image region except the masked cells.
                x0_unmask[i, last_start:last_end] = True
                x0_unmask[i, mask_arr] = False

        if gen_img and not gen_txt and getattr(config.eval, "static_img_txt_demo", False):
            print(f"Unmasking all text positions for static_img_txt_demo: {x0_unmask[i, modality_seq == 0].sum().item()}")
            x0_unmask[i, modality_seq == 0] = True
        elif gen_txt and not getattr(config.eval, "static_img_txt_demo", False):
            bos_positions = torch.where(x0[i] == model.tokenizer.bos_token_id)[0]
            if len(bos_positions) == 0:
                continue

            # After reordering, the last BOS opens the trailing image block, so
            # the text span to generate starts at the preceding BOS.
            last_bos = bos_positions[-2] if force_reorder and len(bos_positions) > 1 else bos_positions[-1]
            eos_positions = torch.where((x0[i] == model.tokenizer.eos_token_id) & (torch.arange(len(x0[i]), device=x0.device) > last_bos))[0]

            print(f"BOS positions: {bos_positions}, EOS positions: {eos_positions}")
            unmask_to_eos = True
            if unmask_to_eos and len(eos_positions) > 0:
                last_eos = eos_positions[0]
            else:
                last_eos = None

            x0_unmask[i, last_bos+1:last_eos] = False
            if mask_eos and force_reorder and last_eos is not None:
                x0_unmask[i, last_bos+1:last_eos+3] = False
            print(f"Masked slice: {last_bos}:{last_eos}")

        # Positions holding the literal mask token are always re-generated.
        to_mask = x0[i] == mask_token_id
        if to_mask.sum().item() > 0:
            x0_unmask[i, to_mask] = False
            print(f"Found {to_mask.sum().item()} text mask tokens")
            txt_contains_mask = True

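    # Debug summary: how much is conditioned on, and of which modality.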
    true_indices = torch.where(x0_unmask[0])[0]
    first_true = true_indices[0].item()
    last_true = true_indices[-1].item()
    total_true = x0_unmask[0].sum().item()

    masked_modalities = batch["modality"][0][x0_unmask[0]]
    zeros = (masked_modalities == 0).sum().item()
    ones = (masked_modalities == 1).sum().item()

    print(f"x0_unmask num unmasked: {total_true}, x0_unmask first position: {first_true}, x0_unmask last position: {last_true}")
    print(f"x0_unmask num txt (0) count: {zeros}, x0_unmask num img (1) count: {ones}")
    print(f"Masking {((~x0_unmask) & (batch['sample_ids'] >= 0)).sum().item()} positions, modality shape: {batch['modality'].shape}")

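    # Visualize the unmask/modality/sample-id layouts as grid images, cropped
    # to the valid region (positions with sample_ids == -1 are padding).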
    invalid_positions = (batch["sample_ids"][0].long() == -1).nonzero(as_tuple=True)[0]
    first_invalid_sample_id = invalid_positions[0].item() if len(invalid_positions) > 0 else len(batch["sample_ids"][0])
    print(f"First invalid sample ID position: {first_invalid_sample_id}")
    row_len = save_grid_image(x0_unmask[0][:first_invalid_sample_id], "x0_unmask_viz.png")
    save_grid_image(batch["modality"][0][:first_invalid_sample_id], "modality_viz.png", row_len=row_len)
    _sc = batch["sample_ids"][0].clone()
    _sc[_sc == 0] = 1
    _sc[_sc == -1] = 0
    save_grid_image(_sc, "sample_ids_viz.png", row_len=row_len)

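    # Best-of-N with reward models: draw num_iter candidates, score each with
    # the configured reward models, and keep the highest-scoring one.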
    if request.use_reward_models:
        idx = 0
        bs = 1
        num_iter = 4
        from tensordict import TensorDict
        gen_batch = TensorDict.from_dict(batch, batch_size=[batch['input_ids'].shape[0]])
        text_samples_list = []
        img_samples_list = []
        # Replicate the selected row num_iter times so the batch size matches
        # the number of candidates scored below.
        _gen_batch = []
        for i in range(num_iter):
            _gen_batch.append(gen_batch[[idx]])
        gen_batch = torch.cat(_gen_batch, dim=0)

        for j in range(num_iter):
            _modality = gen_batch[[idx]].get("modality", None)
            _sample_ids = gen_batch[[idx]].get("sample_ids", None)
            if _modality is not None:
                _modality = _modality.to(model.device)
            if _sample_ids is not None:
                _sample_ids = _sample_ids.to(model.device)
            else:
                _sample_ids = torch.zeros_like(_modality)
            text_samples, img_samples, x = model._sample(
                text_only=False,
                num_steps=sampling_steps,
                batch_size_per_gpu=bs,
                modality=_modality,
                sample_ids=_sample_ids,
                x0=gen_batch["input_ids"][[idx]].to(model.device),
                x0_unmask=x0_unmask[[idx]].to(model.device),
                return_raw_data=True,
                allow_interleaved_conditional=True
            )
            gen_batch[[idx]]['input_ids'] = x
            text_samples_list.extend(text_samples)
            img_samples_list.extend(img_samples)
            print(f"Sampled {j + 1} / {num_iter}")

        text_samples_list = wrapped_batch_decode(
            model.tokenizer,
            torch.stack(text_samples_list, dim=0),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
            disable_mask_after_eos=True
        )

        # Strip the fixed system prompt from the decoded samples.
        text_samples_list = [x.replace("You are a highly intelligent multimodal AI with the ability to analyze and generate images.", "").removeprefix(" ") for x in text_samples_list]
        if isinstance(img_samples_list[0], Image.Image):
            img_tensors = []
            for img in img_samples_list:
                img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
                img_tensors.append(img_tensor.unsqueeze(0))
            img_samples_list = torch.cat(img_tensors, dim=0)
        else:
            img_samples_list = torch.cat(img_samples_list, dim=0)

        reward_config = config.eval.auto_enhance_reward_config
        rewards, raw_rewards = model.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True)

        gprint(f"Avg Rewards: {rewards}")

        # Rank candidates by reward and keep the best text/image pair.
        sorted_indices = torch.argsort(rewards, descending=True).tolist()
        sorted_text_samples = [text_samples_list[i] for i in sorted_indices]
        sorted_img_samples = [img_samples_list[i] for i in sorted_indices]
        sorted_avg_rewards = [rewards[i] for i in sorted_indices]
        sorted_raw_rewards = {k: [raw_rewards[k][i] for i in sorted_indices] for k in raw_rewards}

        txt_samples = [sorted_text_samples[0]]
        img_samples = [Im(sorted_img_samples[0]).pil]
    else:
        txt_samples, img_samples = model._sample(
            text_only=False,
            num_steps=sampling_steps,
            batch_size_per_gpu=1,
            example_batch=batch,
            sample_batch_idx=0,
            modality=batch["modality"],
            sample_ids=batch["sample_ids"],
            allow_interleaved_conditional=True,
            x0_unmask=x0_unmask,
            x0=x0,
        )

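    # The non-interleaved model returns raw token ids: decode the text slice
    # and the trailing image-token block separately.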
    if not config.trainer.interleaved:
        txt_samples = model.tokenizer.batch_decode(txt_samples[..., model.static_txt_sl], skip_special_tokens=True)
        txt_samples[0] = txt_samples[0].replace("<unk>", "").strip()
        img_len = (resolution // config.model.downscale_ratio)**2
        img_samples = decode_latents(config, model.get_vae(), img_samples[..., -img_len:])
        assert img_samples.shape[0] == 1
        img_samples = [Im(img_samples[0]).pil]

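    # Package the generated image and/or text as a new assistant message.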
    returned_message = ChatMessage(
        role="assistant",
        content=[]
    )
    if img_contains_mask or gen_img:
        print(f"Inference returned img_samples: {img_samples}")
        returned_message.content.append(ContentPart(
            type="image_url",
            image_url=img_samples[-1]
        ))

    if txt_contains_mask or not gen_img:
        print(f"Inference returned txt_samples: {txt_samples}")
        # Keep the last non-empty segment after stripping special tokens,
        # collapsing repeated spaces, and dropping non-alphabetic characters.
        last_new_txt = ""
        for _txt in txt_samples[-1].split("<s>"):
            if len(_txt) > 0:
                _txt = _txt.replace("</s>", "").replace("<image>", "").replace("<s>", "").strip()
                _txt = re.sub(r' {2,}', ' ', _txt).replace(' .', '.').replace('\\n', ' ')
                _txt = re.sub(r'[^a-zA-Z. ]', '', _txt)
                if len(_txt) > 0:
                    last_new_txt = _txt

        returned_message.content.append(ContentPart(
            type="text",
            text=last_new_txt
        ))

    input_request.messages.append(returned_message)

    clear_cache()

    return input_request


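# Hydra entry point: resolves the config and saves it via `setup` for later
# demo runs.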
@hydra.main(version_base=None, config_path="../configs", config_name="config")
@torch.no_grad()
def main(config=None):
    # `setup(save_config=True)` saves the resolved config and exits, so the
    # call below never runs; it documents how to invoke the returned callable.
    inference = setup(config, save_config=True)
    exit()
    inference([{"type": "text", "text": "Hello, how are you?"}])


if __name__ == "__main__":
    with breakpoint_on_error():
        main()