from __future__ import annotations

import re
import copy
import os
import pickle
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from decoupled_utils import clear_cache

from model_utils import wrapped_batch_decode
from unidisc.tokenizers.image_tokenizers import decode_latents, get_image_batch

os.environ["UNIDISC_FORCE_CUDNN_SPDA_CONTEXT"] = "1"
os.environ["UNIDISC_DISABLE_APEX_RMSNORM"] = "1"

import hydra
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from hydra import compose, initialize
from image_utils import Im
from omegaconf import OmegaConf, open_dict
from PIL import Image
from accelerate import PartialState

import dataloader
from decoupled_utils import (breakpoint_on_error, gprint,
                             set_global_breakpoint, set_global_exists)
from demo.inference_utils import (convert_to_model_input, messages_to_batch, save_grid_image)
from demo.server import ChatRequest, ChatMessage, ContentPart
from utils import set_omega_conf_resolvers, set_torch_defaults

os.environ["HYDRA_FULL_ERROR"] = "1"

set_global_breakpoint()  # Overrides breakpoint() to use ipdb.set_trace() instead and handle distributed training
set_global_exists()
set_omega_conf_resolvers()

def get_cfg(overrides: Union[str, list[str]]):
    with initialize(version_base=None, config_path='configs'):
        cfg = compose(config_name='config.yaml', return_hydra_config=False, overrides=overrides)
        return cfg
    
def set_accelerator(config):
    from accelerate import Accelerator
    mixed_precision = "bf16"
    accelerator = Accelerator(mixed_precision=mixed_precision)
    compute_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        compute_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        compute_dtype = torch.bfloat16
    gprint(f"Compute dtype is: {compute_dtype}")
    with open_dict(config):
        config.trainer.devices = accelerator.num_processes
        config.trainer.dtype = str(compute_dtype)

    return config, accelerator
    
def setup(config=None, save_config=False, demo_type="jan", device=None, profile_memory=True):
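    """Build the Diffusion demo model from a Hydra config (or from a previously
    pickled config when config is None) and return `inference` partially applied
    with that config and model.

    With save_config=True the resolved config is pickled under outputs/ and the
    process exits; a later call with config=None reloads it.
    """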
    if profile_memory:
        torch.cuda.memory._record_memory_history()
        from torchtnt.utils.oom import attach_oom_observer
        attach_oom_observer(output_dir=str(os.getcwd()), trace_max_entries=500000)

    set_torch_defaults()
    if config is not None:
        demo_type = "jan"
    # Used both when saving (save_config=True) and when loading a previously
    # pickled config (config is None).
    config_path = Path(__file__).parent / f'outputs/config_{demo_type}.pkl'

    if save_config:
        OmegaConf.resolve(config)
        config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(config_path, 'wb') as f:
            pickle.dump(config, f)

        yaml_config_path = config_path.with_suffix('.yaml')
        with open(yaml_config_path, 'w') as yaml_file:
            OmegaConf.save(config, yaml_file)
        print(f"Saved config to {config_path}")
        exit()
    elif config is None:
        with open(config_path, 'rb') as f:
            config = pickle.load(f)
        print(f"Loaded config from {config_path}")

    config, accelerator = set_accelerator(config)
    from model import Diffusion
    device = PartialState().device if device is None else device
    model = Diffusion(config=config, tokenizer=dataloader.get_tokenizer(config), device=device)
    model.set_accelerator(accelerator)
    return partial(inference, config=config, model=model)

mask_token = "<m>"  # placeholder marking a position to regenerate; "<mN>" in user text expands to N copies (see expand_mask_tokens)

def expand_mask_tokens(messages: List[ChatMessage]) -> List[ChatMessage]:
    # Expand <mN> -> <m><m><m>...<m> (N times)
    def replace_match(match: re.Match) -> str:
        # Extract the number after <m
        count = int(match.group(1))
        return mask_token * count
    
    def expand_mask_in_text(text: str) -> str:
        # Match <m followed by one or more digits and >
        pattern = r'<m(\d+)>'
        return re.sub(pattern, replace_match, text)
    
    messages = copy.deepcopy(messages)
    for message in messages:
        for content in message.content:
            if content.type == "text":
                print(f"Input text: {content.text}")
                content.text = expand_mask_in_text(content.text)
                print(f"Expanded text: {content.text}")
    
    return messages

def get_fixed_batch(config, tokenizer, model, input_data, resolution):
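    """Build a fixed-layout (non-interleaved) batch from one text prompt and one image.

    The sequence is config.data.block_size text tokens followed by the image's VQ
    token ids, with modality 0 for text and 1 for image positions; image token ids
    are shifted by config.data.img_token_shift.
    """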
    assert len(input_data) == 2, "Input data must contain 2 messages"
    images = [content["image_url"] for content in input_data if content['type'] == "image_url"]
    texts = [content["text"] for content in input_data if content['type'] == "text"]
    assert len(images) == 1
    assert len(texts) == 1
    _img = Im(images[0])
    if _img.height != _img.width:
        _img = _img.square(resolution, resolution)
    elif _img.height != resolution or _img.width != resolution:
        _img = _img.resize(resolution, resolution)
    img_image_ids = get_image_batch(config, model.get_vae(), {"img": _img.torch[None]}, model.device)
    txt_input_ids = dataloader.tokenize_text(tokenizer, config.data.block_size, texts)

    data = {}
    seq_len = config.data.block_size + img_image_ids.shape[1] # Allow variable token length
    data["input_ids"] = txt_input_ids["input_ids"]
    data["attention_mask"] = txt_input_ids["attention_mask"]
    data["modality"] = torch.full((1, seq_len,), dtype=torch.int64, fill_value=1)  # assuming images
    data["modality"][..., :data["input_ids"].shape[1]] = 0
    data["input_ids"] = torch.cat([data["input_ids"].to(model.device), img_image_ids], dim=-1)
    data["attention_mask"] = torch.cat([data["attention_mask"], torch.full((1, seq_len - data["attention_mask"].shape[1],), dtype=torch.bool, fill_value=True)], dim=-1).bool()
    data["img"] = Im(images[0]).torch[None]
    data["sample_ids"] = torch.full((1, seq_len), dtype=torch.int64, fill_value=0)
    for k in list(data.keys()):
        data[k] = data[k].to(model.device)

    data["input_ids"] = torch.where(
        (data["modality"] == 1) & (data["input_ids"] != -1),
        data["input_ids"] + config.data.img_token_shift,
        data["input_ids"]
    )

    return data


def inference(
    request: ChatRequest,
    config = None,
    model = None,
):
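    """Run one round of masked generation for a chat request.

    Builds a token batch from the request's messages and decides which positions
    stay fixed (x0_unmask True) versus regenerated: the appended blank image for
    <image> requests, the masked image region for inpainting, and the appended
    assistant text span or explicit <m> tokens for text. The diffusion model then
    samples the masked positions and the result is appended to the request as an
    assistant message.
    """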
    messages = request.messages
    messages = expand_mask_tokens(messages)
    input_request = copy.deepcopy(request)

    with open_dict(config):
        config.eval.top_p = request.top_p
        config.eval.temperature = request.temperature
        config.eval.maskgit_r_temp = request.maskgit_r_temp
        config.eval.cfg = request.cfg
        config.sampling.predictor = request.sampler
        model.sampler = config.sampling.predictor

    gen_img = False
    gen_txt = False
    resolution = request.resolution
    print(f"messages: {messages}")
    img_contains_mask = any(content.is_mask for msg in messages for content in msg.content)
    input_contains_img = any(content.type == "image_url" for msg in messages for content in msg.content)
    if img_contains_mask:
        print(f"img_contains_mask: {img_contains_mask}")
        recent_mask = [content.image_url for content in messages[-1].content if content.is_mask]
        recent_img = [content.image_url for content in messages[-1].content if (content.type == "image_url" and not content.is_mask)]
        assert len(recent_mask) == len(recent_img) == 1, "Number of masks must match number of images"
        recent_mask = recent_mask[0]
        recent_img = recent_img[0]
        for msg in messages:
            msg.content = [content for content in msg.content if not content.is_mask]
    
    if any("<image>" in content.text for content in messages[-1].content if content.type == "text") or (not input_contains_img):
        print(f"Generating image: {messages[-1].content[-1].text}")
        gen_img = True
        messages[-1].content[-1].text = messages[-1].content[-1].text.replace("<image>", "")
        messages.append(ChatMessage(
            role="assistant",
            content=[ContentPart(
                type="image_url",
                image_url=Image.new("RGB", (resolution, resolution), color=(0, 0, 0))
            )]
        ))
    elif not any(mask_token in content.text for content in messages[-1].content if content.type == "text") and config.trainer.interleaved and not getattr(config.eval, "static_img_txt_demo", False):
        print(f"Generating {request.max_tokens} tokens of text")
        gen_txt = True
        messages.append(ChatMessage(
            role="assistant",
            content=[ContentPart(
                # "authentication" is a single token in the tokenizer so this gives us exact control over the number of tokens
                type="text", text="authentication" * request.max_tokens
            )]
        ))
    elif any(mask_token in content.text for content in messages[-2].content if content.type == "text") and getattr(config.eval, "static_img_txt_demo", False):
        print(f"Got user text input with mask tokens, generating text")
        gen_txt = True

    force_reorder = True
    mask_eos = True
    if force_reorder and input_contains_img:
        image_messages = [msg for msg in messages if any(content.type == "image_url" for content in msg.content)]
        messages = [msg for msg in messages if not any(content.type == "image_url" for content in msg.content)]
        messages.extend(image_messages)
        print(f"Reordered messages, images are now last")
    
    messages = convert_to_model_input(messages)
    print(f"input messages: {messages}")

    all_special_tokens = {x.content for x in model.tokenizer.added_tokens_decoder.values()}
    if mask_token not in all_special_tokens:
        new_tokens = [mask_token]
        new_tokens = list(set(new_tokens) - set(model.tokenizer.get_vocab().keys()))
        model.tokenizer.add_special_tokens({"additional_special_tokens": new_tokens}, replace_additional_special_tokens=False)
        assert model.tokenizer.added_tokens_decoder[len(model.tokenizer) - 1].content == mask_token
        print(model.tokenizer(new_tokens, add_special_tokens=False))

    mask_token_id = len(model.tokenizer) - 1
    if config.trainer.interleaved:
        batch = messages_to_batch(config, model.tokenizer, model, messages, resolution=resolution)
    else:
        batch = get_fixed_batch(config, model.tokenizer, model, messages, resolution=resolution)

    sampling_steps = request.sampling_steps
    sample_modality = batch["modality"]
    x0 = batch["input_ids"]
    x0_unmask = torch.ones_like(x0, dtype=torch.bool)
    txt_contains_mask = False
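    # x0_unmask semantics: True keeps the existing token as conditioning, False marks
    # a position to be (re)generated by the sampler. The loop below masks the last
    # image span for generation/inpainting, the final assistant text span for text
    # generation, and any explicit <m> mask tokens.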
    for i in range(x0.shape[0]):
        if gen_img or img_contains_mask:
            modality_seq = sample_modality[i]
            changes = torch.diff(modality_seq)
            change_points = torch.where(changes != 0)[0] + 1
            change_points = torch.cat([torch.tensor([0], device=change_points.device), change_points, torch.tensor([len(modality_seq)], device=change_points.device)])
            
            sequences = []
            for start, end in zip(change_points[:-1], change_points[1:]):
                if modality_seq[start] == 1:
                    sequences.append((start.item(), end.item()))
            
            if sequences:
                last_start, last_end = sequences[-1]
                x0_unmask[i, last_start:last_end] = False
                print(f"Masked slice: {last_start}:{last_end}")
            else:
                print(f"WARNING: No sequences found")

            if img_contains_mask:
                def downscale_bool(arr: np.ndarray, D: int) -> np.ndarray:
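                    """Downscale a boolean (H, W) or (H, W, C) mask by factor D; an
                    output cell is True if any pixel in its DxD patch is True."""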
                    if len(arr.shape) == 3:
                        print(f"Converting (H, W, C) to (H, W)")
                        arr = arr.sum(axis=-1)
                    H, W = arr.shape
                    assert H % D == 0 and W % D == 0, "H and W must be divisible by D"
                    return arr.reshape(H // D, D, W // D, D).any(axis=(1, 3))
                
                import math
                _res = int(math.sqrt(last_end - last_start) * config.model.downscale_ratio)
                if recent_mask.size != (_res, _res):
                    print(f"WARNING!! recent_mask.size: {recent_mask.size}, last_end - last_start: {last_end - last_start}")
                mask_arr = downscale_bool(np.array(recent_mask.convert("RGB").resize((resolution, resolution), resample=Image.Resampling.NEAREST)).astype(np.bool_), config.model.downscale_ratio)
                print(Im(mask_arr).save())
                mask_arr = torch.from_numpy(mask_arr).to(x0_unmask.device).reshape(-1).nonzero().squeeze()
                mask_arr = mask_arr + last_start
                x0_unmask[i, last_start:last_end] = True
                x0_unmask[i, mask_arr] = False
        
        if gen_img and not gen_txt and getattr(config.eval, "static_img_txt_demo", False):
            print(f"Unmasking all text positions for static_img_txt_demo: {x0_unmask[i, modality_seq == 0].sum().item()}")
            x0_unmask[i, modality_seq == 0] = True
        elif gen_txt and not getattr(config.eval, "static_img_txt_demo", False):
            bos_positions = torch.where(x0[i] == model.tokenizer.bos_token_id)[0]
            if len(bos_positions) == 0:
                continue
                
            last_bos = bos_positions[-2] if force_reorder else bos_positions[-1]
            eos_positions = torch.where((x0[i] == model.tokenizer.eos_token_id) & (torch.arange(len(x0[i]), device=x0.device) > last_bos))[0]
            
            print(f"BOS positions: {bos_positions}, EOS positions: {eos_positions}")
            unmask_to_eos = True
            if unmask_to_eos and len(eos_positions) > 0:
                last_eos = eos_positions[0]
            else:
                last_eos = None  # Mask everything after the last BOS

            x0_unmask[i, last_bos+1:last_eos] = False
            if mask_eos and force_reorder and last_eos is not None:
                # Also mask the EOS token plus a couple of trailing positions.
                x0_unmask[i, last_bos+1:last_eos+3] = False
            print(f"Masked slice: {last_bos}:{last_eos}")

        to_mask = x0[i] == mask_token_id
        if to_mask.sum().item() > 0:
            x0_unmask[i, to_mask] = False
            print(f"Found {to_mask.sum().item()} text mask tokens")
            txt_contains_mask = True


    # Add metrics for x0_unmask[0]
    true_indices = torch.where(x0_unmask[0])[0]
    first_true = true_indices[0].item()
    last_true = true_indices[-1].item()
    total_true = x0_unmask[0].sum().item()
    
    masked_modalities = batch["modality"][0][x0_unmask[0]]
    zeros = (masked_modalities == 0).sum().item()
    ones = (masked_modalities == 1).sum().item()
    
    print(f"x0_unmask num unmasked: {total_true}, x0_unmask, first position: {first_true}, x0_unmask last position: {last_true}")
    print(f"x0_unmask num txt (0) count: {zeros}, x0_unmask num img (1) count: {ones}")
    print(f"Masking {((~x0_unmask) & (batch['sample_ids'] >= 0)).sum().item()} positions, modality shape: {batch['modality'].shape}")

    # Find first invalid sample ID, defaulting to full length if none found
    invalid_positions = (batch["sample_ids"][0].long() == -1).nonzero(as_tuple=True)[0]
    first_invalid_sample_id = invalid_positions[0].item() if len(invalid_positions) > 0 else len(batch["sample_ids"][0])
    print(f"First invalid sample ID position: {first_invalid_sample_id}")
    row_len = save_grid_image(x0_unmask[0][:first_invalid_sample_id], "x0_unmask_viz.png")
    save_grid_image(batch["modality"][0][:first_invalid_sample_id], "modality_viz.png", row_len=row_len)
    _sc = batch["sample_ids"][0].clone()
    _sc[_sc == 0] = 1
    _sc[_sc == -1] = 0
    save_grid_image(_sc, "sample_ids_viz.png", row_len=row_len)

    if request.use_reward_models:
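        # Best-of-N re-ranking: draw num_iter candidate samples, score them with the
        # reward models configured in config.eval.auto_enhance_reward_config, and keep
        # the highest-reward text/image pair.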
        idx = 0
        bs = 1
        num_iter = 4
        from tensordict import TensorDict
        gen_batch = TensorDict.from_dict(batch, batch_size=[batch['input_ids'].shape[0]])
        text_samples_list = []
        img_samples_list = []
        _gen_batch = []
        for i in range(num_iter):
            _gen_batch.append(gen_batch[[idx]])
        gen_batch = torch.cat(_gen_batch, dim=0)

        for j in range(num_iter):
            _modality = gen_batch[[idx]].get("modality", None)
            _sample_ids = gen_batch[[idx]].get("sample_ids", None)
            if _modality is not None:
                _modality = _modality.to(model.device)
            if _sample_ids is not None:
                _sample_ids = _sample_ids.to(model.device)
            else:
                _sample_ids = torch.zeros_like(_modality)
            text_samples, img_samples, x = model._sample(
                text_only=False,
                num_steps=sampling_steps,
                batch_size_per_gpu=bs,
                modality=_modality,
                sample_ids=_sample_ids,
                x0=gen_batch["input_ids"][[idx]].to(model.device),
                x0_unmask=x0_unmask[[idx]].to(model.device),
                return_raw_data=True,
                allow_interleaved_conditional=True
            )
            gen_batch[[idx]]['input_ids'] = x
            text_samples_list.extend(text_samples)
            img_samples_list.extend(img_samples)
            print(f"Sampled {j + 1} / {num_iter}")

        text_samples_list = wrapped_batch_decode(
            model.tokenizer,
            torch.stack(text_samples_list, dim=0),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
            disable_mask_after_eos=True
        )

        text_samples_list = [x.replace("You are a highly intelligent multimodal AI with the ability to analyze and generate images.", "").removeprefix(" ") for x in text_samples_list]
        if isinstance(img_samples_list[0], Image.Image):
            img_tensors = []
            for img in img_samples_list:
                img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
                img_tensors.append(img_tensor.unsqueeze(0))
            img_samples_list = torch.cat(img_tensors, dim=0)
        else:
            img_samples_list = torch.cat(img_samples_list, dim=0)

        reward_config = config.eval.auto_enhance_reward_config
        rewards, raw_rewards = model.get_rewards(reward_config, img_samples_list, text_samples_list, batch=gen_batch, return_raw_rewards=True)

        gprint(f"Avg Rewards: {rewards}")

        sorted_indices = torch.argsort(rewards, descending=True).tolist()
        sorted_text_samples = [text_samples_list[i] for i in sorted_indices]
        sorted_img_samples = [img_samples_list[i] for i in sorted_indices]
        sorted_avg_rewards = [rewards[i] for i in sorted_indices]
        sorted_raw_rewards = {k: [raw_rewards[k][i] for i in sorted_indices] for k in raw_rewards}

        txt_samples = [sorted_text_samples[0]]
        img_samples = [Im(sorted_img_samples[0]).pil]
    else:
        txt_samples, img_samples = model._sample(
            text_only=False,
            num_steps=sampling_steps,
            batch_size_per_gpu=1,
            example_batch=batch,
            sample_batch_idx=0,
            modality=batch["modality"],
            sample_ids=batch["sample_ids"],
            allow_interleaved_conditional=True,
            x0_unmask=x0_unmask,
            x0=x0,
        )

        if not config.trainer.interleaved:
            txt_samples = model.tokenizer.batch_decode(txt_samples[..., model.static_txt_sl], skip_special_tokens=True)
            txt_samples[0] = txt_samples[0].replace("<unk>", "").strip()
            img_len = (resolution // config.model.downscale_ratio)**2
            img_samples = decode_latents(config, model.get_vae(), img_samples[..., -img_len:])
            assert img_samples.shape[0] == 1
            img_samples = [Im(img_samples[0]).pil]
    
    returned_message = ChatMessage(
        role="assistant",
        content=[]
    )
    if img_contains_mask or gen_img:
        print(f"Inference returned img_samples: {img_samples}")
        returned_message.content.append(ContentPart(
            type="image_url",
            image_url=img_samples[-1]
        ))

    if txt_contains_mask or not gen_img:
        print(f"Inference returned txt_samples: {txt_samples}")
        # Keep only the last non-empty segment after stripping special tokens and non-alphabetic characters.
        last_new_txt = ""
        for _txt in txt_samples[-1].split("<s>"):
            if len(_txt) > 0:
                _txt = _txt.replace("</s>", "").replace("<image>", "").replace("<s>", "").strip().replace('  ', ' ').replace('  ', ' ').replace(' .', '.').replace('\\n', ' ')
                _txt = re.sub(r'[^a-zA-Z. ]', '', _txt)
                if len(_txt) > 0:
                    last_new_txt = _txt

        returned_message.content.append(ContentPart(
            type="text",
            text=last_new_txt
        ))
    
    input_request.messages.append(returned_message)

    clear_cache()

    return input_request
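
# Minimal usage sketch (illustrative only; the exact ChatRequest constructor is an
# assumption, and the field names below simply mirror the attributes read in
# `inference` above):
#
#   run_inference = setup()  # reloads the pickled config from outputs/config_jan.pkl
#   request = ChatRequest(
#       messages=[ChatMessage(role="user", content=[ContentPart(type="text", text="A red barn in a field <image>")])],
#       top_p=0.95, temperature=0.9, maskgit_r_temp=1.0, cfg=2.5, sampler="maskgit",
#       resolution=256, max_tokens=32, sampling_steps=32, use_reward_models=False,
#   )
#   response = run_inference(request)
#   generated = response.messages[-1]  # assistant reply with image and/or text content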

@hydra.main(version_base=None, config_path="../configs", config_name="config")
@torch.no_grad()
def main(config=None):
    # Note: setup(save_config=True) pickles the resolved config and calls exit(),
    # so the lines after it in this function are never reached.
    inference = setup(config, save_config=True)
    exit()
    inference([{"type": "text", "text": "Hello, how are you?"}])
    
if __name__ == "__main__":
    with breakpoint_on_error():
        main()