|
import base64 |
|
import io |
|
import json |
|
import shutil |
|
import time |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import requests |
|
import typer |
|
from PIL import Image |
|
from tqdm import tqdm |
|
from image_utils import Im |
|
|
|
def _passthrough_command_name(name):
    """Return *name* exactly as given."""
    return name


# Swap Typer's command-name normaliser for a pass-through so names reach the
# CLI exactly as they are written in the code.
# NOTE(review): presumably done to keep underscores in command/option names;
# confirm against the installed Typer version.
typer.main.get_command_name = _passthrough_command_name

# The Typer application for this script; local variables are left out of the
# pretty-printed tracebacks.
app = typer.Typer(pretty_exceptions_show_locals=False)
|
|
|
def square_crop(image: Image.Image) -> Image.Image:
    """Return the largest centred square region of *image*.

    The square's side equals the shorter of the two image dimensions. The
    surplus on the longer axis is trimmed from both ends; when the surplus is
    odd, the extra pixel is removed from the right/bottom end.
    """
    w, h = image.size
    edge = min(w, h)
    x0 = (w - edge) // 2
    y0 = (h - edge) // 2
    return image.crop((x0, y0, x0 + edge, y0 + edge))
|
|
|
def process(image: Image.Image, desired_resolution: int = 512) -> Image.Image:
    """Convert *image* to RGB, centre-crop it to a square and resize it.

    The result is ``desired_resolution`` pixels on each side, resampled with
    the Lanczos filter.
    """
    rgb = image.convert("RGB")
    square = square_crop(rgb)
    target_size = (desired_resolution, desired_resolution)
    return square.resize(target_size, Image.LANCZOS)
|
|
|
def encode_image(file: Path | io.BytesIO | Image.Image) -> dict:
    """Wrap an image in a base64 data URL: {'url': 'data:image/jpeg;base64,...'}.

    A PIL image is serialised to JPEG first. A ``Path`` or an in-memory buffer
    contributes its raw bytes unchanged, so the URL is labelled ``image/jpeg``
    whatever format those bytes actually hold.
    """
    if isinstance(file, Image.Image):
        jpeg_buffer = io.BytesIO()
        file.save(jpeg_buffer, format="JPEG")
        raw_bytes = jpeg_buffer.getvalue()
    elif isinstance(file, Path):
        raw_bytes = file.read_bytes()
    else:
        raw_bytes = file.getvalue()

    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    return {"url": f"data:image/jpeg;base64,{encoded}"}
|
|
|
def encode_array_image(array: np.ndarray) -> dict:
    """Encode a mask array as a JPEG data URL: {'url': 'data:image/jpeg;base64,...'}.

    Boolean arrays are expanded to uint8 with True mapped to 255 and False to
    0; any other dtype is handed to PIL unchanged. The JPEG is written at
    quality 95.
    """
    pixels = array.astype(np.uint8) * 255 if array.dtype == bool else array

    jpeg_buffer = io.BytesIO()
    Image.fromarray(pixels).save(jpeg_buffer, format="JPEG", quality=95)

    encoded = base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
    return {"url": f"data:image/jpeg;base64,{encoded}"}
|
|
|
def call_unidisc_api( |
|
image_path: Path | None, |
|
caption: str | None, |
|
mask_path: Path | None, |
|
cfg: dict, |
|
) -> list: |
|
""" |
|
Build the payload and call the UniDisc API, returning a list of |
|
output pieces. Each piece is a dict with either: |
|
{"type": "text", "text": "..."} |
|
or {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}} |
|
""" |
|
|
|
messages = [] |
|
if caption: |
|
messages.append({"type": "text", "text": caption}) |
|
|
|
if image_path and image_path.exists(): |
|
resolution = int(cfg.get("resolution", 512)) |
|
current_image = process(Image.open(image_path), resolution) |
|
img_data = encode_image(current_image)["url"] |
|
messages.append({ |
|
"type": "image_url", |
|
"image_url": {"url": img_data}, |
|
"is_mask": False |
|
}) |
|
|
|
if mask_path and mask_path.exists(): |
|
mask_array = np.array(Image.open(mask_path)) |
|
mask_data_url = encode_array_image(mask_array)["url"] |
|
messages.append({ |
|
"type": "image_url", |
|
"image_url": {"url": mask_data_url}, |
|
"is_mask": True |
|
}) |
|
|
|
config_payload = { |
|
"max_tokens": int(cfg.get("max_tokens", 32)), |
|
"resolution": int(cfg.get("resolution", 512)), |
|
"sampling_steps": int(cfg.get("sampling_steps", 32)), |
|
"top_p": float(cfg.get("top_p", 0.95)), |
|
"temperature": float(cfg.get("temperature", 0.9)), |
|
"maskgit_r_temp": float(cfg.get("maskgit_r_temp", 4.5)), |
|
"cfg": float(cfg.get("cfg", 2.5)), |
|
"sampler": cfg.get("sampler", "maskgit_nucleus"), |
|
"use_reward_models": bool(cfg.get("use_reward_models", False)), |
|
} |
|
|
|
port = cfg.get('port', 8001) |
|
hostname = f"{port}" if ":" in port else f"localhost:{port}" |
|
|
|
payload = { |
|
"messages": [{"role": "user", "content": messages}], |
|
"model": "unidisc", |
|
**config_payload |
|
} |
|
|
|
api_url = f"http://{hostname}/v1/chat/completions" |
|
response = requests.post(api_url, json=payload) |
|
if response.status_code != 200: |
|
return [{"type": "text", "text": f"API Error: {response.text}", "error": True}] |
|
|
|
response_json = response.json() |
|
if "choices" not in response_json: |
|
return [{"type": "text", "text": f"Malformed response: {response.text}", "error": True}] |
|
|
|
|
|
content = response_json["choices"][0]["message"]["content"] |
|
if isinstance(content, list): |
|
return content |
|
else: |
|
|
|
return [{"type": "text", "text": content}] |
|
|
|
def decode_image_base64(url_str: str) -> Image.Image:
    """Turn a 'data:image/...;base64,...' string into a PIL.Image.

    Everything after the last ``base64,`` marker is taken as the payload; a
    string without the marker is decoded in full.
    """
    _, _, encoded = url_str.rpartition("base64,")
    image_bytes = base64.b64decode(encoded)
    return Image.open(io.BytesIO(image_bytes))
|
|
|
def run_inference_for_folder(
    folder: Path,
    output_folder: Path,
    cfg: dict,
    use_image: bool,
    use_img_mask: bool,
    use_caption: bool,
    use_cap_mask: bool,
):
    """
    For a single folder with an image, caption, and mask, call the API,
    then write out the returned content (images/text).

    Inputs are discovered by filename prefix inside ``folder``:
    ``image*.{jpg,jpeg,png}``, ``mask*.png``, ``caption*.txt`` and
    ``mask_caption*.txt``. Any of them may be absent.

    Files written to ``output_folder`` (created if needed):
        image.png          returned image, or a copy of the input image when
                           the API returned none (and an input image exists)
        caption.txt        returned text, or a copy of the input caption when
                           the API returned none (and an input caption exists)
        config.json        ``cfg`` plus the mode flags used for this run
        input_image.png    the mask or image that was sent, else a blank image
        input_caption.txt  the caption that was sent, else "Empty caption"

    Args:
        folder: Directory holding the input files.
        output_folder: Directory that receives the outputs.
        cfg: API/sampling configuration; it is not modified.
        use_image: Send the input image.
        use_img_mask: Send the image mask.
        use_caption: Send the plain caption.
        use_cap_mask: Send the masked caption (preferred over the plain one).
    """
    image_file = None
    caption_file = None
    mask_file = None
    # Must start as None: folders without a masked caption are valid input.
    mask_caption_file = None
    for f in folder.iterdir():
        name_lower = f.name.lower()
        suffix = f.suffix.lower()
        if name_lower.startswith("image") and suffix in (".jpg", ".jpeg", ".png"):
            image_file = f
        if name_lower.startswith("mask") and suffix == ".png":
            mask_file = f
        if name_lower.startswith("caption") and suffix == ".txt":
            caption_file = f
        if name_lower.startswith("mask_caption") and suffix == ".txt":
            mask_caption_file = f

    # Choose the caption once so that what is sent and what is recorded as
    # the input always agree. The masked caption takes precedence.
    if mask_caption_file and use_cap_mask:
        caption_source = mask_caption_file
    elif caption_file and use_caption:
        caption_source = caption_file
    else:
        caption_source = None
    caption_text = caption_source.read_text().strip() if caption_source else None

    results = call_unidisc_api(
        image_path=image_file if use_image else None,
        caption=caption_text,
        mask_path=mask_file if use_img_mask else None,
        cfg=cfg,
    )
    output_folder.mkdir(parents=True, exist_ok=True)

    text_parts = []
    img_count = 0
    for item in results:
        item_type = item["type"]
        if item_type == "text":
            # Covers error pieces too, so their message is recorded once.
            text_parts.append(item["text"])
            continue
        if item_type == "image_url":
            out_img = decode_image_base64(item["image_url"]["url"])
            # Every returned image goes to the same path; a later image
            # replaces an earlier one.
            out_img.save(output_folder / "image.png")
            img_count += 1
        # Keep the message of an error flagged on a non-text piece.
        if "error" in item and "text" in item:
            text_parts.append(item["text"])

    # Annotate a private copy: the caller may share `cfg` between concurrent
    # runs, so writing into it would let runs overwrite each other's flags.
    mode = f"{'img_' if use_image else ''}{'imgmask_' if use_img_mask else ''}{'cap_' if use_caption else ''}{'capmask_' if use_cap_mask else ''}"
    run_cfg = {
        **cfg,
        "mode": mode,
        "use_image": use_image,
        "use_img_mask": use_img_mask,
        "use_caption": use_caption,
        "use_cap_mask": use_cap_mask,
    }

    if text_parts:
        (output_folder / "caption.txt").write_text("\n".join(text_parts))
    elif caption_file is not None:
        shutil.copy(caption_file, output_folder / "caption.txt")
        print(f"No text found, copied input caption to output: mode={mode}")

    if img_count == 0 and image_file is not None:
        shutil.copy(image_file, output_folder / "image.png")
        print(f"No image found, copied input image to output: mode={mode}")

    (output_folder / "config.json").write_text(json.dumps(run_cfg, indent=2))

    # Record the conditioning inputs next to the outputs.
    if use_img_mask:
        input_img_path = mask_file
    elif use_image:
        input_img_path = image_file
    else:
        input_img_path = None

    # `Im` comes from the project's image_utils; a blank 512x512 image stands
    # in when no image was sent.
    input_img = Im(input_img_path) if input_img_path else Im.new(h=512, w=512)
    input_txt = caption_text if caption_source is not None else "Empty caption"

    input_img.save(output_folder / "input_image.png")
    (output_folder / "input_caption.txt").write_text(input_txt)
|
|
|
@app.command()
def main(
    input_dir: Path | None = None,
    output_dir: Path | None = None,
    param_file: Path | None = None,
    num_pairs: int | None = None,
    num_workers: int = 32,
    batch_sleep: float = 0.2,
    use_image: bool = False,
    use_img_mask: bool = False,
    use_caption: bool = False,
    use_cap_mask: bool = False,
    iterate_over_modes: bool = False,
    single_config: bool = False,
):
    """
    Generate datasets by calling the UniDisc API on each (image, caption, mask) triplet in input_dir.

    Modified version:
    - Queues tasks in order on a single global ThreadPoolExecutor.
    - After queueing each batch, sleeps for a bit.
    - Does not wait indefinitely for a batch to finish before moving on.
    """
    # NOTE: the docstring above doubles as the CLI help text, so further
    # documentation lives in comments.
    #
    # Options:
    #   input_dir / output_dir  required; one sub-folder of input_dir per sample
    #   param_file              JSON list of config dicts (else built-in hosts)
    #   num_pairs               only process the first N sub-folders (sorted)
    #   num_workers             thread-pool size; also sets the batch size
    #   batch_sleep             seconds to pause after queueing each batch
    #   use_*                   which inputs to send (single-mode runs)
    #   iterate_over_modes      run a fixed list of modes instead of use_*
    #   single_config           ignore other configs; use localhost:8000

    # Explicit raises rather than `assert`, which is stripped under `python -O`.
    if use_img_mask and not use_image:
        raise ValueError("use_img_mask requires use_image.")

    if input_dir is None or output_dir is None:
        raise ValueError("Both input_dir and output_dir must be provided.")

    if param_file is not None:
        all_configs = json.loads(param_file.read_text())
        if not isinstance(all_configs, list):
            raise ValueError("param_file must contain a JSON list of configs.")
    else:
        all_configs = [
            dict(port=port, cfg=cfg_scale)
            for cfg_scale in [2.5]
            for port in ["babel-10-9:8000", "babel-6-29:8001"]
        ]

    if single_config:
        all_configs = [{'port': 'localhost:8000', 'cfg': 2.5}]

    # An empty list would otherwise surface as a ZeroDivisionError below.
    if not all_configs:
        raise ValueError("No configs to run: the config list is empty.")

    subfolders = sorted(f for f in input_dir.iterdir() if f.is_dir())
    if num_pairs is not None:
        subfolders = subfolders[:num_pairs]

    from decoupled_utils import sanitize_filename

    # One output directory per config, named after the config's "name" or,
    # failing that, after its sorted key=value pairs.
    configs = []
    for cfg in all_configs:
        if "name" in cfg:
            cfg_name = cfg["name"]
        else:
            cfg_name = "_".join(
                f"{k}={str(v).replace('/', '_').replace(' ', '_')}"
                for k, v in sorted(cfg.items())
            )
        cfg_output_dir = output_dir / sanitize_filename(cfg_name)
        cfg_output_dir.mkdir(parents=True, exist_ok=True)
        configs.append((cfg, cfg_name, cfg_output_dir))

    batch_size = max(1, num_workers // len(configs))
    total_folders = len(subfolders)
    print(f"Processing {total_folders} folders across {len(configs)} configs with batch size {batch_size} per config.")

    if iterate_over_modes:
        # Each mode maps to a distinct output-folder key. A second
        # caption-only entry would repeat the same request and have two
        # tasks writing into the same folder, so it is listed only once.
        modes = [
            dict(use_image=False, use_img_mask=False, use_caption=True, use_cap_mask=False),
            dict(use_image=True, use_img_mask=False, use_caption=False, use_cap_mask=False),
            dict(use_image=True, use_img_mask=True, use_caption=True, use_cap_mask=True),
            dict(use_image=True, use_img_mask=False, use_caption=True, use_cap_mask=True),
        ]
    else:
        modes = [
            dict(use_image=use_image, use_img_mask=use_img_mask, use_caption=use_caption, use_cap_mask=use_cap_mask),
        ]

    futures = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for batch_start in range(0, total_folders, batch_size):
            batch_folders = subfolders[batch_start : batch_start + batch_size]
            for cfg, _, cfg_out in configs:
                for folder in batch_folders:
                    for mode in modes:
                        # Folder key: the mask variant wins over the plain
                        # variant, for the image and the caption alike.
                        key_parts = []
                        if mode["use_img_mask"]:
                            key_parts.append("imgmask")
                        elif mode["use_image"]:
                            key_parts.append("img")
                        if mode["use_cap_mask"]:
                            key_parts.append("capmask")
                        elif mode["use_caption"]:
                            key_parts.append("cap")
                        key = "_".join(key_parts)

                        folder_output = cfg_out / f"{key}__{folder.name}"
                        futures.append(
                            executor.submit(
                                run_inference_for_folder,
                                folder=folder,
                                output_folder=folder_output,
                                # Private copy per task: tasks run
                                # concurrently and must not share one dict.
                                cfg=dict(cfg),
                                **mode,
                            )
                        )
            print(f"Queued batch {batch_start} to {batch_start + len(batch_folders)}. Sleeping for {batch_sleep} seconds...")
            time.sleep(batch_sleep)

        # Drain results as they finish; `result()` re-raises worker errors.
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing folders..."):
            future.result()

    print("All processing complete.")
|
|
|
# Run the Typer CLI only when this file is executed as a script.
if __name__ == "__main__":
    app()