|
import json |
|
from pathlib import Path |
|
|
|
import torch |
|
import typer |
|
from image_utils import Im |
|
from omegaconf import OmegaConf |
|
from tqdm import tqdm |
|
from accelerate.state import PartialState |
|
from accelerate.utils import gather_object |
|
from PIL import Image |
|
|
|
from decoupled_utils import set_global_breakpoint |
|
from model import Diffusion |
|
|
|
def _keep_name(name):
    """Return ``name`` exactly as given."""
    return name


# Swap typer's name-derivation hook for the identity function so that names
# are used exactly as written in Python instead of whatever typer would
# derive from them by default.
typer.main.get_command_name = _keep_name

# The CLI application object; local variables are left out of typer's
# pretty-printed tracebacks.
app = typer.Typer(pretty_exceptions_show_locals=False)

# Project helper from decoupled_utils, invoked once at import time.
# NOTE(review): presumably installs a global debugging breakpoint hook -- confirm.
set_global_breakpoint()
|
|
|
def _load_image_tensor(image_path: Path, resolution: int) -> torch.Tensor:
    """Open ``image_path`` and return it as a tensor with a leading batch axis.

    Images whose height or width differs from ``resolution`` are center-cropped
    to a square and resized to ``resolution`` x ``resolution`` with LANCZOS
    resampling (a warning is printed). The file handle is closed before
    returning.
    """
    with Image.open(image_path) as source:
        img = source
        if resolution != source.height or resolution != source.width:
            print(f"WARNING!!! Image resolution {source.height}x{source.width} does not match resolution {resolution}x{resolution}")
            min_dim = min(source.width, source.height)
            left = (source.width - min_dim) // 2
            top = (source.height - min_dim) // 2
            img = source.crop((left, top, left + min_dim, top + min_dim))
            img = img.resize((resolution, resolution), Image.Resampling.LANCZOS)
        # Convert while the file is still open: PIL loads pixel data lazily.
        return Im(img).torch.unsqueeze(0)


def _load_pair(pair_dir: Path, resolution: int) -> tuple[torch.Tensor, str] | None:
    """Load ``image.png`` and ``caption.txt`` from ``pair_dir``.

    Returns ``(image_tensor, caption)`` or ``None`` when either file is missing
    or unreadable. Both are loaded before anything is returned so that callers
    can never end up with an image without its caption (or vice versa).
    """
    image_path = pair_dir / "image.png"
    caption_path = pair_dir / "caption.txt"

    if not (image_path.exists() and caption_path.exists()):
        print(f"  Skipping {pair_dir}: missing image.png or caption.txt")
        return None

    # Best-effort loading: one bad pair must not abort the whole dataset.
    try:
        image = _load_image_tensor(image_path, resolution)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return None

    try:
        caption = caption_path.read_text().strip()
    except Exception as e:
        print(f"Error reading caption {caption_path}: {e}")
        return None

    return image, caption


def _score_dataset(model, reward_config, images, captions, batch_size, device):
    """Run ``model.get_rewards`` over the pairs in batches of ``batch_size``.

    Returns ``(rewards, raw_rewards)`` where ``rewards`` is a list and
    ``raw_rewards`` maps each raw reward key to a list, both concatenated over
    all batches in input order.
    """
    reward_batches = []
    raw_reward_batches = []
    for start in tqdm(range(0, len(images), batch_size), desc="Processing pairs"):
        # NOTE(review): the division assumes Im(...).torch holds 0-255 values -- confirm.
        batch_imgs = torch.cat(images[start : start + batch_size], dim=0).to(device) / 255.0
        batch_texts = captions[start : start + batch_size]
        with torch.inference_mode():
            rewards, raw_rewards = model.get_rewards(
                reward_config, batch_imgs, batch_texts, None, return_raw_rewards=True
            )
        reward_batches.append(rewards.cpu())
        raw_reward_batches.append(raw_rewards)

    rewards_tensor = torch.cat(reward_batches, dim=0)
    raw_rewards_lists = {
        key: torch.cat([batch[key] for batch in raw_reward_batches], dim=0).tolist()
        for key in raw_reward_batches[-1]
    }
    return rewards_tensor.tolist(), raw_rewards_lists


@app.command()
def main(
    input_dir: Path | None = None,
    output_file: Path | None = None,
    batch_size: int = 32,
    resolution: int = 512,
    num_pairs: int | None = None,
    num_dirs: int | None = None,
):
    """
    Process datasets contained in subdirectories of `input_dir`, distributed across multiple processes.

    Expected layout: `input_dir/<dataset>/<pair>/image.png` and `caption.txt`.
    Each process handles complete datasets; the per-dataset rewards from all
    processes are gathered, printed by the main process and, when
    `output_file` is given, written there as JSON.

    `num_dirs` limits how many dataset directories are processed and
    `num_pairs` limits how many pair directories are read per dataset (both in
    name order). Images are center-cropped and resized to `resolution`.
    """
    # Fail early with a clear message instead of an AttributeError/ValueError later.
    if input_dir is None:
        raise typer.BadParameter("input_dir must be provided")
    if batch_size <= 0:
        raise typer.BadParameter("batch_size must be a positive integer")

    distributed_state = PartialState()
    device = distributed_state.device
    dtype = torch.bfloat16

    model = Diffusion(None, None, device, disable_init=True)
    model.device = device
    model.dtype = dtype

    reward_config = OmegaConf.create({
        "dfn_score": 1.0,
        "hpsv2_score": 1.0,
        "clip_score": 1.0,
        "laion_aesthetic_score": 1.0,
        "text_reward_model_score": 1.0
    })

    all_rewards = {}

    dataset_dirs = sorted((p for p in input_dir.iterdir() if p.is_dir()), key=lambda p: p.name)
    if not dataset_dirs:
        if distributed_state.is_main_process:
            print("No dataset directories found in the input directory.")
        raise typer.Exit()

    if num_dirs is not None:
        dataset_dirs = dataset_dirs[:num_dirs]

    with distributed_state.split_between_processes(dataset_dirs) as process_dataset_dirs:
        for ds_dir in tqdm(process_dataset_dirs, desc=f"Processing datasets (GPU {distributed_state.process_index})"):
            if distributed_state.is_main_process:
                print(f"Processing dataset: {ds_dir.name}")

            pair_dirs = sorted((p for p in ds_dir.iterdir() if p.is_dir()), key=lambda p: p.name)
            if num_pairs is not None:
                pair_dirs = pair_dirs[:num_pairs]
            if not pair_dirs:
                if distributed_state.is_main_process:
                    print(f"  No pair subdirectories found in {ds_dir.name}, skipping.")
                continue

            # Keep images, captions and their directories in lockstep so that
            # reward i always belongs to scored_dirs[i].
            images = []
            captions = []
            scored_dirs = []
            for pair_dir in pair_dirs:
                loaded = _load_pair(pair_dir, resolution)
                if loaded is None:
                    continue
                image, caption = loaded
                images.append(image)
                captions.append(caption)
                scored_dirs.append(pair_dir)

            # A separate name on purpose: the `num_pairs` option must keep its
            # value for the remaining datasets.
            num_valid = len(images)
            if num_valid == 0:
                print(f"No valid pairs found in dataset {ds_dir.name}, skipping.")
                continue

            rewards, raw_rewards = _score_dataset(
                model, reward_config, images, captions, batch_size, device
            )

            all_rewards[ds_dir.name] = {
                "rewards": rewards,
                "raw_rewards": raw_rewards,
                "folder_names": [d.name for d in scored_dirs],
                "folder_paths": [d.as_posix() for d in scored_dirs]
            }
            if distributed_state.is_main_process:
                print(f"Finished processing {num_valid} pairs from {ds_dir.name}")

    # Merge the per-process dictionaries; a dataset must come from exactly one process.
    gathered_rewards = {}
    for process_rewards in gather_object([all_rewards]):
        duplicates = gathered_rewards.keys() & process_rewards.keys()
        if duplicates:
            raise RuntimeError(f"Datasets processed by more than one process: {sorted(duplicates)}")
        gathered_rewards.update(process_rewards)

    if distributed_state.is_main_process:
        print("All rewards:")
        print(json.dumps(gathered_rewards, indent=2))

        if output_file is None:
            print("No output_file given; rewards were not saved.")
        else:
            # Best-effort save: the rewards were already printed above.
            try:
                output_file.parent.mkdir(parents=True, exist_ok=True)
                with open(output_file, "w") as f:
                    json.dump(gathered_rewards, f, indent=2)
                print(f"Rewards saved to {output_file}")
            except Exception as e:
                print(f"Error saving rewards to file: {e}")
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to the typer application.
    app()