from collections import defaultdict
from pathlib import Path
import json
import re

import typer

typer.main.get_command_name = lambda name: name
app = typer.Typer(pretty_exceptions_show_locals=False)
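# Note: typer kebab-cases command names by default (e.g. "my_cmd" -> "my-cmd");
# the get_command_name override above keeps function names unchanged. It is a
# typer internal rather than public API, so it may change across typer versions.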

PREFIX_PATTERN = re.compile(r"(.+?)__pair_")
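# e.g. PREFIX_PATTERN.match("img__pair_0003") captures "img"; names without
# "__pair_" do not match and are handled by extract_prefix's fallback below.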


def extract_prefix(folder: str) -> str:
    """
    Extracts the prefix from a folder name using PREFIX_PATTERN.
    If the pattern does not match, returns the folder name as-is.
    """
    match = PREFIX_PATTERN.match(folder)
    return match.group(1) if match else folder


def get_ignored_reward_keys(prefix: str) -> set[str]:
    """
    Returns a set of raw reward keys to ignore based on the prefix.
    Adjust this mapping as your application logic requires.
    """
    if "capmask" not in prefix and "cap" in prefix:
        return {"text_reward_model_score"}
    elif "imgmask" not in prefix and "img" in prefix:
        return {"laion_aesthetic_score"}
    return set()
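
# Examples (derived from the prefix checks above):
#   get_ignored_reward_keys("cap__style")     -> {"text_reward_model_score"}
#   get_ignored_reward_keys("capmask__style") -> set()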


@app.command()
def main(
    input_file: Path,
    save_image: bool = False,
):
    """
    Reads a generated JSON rewards file and, for each dataset, processes each
    unique prefix. For each prefix, it finds the matching examples in the
    dataset and computes:

    - The overall normalized reward average (normalized from 0 to 1).
    - The normalized average for each raw reward type (ignoring certain types
      based on the prefix).

    The prefix is extracted from each folder name by matching everything before
    the '__pair_' substring. If the folder name does not match, the entire
    folder name is used.

    Normalization uses the minimum and maximum values for each (prefix, reward
    type) pair, computed over all datasets and indices where the reward is not
    ignored.
    """
    try:
        content = input_file.read_text()
        data = json.loads(content)
    except Exception as e:
        typer.echo(f"Error reading JSON file: {e}")
        raise typer.Exit(code=1)

    # First pass: running per-(prefix, reward_key) min/max for normalization.
    norm_stats: dict[str, dict[str, tuple[float, float]]] = {}
    for dataset in data.values():
        folder_names: list[str] = dataset.get("folder_names", [])
        raw_rewards: dict[str, list[float]] = dataset.get("raw_rewards", {})
        for i, folder in enumerate(folder_names):
            current_prefix = extract_prefix(folder)
            ignore_keys = get_ignored_reward_keys(current_prefix)
            for reward_key, values in raw_rewards.items():
                if reward_key in ignore_keys:
                    continue
                if i >= len(values):
                    continue
                value = values[i]
                if current_prefix not in norm_stats:
                    norm_stats[current_prefix] = {}
                if reward_key not in norm_stats[current_prefix]:
                    norm_stats[current_prefix][reward_key] = (value, value)
                else:
                    curr_min, curr_max = norm_stats[current_prefix][reward_key]
                    norm_stats[current_prefix][reward_key] = (min(curr_min, value), max(curr_max, value))

    unique_prefixes: set[str] = set()
    for dataset in data.values():
        folder_names = dataset.get("folder_names", [])
        for folder in folder_names:
            unique_prefixes.add(extract_prefix(folder))
    sorted_prefixes = sorted(unique_prefixes)

    typer.echo(f"Found {len(sorted_prefixes)} unique prefixes: {sorted_prefixes}")

    for prefix in sorted_prefixes:
        typer.echo(f"Prefix: {prefix}")
        dataset_outputs = []
        img_outputs = defaultdict(list)
        for dataset_name, dataset in data.items():
            output_lines = []
            output_lines.append(f" Dataset: {dataset_name}")

            folder_names: list[str] = dataset.get("folder_names", [])
            folder_paths: list[str] = dataset.get("folder_paths", [])
            raw_rewards: dict[str, list[float]] = dataset.get("raw_rewards", {})

            if not folder_names:
                output_lines.append(" No folder names provided in this dataset.")
                dataset_outputs.append((float("-inf"), "\n".join(output_lines)))
                continue

            indices = [
                idx for idx, folder in enumerate(folder_names)
                if extract_prefix(folder) == prefix
            ]

            if save_image:
                num_to_save = 2
                # folder_paths is assumed to be parallel to folder_names;
                # sorting it here would desynchronize it from `indices`,
                # which were computed against folder_names.
                _folder_paths = [Path(p) for p in folder_paths]
                for idx in indices[:num_to_save]:
                    img_outputs[_folder_paths[idx].name].append((dataset_name, _folder_paths[idx]))

            ignore_keys = get_ignored_reward_keys(prefix)
            reward_details = ""
            total_norm_rewards = 0.0
            num_contributing_keys = 0
            for reward_key, values in raw_rewards.items():
                if reward_key in ignore_keys:
                    continue

                norm_info = norm_stats.get(prefix, {}).get(reward_key)
                if norm_info is None:
                    reward_details += f"{reward_key}: No data, "
                    continue
                min_val, max_val = norm_info

                # Min-max normalize each value into [0, 1]; a degenerate
                # range (min == max) maps to 0.0.
                group_norm_values = []
                for i in indices:
                    if i < len(values):
                        orig_value = values[i]
                        normalized = ((orig_value - min_val) / (max_val - min_val)) if max_val != min_val else 0.0
                        group_norm_values.append(normalized)
                if group_norm_values:
                    avg_norm = sum(group_norm_values) / len(group_norm_values)
                    reward_details += f"{reward_key}: {avg_norm:.4f}, "
                    total_norm_rewards += avg_norm
                    num_contributing_keys += 1
                else:
                    reward_details += f"{reward_key}: No data, "

            # Average only over reward types that contributed, so ignored or
            # empty keys do not deflate the overall score.
            overall_avg = total_norm_rewards / num_contributing_keys if num_contributing_keys else 0.0

            reward_details = f" Avg: {overall_avg:.4f}, " + reward_details
            output_lines.append(reward_details)
            dataset_outputs.append((overall_avg, "\n".join(output_lines)))

        for _avg, out in sorted(dataset_outputs, key=lambda x: x[0], reverse=True):
            typer.echo(out)

        typer.echo("-" * 40)

        if save_image:
            from unidisc.utils.viz_utils import create_text_image
            from PIL import Image
            from image_utils import Im

            for k, v in img_outputs.items():
                imgs = []
                for _dataset_name, _folder_path in v:
                    def get_img(_img_path, _txt_path):
                        # Stack the image above its caption rendered as text.
                        _img = Image.open(_img_path).resize((1024, 1024))
                        out = f'{_dataset_name}: {_txt_path.read_text()}'
                        txt_img = create_text_image(out, desired_width=_img.width)
                        return Im.concat_vertical(_img, txt_img)

                    input_img = None
                    if (_folder_path / "input_image.png").exists():
                        input_img = get_img(_folder_path / "input_image.png", _folder_path / "input_caption.txt")

                    out_img = get_img(_folder_path / "image.png", _folder_path / "caption.txt")
                    if input_img is not None:
                        imgs.append(Im.concat_vertical(input_img, out_img))
                    else:
                        imgs.append(out_img)

                Im.concat_horizontal([x.pil for x in imgs]).save(f"{k}.png")
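
# Example invocation (hypothetical script name and path):
#   python eval_rewards.py /path/to/rewards.json --save-image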

if __name__ == "__main__":
    app()