File size: 6,485 Bytes
131da64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import json
from pathlib import Path

import torch
import typer
from image_utils import Im
from omegaconf import OmegaConf
from tqdm import tqdm
from accelerate.state import PartialState
from accelerate.utils import gather_object
from PIL import Image

from decoupled_utils import set_global_breakpoint
from model import Diffusion

# Override typer's command-name derivation with the identity function so CLI
# command names match the Python function names verbatim (disables typer's
# default renaming, e.g. snake_case -> kebab-case — TODO confirm default).
typer.main.get_command_name = lambda name: name
app = typer.Typer(pretty_exceptions_show_locals=False)

# Import-time side effect: installs a project-wide breakpoint hook
# (see decoupled_utils) so `breakpoint()` behaves consistently under launchers.
set_global_breakpoint()

@app.command()
def main(
    input_dir: Path | None = None,
    output_file: Path | None = None,
    batch_size: int = 32,
    resolution: int = 512,
    num_pairs: int | None = None,
    num_dirs: int | None = None,
):
    """
    Score (image, caption) pairs with a set of reward models, distributed across GPUs.

    `input_dir` is expected to contain one subdirectory per dataset; each dataset
    subdirectory contains one subdirectory per pair holding `image.png` and
    `caption.txt`. Datasets are split across processes (each process scores whole
    datasets for efficiency), results are gathered, and the main process writes a
    JSON mapping dataset name -> rewards/raw rewards/folder info to `output_file`.

    Args:
        input_dir: Root directory of datasets. Required (the None default only
            exists so typer renders it as an option).
        output_file: Destination JSON path. Required, same caveat as above.
        batch_size: Number of pairs scored per reward-model forward pass.
        resolution: Expected square image size; mismatched images are
            center-cropped and Lanczos-resized to (resolution, resolution).
        num_pairs: If set, only the first N pair directories of each dataset are used.
        num_dirs: If set, only the first N dataset directories are used.

    Raises:
        typer.BadParameter: If `input_dir` or `output_file` is not provided.
        typer.Exit: If `input_dir` contains no dataset subdirectories.
    """
    # Fail fast with a clear CLI error instead of an AttributeError on `None.iterdir()`.
    if input_dir is None or output_file is None:
        raise typer.BadParameter("Both --input-dir and --output-file are required.")

    distributed_state = PartialState()
    device = distributed_state.device
    dtype = torch.bfloat16

    # Initialize model without Accelerator (inference only).
    model = Diffusion(None, None, device, disable_init=True)
    model.device = device
    model.dtype = dtype

    # Equal weighting across all reward models.
    reward_config = OmegaConf.create({
        "dfn_score": 1.0,
        "hpsv2_score": 1.0,
        "clip_score": 1.0,
        "laion_aesthetic_score": 1.0,
        "text_reward_model_score": 1.0,
    })

    all_rewards = {}
    # Deterministic ordering so every process sees the same split.
    dataset_dirs = sorted([p for p in input_dir.iterdir() if p.is_dir()], key=lambda p: p.name)
    if not dataset_dirs:
        if distributed_state.is_main_process:
            print("No dataset directories found in the input directory.")
        raise typer.Exit()

    if num_dirs is not None:
        dataset_dirs = dataset_dirs[:num_dirs]

    # Split datasets across processes; each process handles complete datasets.
    with distributed_state.split_between_processes(dataset_dirs) as process_dataset_dirs:
        for ds_dir in tqdm(process_dataset_dirs, desc=f"Processing datasets (GPU {distributed_state.process_index})"):
            if distributed_state.is_main_process:
                print(f"Processing dataset: {ds_dir.name}")

            pair_dirs = sorted([p for p in ds_dir.iterdir() if p.is_dir()], key=lambda p: p.name)
            if num_pairs is not None:
                pair_dirs = pair_dirs[:num_pairs]
            if not pair_dirs:
                if distributed_state.is_main_process:
                    print(f"  No pair subdirectories found in {ds_dir.name}, skipping.")
                continue

            images = []
            captions = []
            # BUG FIX: track the pair dirs that actually loaded, so the saved
            # folder_names/folder_paths stay aligned row-for-row with the rewards
            # (previously all of pair_dirs was saved, including skipped pairs).
            valid_pair_dirs = []
            for pair_dir in pair_dirs:
                image_path = pair_dir / "image.png"
                caption_path = pair_dir / "caption.txt"

                if not (image_path.exists() and caption_path.exists()):
                    print(f"  Skipping {pair_dir}: missing image.png or caption.txt")
                    continue

                # BUG FIX: read the caption BEFORE touching the image lists, so a
                # caption-read failure can no longer leave an orphaned image
                # appended without its caption (which misaligned every later pair).
                try:
                    caption = caption_path.read_text().strip()
                except Exception as e:
                    print(f"Error reading caption {caption_path}: {e}")
                    continue

                try:
                    img = Image.open(image_path)
                    if resolution != img.height or resolution != img.width:
                        print(f"WARNING!!! Image resolution {img.height}x{img.width} does not match resolution {resolution}x{resolution}")
                        # Center-crop to square, then resize to the target resolution.
                        min_dim = min(img.width, img.height)
                        left = (img.width - min_dim) // 2
                        top = (img.height - min_dim) // 2
                        img = img.crop((left, top, left + min_dim, top + min_dim))
                        img = img.resize((resolution, resolution), Image.Resampling.LANCZOS)
                    tensor_img = Im(img).torch.unsqueeze(0)
                except Exception as e:
                    print(f"Error processing image {image_path}: {e}")
                    continue

                # Append image, caption, and dir together only once both succeeded.
                images.append(tensor_img)
                captions.append(caption)
                valid_pair_dirs.append(pair_dir)

            # BUG FIX: do not shadow the `num_pairs` CLI parameter — the original
            # overwrote it here, so every dataset after the first was truncated to
            # the previous dataset's pair count instead of the user's limit.
            pair_count = len(images)
            if pair_count == 0:
                print(f"No valid pairs found in dataset {ds_dir.name}, skipping.")
                continue

            dataset_reward_batches = []
            dataset_raw_rewards = []
            for i in tqdm(range(0, pair_count, batch_size), desc="Processing pairs"):
                # Images arrive as uint8 [0, 255]; normalize to [0, 1] floats.
                batch_imgs = torch.cat(images[i : i + batch_size], dim=0).to(device) / 255.0
                batch_texts = captions[i : i + batch_size]
                with torch.inference_mode():
                    rewards, raw_rewards = model.get_rewards(reward_config, batch_imgs, batch_texts, None, return_raw_rewards=True)
                dataset_reward_batches.append(rewards.cpu())
                dataset_raw_rewards.append(raw_rewards)

            dataset_rewards_tensor = torch.cat(dataset_reward_batches, dim=0)
            # Concatenate per-batch raw-reward dicts key by key. Keys come from the
            # first batch rather than relying on the loop variable leaking out.
            dataset_raw_rewards_dict = {
                key: torch.cat([batch[key] for batch in dataset_raw_rewards], dim=0)
                for key in dataset_raw_rewards[0].keys()
            }

            all_rewards[ds_dir.name] = {
                "rewards": dataset_rewards_tensor.tolist(),
                "raw_rewards": {k: v.tolist() for k, v in dataset_raw_rewards_dict.items()},
                "folder_names": [f.name for f in valid_pair_dirs],
                "folder_paths": [f.as_posix() for f in valid_pair_dirs],
            }
            if distributed_state.is_main_process:
                print(f"Finished processing {pair_count} pairs from {ds_dir.name}")

    # Collect each process's per-dataset results; datasets are disjoint by
    # construction, which the assert below double-checks.
    gathered_rewards = gather_object([all_rewards])

    all_keys = set()
    all_gathered_rewards = {}
    for process_rewards in gathered_rewards:
        assert len(set(process_rewards.keys()).intersection(all_keys)) == 0
        all_keys.update(process_rewards.keys())
        all_gathered_rewards.update(process_rewards)

    gathered_rewards = all_gathered_rewards

    if distributed_state.is_main_process:
        print("All rewards:")
        print(json.dumps(gathered_rewards, indent=2))

        try:
            output_file.parent.mkdir(parents=True, exist_ok=True)
            with open(output_file, "w") as f:
                json.dump(gathered_rewards, f, indent=2)
            print(f"Rewards saved to {output_file}")
        except Exception as e:
            print(f"Error saving rewards to file: {e}")


# Script entry point: dispatch to the typer app (runs `main` as the sole command).
if __name__ == "__main__":
    app()