import ml_collections
import torch
from torch import multiprocessing as mp
from datasets import get_dataset
from torchvision.utils import make_grid, save_image
import utils
import einops
from torch.utils._pytree import tree_map
import accelerate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import tempfile
from absl import logging
import builtins
import os
import wandb
import numpy as np
import time
import random

import libs.autoencoder
from libs.t5 import T5Embedder
from libs.clip import FrozenCLIPEmbedder
from diffusion.flow_matching import FlowMatching, ODEFlowMatchingSolver, ODEEulerFlowMatchingSolver
from tools.fid_score import calculate_fid_given_paths
from tools.clip_score import ClipSocre


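# train(config): end-to-end training loop for a latent flow-matching text-to-image model.
# It runs under HuggingFace `accelerate` (one process per device), encodes images with a
# frozen autoencoder, conditions on T5 or CLIP text embeddings, and periodically evaluates
# sample grids, FID, and CLIP score.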
def train(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    assert config.train.batch_size % accelerator.num_processes == 0
    mini_batch_size = config.train.batch_size // accelerator.num_processes

    if accelerator.is_main_process:
        os.makedirs(config.ckpt_root, exist_ok=True)
        os.makedirs(config.sample_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
                   name=config.hparams, job_type='train', mode='offline')
        utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
        logging.info(config)
    else:
        utils.set_logger(log_level='error')
        builtins.print = lambda *args: None
    logging.info(f'Run on {accelerator.num_processes} devices')

    dataset = get_dataset(**config.dataset)
    assert os.path.exists(dataset.fid_stat)

    gpu_model = torch.cuda.get_device_name(torch.cuda.current_device())
    num_workers = 8

    train_dataset = dataset.get_split(split='train', labeled=True)
    train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
                                      num_workers=num_workers, pin_memory=True, persistent_workers=True)

    test_dataset = dataset.get_split(split='test', labeled=True)
    test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True, drop_last=True,
                                     num_workers=num_workers, pin_memory=True, persistent_workers=True)

    train_state = utils.initialize_train_state(config, device)
    nnet, nnet_ema, optimizer, train_dataset_loader, test_dataset_loader = accelerator.prepare(
        train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader, test_dataset_loader)
    lr_scheduler = train_state.lr_scheduler
    train_state.resume(config.ckpt_root)

    autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    autoencoder.to(device)

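    # The conditioning dimension selects the frozen text encoder:
    # clip_dim == 4096 -> T5 embeddings, clip_dim == 768 -> CLIP embeddings.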
    if config.nnet.model_args.clip_dim == 4096:
        llm = "t5"
        t5 = T5Embedder(device=device)
    elif config.nnet.model_args.clip_dim == 768:
        llm = "clip"
        clip = FrozenCLIPEmbedder()
        clip.eval()
        clip.to(device)
    else:
        raise NotImplementedError

    ss_empty_context = None

    ClipSocre_model = ClipSocre(device=device)

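    # Mixed-precision helpers around the frozen autoencoder; `decode` maps latents back
    # to image space during sampling and evaluation.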
    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    def get_data_generator():
        while True:
            for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
                yield data

    data_generator = get_data_generator()

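    # Infinite generator over the test split that yields the conditioning needed for
    # sampling: text context, token mask, raw tokens, and caption. For 5-D image batches
    # it additionally yields a latent sampled from the second view (`_img[:, 1]`).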
    def get_context_generator(autoencoder):
        while True:
            for data in test_dataset_loader:
                if len(data) == 5:
                    _img, _context, _token_mask, _token, _caption = data
                else:
                    _img, _context = data
                    _token_mask = None
                    _token = None
                    _caption = None

                if len(_img.shape) == 5:
                    _testbatch_img_blurred = autoencoder.sample(_img[:, 1, :])
                    yield _context, _token_mask, _token, _caption, _testbatch_img_blurred
                else:
                    assert len(_img.shape) == 4
                    yield _context, _token_mask, _token, _caption, None

    context_generator = get_context_generator(autoencoder)

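    # The flow-matching objective; train_step below runs one optimization step: encode
    # the image batch into latents, compute the flow-matching loss with the text
    # conditioning, backprop through `accelerator`, then step the LR schedule and EMA weights.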
    _flow_matching_model = FlowMatching()

    def train_step(_batch, _ss_empty_context):
        _metrics = dict()
        optimizer.zero_grad()

        assert len(_batch) == 6
        assert not config.dataset.cfg
        _batch_img = _batch[0]
        _batch_con = _batch[1]
        _batch_mask = _batch[2]
        _batch_token = _batch[3]
        _batch_caption = _batch[4]
        _batch_img_ori = _batch[5]

        _z = autoencoder.sample(_batch_img)

        loss, loss_dict = _flow_matching_model(_z, nnet, loss_coeffs=config.loss_coeffs, cond=_batch_con, con_mask=_batch_mask, batch_img_clip=_batch_img_ori,
                                               nnet_style=config.nnet.name, text_token=_batch_token, model_config=config.nnet.model_args, all_config=config, training_step=train_state.step)

        _metrics['loss'] = accelerator.gather(loss.detach()).mean()
        for key in loss_dict.keys():
            _metrics[key] = accelerator.gather(loss_dict[key].detach()).mean()
        accelerator.backward(loss.mean())
        optimizer.step()
        lr_scheduler.step()
        train_state.ema_update(config.get('ema_rate', 0.9999))
        train_state.step += 1
        return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)

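    # Sampling: the EMA network's text-encoder branch maps the context to an initial
    # latent, which an Euler flow-matching ODE solver integrates with classifier-free
    # guidance (scale from config.sample.scale); latents are decoded to images, optionally
    # together with a CLIP score for the given captions.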
    def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, context=None, caption=None, testbatch_img_blurred=None, two_stage_generation=-1, token_mask=None, return_clipScore=False, ClipSocre_model=None):
        with torch.no_grad():
            _z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)

            _z_x0, _mu, _log_var = nnet_ema(context, text_encoder=True, shape=_z_gaussian.shape, mask=token_mask)
            _z_init = _z_x0.reshape(_z_gaussian.shape)

            assert config.sample.scale > 1
            _cfg = config.sample.scale

            has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")

            ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, step_size_type="step_in_dsigma", guidance_scale=_cfg)
            _z, _ = ode_solver.sample(x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps, unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator)

            image_unprocessed = decode(_z)

            if return_clipScore:
                clip_score = ClipSocre_model.calculate_clip_score(caption, image_unprocessed)
                return image_unprocessed, clip_score
            else:
                return image_unprocessed

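    # Evaluation: generate n_samples images into a directory, compute FID against the
    # precomputed dataset statistics on the main process, log FID and mean CLIP score,
    # and share the FID with all processes via accelerator.reduce (non-main processes
    # contribute 0 to the sum).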
    def eval_step(n_samples, sample_steps):
        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=ODE_Euler_Flow_Matching_Solver, '
                     f'mini_batch_size={config.sample.mini_batch_size}')

        def sample_fn(_n_samples, return_caption=False, return_clipScore=False, ClipSocre_model=None, config=None):
            _context, _token_mask, _token, _caption, _testbatch_img_blurred = next(context_generator)
            assert _context.size(0) == _n_samples
            assert not return_caption
            if return_caption:
                return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask), _caption
            elif return_clipScore:
                return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask, return_clipScore=return_clipScore, ClipSocre_model=ClipSocre_model, caption=_caption)
            else:
                return ode_fm_solver_sample(nnet_ema, _n_samples, sample_steps, context=_context, token_mask=_token_mask)

        with tempfile.TemporaryDirectory() as temp_path:
            path = config.sample.path or temp_path
            if accelerator.is_main_process:
                os.makedirs(path, exist_ok=True)
            clip_score_list = utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, return_clipScore=True, ClipSocre_model=ClipSocre_model, config=config)
            _fid = 0
            if accelerator.is_main_process:
                _fid = calculate_fid_given_paths((dataset.fid_stat, path))
                _clip_score_list = torch.cat(clip_score_list)
                logging.info(f'step={train_state.step} fid{n_samples}={_fid} clip_score{len(_clip_score_list)} = {_clip_score_list.mean().item()}')
                with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
                    print(f'step={train_state.step} fid{n_samples}={_fid} clip_score{len(_clip_score_list)} = {_clip_score_list.mean().item()}', file=f)
                wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
            _fid = torch.tensor(_fid, device=device)
            _fid = accelerator.reduce(_fid, reduction='sum')

        return _fid.item()

    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')

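    # Main loop: one flow-matching optimization step per iteration, with periodic metric
    # logging, a sample grid every eval_interval steps, and checkpoint + FID evaluation
    # every save_interval steps (and at the final step).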
    step_fid = []
    while train_state.step < config.train.n_steps:
        nnet.train()
        batch = tree_map(lambda x: x, next(data_generator))
        metrics = train_step(batch, ss_empty_context)

        nnet.eval()
        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
            logging.info(config.workdir)
            wandb.log(metrics, step=train_state.step)

        if train_state.step % config.train.eval_interval == 0:
            torch.cuda.empty_cache()
            logging.info('Save a grid of images...')
            if hasattr(dataset, "token_embedding"):
                contexts = torch.tensor(dataset.token_embedding, device=device)[:config.train.n_samples_eval]
                token_mask = torch.tensor(dataset.token_mask, device=device)[:config.train.n_samples_eval]
            elif hasattr(dataset, "contexts"):
                contexts = torch.tensor(dataset.contexts, device=device)[:config.train.n_samples_eval]
                token_mask = None
            else:
                raise NotImplementedError
            samples = ode_fm_solver_sample(nnet_ema, _n_samples=config.train.n_samples_eval, _sample_steps=50, context=contexts, token_mask=token_mask)
            samples = make_grid(dataset.unpreprocess(samples), 5)
            if accelerator.is_main_process:
                save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
                wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
            accelerator.wait_for_everyone()
            torch.cuda.empty_cache()

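        # Checkpoint on save_interval: save train_state, evaluate FID on 10k samples,
        # and record (step, fid) so the best checkpoint can be reloaded after training.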
        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            torch.cuda.empty_cache()
            logging.info(f'Save and eval checkpoint {train_state.step}...')

            if accelerator.local_process_index == 0:
                train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
            accelerator.wait_for_everyone()

            fid = eval_step(n_samples=10000, sample_steps=50)
            step_fid.append((train_state.step, fid))

            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')
    step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
    logging.info(f'step_best: {step_best}')
    train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
    del metrics
    accelerator.wait_for_everyone()
    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)


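# Command-line entry point. The script is configured with an ml_collections config file
# plus optional --config.* overrides; get_hparams() turns those overrides into a short
# run name. A typical multi-GPU launch (example only; the script name and paths here
# are placeholders, not from this repo) might look like:
#   accelerate launch train.py --config=configs/my_config.py --workdir=workdir/run1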
from absl import flags
from absl import app
from ml_collections import config_flags
import sys
from pathlib import Path


FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("workdir", None, "Work unit directory.")


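# Run-naming helpers: get_config_name() takes the stem of the --config= file, and
# get_hparams() joins the --config.* overrides into a 'key=value-key=value' string
# (or 'default' when there are none); both feed into the workdir layout in main().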
def get_config_name():
    argv = sys.argv
    for i in range(1, len(argv)):
        if argv[i].startswith('--config='):
            return Path(argv[i].split('=')[-1]).stem


def get_hparams():
    argv = sys.argv
    lst = []
    for i in range(1, len(argv)):
        assert '=' in argv[i]
        if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
            hparam, val = argv[i].split('=')
            hparam = hparam.split('.')[-1]
            if hparam.endswith('path'):
                val = Path(val).stem
            lst.append(f'{hparam}={val}')
    hparams = '-'.join(lst)
    if hparams == '':
        hparams = 'default'
    return hparams


def main(argv):
    config = FLAGS.config
    config.config_name = get_config_name()
    config.hparams = get_hparams()
    config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
    config.ckpt_root = os.path.join(config.workdir, 'ckpts')
    config.sample_dir = os.path.join(config.workdir, 'samples')
    train(config)


if __name__ == "__main__":
    app.run(main)