"""
This file is used for T2I generation, it also compute the clip similarity between the generated images and the input prompt
"""
from absl import flags
from absl import app
from ml_collections import config_flags
import os
import ml_collections
import torch
from torch import multiprocessing as mp
import torch.nn as nn
import accelerate
import utils
import tempfile
from absl import logging
import builtins
import einops
import math
import numpy as np
import time
from PIL import Image
from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
from tools.clip_score import ClipSocre
import libs.autoencoder
from libs.clip import FrozenCLIPEmbedder
from libs.t5 import T5Embedder
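

# Map images from the model's [-1, 1] output range back to [0, 1] for saving.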
def unpreprocess(x):
    x = 0.5 * (x + 1.)
    x.clamp_(0., 1.)
    return x


def batch_decode(_z, decode, batch_size=10):
    """
    The VAE decoder requires a large amount of GPU memory. To run the interpolation model on
    GPUs with 24 GB of VRAM or less, this function reduces the VAE's memory usage by splitting
    the input tensor into smaller chunks and decoding them sequentially.
    """
    num_samples = _z.size(0)
    decoded_batches = []
    for i in range(0, num_samples, batch_size):
        batch = _z[i:i + batch_size]
        decoded_batch = decode(batch)
        decoded_batches.append(decoded_batch)
    image_unprocessed = torch.cat(decoded_batches, dim=0)
    return image_unprocessed
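

# Build the batch of text conditions for the chosen mode and encode it with CLIP or T5.
# A batch of 3 holds two prompts plus a placeholder (addition-only or subtraction-only),
# a batch of 4 holds three prompts plus a placeholder (addition and subtraction), and a
# batch of >= 5 holds the two interpolation endpoints with placeholders in between.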
def get_caption(llm, text_model, prompt_dict, batch_size):
    if batch_size == 3:
        # only addition or only subtraction
        assert len(prompt_dict) == 2
        _batch_con = list(prompt_dict.values()) + [' ']
    elif batch_size == 4:
        # addition and subtraction
        assert len(prompt_dict) == 3
        _batch_con = list(prompt_dict.values()) + [' ']
    elif batch_size >= 5:
        # linear interpolation
        assert len(prompt_dict) == 2
        _batch_con = [prompt_dict['prompt_1']] + [' '] * (batch_size - 2) + [prompt_dict['prompt_2']]

    if llm == "clip":
        _latent, _latent_and_others = text_model.encode(_batch_con)
        _con = _latent_and_others['token_embedding'].detach()
    elif llm == "t5":
        _latent, _latent_and_others = text_model.get_text_embeddings(_batch_con)
        _con = (_latent_and_others['token_embedding'] * 10.0).detach()
    else:
        raise NotImplementedError

    _con_mask = _latent_and_others['token_mask'].detach()
    _batch_token = _latent_and_others['tokens'].detach()
    _batch_caption = _batch_con
    return (_con, _con_mask, _batch_token, _batch_caption)
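

# End-to-end evaluation: set up accelerate, load the trained network, encode the prompts,
# sample with the flow-matching ODE solver, decode with the VAE, and report CLIP scores.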
def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    if accelerator.is_main_process:
        utils.set_logger(log_level='info', fname=config.output_path)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None

    nnet = utils.get_nnet(**config.nnet)
    nnet = accelerator.prepare(nnet)
    logging.info(f'load nnet from {config.nnet_path}')
    accelerator.unwrap_model(nnet).load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.eval()
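
    # Select the text encoder matching the context dimension the network was trained with:
    # clip_dim == 4096 -> T5, clip_dim == 768 -> CLIP.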
    ##
    if config.nnet.model_args.clip_dim == 4096:
        llm = "t5"
        t5 = T5Embedder(device=device)
    elif config.nnet.model_args.clip_dim == 768:
        llm = "clip"
        clip = FrozenCLIPEmbedder()
        clip.eval()
        clip.to(device)
    else:
        raise NotImplementedError
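
    # 'interpolation' needs prompt_1/prompt_2 and at least 5 samples (two endpoints plus the
    # intermediate latents); 'arithmetic' needs prompt_ori plus prompt_a and/or prompt_s, with
    # one extra slot in the mini-batch reserved for the combined result.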
    config = ml_collections.ConfigDict(config)
    if config.test_type == 'interpolation':
        prompt_dict = {'prompt_1': config.prompt_1, 'prompt_2': config.prompt_2}
        for key in prompt_dict.keys():
            assert prompt_dict[key] is not None
        config.sample.mini_batch_size = config.num_of_interpolation
        assert config.sample.mini_batch_size >= 5, "for linear interpolation, please sample at least five images"
    elif config.test_type == 'arithmetic':
        prompt_dict = {'prompt_ori': config.prompt_ori, 'prompt_a': config.prompt_a, 'prompt_s': config.prompt_s}
        keys_to_remove = [key for key, value in prompt_dict.items() if value is None]
        for key in keys_to_remove:
            del prompt_dict[key]
        counter = len(prompt_dict)
        assert prompt_dict['prompt_ori'] is not None
        assert counter == 2 or counter == 3
        config.sample.mini_batch_size = counter + 1
    else:
        raise NotImplementedError
    config = ml_collections.FrozenConfigDict(config)

    if llm == "clip":
        context_generator = get_caption(llm, clip, prompt_dict=prompt_dict, batch_size=config.sample.mini_batch_size)
    elif llm == "t5":
        context_generator = get_caption(llm, t5, prompt_dict=prompt_dict, batch_size=config.sample.mini_batch_size)
    else:
        raise NotImplementedError
    ##
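
    # VAE that maps between image space and the latent space the network operates in;
    # encode/decode are wrapped in autocast so decoding runs in mixed precision.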
    autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    bdv_nnet = None  # We don't use Autoguidance
    ClipSocre_model = ClipSocre(device=device)  # we also return the CLIP score

    #######
    logging.info(config.sample)
    logging.info(f'sample: n_samples={config.sample.n_samples}, mode=t2i, mixed_precision={config.mixed_precision}')
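
    # Sample a Gaussian latent to fix the target shape, obtain the text-conditioned initial
    # latents from the network's text-encoder branch, optionally combine or interpolate them,
    # then integrate the flow-matching ODE with an Euler solver and decode the result.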
    def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, bdv_nnet=bdv_nnet, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token=None, token_mask=None, return_clipScore=False, ClipSocre_model=None):
        with torch.no_grad():
            del testbatch_img_blurred
            _z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)

            if 'dimr' in config.nnet.name or 'dit' in config.nnet.name:
                _z_x0, _mu, _log_var = nnet_ema(context, text_encoder=True, shape=_z_gaussian.shape, mask=token_mask)
                _z_init = _z_x0.reshape(_z_gaussian.shape)
            else:
                raise NotImplementedError
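
            # Combine the per-prompt initial latents according to the batch layout:
            #   3   -> slot 2 gets the standardized sum (addition) or difference (subtraction),
            #   4   -> slot 3 gets prompt_ori + prompt_a - prompt_s, standardized,
            #   >=5 -> linear interpolation between the first and last latent.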
            if len(_z_init) == 3:
                if config.prompt_a is not None:
                    assert config.prompt_s is None
                    _z_x0_temp = _z_x0[0] + _z_x0[1]
                elif config.prompt_s is not None:
                    assert config.prompt_a is None
                    _z_x0_temp = _z_x0[0] - _z_x0[1]
                else:
                    raise NotImplementedError
                mean = _z_x0_temp.mean()
                std = _z_x0_temp.std()
                _z_x0[2] = (_z_x0_temp - mean) / std
            elif len(_z_init) == 4:
                _z_x0_temp = _z_x0[0] + _z_x0[1] - _z_x0[2]
                mean = _z_x0_temp.mean()
                std = _z_x0_temp.std()
                _z_x0[3] = (_z_x0_temp - mean) / std
            elif len(_z_init) >= 5:
                tensor_a = _z_init[0]
                tensor_b = _z_init[-1]
                num_interpolations = len(_z_init) - 2
                interpolations = [tensor_a + (tensor_b - tensor_a) * (i / (num_interpolations + 1)) for i in range(1, num_interpolations + 1)]
                _z_init = torch.stack([tensor_a] + interpolations + [tensor_b], dim=0)
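
            # Classifier-free guidance scale: use the --cfg flag when given, otherwise fall back
            # to the scale defined in the config, then run the Euler flow-matching ODE solver.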
            assert config.sample.scale > 1
            if config.cfg != -1:
                _cfg = config.cfg
            else:
                _cfg = config.sample.scale

            has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
            _sample_steps = config.sample.sample_steps

            ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, bdv_model_fn=bdv_nnet, step_size_type="step_in_dsigma", guidance_scale=_cfg)
            _z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)

            if config.save_gpu_memory:
                image_unprocessed = batch_decode(_z, decode)
            else:
                image_unprocessed = decode(_z)
            clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)

            return image_unprocessed, clip_score
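
    # Thin wrapper used by the sampling loop: it reuses the precomputed text context and
    # returns the decoded images together with their CLIP scores.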
    def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
        _context, _token_mask, _token, _caption = context_generator
        assert return_clipScore
        assert not return_caption
        return ode_fm_solver_sample(nnet, _n_samples, config.sample.sample_steps, bdv_nnet=bdv_nnet, context=_context, token=_token, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
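
    # Write samples to img_save_path (or config.sample.path, or a temporary directory) and
    # log the mean CLIP score on the main process.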
    with tempfile.TemporaryDirectory() as temp_path:
        path = config.img_save_path or config.sample.path or temp_path
        if accelerator.is_main_process:
            os.makedirs(path, exist_ok=True)
        logging.info(f'Samples are saved in {path}')

        clip_score_list = utils.sample2dir_wCLIP(accelerator, path, config.sample.n_samples, config.sample.mini_batch_size, sample_fn, unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)

        if clip_score_list is not None:
            _clip_score_list = torch.cat(clip_score_list)
            if accelerator.is_main_process:
                logging.info(f'nnet_path={config.nnet_path}, clip_score{len(_clip_score_list)}={_clip_score_list.mean().item()}')


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
flags.DEFINE_string("output_path", None, "The path to output log.")
flags.DEFINE_float("cfg", -1, 'cfg scale, will use the scale defined in the config file is not assigned')
flags.DEFINE_string("img_save_path", None, "The path to image log.")
flags.DEFINE_string("test_type", None, "The prompt used for generation.")
flags.DEFINE_string("prompt_1", None, "The prompt used for linear interpolation.")
flags.DEFINE_string("prompt_2", None, "The prompt used for linear interpolation.")
flags.DEFINE_integer("num_of_interpolation", -1, 'number of images to sample for linear interpolation')
flags.DEFINE_boolean('save_gpu_memory', False, 'To save VRAM')
flags.DEFINE_string("prompt_ori", None, "The prompt used for arithmetic operations.")
flags.DEFINE_string("prompt_a", None, "The prompt used for arithmetic operations (addition).")
flags.DEFINE_string("prompt_s", None, "The prompt used for arithmetic operations (subtraction).")


def main(argv):
    config = FLAGS.config
    config.nnet_path = FLAGS.nnet_path
    config.output_path = FLAGS.output_path
    config.img_save_path = FLAGS.img_save_path
    config.cfg = FLAGS.cfg
    config.test_type = FLAGS.test_type
    config.prompt_1 = FLAGS.prompt_1
    config.prompt_2 = FLAGS.prompt_2
    config.num_of_interpolation = FLAGS.num_of_interpolation
    config.save_gpu_memory = FLAGS.save_gpu_memory
    config.prompt_ori = FLAGS.prompt_ori
    config.prompt_a = FLAGS.prompt_a
    config.prompt_s = FLAGS.prompt_s
    evaluate(config)


if __name__ == "__main__":
    app.run(main)