File size: 15,059 Bytes
d643072 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 |
import io
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import clip
import numpy as np
import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader, Dataset, IterableDataset
from diffusion.data.transforms import get_transform
from tools.metrics.utils import tracker
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional: fall back to a pass-through that returns the iterable
    # unchanged. It must accept (and ignore) the extra positional/keyword
    # arguments the progress loops pass (desc=, position=, leave=); a
    # one-argument stand-in would raise TypeError at those call sites.
    def tqdm(iterable, *args, **kwargs):
        return iterable
import json
# Lower-case file extensions (without the leading dot) recognised as image and
# text files. NOTE(review): not referenced by the code visible in this file.
IMAGE_EXTENSIONS = set("bmp jpg jpeg pgm png ppm tif tiff webp".split())
TEXT_EXTENSIONS = set("txt".split())
class DummyDataset(Dataset):
    """Map-style dataset pairing generated images with their text prompts.

    Both ``real_path`` and ``fake_path`` must be JSON files whose top level maps
    a sample key to a record. For each selected key (file order, at most
    ``sample_nums`` keys):

    * the image path is ``<gen_img_path>/<key>.jpg`` (keys taken from ``real_path``);
    * the prompt is the record's ``"prompt"`` field (taken from ``fake_path``).

    Each item is a dict with ``real`` (the image, passed through ``transform`` if
    given), ``fake`` (the prompt, passed through ``tokenizer`` if given) and
    ``prompt`` (the raw prompt string).
    """

    FLAGS = ["img", "txt", "json"]

    def __init__(
        self,
        real_path,
        fake_path,
        real_flag: str = "img",
        fake_flag: str = "img",
        gen_img_path="",
        transform=None,
        tokenizer=None,
        sample_nums=None,
    ) -> None:
        """
        Args:
            real_path: JSON file whose keys name the images to load.
            fake_path: JSON file whose records carry a ``"prompt"`` field.
            real_flag: modality of ``real_path``; must be in ``FLAGS``, but only
                ``"img"`` can actually be loaded by ``__getitem__``.
            fake_flag: modality of ``fake_path``; validated against ``FLAGS``
                only — prompts are always handled as text.
            gen_img_path: directory holding the ``<key>.jpg`` images.
            transform: optional callable applied to each opened PIL image.
            tokenizer: optional callable invoked as
                ``tokenizer(text, context_length=77, truncate=True)``; its result
                is ``squeeze()``-d.
            sample_nums: maximum number of keys to read from each JSON file.
                ``None`` falls back to a module-level ``sample_nums`` (set when
                this file runs as a script) or, if that is absent, no limit.

        Raises:
            AssertionError: if either flag is not in ``FLAGS``.
            ValueError: if a path does not end with ``.json``.
        """
        super().__init__()
        assert (
            real_flag in self.FLAGS and fake_flag in self.FLAGS
        ), f"CLIP Score only support modality of {self.FLAGS}. However, get {real_flag} and {fake_flag}"
        if sample_nums is None:
            # Backward compatibility with script usage, where the limit lives in
            # the module-level ``sample_nums`` global.
            sample_nums = globals().get("sample_nums")
        self.sample_nums = sample_nums
        self.gen_img_path = gen_img_path
        print(f"images are from {gen_img_path}")
        self.real_folder = self._load_img_from_path(real_path)
        self.real_flag = real_flag
        self.fake_data = self._load_txt_from_path(fake_path)
        self.transform = transform
        self.tokenizer = tokenizer
        self.data_dict = {}  # not read anywhere in this file; kept for compatibility

    def __len__(self):
        return len(self.real_folder)

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError
        real_data = self._load_modality(self.real_folder[index], self.real_flag)
        prompt = self.fake_data[index]
        return dict(real=real_data, fake=self._load_txt(prompt), prompt=prompt)

    def _load_modality(self, path, modality):
        """Load ``path`` according to ``modality``; only ``"img"`` is supported."""
        if modality == "img":
            return self._load_img(path)
        raise TypeError(f"Got unexpected modality: {modality}")

    def _load_txt(self, data):
        """Tokenize ``data`` when a tokenizer is configured, else return it unchanged."""
        if self.tokenizer is not None:
            data = self.tokenizer(data, context_length=77, truncate=True).squeeze()
        return data

    def _load_img(self, path):
        """Open the image at ``path`` and apply ``transform`` if one is set."""
        img = Image.open(path)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def _read_json_keys(self, path):
        """Load JSON ``path``; return ``(data, keys)`` with keys truncated to ``self.sample_nums``."""
        if not path.endswith(".json"):
            raise ValueError(f"Only JSON file type is supported now. Wrong with: {path}")
        # JSON text is UTF-8 by specification; don't depend on the platform locale.
        with open(path, encoding="utf-8") as file:
            data = json.load(file)
        # Slicing with None keeps every key.
        return data, list(data.keys())[: self.sample_nums]

    def _load_img_from_path(self, path):
        """Return ``<gen_img_path>/<key>.jpg`` for every selected key of JSON ``path``."""
        _, keys = self._read_json_keys(path)
        return [os.path.join(self.gen_img_path, f"{k}.jpg") for k in keys]

    def _load_txt_from_path(self, path):
        """Return the ``"prompt"`` field of every selected record of JSON ``path``."""
        data, keys = self._read_json_keys(path)
        return [data[k]["prompt"] for k in keys]
class DummyTarDataset(IterableDataset):
    """Iterable dataset streaming image/prompt pairs out of a WebDataset tar.

    Each tar sample must provide a ``png`` or ``jpg`` image and a ``json``
    metadata record. The image is decoded to RGB by ``process_sample``. Yielded
    items are dicts with:

    * ``real``: the image after ``transform`` (the raw image if ``transform`` is None);
    * ``fake``: the prompt after ``tokenizer``;
    * ``prompt``, ``ori_img``, ``key`` and ``prompt_key``.
    """

    def __init__(
        self,
        tar_path,
        transform=None,
        external_json_path=None,
        prompt_key="prompt",
        tokenizer=None,
        sample_nums=None,
        **kwargs,
    ):
        """
        Args:
            tar_path: tar shard path/URL; must contain ``".tar"``.
            transform: optional callable applied to each RGB PIL image.
            external_json_path: optional JSON file mapping sample keys to extra
                metadata merged over each sample's own record. It is ignored
                when the file does not exist.
            prompt_key: metadata field holding the prompt (only ``"prompt"`` is
                accepted).
            tokenizer: optional callable invoked as
                ``tokenizer(text, context_length=77, truncate=True)``; its result
                is ``squeeze()``-d.
            sample_nums: maximum number of samples to stream. ``None`` falls
                back to the module-level ``args.sample_nums`` set when this file
                runs as a script.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        assert ".tar" in tar_path
        if sample_nums is None:
            sample_nums = args.sample_nums
        self.sample_nums = sample_nums
        self.dataset = (
            wds.WebDataset(tar_path)
            .map(self.safe_decode)
            .to_tuple("png;jpg", "json", "__key__")
            .map(self.process_sample)
            .slice(0, self.sample_nums)
        )
        if external_json_path is not None and os.path.exists(external_json_path):
            print(f"Loading {external_json_path}, wait...")
            with open(external_json_path, encoding="utf-8") as f:
                self.json_file = json.load(f)
        else:
            self.json_file = {}
        assert prompt_key == "prompt"
        self.prompt_key = prompt_key
        self.transform = transform
        self.tokenizer = tokenizer

    def __iter__(self):
        return self._generator()

    def _generator(self):
        for ori_img, info, key in self.dataset:
            # Fall back to the untransformed image so transform=None works.
            img = self.transform(ori_img) if self.transform is not None else ori_img
            if key in self.json_file:
                info.update(self.json_file[key])
            prompt = info.get(self.prompt_key, "")
            if not prompt:
                # Covers a missing key as well as an explicit null/empty prompt.
                prompt = ""
                print(f"{self.prompt_key} not exist in {key}.json")
            yield dict(
                real=img,
                fake=self._load_txt(prompt),
                prompt=prompt,
                # NOTE(review): built from the transformed ``img``, not ``ori_img``; confirm intent.
                ori_img=np.array(img),
                key=key,
                prompt_key=self.prompt_key,
            )

    def __len__(self):
        # Upper bound only: the tar may hold fewer than sample_nums samples.
        return self.sample_nums

    def _load_txt(self, data):
        """Tokenize ``data`` when a tokenizer is configured, else return it unchanged."""
        if self.tokenizer is not None:
            data = self.tokenizer(data, context_length=77, truncate=True).squeeze()
        return data

    @staticmethod
    def process_sample(sample):
        """Decode ``(image_bytes, json_bytes, key)`` into ``(RGB image, dict, key)``.

        Returns None for samples that fail to decode. This is intended to make
        the pipeline skip them; it assumes webdataset's ``map`` drops None
        results (verify).
        """
        try:
            image_bytes, json_bytes, key = sample
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            json_dict = json.loads(json_bytes)
            return image, json_dict, key
        except (ValueError, TypeError, OSError) as e:
            print(f"Skipping sample due to error: {e}")
            return None

    @staticmethod
    def safe_decode(sample):
        """Return a shallow dict copy of the raw sample, or None if that fails."""
        try:
            return {k: v for k, v in sample.items()}
        except Exception as e:
            print(f"skipping sample due to decode error: {e}")
            return None
@torch.no_grad()
def calculate_clip_score(dataloader, model, real_flag, fake_flag, save_json_path=None):
    """Mean scaled cosine similarity between paired ``real`` and ``fake`` inputs.

    Both sides are encoded with ``forward_modality`` and L2-normalised. For each
    sample the two embeddings are dotted and multiplied by
    ``model.logit_scale.exp()``. The result is the mean of these per-sample
    scores over the whole dataloader.

    Reads the module-level ``args`` (``exp_name``, ``gpu_id``) for the progress bar.

    Args:
        dataloader: yields dicts with ``real`` and ``fake`` batches. When
            ``save_json_path`` is set, each batch must also carry per-sample
            ``key`` and ``prompt_key`` lists (as ``DummyTarDataset`` produces).
        model: CLIP-like model with ``encode_image``/``encode_text`` and ``logit_scale``.
        real_flag: modality of ``real`` (``"img"`` or ``"txt"``).
        fake_flag: modality of ``fake`` (``"img"`` or ``"txt"``).
        save_json_path: if given, per-sample scores are written there as
            ``{key: {prompt_key: "<score>"}}``.

    Returns:
        0-dim tensor holding the mean score.

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    score_acc = 0.0
    sample_num = 0
    json_dict = {} if save_json_path is not None else None
    logit_scale = model.logit_scale.exp()
    for batch_data in tqdm(dataloader, desc=f"CLIP-Score: {args.exp_name}", position=args.gpu_id, leave=True):
        real_features = forward_modality(model, batch_data["real"], real_flag)
        fake_features = forward_modality(model, batch_data["fake"], fake_flag)
        # normalize features (norms cast to float32 so the division is done in float32)
        real_features = real_features / real_features.norm(dim=1, keepdim=True).to(torch.float32)
        fake_features = fake_features / fake_features.norm(dim=1, keepdim=True).to(torch.float32)
        # One score per sample; keeping them separate lets each key get its own value.
        per_sample = logit_scale * (fake_features * real_features).sum(dim=1)
        if json_dict is not None:
            for key, p_key, value in zip(batch_data["key"], batch_data["prompt_key"], per_sample.tolist()):
                json_dict[key] = {str(p_key): f"{value:.04f}"}
        score_acc += per_sample.sum()
        sample_num += batch_data["real"].shape[0]
    if json_dict is not None:
        with open(save_json_path, "w") as f:
            json.dump(json_dict, f)
    return score_acc / sample_num
@torch.no_grad()
def calculate_clip_score_official(dataloader):
    """CLIP score computed with torchmetrics' ``CLIPScore`` (openai/clip-vit-large-patch14).

    ``batch_data["real"]`` is mapped with ``(x + 1) * 0.5 * 255`` to uint8 and
    scored against ``batch_data["prompt"]``. This assumes the image transform
    yields values in [-1, 1] (TODO confirm against ``get_transform``). The
    mapping modifies the batch tensor in place.

    Reads the module-level ``device`` and ``args`` (``exp_name``, ``gpu_id``).

    Returns:
        float: the metric's score aggregated over all samples.
    """
    from torchmetrics.multimodal.clip_score import CLIPScore

    # Alternative backbone: "openai/clip-vit-base-patch16".
    clip_score_fn = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
    for batch_data in tqdm(dataloader, desc=args.exp_name, position=args.gpu_id, leave=True):
        imgs = batch_data["real"].add_(1.0).mul_(0.5)
        imgs = (imgs * 255).to(dtype=torch.uint8, device=device)
        # update() only accumulates state. Calling the metric directly (forward)
        # would also compute a per-batch value that is not needed here.
        clip_score_fn.update(imgs, batch_data["prompt"])
    # compute() aggregates over samples, so a short final batch is weighted
    # correctly (a plain mean of per-batch means is not).
    return float(clip_score_fn.compute())
def forward_modality(model, data, flag):
    """Encode ``data`` with the model encoder selected by ``flag``.

    ``"img"`` dispatches to ``model.encode_image`` and ``"txt"`` to
    ``model.encode_text``. ``data`` is first moved to the device of the model's
    first parameter. Any other flag raises ``TypeError``.
    """
    target_device = next(model.parameters()).device
    encoder_by_flag = {"img": "encode_image", "txt": "encode_text"}
    if flag not in encoder_by_flag:
        raise TypeError
    encode = getattr(model, encoder_by_flag[flag])
    return encode(data.to(target_device))
def main():
    """Compute, or load a cached, CLIP score for ``args.exp_name``.

    Reads the module-level ``args``, ``device``, ``num_workers`` and
    ``sample_nums`` set in the ``__main__`` section. Images come from
    ``<img_path>/<exp_name>``, which is a tar shard if the path contains
    ``".tar"``. Results are written under ``txt_path``, which defaults to
    ``img_path``.

    Returns:
        ``{args.exp_name: score}`` on success or when a cached ``.txt`` result
        exists. ``None`` when the per-sample JSON of a tar input already exists.
    """
    txt_path = args.txt_path if args.txt_path is not None else args.img_path
    gen_img_path = str(os.path.join(args.img_path, args.exp_name))
    is_tar = ".tar" in gen_img_path

    if is_tar:
        save_txt_path = os.path.join(txt_path, f"{args.exp_name}_{args.tar_prompt_key}_clip_score.txt").replace(
            ".tar", ""
        )
        # Per-sample scores live next to the summary: same name, .json extension.
        save_json_path = os.path.splitext(save_txt_path)[0] + ".json"
        if os.path.exists(save_json_path):
            print(f"{save_json_path} exists. Finished.")
            return None
    else:
        save_txt_path = os.path.join(txt_path, f"{args.exp_name}_sample{sample_nums}_clip_score.txt")
        save_json_path = None
        if os.path.exists(save_txt_path):
            with open(save_txt_path) as f:
                clip_score = f.readline().strip()
            print(f"CLIP Score: {clip_score}: {args.exp_name}")
            return {args.exp_name: float(clip_score)}

    print(f"Loading CLIP model: {args.clip_model}")
    if args.clipscore_type == "diffusers":
        # The torchmetrics scorer loads its own CLIP model; only the transform is needed here.
        preprocess = get_transform("default_train", 512)
    else:
        model, preprocess = clip.load(args.clip_model, device=device)

    if is_tar:
        dataset = DummyTarDataset(
            gen_img_path,
            transform=preprocess,
            external_json_path=args.external_json_file,
            prompt_key=args.tar_prompt_key,
            tokenizer=clip.tokenize,
        )
    else:
        dataset = DummyDataset(
            args.real_path,
            args.fake_path,
            args.real_flag,
            args.fake_flag,
            transform=preprocess,
            tokenizer=clip.tokenize,
            gen_img_path=gen_img_path,
        )
    dataloader = DataLoader(dataset, args.batch_size, num_workers=num_workers, pin_memory=True)

    print("Calculating CLIP Score:")
    if args.clipscore_type == "diffusers":
        clip_score = calculate_clip_score_official(dataloader)
    else:
        clip_score = calculate_clip_score(
            dataloader, model, args.real_flag, args.fake_flag, save_json_path=save_json_path
        )
    if isinstance(clip_score, torch.Tensor):
        # calculate_clip_score returns a tensor; calculate_clip_score_official already returns a float.
        clip_score = clip_score.cpu().item()
    print("CLIP Score: ", clip_score)
    with open(save_txt_path, "w") as file:
        file.write(str(clip_score))
    print(f"Result saved at: {save_txt_path}")
    return {args.exp_name: clip_score}
def parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use")
    # Another commonly used backbone is "ViT-B/16".
    parser.add_argument("--clip-model", type=str, default="ViT-L/14", help="CLIP model to use")
    parser.add_argument(
        "--img_path", type=str, default=None, help="Directory containing the `exp_name` image folder or tar file"
    )
    parser.add_argument(
        "--txt_path", type=str, default=None, help="Directory for the result files. Defaults to --img_path"
    )
    parser.add_argument("--sample_nums", type=int, default=30_000, help="Maximum number of samples to score")
    parser.add_argument("--exp_name", type=str, default="Sana", help="Experiment name, joined to --img_path")
    parser.add_argument(
        "--num-workers", type=int, help="Number of processes to use for data loading. Defaults to `min(8, num_cpus)`"
    )
    parser.add_argument("--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu")
    parser.add_argument("--real_flag", type=str, default="img", help="The modality of real path. Default to img")
    parser.add_argument("--fake_flag", type=str, default="txt", help="The modality of fake path. Default to txt")
    parser.add_argument("--real_path", type=str, help="JSON metadata file whose keys name the generated images")
    parser.add_argument("--fake_path", type=str, help="JSON metadata file providing the 'prompt' of each sample")
    parser.add_argument("--external_json_file", type=str, default=None, help="external meta json file for tar_file")
    parser.add_argument("--tar_prompt_key", type=str, default="prompt", help="key name of prompt in json")
    # online logging setting
    parser.add_argument("--clipscore_type", type=str, default="self", choices=["diffusers", "self"])
    parser.add_argument("--log_metric", type=str, default="metric")
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--log_clip_score", action="store_true")
    parser.add_argument("--suffix_label", type=str, default="", help="used for clip_score online log")
    parser.add_argument("--tracker_pattern", type=str, default="epoch_step", help="used for clip_score online log")
    parser.add_argument(
        "--report_to",
        type=str,
        default=None,
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="t2i-evit-baseline",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--name",
        type=str,
        default="baseline",
        help=("Wandb Project Name"),
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # The globals assigned here (args, sample_nums, device, num_workers) are read
    # directly by main() and the helpers it calls, so they must exist first.
    args = parse_args()
    sample_nums = args.sample_nums

    device = (
        torch.device(args.device)
        if args.device is not None
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )

    num_workers = args.num_workers
    if num_workers is None:
        # Default: min(8, usable CPUs), or 0 when the CPU count is unknown.
        try:
            num_cpus = len(os.sched_getaffinity(0))
        except AttributeError:  # os.sched_getaffinity does not exist on every platform
            num_cpus = os.cpu_count()
        num_workers = 0 if num_cpus is None else min(num_cpus, 8)

    # Keep only the last path component. If exp_name ends with a separator, its
    # basename is empty, and the path without that separator (dirname) is used.
    exp_basename = os.path.basename(args.exp_name)
    args.exp_name = exp_basename if exp_basename else os.path.dirname(args.exp_name)

    score_result = main()
    if args.log_clip_score:
        tracker(args, score_result, args.suffix_label, pattern=args.tracker_pattern, metric="CLIP-Score")
|