# Spaces: Running on Zero
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import numpy as np
import torch
import torchvision.transforms.functional as TVF
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
def bucket_images(images: list[torch.Tensor], resolution: int = 512):
    """Resize a batch of image tensors to a shared aspect-ratio bucket and stack them.

    The bucket (h, w) whose h/w ratio is closest to the mean aspect ratio of
    *images* is chosen from a fixed table (defined at a 512px base and scaled
    linearly by *resolution*); every image is resized to that bucket and the
    results are stacked along a new leading batch dimension.

    Args:
        images: non-empty list of tensors shaped (..., H, W), typically (C, H, W).
        resolution: base resolution the bucket table is rescaled to.

    Returns:
        Tensor of shape (len(images), ..., new_h, new_w).

    Raises:
        ValueError: if *images* is empty (the mean aspect ratio is undefined;
            the original code produced a nan mean and crashed later in
            torch.stack with an unhelpful message).
    """
    if not images:
        raise ValueError("bucket_images() requires at least one image")

    # (h, w) candidates tuned for a 512px base resolution.
    bucket_override = [
        (256, 768),
        (320, 768),
        (320, 704),
        (384, 640),
        (448, 576),
        (512, 512),
        (576, 448),
        (640, 384),
        (704, 320),
        (768, 320),
        (768, 256),
    ]
    # Scale buckets to the requested resolution, then snap each side down to a
    # multiple of 16 (stride requirement of the downstream model).
    bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
    bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]

    aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
    mean_aspect_ratio = float(np.mean(aspect_ratios))

    # Pick the bucket whose aspect ratio is closest to the batch mean.
    # min() keeps the first candidate on ties, matching the original loop.
    new_h, new_w = min(bucket_override, key=lambda hw: abs(hw[0] / hw[1] - mean_aspect_ratio))

    images = [TVF.resize(image, (new_h, new_w)) for image in images]
    return torch.stack(images, dim=0)
class FluxPairedDatasetV2(Dataset): | |
def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None): | |
super().__init__() | |
self.json_file = json_file | |
self.resolution = resolution | |
self.resolution_ref = resolution_ref if resolution_ref is not None else resolution | |
self.image_root = os.path.dirname(json_file) | |
with open(self.json_file, "rt") as f: | |
self.data_dicts = json.load(f) | |
self.transform = Compose([ | |
ToTensor(), | |
Normalize([0.5], [0.5]), | |
]) | |
def __getitem__(self, idx): | |
data_dict = self.data_dicts[idx] | |
image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"] | |
txt = data_dict["prompt"] | |
image_tgt_path = data_dict.get("image_tgt_path", None) | |
ref_imgs = [ | |
Image.open(os.path.join(self.image_root, path)).convert("RGB") | |
for path in image_paths | |
] | |
ref_imgs = [self.transform(img) for img in ref_imgs] | |
img = None | |
if image_tgt_path is not None: | |
img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB") | |
img = self.transform(img) | |
return { | |
"img": img, | |
"txt": txt, | |
"ref_imgs": ref_imgs, | |
} | |
def __len__(self): | |
return len(self.data_dicts) | |
def collate_fn(self, batch): | |
img = [data["img"] for data in batch] | |
txt = [data["txt"] for data in batch] | |
ref_imgs = [data["ref_imgs"] for data in batch] | |
assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))]) | |
n_ref = len(ref_imgs[0]) | |
img = bucket_images(img, self.resolution) | |
ref_imgs_new = [] | |
for i in range(n_ref): | |
ref_imgs_i = [refs[i] for refs in ref_imgs] | |
ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref) | |
ref_imgs_new.append(ref_imgs_i) | |
return { | |
"txt": txt, | |
"img": img, | |
"ref_imgs": ref_imgs_new, | |
} | |
if __name__ == '__main__':
    # Smoke test: iterate the dataset with batched collation and print results.
    import argparse
    from pprint import pprint

    parser = argparse.ArgumentParser()
    parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
    args = parser.parse_args()

    dataset = FluxPairedDatasetV2(args.json_file, 512)
    # fix: variable was misspelled "dataloder"; also removed the leftover
    # breakpoint() debug stop and the commented-out duplicate argparse line.
    dataloader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)

    for i, data_dict in enumerate(dataloader):
        pprint(i)
        pprint(data_dict)