import os
import sys
import torch
import random
import numpy as np
from tap import Tap
from typing import Optional, Union
from collections import OrderedDict
from unitok import dist


class Args(Tap):
    model: str = 'vitamin_large'  # 'vitamin_base', 'vitamin_large', xxx
    exp_name: str = 'unitok_large'
    output_dir: str = 'local_output'
    resume_from: str = ''  # if specified, load this checkpoint; otherwise load the latest checkpoint in output_dir (if one exists)
    lpips_path: str = 'external/lpips_with_vgg.pth'
    dino_path: str = 'external/dinov2_vits14_pretrain.pth'
    fid_eval_src: str = ''
    fid_eval_dst: str = ''
    vis_img_dir: str = 'asset/vis_imgs/'
    fid_feature_extractor: str = 'external/weights-inception-2015-12-05-6726825d.pth'
    clip_pretrain_path: str = ''

    # speed-up
    fp16: bool = False  # whether to use FP16
    bf16: bool = True  # whether to use BF16
    tf32: bool = True  # whether to use TensorFloat32
    compile_model: bool = False  # whether to use torch.compile()
    ddp_static: bool = False  # whether to use a static graph in DDP
    grad_ckpt: bool = True  # gradient checkpointing
    grad_accu: int = 1  # gradient accumulation
    device: str = 'cpu'  # will be set automatically
    dtype: torch.dtype = torch.float32  # will be set automatically

    # data
    train_data: str = None
    val_data: str = None
    dataset_type: str = 'webdataset'
    imagenet_val: str = None
    imagenet_v2: str = None
    subset_ratio: float = 1.0
    img_size: int = 256
    resize_ratio: float = 1.125  # only applicable to the 'img' dataset_type
    hflip: bool = False
    workers: int = 8  # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    train_num_samples: int = 1_280_000_000
    train_data_upsampling_factors: str = None
    dataset_resampled: bool = False
    use_aug: bool = False

    # quantizer
    vocab_size: int = 32768
    vocab_width: int = 64
    vocab_norm: bool = True
    vq_beta: float = 0.25  # commitment loss weight
    num_codebooks: int = 8
    quant_proj: str = 'attn'

    # model
    embed_dim: int = 768
    num_query: int = 0
    use_clip_pretrain: bool = False
    patch_size: int = 16
    drop_path: float = 0.1
    text_width: int = 768
    text_heads: int = 12
    text_layers: int = 12
    text_vocab_size: int = 49408
    text_context_length: int = 77

    # CLIP
    local_loss: bool = True
    gather_with_grad: bool = True
    pretrained_clip: str = None
    pretrained_clip_text: str = None
    lock_text: bool = False
    lock_text_unlocked_layers: int = 0
    lock_text_freeze_layer_norm: bool = False
    force_custom_text: bool = False
    force_custom_vision: bool = False
    zeroshot_eval_freq: int = 1

    # discriminator
    dino_depth: int = 12
    dino_kernel_size: int = 9
    disc_norm: str = 'gn'  # gn: group norm, bn: batch norm, sbn: sync batch norm, hbn: hybrid sync batch norm
    disc_aug_prob: float = 1.0
    disc_specnorm: bool = False
    step_disc_every: int = 1

    # initialization
    vae_init: float = -0.5  # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
    vocab_init: float = -1  # <0: uniform(-abs(init)*base, abs(init)*base), where base = 20/vocab_size; >0: trunc_normal_(std=init)
    disc_init: float = -0.5  # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)

    # optimization
    epoch: int = 1  # number of epochs
    local_bs: int = 64  # batch size per device; if this is specified, --global_bs will be ignored
    vae_local_bs: int = 64  # sub-batch size for the VAE loss calculation
    global_bs: int = 0  # global batch size (mutually exclusive with --local_bs)
    lr: float = 5e-4  # learning rate
    wd: float = 0.02  # weight decay
    disc_lr: float = 2e-5  # discriminator learning rate
    disc_wd: float = 0.2  # discriminator weight decay
    grad_clip: float = 10  # <=0 disables gradient clipping
    ema: float = 0.9999  # EMA ratio
    warmup_iter: int = None
    warmup_ep: float = 0.01  # lr warmup: epochs
    disc_start_ep: float = 0.375  # start applying the disc loss to the VAE after this many epochs
    disc_warmup_ep: float = 0.03  # disc loss warmup epochs
    schedule: str = 'cos'  # lr schedule type
    lr_start_ratio: float = 0.  # lr warmup: initial lr ratio
    lr_end_ratio: float = 0.1  # lr schedule: final lr ratio
    disc_lr_end_ratio: float = 0.1
    custom_lr_multiplier: float = None
    optimizer: str = 'adamw'
    optim_eps: float = 1e-6
    fuse_opt: bool = False  # whether to use the fused optimizer
    optim_beta: str = '0.9_0.95'  # beta1, beta2 of the optimizer
    disc_optim_beta: str = '0.5_0.9'  # beta1, beta2 of the disc optimizer

    # loss
    l1: float = 0.2  # L1 reconstruction loss weight
    l2: float = 1.0  # L2 reconstruction loss weight
    lp: float = 1.0  # LPIPS loss weight
    lpr: int = 48  # only compute the LPIPS loss for images at or above this resolution
    ld: float = 0.4  # discriminator loss weight; if <0: no adaptive weight
    le: float = 0.0  # VQ entropy loss weight
    lq: float = 1.0
    lc: float = 1.0  # CLIP loss weight
    e_temp: float = 0.01
    gada: int = 1
    bcr: float = 4.  # balanced consistency regularization; used on small, low-resolution datasets (StyleSwin uses 10.0)
    bcr_cut: float = 0.2  # cutout ratio (0.5: 50% width)
    dcrit: str = 'hg'  # 'hg': hinge, 'sp': softplus, 'ln': linear

    # wandb log
    report_wandb: bool = True
    wandb_notes: str = None
    run_id: str = None

    # debug
    eval_per_epoch: int = 8
    dbg_unused_param: bool = False
    dbg_nan: bool = False  # 'KEVIN_LOCAL' in os.environ
    seed: int = None
    deterministic: bool = False
    same_seed_for_all_ranks: int = 0  # this is only for the distributed sampler
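
    def get_optim_betas(self) -> tuple:
        # Illustrative helper (editorial sketch, not part of the original file): the underscore-
        # separated strings optim_beta / disc_optim_beta are presumably parsed into float pairs
        # before the optimizers are built, e.g.
        #   torch.optim.AdamW(params, lr=self.lr, betas=self.get_optim_betas(), eps=self.optim_eps, weight_decay=self.wd)
        beta1, beta2 = map(float, self.optim_beta.split('_'))
        return beta1, beta2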

    def seed_everything(self):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        if self.seed is not None:
            if self.deterministic:
                torch.backends.cudnn.benchmark = False
                torch.backends.cudnn.deterministic = True
                torch.use_deterministic_algorithms(True)
                os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
            seed = self.seed + dist.get_rank() * 10000  # give each rank a different but reproducible seed
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g
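
    # Editorial note (assumption, not original code): the per-rank generator above is typically
    # passed to torch.utils.data.DataLoader(..., generator=g) so that random augmentations differ
    # across ranks while remaining reproducible for a fixed --seed.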

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        for k in self.class_variables.keys():
            if k not in {'device'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, state_dict):
        for k, v in state_dict.items():
            try:
                setattr(self, k, v)
            except Exception as e:
                print(f'k={k}, v={v}')
                raise e

    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')

    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                s.append(f' {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def init_dist_and_get_args():
    # strip the --local-rank/--local_rank argument injected by torch.distributed launchers
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)

    # warn about args.extra_args
    if len(args.extra_args) > 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')

    # init torch distributed
    os.makedirs(args.output_dir, exist_ok=True)
    dist.init_distributed_mode(local_out_path=args.output_dir, timeout_minutes=30)

    # set env
    args.set_tf32(args.tf32)
    args.seed_everything()
    args.device = dist.get_device()

    # update args
    if args.local_bs == 0:
        args.local_bs = max(1, round(args.global_bs / args.grad_accu / dist.get_world_size()))
    args.global_bs = args.local_bs * dist.get_world_size()
    if args.fp16 or args.bf16:
        args.dtype = torch.float16 if args.fp16 else torch.bfloat16
    return args
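

# Illustrative usage sketch (editorial addition, not part of the original file): a minimal,
# hypothetical entry point showing how a training script might consume this module. The
# checkpoint dict layout below is an assumption for demonstration only.
if __name__ == '__main__':
    args = init_dist_and_get_args()
    print(args)                              # pretty-printed via Args.__str__
    snapshot = {'args': args.state_dict()}   # how a trainer might save the resolved config
    args.load_state_dict(snapshot['args'])   # and restore it when resuming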