# unitok/config.py
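"""Training configuration for UniTok: the Tap-based `Args` class that holds every
run option, plus `init_dist_and_get_args()`, which parses the command line and sets
up the distributed environment."""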
import os
import sys
import torch
import random
import numpy as np
from tap import Tap
from typing import Optional, Union
from collections import OrderedDict
from unitok import dist


class Args(Tap):
    model: str = 'vitamin_large'  # 'vitamin_base', 'vitamin_large', etc.
    exp_name: str = 'unitok_large'
    output_dir: str = 'local_output'
    resume_from: str = ''  # if specified, load this checkpoint; if not, load the latest checkpoint in output_dir (if it exists)
    lpips_path: str = 'external/lpips_with_vgg.pth'
    dino_path: str = 'external/dinov2_vits14_pretrain.pth'
    fid_eval_src: str = ''
    fid_eval_dst: str = ''
    vis_img_dir: str = 'asset/vis_imgs/'
    fid_feature_extractor: str = 'external/weights-inception-2015-12-05-6726825d.pth'
    clip_pretrain_path: str = ''
    # speed-up
    fp16: bool = False  # whether to use FP16
    bf16: bool = True  # whether to use BF16
    tf32: bool = True  # whether to use TensorFloat32
    compile_model: bool = False  # whether to use torch.compile()
    ddp_static: bool = False  # whether to use static graph in DDP
    grad_ckpt: bool = True  # gradient checkpointing
    grad_accu: int = 1  # gradient accumulation
    device: str = 'cpu'  # will be set automatically
    dtype: torch.dtype = torch.float32  # will be set automatically
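    # NOTE: device and dtype are filled in by init_dist_and_get_args() at the bottom of this
    # file; if both fp16 and bf16 are enabled there, fp16 takes precedence (dtype -> float16).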
    # data
    train_data: Optional[str] = None
    val_data: Optional[str] = None
    dataset_type: str = 'webdataset'
    imagenet_val: Optional[str] = None
    imagenet_v2: Optional[str] = None
    subset_ratio: float = 1.0
    img_size: int = 256
    resize_ratio: float = 1.125  # only applicable to 'img' dataset_type
    hflip: bool = False
    workers: int = 8  # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    train_num_samples: int = 1_280_000_000
    train_data_upsampling_factors: Optional[str] = None
    dataset_resampled: bool = False
    use_aug: bool = False
    # quantizer
    vocab_size: int = 32768
    vocab_width: int = 64
    vocab_norm: bool = True
    vq_beta: float = 0.25  # commitment loss weight
    num_codebooks: int = 8
    quant_proj: str = 'attn'
    # model
    embed_dim: int = 768
    num_query: int = 0
    use_clip_pretrain: bool = False
    patch_size: int = 16
    drop_path: float = 0.1
    text_width: int = 768
    text_heads: int = 12
    text_layers: int = 12
    text_vocab_size: int = 49408
    text_context_length: int = 77
    # CLIP
    local_loss: bool = True
    gather_with_grad: bool = True
    pretrained_clip: Optional[str] = None
    pretrained_clip_text: Optional[str] = None
    lock_text: bool = False
    lock_text_unlocked_layers: int = 0
    lock_text_freeze_layer_norm: bool = False
    force_custom_text: bool = False
    force_custom_vision: bool = False
    zeroshot_eval_freq: int = 1
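    # The flags above (local_loss, gather_with_grad, lock_text*, zeroshot_eval_freq, ...)
    # appear to mirror OpenCLIP's contrastive-training options; their semantics presumably
    # match that codebase.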
    # discriminator
    dino_depth: int = 12
    dino_kernel_size: int = 9
    disc_norm: str = 'gn'  # gn: group norm, bn: batch norm, sbn: sync batch norm, hbn: hybrid sync batch norm
    disc_aug_prob: float = 1.0
    disc_specnorm: bool = False
    step_disc_every: int = 1
    # initialization
    vae_init: float = -0.5  # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
    vocab_init: float = -1  # <0: uniform(-abs(init)*base, abs(init)*base), where base = 20/vocab_size; >0: trunc_normal_(std=init)
    disc_init: float = -0.5  # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
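    # Worked example of the rule above: with the defaults (vocab_init=-1, vocab_size=32768),
    # the codebook would be initialized as uniform(-20/32768, 20/32768) ~= uniform(-6.1e-4, 6.1e-4).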
    # optimization
    epoch: int = 1  # number of epochs
    local_bs: int = 64  # batch size per device; if this is specified, --global_bs will be ignored
    vae_local_bs: int = 64  # sub-batch size for vae loss calculation
    global_bs: int = 0  # global batch size (only used when --local_bs is 0)
    lr: float = 5e-4  # learning rate
    wd: float = 0.02  # weight decay
    disc_lr: float = 2e-5  # disc lr
    disc_wd: float = 0.2
    grad_clip: float = 10  # <=0: disable grad clipping
    ema: float = 0.9999  # ema ratio
    warmup_iter: Optional[int] = None
    warmup_ep: float = 0.01  # lr warmup: epochs
    disc_start_ep: float = 0.375  # start using the disc loss for the VAE after this many epochs
    disc_warmup_ep: float = 0.03  # disc loss warmup epochs
    schedule: str = 'cos'  # lr schedule type
    lr_start_ratio: float = 0.  # lr warmup: initial lr ratio
    lr_end_ratio: float = 0.1  # lr schedule: final lr ratio
    disc_lr_end_ratio: float = 0.1
    custom_lr_multiplier: Optional[float] = None
    optimizer: str = 'adamw'
    optim_eps: float = 1e-6
    fuse_opt: bool = False  # whether to use fused optimizer
    optim_beta: str = '0.9_0.95'  # beta1, beta2 of optimizer
    disc_optim_beta: str = '0.5_0.9'  # beta1, beta2 of disc optimizer
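    # optim_beta/disc_optim_beta pack (beta1, beta2) into a 'b1_b2' string; they are presumably
    # split downstream, e.g. betas = tuple(map(float, args.optim_beta.split('_'))) -> (0.9, 0.95).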
    # loss
    l1: float = 0.2  # L1 rec loss weight
    l2: float = 1.0  # L2 rec loss weight
    lp: float = 1.0  # lpips loss weight
    lpr: int = 48  # only calculate lpips at image resolutions >= this value
    ld: float = 0.4  # discriminator loss weight; if <0: NO ADAPTIVE WEIGHT
    le: float = 0.0  # VQ entropy loss weight
    lq: float = 1.0
    lc: float = 1.0  # CLIP loss weight
    e_temp: float = 0.01
    gada: int = 1
    bcr: float = 4.  # balanced consistency regularization; used on small datasets with low resolution (StyleSwin uses 10.0)
    bcr_cut: float = 0.2  # cutout ratio (0.5: 50% width)
    dcrit: str = 'hg'  # 'hg': hinge, 'sp': softplus, 'ln': linear
    # wandb log
    report_wandb: bool = True
    wandb_notes: Optional[str] = None
    run_id: Optional[str] = None
    # debug
    eval_per_epoch: int = 8
    dbg_unused_param: bool = False
    dbg_nan: bool = False  # 'KEVIN_LOCAL' in os.environ
    seed: Optional[int] = None
    deterministic: bool = False
    same_seed_for_all_ranks: int = 0  # this is only for distributed sampler

    def seed_everything(self):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        if self.seed is not None:
            if self.deterministic:
                torch.backends.cudnn.benchmark = False
                torch.backends.cudnn.deterministic = True
                torch.use_deterministic_algorithms(True)
                os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
            # offset the seed by rank so each process draws a different random stream
            seed = self.seed + dist.get_rank() * 10000
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        for k in self.class_variables.keys():
            if k not in {'device'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, state_dict):
        for k, v in state_dict.items():
            try:
                setattr(self, k, v)
            except Exception as e:
                print(f'k={k}, v={v}')
                raise e

    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')

    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                s.append(f' {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'
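
# Round-trip sketch (illustrative, not called anywhere in this file): checkpointing code
# could persist and restore the run configuration via Args.state_dict()/load_state_dict(), e.g.
#   saved = args.state_dict()                                  # OrderedDict of every field except 'device'
#   new_args = Args(explicit_bool=True).parse_args(known_only=True)
#   new_args.load_state_dict(saved)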


def init_dist_and_get_args():
    # strip the legacy --local-rank/--local_rank argument injected by torch.distributed.launch
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    # warn about unexpected extra args
    if len(args.extra_args) > 0:
        print('======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print('=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print('======================================================================================\n\n')
    # init torch distributed
    os.makedirs(args.output_dir, exist_ok=True)
    dist.init_distributed_mode(local_out_path=args.output_dir, timeout_minutes=30)
    # set env
    args.set_tf32(args.tf32)
    args.seed_everything()
    args.device = dist.get_device()
    # update args: derive batch sizes and the autocast dtype
    if args.local_bs == 0:
        args.local_bs = max(1, round(args.global_bs / args.grad_accu / dist.get_world_size()))
    args.global_bs = args.local_bs * dist.get_world_size()
    if args.fp16 or args.bf16:
        args.dtype = torch.float16 if args.fp16 else torch.bfloat16
    return args
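

# Illustrative usage only (not part of the original module): a training entry point would
# typically obtain its config like this. Running it stand-alone assumes `unitok.dist` can
# initialize (or fall back from) a single-process setup.
if __name__ == '__main__':
    args = init_dist_and_get_args()
    print(args)  # Args.__str__ dumps every field (minus the excluded keys)
    torch.save(args.state_dict(), os.path.join(args.output_dir, 'args.pth'))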