from __future__ import annotations

import copy
import math
import os
import pickle
import re
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from decoupled_utils import clear_cache
from model_utils import wrapped_batch_decode
from unidisc.tokenizers.image_tokenizers import decode_latents, get_image_batch

os.environ["UNIDISC_FORCE_CUDNN_SPDA_CONTEXT"] = "1"
os.environ["UNIDISC_DISABLE_APEX_RMSNORM"] = "1"

import hydra
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import PartialState
from einops import rearrange
from hydra import compose, initialize
from image_utils import Im
from omegaconf import OmegaConf, open_dict
from PIL import Image

import dataloader
from decoupled_utils import (breakpoint_on_error, gprint, set_global_breakpoint,
                             set_global_exists)
from demo.inference_utils import (convert_to_model_input, messages_to_batch,
                                  save_grid_image)
from demo.server import ChatMessage, ChatRequest, ContentPart
from utils import set_omega_conf_resolvers, set_torch_defaults

os.environ["HYDRA_FULL_ERROR"] = "1"

set_global_breakpoint()  # Overrides breakpoint() to use ipdb.set_trace() instead and handle distributed training
set_global_exists()
set_omega_conf_resolvers()


def get_cfg(overrides: Union[str, list[str]]):
    with initialize(version_base=None, config_path='configs'):
        cfg = compose(config_name='config.yaml', return_hydra_config=False, overrides=overrides)
    return cfg


def set_accelerator(config):
    from accelerate import Accelerator
    mixed_precision = "bf16"
    accelerator = Accelerator(mixed_precision=mixed_precision)
    compute_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        compute_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        compute_dtype = torch.bfloat16

    gprint(f"Compute dtype is: {compute_dtype}")
    with open_dict(config):
        config.trainer.devices = accelerator.num_processes
        config.trainer.dtype = str(compute_dtype)
    return config, accelerator


def setup(config=None, save_config=False, demo_type="jan", device=None, profile_memory=True):
    if profile_memory:
        torch.cuda.memory._record_memory_history()
        from torchtnt.utils.oom import attach_oom_observer
        attach_oom_observer(output_dir=str(os.getcwd()), trace_max_entries=500000)

    set_torch_defaults()
    if config is not None:
        demo_type = "jan"

    config_path = Path(__file__).parent / f'outputs/config_{demo_type}.pkl'
    if save_config:
        # Resolve and snapshot the config (pickle + YAML), then exit
        OmegaConf.resolve(config)
        config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(config_path, 'wb') as f:
            pickle.dump(config, f)
        yaml_config_path = config_path.with_suffix('.yaml')
        with open(yaml_config_path, 'w') as yaml_file:
            OmegaConf.save(config, yaml_file)
        print(f"Saved config to {config_path}")
        exit()
    elif config is None:
        with open(config_path, 'rb') as f:
            config = pickle.load(f)
        print(f"Loaded config from {config_path}")

    config, accelerator = set_accelerator(config)
    from model import Diffusion
    device = PartialState().device if device is None else device
    model = Diffusion(config=config, tokenizer=dataloader.get_tokenizer(config), device=device)
    model.set_accelerator(accelerator)
    return partial(inference, config=config, model=model)
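
# How this module is typically driven (a sketch; the override strings are
# illustrative, real keys live under configs/):
#
#   cfg = get_cfg(["model=small", "data.resolution=256"])  # hypothetical overrides
#   run_inference = setup(cfg, save_config=False)          # -> partial(inference, config=cfg, model=...)
#   response = run_inference(chat_request)                 # chat_request is a demo.server.ChatRequest
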
# NOTE: the literal special-token strings below ("<m>", "<mN>") are assumptions;
# the originals were lost and must match the tokens the tokenizer was configured with.
mask_token = "<m>"


def expand_mask_tokens(messages: List[ChatMessage]) -> List[ChatMessage]:
    # Expand "<mN>" -> "<m><m>...<m>" (N times)
    def replace_match(match: re.Match) -> str:
        # Extract the number after "<m" and repeat the mask token that many times
        return mask_token * int(match.group(1))

    def expand_mask_in_text(text: str) -> str:
        # Match "<m" followed by one or more digits and a closing ">"
        pattern = r'<m(\d+)>'
        return re.sub(pattern, replace_match, text)

    messages = copy.deepcopy(messages)
    for message in messages:
        for content in message.content:
            if content.type == "text":
                print(f"Input text: {content.text}")
                content.text = expand_mask_in_text(content.text)
                print(f"Expanded text: {content.text}")
    return messages


def get_fixed_batch(config, tokenizer, model, input_data, resolution):
    assert len(input_data) == 2, "Input data must contain 2 messages"
    images = [content["image_url"] for content in input_data if content['type'] == "image_url"]
    texts = [content["text"] for content in input_data if content['type'] == "text"]
    assert len(images) == 1
    assert len(texts) == 1

    _img = Im(images[0])
    if _img.height != _img.width:
        _img = _img.square(resolution, resolution)
    elif _img.height != resolution or _img.width != resolution:
        _img = _img.resize(resolution, resolution)

    img_image_ids = get_image_batch(config, model.get_vae(), {"img": _img.torch[None]}, model.device)
    txt_input_ids = dataloader.tokenize_text(tokenizer, config.data.block_size, texts)

    data = {}
    seq_len = config.data.block_size + img_image_ids.shape[1]  # Allow variable token length
    data["input_ids"] = txt_input_ids["input_ids"]
    data["attention_mask"] = txt_input_ids["attention_mask"]
    data["modality"] = torch.full((1, seq_len), dtype=torch.int64, fill_value=1)  # assume image tokens everywhere...
    data["modality"][..., :data["input_ids"].shape[1]] = 0  # ...then mark the text prefix as modality 0
    data["input_ids"] = torch.cat([data["input_ids"].to(model.device), img_image_ids], dim=-1)
    data["attention_mask"] = torch.cat([
        data["attention_mask"],
        torch.full((1, seq_len - data["attention_mask"].shape[1]), dtype=torch.bool, fill_value=True),
    ], dim=-1).bool()
    data["img"] = Im(images[0]).torch[None]
    data["sample_ids"] = torch.full((1, seq_len), dtype=torch.int64, fill_value=0)

    for k in list(data.keys()):
        data[k] = data[k].to(model.device)

    # Shift image token ids into their own range so text and image vocabularies don't collide
    data["input_ids"] = torch.where(
        (data["modality"] == 1) & (data["input_ids"] != -1),
        data["input_ids"] + config.data.img_token_shift,
        data["input_ids"]
    )
    return data


def inference(
    request: ChatRequest,
    config=None,
    model=None,
):
    messages = request.messages
    messages = expand_mask_tokens(messages)
    input_request = copy.deepcopy(request)

    with open_dict(config):
        config.eval.top_p = request.top_p
        config.eval.temperature = request.temperature
        config.eval.maskgit_r_temp = request.maskgit_r_temp
        config.eval.cfg = request.cfg
        config.sampling.predictor = request.sampler

    model.sampler = config.sampling.predictor
    gen_img = False
    gen_txt = False
    resolution = request.resolution
    print(f"messages: {messages}")

    img_contains_mask = any(content.is_mask for msg in messages for content in msg.content)
    input_contains_img = any(content.type == "image_url" for msg in messages for content in msg.content)
    if img_contains_mask:
        print(f"img_contains_mask: {img_contains_mask}")
        recent_mask = [content.image_url for content in messages[-1].content if content.is_mask]
        recent_img = [content.image_url for content in messages[-1].content if (content.type == "image_url" and not content.is_mask)]
        assert len(recent_mask) == len(recent_img) == 1, "Number of masks must match number of images"
        recent_mask = recent_mask[0]
        recent_img = recent_img[0]
        # Drop the mask parts from the conversation; only the image itself is tokenized
        for msg in messages:
            msg.content = [content for content in msg.content if not content.is_mask]
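
    # What to generate is decided by the three branches below:
    #   1. Image generation: the last user turn contains the image-generation
    #      trigger token, or the request contains no image at all. A blank canvas
    #      is appended and fully masked so the sampler fills it in.
    #   2. Free-form text (interleaved models): placeholder tokens are appended,
    #      then masked and regenerated, giving exact control over length.
    #   3. Text infilling: the user text already contains mask tokens.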
    # "<image_gen>" is an assumed literal for the image-generation trigger token (original lost)
    if any("<image_gen>" in content.text for content in messages[-1].content if content.type == "text") or (not input_contains_img):
        print(f"Generating image: {messages[-1].content[-1].text}")
        gen_img = True
        messages[-1].content[-1].text = messages[-1].content[-1].text.replace("<image_gen>", "")
        messages.append(ChatMessage(
            role="assistant",
            content=[ContentPart(
                type="image_url",
                image_url=Image.new("RGB", (resolution, resolution), color=(0, 0, 0))
            )]
        ))
    elif not any(mask_token in content.text for content in messages[-1].content if content.type == "text") and config.trainer.interleaved and not getattr(config.eval, "static_img_txt_demo", False):
        print(f"Generating {request.max_tokens} tokens of text")
        gen_txt = True
        messages.append(ChatMessage(
            role="assistant",
            content=[ContentPart(
                # "authentication" is a single token in the tokenizer so this gives us exact control over the number of tokens
                type="text",
                text="authentication" * request.max_tokens
            )]
        ))
    elif any(mask_token in content.text for content in messages[-2].content if content.type == "text") and getattr(config.eval, "static_img_txt_demo", False):
        print(f"Got user text input with mask tokens, generating text")
        gen_txt = True

    force_reorder = True
    mask_eos = True
    if force_reorder and input_contains_img:
        image_messages = [msg for msg in messages if any(content.type == "image_url" for content in msg.content)]
        messages = [msg for msg in messages if not any(content.type == "image_url" for content in msg.content)]
        messages.extend(image_messages)
        print(f"Reordered messages, images are now last")

    messages = convert_to_model_input(messages)
    print(f"input messages: {messages}")

    # Register the mask token with the tokenizer if it is not already a special token
    all_special_tokens = {x.content for x in model.tokenizer.added_tokens_decoder.values()}
    if mask_token not in all_special_tokens:
        new_tokens = [mask_token]
        new_tokens = list(set(new_tokens) - set(model.tokenizer.get_vocab().keys()))
        model.tokenizer.add_special_tokens({"additional_special_tokens": new_tokens}, replace_additional_special_tokens=False)
        assert model.tokenizer.added_tokens_decoder[len(model.tokenizer) - 1].content == mask_token
        print(model.tokenizer(new_tokens, add_special_tokens=False))

    mask_token_id = len(model.tokenizer) - 1

    if config.trainer.interleaved:
        batch = messages_to_batch(config, model.tokenizer, model, messages, resolution=resolution)
    else:
        batch = get_fixed_batch(config, model.tokenizer, model, messages, resolution=resolution)

    sampling_steps = request.sampling_steps
    sample_modality = batch["modality"]
    x0 = batch["input_ids"]
    x0_unmask = torch.ones_like(x0, dtype=torch.bool)
    txt_contains_mask = False
    for i in range(x0.shape[0]):
        if gen_img or img_contains_mask:
            # Find contiguous runs of image tokens (modality == 1); the last run is the image to (re)generate
            modality_seq = sample_modality[i]
            changes = torch.diff(modality_seq)
            change_points = torch.where(changes != 0)[0] + 1
            change_points = torch.cat([
                torch.tensor([0], device=change_points.device),
                change_points,
                torch.tensor([len(modality_seq)], device=change_points.device),
            ])
            sequences = []
            for start, end in zip(change_points[:-1], change_points[1:]):
                if modality_seq[start] == 1:
                    sequences.append((start.item(), end.item()))

            if sequences:
                last_start, last_end = sequences[-1]
                x0_unmask[i, last_start:last_end] = False
                print(f"Masked slice: {last_start}:{last_end}")
            else:
                print(f"WARNING: No sequences found")
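
            # The user-drawn mask lives in pixel space, but x0_unmask indexes latent
            # tokens, so the mask image is reduced patch-wise below: with
            # downscale_ratio D, each D x D pixel patch maps to one image token
            # (e.g. D=16 turns a 256x256 mask into a 16x16 token grid), and a patch
            # counts as masked if any pixel inside it is masked.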
            if img_contains_mask:
                def downscale_bool(arr: np.ndarray, D: int) -> np.ndarray:
                    # Reduce each D x D pixel patch to one boolean: True if any pixel is set
                    if len(arr.shape) == 3:
                        print(f"Converting (H, W, C) to (H, W)")
                        arr = arr.sum(axis=-1)
                    H, W = arr.shape
                    assert H % D == 0 and W % D == 0, "H and W must be divisible by D"
                    return arr.reshape(H // D, D, W // D, D).any(axis=(1, 3))

                _res = int(math.sqrt(last_end - last_start) * config.model.downscale_ratio)
                if recent_mask.size != (_res, _res):
                    print(f"WARNING!! recent_mask.size: {recent_mask.size}, last_end - last_start: {last_end - last_start}")

                mask_arr = downscale_bool(
                    np.array(recent_mask.convert("RGB").resize((resolution, resolution), resample=Image.Resampling.NEAREST)).astype(np.bool_),
                    config.model.downscale_ratio,
                )
                print(Im(mask_arr).save())
                mask_arr = torch.from_numpy(mask_arr).to(x0_unmask.device).reshape(-1).nonzero().squeeze()
                mask_arr = mask_arr + last_start
                # Keep the image visible except at the user-masked token positions
                x0_unmask[i, last_start:last_end] = True
                x0_unmask[i, mask_arr] = False

            if gen_img and not gen_txt and getattr(config.eval, "static_img_txt_demo", False):
                print(f"Unmasking all text positions for static_img_txt_demo: {x0_unmask[i, modality_seq == 0].sum().item()}")
                x0_unmask[i, modality_seq == 0] = True
        elif gen_txt and not getattr(config.eval, "static_img_txt_demo", False):
            bos_positions = torch.where(x0[i] == model.tokenizer.bos_token_id)[0]
            if len(bos_positions) == 0:
                continue
            last_bos = bos_positions[-2] if force_reorder else bos_positions[-1]
            eos_positions = torch.where((x0[i] == model.tokenizer.eos_token_id) & (torch.arange(len(x0[i]), device=x0.device) > last_bos))[0]
            print(f"BOS positions: {bos_positions}, EOS positions: {eos_positions}")
            unmask_to_eos = True
            if unmask_to_eos and len(eos_positions) > 0:
                last_eos = eos_positions[0]
            else:
                last_eos = None

            # Mask everything after the last BOS (up to EOS if one was found)
            x0_unmask[i, last_bos + 1:last_eos] = False
            if mask_eos and force_reorder and last_eos is not None:
                x0_unmask[i, last_bos + 1:last_eos + 3] = False
            print(f"Masked slice: {last_bos}:{last_eos}")

        to_mask = x0[i] == mask_token_id
        if to_mask.sum().item() > 0:
            x0_unmask[i, to_mask] = False
            print(f"Found {to_mask.sum().item()} text mask tokens")
            txt_contains_mask = True

    # Log statistics for x0_unmask[0]
    true_indices = torch.where(x0_unmask[0])[0]
    first_true = true_indices[0].item()
    last_true = true_indices[-1].item()
    total_true = x0_unmask[0].sum().item()
    masked_modalities = batch["modality"][0][x0_unmask[0]]
    zeros = (masked_modalities == 0).sum().item()
    ones = (masked_modalities == 1).sum().item()
    print(f"x0_unmask num unmasked: {total_true}, x0_unmask first position: {first_true}, x0_unmask last position: {last_true}")
    print(f"x0_unmask num txt (0) count: {zeros}, x0_unmask num img (1) count: {ones}")
    print(f"Masking {((~x0_unmask) & (batch['sample_ids'] >= 0)).sum().item()} positions, modality shape: {batch['modality'].shape}")

    # Find the first invalid sample ID, defaulting to the full length if none is found
    invalid_positions = (batch["sample_ids"][0].long() == -1).nonzero(as_tuple=True)[0]
    first_invalid_sample_id = invalid_positions[0].item() if len(invalid_positions) > 0 else len(batch["sample_ids"][0])
    print(f"First invalid sample ID position: {first_invalid_sample_id}")

    row_len = save_grid_image(x0_unmask[0][:first_invalid_sample_id], "x0_unmask_viz.png")
    save_grid_image(batch["modality"][0][:first_invalid_sample_id], "modality_viz.png", row_len=row_len)
    _sc = batch["sample_ids"][0].clone()
    _sc[_sc == 0] = 1
    _sc[_sc == -1] = 0
    save_grid_image(_sc, "sample_ids_viz.png", row_len=row_len)

    if request.use_reward_models:
        idx = 0
        bs = 1
        num_iter = 4
        from tensordict import TensorDict
        gen_batch = TensorDict.from_dict(batch, batch_size=[batch['input_ids'].shape[0]])
        text_samples_list = []
        img_samples_list = []
        # Replicate the selected example num_iter times
        _gen_batch = []
        for i in range(num_iter):
            _gen_batch.append(gen_batch[[idx]])
        gen_batch = torch.cat(_gen_batch, dim=0)
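
        # Best-of-N sampling with reward reranking: draw num_iter candidates
        # sequentially (each round writes its output ids back into gen_batch),
        # score every (text, image) pair with the configured reward models, and
        # keep only the highest-scoring sample.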
        for j in range(num_iter):
            _modality = gen_batch[[idx]].get("modality", None)
            _sample_ids = gen_batch[[idx]].get("sample_ids", None)
            if _modality is not None:
                _modality = _modality.to(model.device)
            if _sample_ids is not None:
                _sample_ids = _sample_ids.to(model.device)
            else:
                _sample_ids = torch.zeros_like(_modality)
            text_samples, img_samples, x = model._sample(
                text_only=False,
                num_steps=sampling_steps,
                batch_size_per_gpu=bs,
                modality=_modality,
                sample_ids=_sample_ids,
                x0=gen_batch["input_ids"][[idx]].to(model.device),
                x0_unmask=x0_unmask[[idx]].to(model.device),
                return_raw_data=True,
                allow_interleaved_conditional=True
            )
            gen_batch[[idx]]['input_ids'] = x
            text_samples_list.extend(text_samples)
            img_samples_list.extend(img_samples)
            print(f"Sampled {j + 1} / {num_iter}")

        text_samples_list = wrapped_batch_decode(
            model.tokenizer,
            torch.stack(text_samples_list, dim=0),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
            disable_mask_after_eos=True
        )
        text_samples_list = [x.replace("You are a highly intelligent multimodal AI with the ability to analyze and generate images.", "").removeprefix(" ") for x in text_samples_list]
        if isinstance(img_samples_list[0], Image.Image):
            img_tensors = []
            for img in img_samples_list:
                img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
                img_tensors.append(img_tensor.unsqueeze(0))
            img_samples_list = torch.cat(img_tensors, dim=0)
        else:
            img_samples_list = torch.cat(img_samples_list, dim=0)

        reward_config = config.eval.auto_enhance_reward_config
        rewards, raw_rewards = model.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True)
        gprint(f"Avg Rewards: {rewards}")

        # Rank candidates by average reward and keep the best (text, image) pair
        sorted_indices = torch.argsort(rewards, descending=True).tolist()
        sorted_text_samples = [text_samples_list[i] for i in sorted_indices]
        sorted_img_samples = [img_samples_list[i] for i in sorted_indices]
        sorted_avg_rewards = [rewards[i] for i in sorted_indices]
        sorted_raw_rewards = {k: [raw_rewards[k][i] for i in sorted_indices] for k in raw_rewards}
        txt_samples = [sorted_text_samples[0]]
        img_samples = [Im(sorted_img_samples[0]).pil]
    else:
        txt_samples, img_samples = model._sample(
            text_only=False,
            num_steps=sampling_steps,
            batch_size_per_gpu=1,
            example_batch=batch,
            sample_batch_idx=0,
            modality=batch["modality"],
            sample_ids=batch["sample_ids"],
            allow_interleaved_conditional=True,
            x0_unmask=x0_unmask,
            x0=x0,
        )

    if not config.trainer.interleaved:
        txt_samples = model.tokenizer.batch_decode(txt_samples[..., model.static_txt_sl], skip_special_tokens=True)
        txt_samples[0] = txt_samples[0].replace(mask_token, "").strip()  # strip leftover mask tokens (assumed literal; original lost)
        img_len = (resolution // config.model.downscale_ratio) ** 2
        img_samples = decode_latents(config, model.get_vae(), img_samples[..., -img_len:])
        assert img_samples.shape[0] == 1
        img_samples = [Im(img_samples[0]).pil]
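
    # Wrap the winning sample back into the request schema: an assistant message
    # with an image part when an image was (re)generated and/or a text part with
    # the cleaned-up generated text, appended to a copy of the original request.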
    returned_message = ChatMessage(
        role="assistant",
        content=[]
    )
    if img_contains_mask or gen_img:
        print(f"Inference returned img_samples: {img_samples}")
        returned_message.content.append(ContentPart(
            type="image_url",
            image_url=img_samples[-1]
        ))

    if txt_contains_mask or not gen_img:
        print(f"Inference returned txt_samples: {txt_samples}")
        last_new_txt = ""
        # Keep the last non-empty cleaned segment. The "<s>"/"</s>"/"<pad>" literals
        # are assumptions; the originals were lost and must match the tokenizer's special tokens.
        for i, _txt in enumerate(txt_samples[-1].rsplit("</s>")):
            if len(_txt) > 0:
                _txt = _txt.replace("<s>", "").replace("</s>", "").replace("<pad>", "").strip()
                _txt = _txt.replace('  ', ' ').replace('  ', ' ').replace(' .', '.').replace('\\n', ' ')
                _txt = re.sub(r'[^a-zA-Z. ]', '', _txt)
                if len(_txt) > 0:
                    last_new_txt = _txt
        returned_message.content.append(ContentPart(
            type="text",
            text=last_new_txt
        ))

    input_request.messages.append(returned_message)
    clear_cache()
    return input_request


@hydra.main(version_base=None, config_path="../configs", config_name="config")
@torch.no_grad()
def main(config=None):
    inference = setup(config, save_config=True)  # setup() exits after saving the config
    exit()
    inference([{"type": "text", "text": "Hello, how are you?"}])


if __name__ == "__main__":
    with breakpoint_on_error():
        main()
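
# Example request construction (a sketch; the field values are illustrative and
# the field names follow the attributes read from `request` above):
#
#   from demo.server import ChatRequest, ChatMessage, ContentPart
#   request = ChatRequest(
#       messages=[ChatMessage(role="user", content=[
#           ContentPart(type="text", text="A watercolor painting of a fox <m8>"),
#       ])],
#       top_p=0.95, temperature=0.9, maskgit_r_temp=4.5, cfg=2.5, sampler="maskgit",
#       resolution=256, max_tokens=32, sampling_steps=35, use_reward_models=False,
#   )
#   response = run_inference(request)  # run_inference from setup(); see the sketch at the top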