import base64
import io
import json
import shutil
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import numpy as np
import requests
import typer
from PIL import Image
from tqdm import tqdm
from image_utils import Im

typer.main.get_command_name = lambda name: name  # keep command names verbatim (no underscore-to-dash conversion)
app = typer.Typer(pretty_exceptions_show_locals=False)
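
# Example invocation (the script name, paths, and worker count below are illustrative):
#   python generate_unidisc_data.py --input-dir data/triplets --output-dir outputs/run1 \
#       --use-image --use-caption --single-config --num-workers 16
# Each subfolder of --input-dir is expected to contain an image, a caption, and optional
# mask files; see run_inference_for_folder below for the exact naming convention.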

def square_crop(image: Image.Image) -> Image.Image:
    """Crop the image to a square (centered)."""
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    right = left + side
    bottom = top + side
    return image.crop((left, top, right, bottom))

def process(image: Image.Image, desired_resolution: int = 512) -> Image.Image:
    """Square-crop and resize the image."""
    cropped_image = square_crop(image.convert("RGB"))
    return cropped_image.resize((desired_resolution, desired_resolution), Image.LANCZOS)

def encode_image(file: Path | io.BytesIO | Image.Image) -> dict:
    """Encode an image as base64 data in a dict of the form {'url': 'data:image/jpeg;base64,...'}."""
    if isinstance(file, Image.Image):
        buffered = io.BytesIO()
        file.save(buffered, format="JPEG")
        base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    elif isinstance(file, Path):
        with file.open("rb") as img_file:
            base64_str = base64.b64encode(img_file.read()).decode("utf-8")
    else:
        base64_str = base64.b64encode(file.getvalue()).decode("utf-8")
    return {"url": f"data:image/jpeg;base64,{base64_str}"}

def encode_array_image(array: np.ndarray) -> dict:
    """Encode a mask array as base64 data in a dict of the form {'url': 'data:image/jpeg;base64,...'}."""
    if array.dtype == bool:
        array = array.astype(np.uint8) * 255
    im = Image.fromarray(array)
    buffered = io.BytesIO()
    im.save(buffered, format="JPEG", quality=95)
    base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return {"url": f"data:image/jpeg;base64,{base64_str}"}
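
# Quick sanity-check sketch for the two encoders above (illustrative, not part of the pipeline):
#   img = Image.new("RGB", (64, 64), "red")
#   assert encode_image(img)["url"].startswith("data:image/jpeg;base64,")
#   mask = np.zeros((64, 64), dtype=bool)
#   assert encode_array_image(mask)["url"].startswith("data:image/jpeg;base64,")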

def call_unidisc_api(
    image_path: Path | None,
    caption: str | None,
    mask_path: Path | None,
    cfg: dict,
) -> list:
    """
    Build the payload and call the UniDisc API, returning a list of
    output pieces. Each piece is a dict with either:
      {"type": "text", "text": "..."}
    or {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}
    """
    # Prepare message content as in reference code
    messages = []
    if caption:
        messages.append({"type": "text", "text": caption})

    if image_path and image_path.exists():
        resolution = int(cfg.get("resolution", 512))
        current_image = process(Image.open(image_path), resolution)
        img_data = encode_image(current_image)["url"]
        messages.append({
            "type": "image_url",
            "image_url": {"url": img_data},
            "is_mask": False
        })

        if mask_path and mask_path.exists():
            mask_array = np.array(Image.open(mask_path))
            mask_data_url = encode_array_image(mask_array)["url"]
            messages.append({
                "type": "image_url",
                "image_url": {"url": mask_data_url},
                "is_mask": True
            })

    config_payload = {
        "max_tokens": int(cfg.get("max_tokens", 32)),
        "resolution": int(cfg.get("resolution", 512)),
        "sampling_steps": int(cfg.get("sampling_steps", 32)),
        "top_p": float(cfg.get("top_p", 0.95)),
        "temperature": float(cfg.get("temperature", 0.9)),
        "maskgit_r_temp": float(cfg.get("maskgit_r_temp", 4.5)),
        "cfg": float(cfg.get("cfg", 2.5)),
        "sampler": cfg.get("sampler", "maskgit_nucleus"),
        "use_reward_models": bool(cfg.get("use_reward_models", False)),
    }

    port = str(cfg.get("port", 8001))
    hostname = port if ":" in port else f"localhost:{port}"

    payload = {
        "messages": [{"role": "user", "content": messages}],
        "model": "unidisc",
        **config_payload
    }

    api_url = f"http://{hostname}/v1/chat/completions"
    response = requests.post(api_url, json=payload)
    if response.status_code != 200:
        return [{"type": "text", "text": f"API Error: {response.text}", "error": True}]

    response_json = response.json()
    if "choices" not in response_json:
        return [{"type": "text", "text": f"Malformed response: {response.text}", "error": True}]

    # The reference code expects "content" to be a list with items typed "text" or "image_url"
    content = response_json["choices"][0]["message"]["content"]
    if isinstance(content, list):
        return content
    else:
        # If it's not a list, wrap it
        return [{"type": "text", "text": content}]
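
# For reference, with all inputs present the request built above has roughly this shape
# (caption text, host, and base64 payloads are illustrative placeholders):
#   POST http://localhost:8001/v1/chat/completions
#   {
#     "messages": [{"role": "user", "content": [
#       {"type": "text", "text": "a photo of a corgi"},
#       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}, "is_mask": false},
#       {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}, "is_mask": true}
#     ]}],
#     "model": "unidisc",
#     "max_tokens": 32, "resolution": 512, "sampling_steps": 32, "top_p": 0.95,
#     "temperature": 0.9, "maskgit_r_temp": 4.5, "cfg": 2.5,
#     "sampler": "maskgit_nucleus", "use_reward_models": false
#   }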

def decode_image_base64(url_str: str) -> Image.Image:
    """Given a 'data:image/...;base64,...' string, return the PIL.Image."""
    # e.g. "data:image/jpeg;base64,xxxx..."
    base64_part = url_str.split("base64,")[-1]
    raw = base64.b64decode(base64_part)
    return Image.open(io.BytesIO(raw))

def run_inference_for_folder(
    folder: Path,
    output_folder: Path,
    cfg: dict,
    use_image: bool,
    use_img_mask: bool,
    use_caption: bool,
    use_cap_mask: bool,
):
    """
    For a single folder with an image, caption, and mask, call the API,
    then write out the returned content (images/text).
    """

    image_file = None
    caption_file = None
    mask_file = None
    mask_caption_file = None
    for f in folder.iterdir():
        name_lower = f.name.lower()
        if name_lower.startswith("image") and f.suffix.lower() in [".jpg", ".jpeg", ".png"]:
            image_file = f
        if name_lower.startswith("mask") and f.suffix.lower() == ".png":
            mask_file = f
        if name_lower.startswith("caption") and f.suffix.lower() in [".txt"]:
            caption_file = f
        if name_lower.startswith("mask_caption") and f.suffix.lower() == ".txt":
            mask_caption_file = f

    if use_cap_mask and mask_caption_file:
        caption = mask_caption_file.read_text().strip()
    elif use_caption and caption_file:
        caption = caption_file.read_text().strip()
    else:
        caption = None

    results = call_unidisc_api(
        image_path=image_file if use_image else None,
        caption=caption,
        mask_path=mask_file if use_img_mask else None,
        cfg=cfg,
    )
    output_folder.mkdir(parents=True, exist_ok=True)

    text_parts = []
    img_count = 0
    for item in results:
        # Error payloads from call_unidisc_api arrive as text items, so they are
        # captured in text_parts as well.
        if item["type"] == "text":
            text_parts.append(item["text"])
        elif item["type"] == "image_url":
            out_img = decode_image_base64(item["image_url"]["url"])
            out_img.save(output_folder / "image.png")
            img_count += 1

    # Annotate a copy of the config so concurrent tasks sharing the same cfg dict
    # do not clobber each other's mode flags.
    cfg = dict(cfg)
    cfg['mode'] = f"{'img_' if use_image else ''}{'imgmask_' if use_img_mask else ''}{'cap_' if use_caption else ''}{'capmask_' if use_cap_mask else ''}"
    cfg['use_image'] = use_image
    cfg['use_img_mask'] = use_img_mask
    cfg['use_caption'] = use_caption
    cfg['use_cap_mask'] = use_cap_mask

    if len(text_parts) > 0:
        out_txt = output_folder / "caption.txt"
        out_txt.write_text("\n".join(text_parts))
    else:
        shutil.copy(caption_file, output_folder / "caption.txt")
        print(f"No text found, copied input caption to output: mode={cfg['mode']}")

    if img_count == 0:
        shutil.copy(image_file, output_folder / "image.png")
        print(f"No image found, copied input image to output: mode={cfg['mode']}")
    
    config_file = output_folder / "config.json"
    config_file.write_text(json.dumps(cfg, indent=2))

    input_img = (mask_file if use_img_mask else (image_file if use_image else None))
    input_txt = (mask_caption_file if use_cap_mask else (caption_file if use_caption else None))

    input_img = Im(input_img) if input_img else Im.new(h=512, w=512)
    input_txt = input_txt.read_text().strip() if input_txt else "Empty caption"

    input_img.save(output_folder / "input_image.png")
    (output_folder / "input_caption.txt").write_text(input_txt)
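
# Expected per-folder input layout (files are matched by name prefix above; the exact
# names below are illustrative):
#   <folder>/image.png          full input image (.jpg/.jpeg/.png)
#   <folder>/mask.png           inpainting mask
#   <folder>/caption.txt        full caption
#   <folder>/mask_caption.txt   masked/partial caption
# Per-folder outputs: image.png, caption.txt, input_image.png, input_caption.txt, config.json.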

@app.command()
def main(
    input_dir: Path | None = None,
    output_dir: Path | None = None,
    param_file: Path | None = None,
    num_pairs: int | None = None,
    num_workers: int = 32,
    batch_sleep: float = 0.2,
    use_image: bool = False,
    use_img_mask: bool = False,
    use_caption: bool = False,
    use_cap_mask: bool = False,
    iterate_over_modes: bool = False,
    single_config: bool = False,
):
    """
    Generate datasets by calling the UniDisc API on each (image, caption, mask) triplet in input_dir.
    
    Modified version:
      - Queues tasks in order on a single global ThreadPoolExecutor.
      - After queueing each batch, sleeps for a bit.
      - Does not wait indefinitely for a batch to finish before moving on.
    """

    if use_img_mask:
        assert use_image, "use_img_mask requires use_image"

    if input_dir is None or output_dir is None:
        raise ValueError("Both input_dir and output_dir must be provided.")

    if param_file is not None:
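        # Example param_file contents (values illustrative). Each dict becomes one config;
        # keys not read by call_unidisc_api keep their defaults there, but every key is
        # reflected in the per-config output folder name and in the saved config.json:
        #   [
        #     {"port": "localhost:8000", "cfg": 2.5},
        #     {"port": "localhost:8001", "cfg": 5.0, "sampler": "maskgit"}
        #   ]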
        all_configs = json.loads(param_file.read_text())
        if not isinstance(all_configs, list):
            raise ValueError("param_file must contain a JSON list of configs.")
    else:
        all_configs = []
        for cfg in [2.5]:
            # for sampler in ["maskgit_nucleus", "maskgit"]:
            for port in ["babel-10-9:8000", "babel-6-29:8001"]:
                all_configs.append(dict(port=port, cfg=cfg))

    if single_config:
        all_configs = [{'port': 'localhost:8000', 'cfg': 2.5}]

    subfolders = sorted([f for f in input_dir.iterdir() if f.is_dir()])
    if num_pairs is not None:
        subfolders = subfolders[:num_pairs]

    configs = []
    from decoupled_utils import sanitize_filename
    for i, cfg in enumerate(all_configs):
        # Build config name from key/values if not explicitly provided
        if "name" not in cfg:
            # Sanitize values and join with underscores
            cfg_name = "_".join(
                f"{k}={str(v).replace('/', '_').replace(' ', '_')}" 
                for k, v in sorted(cfg.items())
            )
        else:
            cfg_name = cfg["name"]
        cfg_output_dir = output_dir / sanitize_filename(cfg_name)
        cfg_output_dir.mkdir(parents=True, exist_ok=True)
        configs.append((cfg, cfg_name, cfg_output_dir))

    # Compute the batch size so each config receives a fair share of workers per batch.
    # E.g., if num_workers=10 and there are 2 configs then each config gets ~5 workers in a batch.
    batch_size = max(1, num_workers // len(configs))
    total_folders = len(subfolders)
    print(f"Processing {total_folders} folders across {len(configs)} configs with batch size {batch_size} per config.")

    modes = []
    if iterate_over_modes:
        modes.append(dict(use_image=False, use_img_mask=False, use_caption=True, use_cap_mask=False)) # T2I
        modes.append(dict(use_image=True, use_img_mask=False, use_caption=False, use_cap_mask=False)) # I2T
        modes.append(dict(use_image=True, use_img_mask=True, use_caption=True, use_cap_mask=True)) # Both masked
        modes.append(dict(use_image=True, use_img_mask=False, use_caption=True, use_cap_mask=True)) # Full image, masked caption
        modes.append(dict(use_image=False, use_img_mask=False, use_caption=True, use_cap_mask=False)) # T2I (same as the first mode)
    else:
        modes.append(dict(use_image=use_image, use_img_mask=use_img_mask, use_caption=use_caption, use_cap_mask=use_cap_mask))

    # Use a single, global ThreadPoolExecutor so we can submit tasks in order
    futures = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for batch_start in range(0, total_folders, batch_size):
            batch_folders = subfolders[batch_start : batch_start + batch_size]
            for cfg, cfg_name, cfg_out in configs:
                for folder in batch_folders:
                    for mode in modes:
                        use_image = mode['use_image']
                        use_img_mask = mode['use_img_mask']
                        use_caption = mode['use_caption']
                        use_cap_mask = mode['use_cap_mask']
                        key = ""
                        if use_img_mask:
                            key += "imgmask_"
                        elif use_image:
                            key += "img_"

                        if use_cap_mask:
                            key += "capmask_"
                        elif use_caption:
                            key += "cap_"

                        key = key.removesuffix('_')
                        folder_output = cfg_out / f"{key}__{folder.name}"
                        futures.append(
                            executor.submit(
                                run_inference_for_folder,
                                folder=folder,
                                output_folder=folder_output,
                                cfg=cfg,
                                use_image=use_image,
                                use_img_mask=use_img_mask,
                                use_caption=use_caption,
                                use_cap_mask=use_cap_mask,
                            )
                        )
            print(f"Queued batch {batch_start} to {batch_start + len(batch_folders)}. Sleeping for {batch_sleep} seconds...")
            time.sleep(batch_sleep)

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing folders"):
            future.result()

    print("All processing complete.")

if __name__ == "__main__":
    app()