Spaces:
Paused
Paused
File size: 17,946 Bytes
0bc84fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 |
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import gc
import itertools
import logging
import os
import random
from copy import deepcopy
from typing import TYPE_CHECKING, Literal
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler
from einops import rearrange
from PIL import Image
from safetensors.torch import load_file
from torch.utils.data import DataLoader
from tqdm import tqdm
from uno.dataset.uno import FluxPairedDatasetV2
from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
from uno.flux.util import load_ae, load_clip, load_flow_model, load_t5, set_lora
if TYPE_CHECKING:
from uno.flux.model import Flux
from uno.flux.modules.autoencoder import AutoEncoder
from uno.flux.modules.conditioner import HFEmbedder
logger = get_logger(__name__)
def get_models(name: str, device, offload: bool=False):
t5 = load_t5(device, max_length=512)
clip = load_clip(device)
model = load_flow_model(name, device="cpu")
vae = load_ae(name, device="cpu" if offload else device)
return model, vae, t5, clip
def inference(
batch: dict,
model: "Flux", t5: "HFEmbedder", clip: "HFEmbedder", ae: "AutoEncoder",
accelerator: Accelerator,
seed: int = 0,
pe: Literal["d", "h", "w", "o"] = "d"
) -> Image.Image:
ref_imgs = batch["ref_imgs"]
prompt = batch["txt"]
neg_prompt = ''
width, height = 512, 512
num_steps = 25
x = get_noise(
1, height, width,
device=accelerator.device,
dtype=torch.bfloat16,
seed=seed + accelerator.process_index
)
timesteps = get_schedule(
num_steps,
(width // 8) * (height // 8) // (16 * 16),
shift=True,
)
with torch.no_grad():
ref_imgs = [
ae.encode(ref_img_.to(accelerator.device, torch.float32)).to(torch.bfloat16)
for ref_img_ in ref_imgs
]
inp_cond = prepare_multi_ip(
t5=t5, clip=clip, img=x, prompt=prompt,
ref_imgs=ref_imgs,
pe=pe
)
neg_inp_cond = prepare_multi_ip(
t5=t5, clip=clip, img=x, prompt=neg_prompt,
ref_imgs=ref_imgs,
pe=pe
)
x = denoise(
model,
**inp_cond,
timesteps=timesteps,
guidance=4,
timestep_to_start_cfg=30,
neg_txt=neg_inp_cond['txt'],
neg_txt_ids=neg_inp_cond['txt_ids'],
neg_vec=neg_inp_cond['vec'],
true_gs=3.5,
image_proj=None,
neg_image_proj=None,
ip_scale=1,
neg_ip_scale=1
)
x = unpack(x.float(), height, width)
x = ae.decode(x)
x1 = x.clamp(-1, 1)
x1 = rearrange(x1[-1], "c h w -> h w c")
output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
return output_img
def resume_from_checkpoint(
resume_from_checkpoint: str | None | Literal["latest"],
project_dir: str,
accelerator: Accelerator,
dit: "Flux",
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
dit_ema_dict: dict | None = None,
) -> tuple["Flux", torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler, dict | None, int]:
# Potentially load in the weights and states from a previous save
if resume_from_checkpoint is None:
return dit, optimizer, lr_scheduler, dit_ema_dict, 0
if resume_from_checkpoint == "latest":
# Get the most recent checkpoint
dirs = os.listdir(project_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
if len(dirs) == 0:
accelerator.print(
f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
)
return dit, optimizer, lr_scheduler, dit_ema_dict, 0
path = dirs[-1]
else:
path = os.path.basename(resume_from_checkpoint)
accelerator.print(f"Resuming from checkpoint {path}")
lora_state = load_file(os.path.join(project_dir, path, 'dit_lora.safetensors'), device=accelerator.device)
unwarp_dit = accelerator.unwrap_model(dit)
unwarp_dit.load_state_dict(lora_state, strict=False)
if dit_ema_dict is not None:
dit_ema_dict = load_file(
os.path.join(project_dir, path, 'dit_lora_ema.safetensors'),
device=accelerator.device
)
if dit is not unwarp_dit:
dit_ema_dict = {f"module.{k}": v for k, v in dit_ema_dict.items() if k in unwarp_dit.state_dict()}
global_step = int(path.split("-")[1])
return dit, optimizer, lr_scheduler, dit_ema_dict, global_step
@dataclasses.dataclass
class TrainArgs:
## accelerator
project_dir: str | None = None
mixed_precision: Literal["no", "fp16", "bf16"] = "bf16"
gradient_accumulation_steps: int = 1,
seed: int = 42
wandb_project_name: str | None = None
wandb_run_name: str | None = None
## model
model_name: Literal["flux", "flux-schnell"] = "flux"
lora_rank: int = 512
double_blocks_indices: list[int] | None = dataclasses.field(
default=None,
metadata={"help": "Indices of double blocks to apply LoRA. None means all double blocks."}
)
single_blocks_indices: list[int] | None = dataclasses.field(
default=None,
metadata={"help": "Indices of double blocks to apply LoRA. None means all single blocks."}
)
pe: Literal["d", "h", "w", "o"] = "d",
gradient_checkpoint: bool = False
## ema
ema: bool = False
ema_interval: int = 1
ema_decay: float = 0.99
## optimizer
learning_rate: float = 1e-2
adam_betas: list[float] = dataclasses.field(default_factory=lambda: [0.9, 0.999])
adam_eps: float = 1e-8
adam_weight_decay: float = 0.01
## lr_scheduler
lr_scheduler: str = "constant"
lr_warmup_steps: int = 100
max_train_steps: int = 100000
## dataloader
train_data_json: str = "datasets/dreambench_singleip.json" # TODO: change to your own dataset, or use one data syenthsize pipeline comming in the future. stay tuned
batch_size: int = 1
text_dropout: float = 0.1
resolution: int = 512
resolution_ref: int | None = None
eval_data_json: str = "datasets/dreambench_singleip.json"
eval_batch_size: int = 1
## misc
resume_from_checkpoint: str | None | Literal["latest"] = None
checkpointing_steps: int = 1000
def main(
args: TrainArgs,
):
## accelerator
deepspeed_plugins = {
"dit": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero2_config.json'),
"t5": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero3_config.json'),
"clip": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero3_config.json')
}
accelerator = Accelerator(
project_dir=args.project_dir,
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
deepspeed_plugins=deepspeed_plugins,
log_with="wandb",
)
set_seed(args.seed, device_specific=True)
accelerator.init_trackers(
project_name=args.wandb_project_name,
config=args.__dict__,
init_kwargs={
"wandb": {
"name": args.wandb_run_name,
"dir": accelerator.project_dir,
},
},
)
weight_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"no": torch.float32,
}.get(accelerator.mixed_precision, torch.float32)
## logger
logging.basicConfig(
format=f"[RANK {accelerator.process_index}] " + "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
force=True
)
logger.info(accelerator.state)
logger.info("Training script launched", main_process_only=False)
## model
dit, vae, t5, clip = get_models(
name=args.model_name,
device=accelerator.device,
)
vae.requires_grad_(False)
t5.requires_grad_(False)
clip.requires_grad_(False)
dit.requires_grad_(False)
dit = set_lora(dit, args.lora_rank, args.double_blocks_indices, args.single_blocks_indices, accelerator.device)
dit.train()
dit.gradient_checkpointing = args.gradient_checkpoint
## optimizer and lr scheduler
optimizer = torch.optim.AdamW(
[p for p in dit.parameters() if p.requires_grad],
lr=args.learning_rate,
betas=args.adam_betas,
weight_decay=args.adam_weight_decay,
eps=args.adam_eps,
)
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
# dataloader
dataset = FluxPairedDatasetV2(
data_json=args.train_data_json,
resolution=args.resolution, resolution_ref=args.resolution_ref
)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
eval_dataset = FluxPairedDatasetV2(
data_json=args.eval_data_json,
resolution=args.resolution, resolution_ref=args.resolution_ref
)
eval_dataloader = DataLoader(eval_dataset, batch_size=args.eval_batch_size, shuffle=False)
dataloader = accelerator.prepare_data_loader(dataloader)
eval_dataloader = accelerator.prepare_data_loader(eval_dataloader)
dataloader = itertools.cycle(dataloader) # as infinite fetch data loader
## parallel
dit = accelerator.prepare_model(dit)
optimizer = accelerator.prepare_optimizer(optimizer)
lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
accelerator.state.select_deepspeed_plugin("t5")
t5 = accelerator.prepare_model(t5)
accelerator.state.select_deepspeed_plugin("clip")
clip = accelerator.prepare_model(clip)
## ema
dit_ema_dict = {
k: deepcopy(v).requires_grad_(False) for k, v in dit.named_parameters() if v.requires_grad
} if args.ema else None
## resume
(
dit,
optimizer,
lr_scheduler,
dit_ema_dict,
global_step
) = resume_from_checkpoint(
args.resume_from_checkpoint,
project_dir=args.project_dir,
accelerator=accelerator,
dit=dit,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dit_ema_dict=dit_ema_dict
)
## noise scheduler
timesteps = get_schedule(
999,
(args.resolution // 8) * (args.resolution // 8) // 4,
shift=True,
)
timesteps = torch.tensor(timesteps, device=accelerator.device)
total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Instantaneous batch size per device = {args.batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
logger.info(f" Total validation prompts = {len(eval_dataloader)}")
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=global_step,
desc="Steps",
total=args.max_train_steps,
disable=not accelerator.is_local_main_process,
)
train_loss = 0.0
while global_step < (args.max_train_steps):
batch = next(dataloader)
prompts = [txt_ if random.random() > args.text_dropout else "" for txt_ in batch["txt"]]
img = batch["img"]
ref_imgs = batch["ref_imgs"]
with torch.no_grad():
x_1 = vae.encode(img.to(accelerator.device).to(torch.float32))
x_ref = [vae.encode(ref_img.to(accelerator.device).to(torch.float32)) for ref_img in ref_imgs]
inp = prepare_multi_ip(t5=t5, clip=clip, img=x_1, prompt=prompts, ref_imgs=tuple(x_ref), pe=args.pe)
x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
x_ref = [rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for x in x_ref]
bs = img.shape[0]
t = torch.randint(0, 1000, (bs,), device=accelerator.device)
t = timesteps[t]
x_0 = torch.randn_like(x_1, device=accelerator.device)
x_t = (1 - t[:, None, None]) * x_1 + t[:, None, None] * x_0
guidance_vec = torch.full((x_t.shape[0],), 1, device=x_t.device, dtype=x_t.dtype)
with accelerator.accumulate(dit):
# Predict the noise residual and compute loss
model_pred = dit(
img=x_t.to(weight_dtype),
img_ids=inp['img_ids'].to(weight_dtype),
ref_img=[x.to(weight_dtype) for x in x_ref],
ref_img_ids=[ref_img_id.to(weight_dtype) for ref_img_id in inp['ref_img_ids']],
txt=inp['txt'].to(weight_dtype),
txt_ids=inp['txt_ids'].to(weight_dtype),
y=inp['vec'].to(weight_dtype),
timesteps=t.to(weight_dtype),
guidance=guidance_vec.to(weight_dtype)
)
loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(dit.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if accelerator.sync_gradients and dit_ema_dict is not None and global_step % args.ema_interval == 0:
src_dict = dit.state_dict()
for tgt_name in dit_ema_dict:
dit_ema_dict[tgt_name].data.lerp_(src_dict[tgt_name].to(dit_ema_dict[tgt_name]), 1 - args.ema_decay)
if accelerator.sync_gradients and accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
logger.info(f"saving checkpoint in {global_step=}")
save_path = os.path.join(args.project_dir, f"checkpoint-{global_step}")
os.makedirs(save_path, exist_ok=True)
# save
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(dit)
unwrapped_model_state = unwrapped_model.state_dict()
unwrapped_model_state = {k: v for k, v in unwrapped_model_state.items() if v.requires_grad}
accelerator.save(
unwrapped_model_state,
os.path.join(save_path, 'dit_lora.safetensors'),
safe_serialization=True
)
unwrapped_opt = accelerator.unwrap_model(optimizer)
accelerator.save(unwrapped_opt.state_dict(), os.path.join(save_path, 'optimizer.bin'))
logger.info(f"Saved state to {save_path}")
if args.ema:
accelerator.save(
{k.split("module.")[-1]: v for k, v in dit_ema_dict.items()},
os.path.join(save_path, 'dit_lora_ema.safetensors')
)
# validate
dit.eval()
torch.set_grad_enabled(False)
for i, batch in enumerate(eval_dataloader):
result = inference(batch, dit, t5, clip, vae, accelerator, seed=0)
accelerator.log({f"eval_gen_{i}": result}, step=global_step)
if args.ema:
original_state_dict = dit.state_dict()
dit.load_state_dict(dit_ema_dict, strict=False)
for batch in eval_dataloader:
result = inference(batch, dit, t5, clip, vae, accelerator, seed=0)
accelerator.log({f"eval_ema_gen_{i}": result}, step=global_step)
dit.load_state_dict(original_state_dict, strict=False)
torch.cuda.empty_cache()
gc.collect()
torch.set_grad_enabled(True)
dit.train()
accelerator.wait_for_everyone()
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.wait_for_everyone()
accelerator.end_training()
if __name__ == "__main__":
parser = transformers.HfArgumentParser([TrainArgs])
args_tuple = parser.parse_args_into_dataclasses(args_file_flag="--config")
main(*args_tuple)
|