# unidisc/demo/scoring/generate_rewards.py
import json
from pathlib import Path

import torch
import typer
from accelerate.state import PartialState
from accelerate.utils import gather_object
from image_utils import Im
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm

from decoupled_utils import set_global_breakpoint
from model import Diffusion
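# Keep underscores in CLI names (typer's default converts them to dashes)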
typer.main.get_command_name = lambda name: name
app = typer.Typer(pretty_exceptions_show_locals=False)
set_global_breakpoint()


@app.command()
def main(
    input_dir: Path | None = None,
    output_file: Path | None = None,
    batch_size: int = 32,
    resolution: int = 512,
    num_pairs: int | None = None,
    num_dirs: int | None = None,
):
"""
Process datasets contained in subdirectories of `input_dir`, distributed across multiple GPUs.
Each GPU processes complete datasets for better efficiency.
"""
    if input_dir is None or output_file is None:
        raise typer.BadParameter("--input_dir and --output_file are required")

    distributed_state = PartialState()
    device = distributed_state.device
    dtype = torch.bfloat16

    # Initialize model without Accelerator
    model = Diffusion(None, None, device, disable_init=True)
    model.device = device
    model.dtype = dtype
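    # Weights for the individual reward models; all 1.0 here, which presumably
    # weights the scores equally when model.get_rewards combines them.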
    reward_config = OmegaConf.create({
        "dfn_score": 1.0,
        "hpsv2_score": 1.0,
        "clip_score": 1.0,
        "laion_aesthetic_score": 1.0,
        "text_reward_model_score": 1.0,
    })
    all_rewards = {}

    # Get all dataset directories and distribute them across GPUs
    dataset_dirs = sorted([p for p in input_dir.iterdir() if p.is_dir()], key=lambda p: p.name)
    if not dataset_dirs:
        if distributed_state.is_main_process:
            print("No dataset directories found in the input directory.")
        raise typer.Exit()

    if num_dirs is not None:
        dataset_dirs = dataset_dirs[:num_dirs]
    # Split datasets across processes
    with distributed_state.split_between_processes(dataset_dirs) as process_dataset_dirs:
        for ds_dir in tqdm(process_dataset_dirs, desc=f"Processing datasets (GPU {distributed_state.process_index})"):
            if distributed_state.is_main_process:
                print(f"Processing dataset: {ds_dir.name}")
            pair_dirs = sorted([p for p in ds_dir.iterdir() if p.is_dir()], key=lambda p: p.name)
            if num_pairs is not None:
                pair_dirs = pair_dirs[:num_pairs]
            if not pair_dirs:
                if distributed_state.is_main_process:
                    print(f" No pair subdirectories found in {ds_dir.name}, skipping.")
                continue
            images = []
            captions = []
            valid_pair_dirs = []  # only dirs that yield a complete (image, caption) pair
            for pair_dir in pair_dirs:  # already sorted above
                image_path = pair_dir / "image.png"
                caption_path = pair_dir / "caption.txt"
                if not (image_path.exists() and caption_path.exists()):
                    print(f" Skipping {pair_dir}: missing image.png or caption.txt")
                    continue
                # Read the caption first so a failed read cannot leave an
                # appended image without a matching caption.
                try:
                    caption = caption_path.read_text().strip()
                except Exception as e:
                    print(f"Error reading caption {caption_path}: {e}")
                    continue
                try:
                    img = Image.open(image_path)
                    if resolution != img.height or resolution != img.width:
                        print(f"WARNING: image is {img.height}x{img.width}, expected {resolution}x{resolution}; center-cropping and resizing")
                        # Center-crop to a square, then resize to the target resolution
                        min_dim = min(img.width, img.height)
                        left = (img.width - min_dim) // 2
                        top = (img.height - min_dim) // 2
                        img = img.crop((left, top, left + min_dim, top + min_dim))
                        img = img.resize((resolution, resolution), Image.Resampling.LANCZOS)
                    # Convert to a tensor here; pixel values are normalized to [0, 1] at batch time
                    images.append(Im(img).torch.unsqueeze(0))
                except Exception as e:
                    print(f"Error processing image {image_path}: {e}")
                    continue
                captions.append(caption)
                valid_pair_dirs.append(pair_dir)
            # Use a new name so the `num_pairs` CLI limit is not clobbered for later datasets
            num_valid_pairs = len(images)
            if num_valid_pairs == 0:
                print(f"No valid pairs found in dataset {ds_dir.name}, skipping.")
                continue
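            # Score the dataset in mini-batches to bound GPU memory use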
            dataset_reward_batches = []
            dataset_raw_rewards = []
            for i in tqdm(range(0, num_valid_pairs, batch_size), desc="Processing pairs"):
                # Stack to (B, C, H, W) and normalize uint8 values to [0, 1]
                batch_imgs = torch.cat(images[i : i + batch_size], dim=0).to(device) / 255.0
                batch_texts = captions[i : i + batch_size]
                with torch.inference_mode():
                    rewards, raw_rewards = model.get_rewards(reward_config, batch_imgs, batch_texts, None, return_raw_rewards=True)
                dataset_reward_batches.append(rewards.cpu())
                dataset_raw_rewards.append(raw_rewards)
            dataset_rewards_tensor = torch.cat(dataset_reward_batches, dim=0)
            # Concatenate the per-batch raw reward tensors key by key
            dataset_raw_rewards_dict = {}
            for key in dataset_raw_rewards[0].keys():
                dataset_raw_rewards_dict[key] = torch.cat(
                    [batch[key] for batch in dataset_raw_rewards], dim=0
                )
            all_rewards[ds_dir.name] = {
                "rewards": dataset_rewards_tensor.tolist(),
                "raw_rewards": {k: v.tolist() for k, v in dataset_raw_rewards_dict.items()},
                # Record only the dirs that produced valid pairs so names stay aligned with rewards
                "folder_names": [f.name for f in valid_pair_dirs],
                "folder_paths": [f.as_posix() for f in valid_pair_dirs],
            }
            if distributed_state.is_main_process:
                print(f"Finished processing {num_valid_pairs} pairs from {ds_dir.name}")
    gathered_rewards = gather_object([all_rewards])
    all_keys = set()
    all_gathered_rewards = {}
    for rank_rewards in gathered_rewards:
        assert not all_keys.intersection(rank_rewards.keys()), "duplicate dataset names across ranks"
        all_keys.update(rank_rewards.keys())
        all_gathered_rewards.update(rank_rewards)
    gathered_rewards = all_gathered_rewards
    if distributed_state.is_main_process:
        print("All rewards:")
        print(json.dumps(gathered_rewards, indent=2))
        try:
            output_file.parent.mkdir(parents=True, exist_ok=True)
            with open(output_file, "w") as f:
                json.dump(gathered_rewards, f, indent=2)
            print(f"Rewards saved to {output_file}")
        except Exception as e:
            print(f"Error saving rewards to file: {e}")


if __name__ == "__main__":
    app()
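# Example invocation (hypothetical paths) -- launch one process per GPU so that
# PartialState can split the dataset directories across ranks:
#   accelerate launch generate_rewards.py \
#       --input_dir samples --output_file rewards.json \
#       --batch_size 32 --resolution 512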