# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import gc
import itertools
import logging
import os
import random
from copy import deepcopy
from typing import TYPE_CHECKING, Literal

import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler
from einops import rearrange
from PIL import Image
from safetensors.torch import load_file
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb  # needed to wrap PIL images before logging them to the wandb tracker

from uno.dataset.uno import FluxPairedDatasetV2
from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
from uno.flux.util import load_ae, load_clip, load_flow_model, load_t5, set_lora

if TYPE_CHECKING:
    from uno.flux.model import Flux
    from uno.flux.modules.autoencoder import AutoEncoder
    from uno.flux.modules.conditioner import HFEmbedder

logger = get_logger(__name__)


def get_models(name: str, device, offload: bool = False):
    t5 = load_t5(device, max_length=512)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu")
    vae = load_ae(name, device="cpu" if offload else device)
    return model, vae, t5, clip


def inference(
    batch: dict,
    model: "Flux",
    t5: "HFEmbedder",
    clip: "HFEmbedder",
    ae: "AutoEncoder",
    accelerator: Accelerator,
    seed: int = 0,
    pe: Literal["d", "h", "w", "o"] = "d",
) -> Image.Image:
    ref_imgs = batch["ref_imgs"]
    prompt = batch["txt"]
    neg_prompt = ""
    width, height = 512, 512
    num_steps = 25

    x = get_noise(
        1, height, width,
        device=accelerator.device,
        dtype=torch.bfloat16,
        seed=seed + accelerator.process_index,
    )
    timesteps = get_schedule(
        num_steps,
        (width // 8) * (height // 8) // (16 * 16),
        shift=True,
    )
    with torch.no_grad():
        ref_imgs = [
            ae.encode(ref_img_.to(accelerator.device, torch.float32)).to(torch.bfloat16)
            for ref_img_ in ref_imgs
        ]
        inp_cond = prepare_multi_ip(t5=t5, clip=clip, img=x, prompt=prompt, ref_imgs=ref_imgs, pe=pe)
        neg_inp_cond = prepare_multi_ip(t5=t5, clip=clip, img=x, prompt=neg_prompt, ref_imgs=ref_imgs, pe=pe)

        x = denoise(
            model,
            **inp_cond,
            timesteps=timesteps,
            guidance=4,
            timestep_to_start_cfg=30,
            neg_txt=neg_inp_cond["txt"],
            neg_txt_ids=neg_inp_cond["txt_ids"],
            neg_vec=neg_inp_cond["vec"],
            true_gs=3.5,
            image_proj=None,
            neg_image_proj=None,
            ip_scale=1,
            neg_ip_scale=1,
        )

        x = unpack(x.float(), height, width)
        x = ae.decode(x)

    x1 = x.clamp(-1, 1)
    x1 = rearrange(x1[-1], "c h w -> h w c")
    output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
    return output_img

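# Checkpoint layout assumed by resume_from_checkpoint() below and produced by the checkpointing block in
# main(): each save lands in <project_dir>/checkpoint-<global_step>/ and contains dit_lora.safetensors,
# optimizer.bin and, when EMA is enabled, dit_lora_ema.safetensors. Resuming only restores the LoRA (and EMA)
# weights; the optimizer and LR-scheduler objects are returned unchanged.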
def resume_from_checkpoint(
    resume_from_checkpoint: str | None | Literal["latest"],
    project_dir: str,
    accelerator: Accelerator,
    dit: "Flux",
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    dit_ema_dict: dict | None = None,
) -> tuple["Flux", torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler, dict | None, int]:
    # Potentially load in the weights and states from a previous save.
    if resume_from_checkpoint is None:
        return dit, optimizer, lr_scheduler, dit_ema_dict, 0

    if resume_from_checkpoint == "latest":
        # Get the most recent checkpoint.
        dirs = os.listdir(project_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        if len(dirs) == 0:
            accelerator.print(
                f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            return dit, optimizer, lr_scheduler, dit_ema_dict, 0
        path = dirs[-1]
    else:
        path = os.path.basename(resume_from_checkpoint)

    accelerator.print(f"Resuming from checkpoint {path}")
    lora_state = load_file(os.path.join(project_dir, path, "dit_lora.safetensors"), device=accelerator.device)
    unwrapped_dit = accelerator.unwrap_model(dit)
    unwrapped_dit.load_state_dict(lora_state, strict=False)

    if dit_ema_dict is not None:
        dit_ema_dict = load_file(
            os.path.join(project_dir, path, "dit_lora_ema.safetensors"),
            device=accelerator.device,
        )
        if dit is not unwrapped_dit:
            # The prepared model is wrapped, so its parameter names carry a "module." prefix.
            dit_ema_dict = {f"module.{k}": v for k, v in dit_ema_dict.items() if k in unwrapped_dit.state_dict()}

    global_step = int(path.split("-")[1])
    return dit, optimizer, lr_scheduler, dit_ema_dict, global_step

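# TrainArgs doubles as the CLI surface: HfArgumentParser (see the __main__ block at the bottom) exposes every
# field below as a --<field_name> flag, and an args file can be supplied via --config.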
@dataclasses.dataclass
class TrainArgs:
    ## accelerator
    project_dir: str | None = None
    mixed_precision: Literal["no", "fp16", "bf16"] = "bf16"
    gradient_accumulation_steps: int = 1
    seed: int = 42
    wandb_project_name: str | None = None
    wandb_run_name: str | None = None

    ## model
    model_name: Literal["flux", "flux-schnell"] = "flux"
    lora_rank: int = 512
    double_blocks_indices: list[int] | None = dataclasses.field(
        default=None,
        metadata={"help": "Indices of double blocks to apply LoRA. None means all double blocks."},
    )
    single_blocks_indices: list[int] | None = dataclasses.field(
        default=None,
        metadata={"help": "Indices of single blocks to apply LoRA. None means all single blocks."},
    )
    pe: Literal["d", "h", "w", "o"] = "d"
    gradient_checkpoint: bool = False

    ## ema
    ema: bool = False
    ema_interval: int = 1
    ema_decay: float = 0.99

    ## optimizer
    learning_rate: float = 1e-2
    adam_betas: list[float] = dataclasses.field(default_factory=lambda: [0.9, 0.999])
    adam_eps: float = 1e-8
    adam_weight_decay: float = 0.01
    max_grad_norm: float = 1.0  # gradient clipping threshold used in the training loop

    ## lr_scheduler
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 100
    max_train_steps: int = 100000

    ## dataloader
    # TODO: change to your own dataset, or use the data synthesis pipeline coming in the future. Stay tuned.
    train_data_json: str = "datasets/dreambench_singleip.json"
    batch_size: int = 1
    text_dropout: float = 0.1
    resolution: int = 512
    resolution_ref: int | None = None
    eval_data_json: str = "datasets/dreambench_singleip.json"
    eval_batch_size: int = 1

    ## misc
    resume_from_checkpoint: str | None | Literal["latest"] = None
    checkpointing_steps: int = 1000


def main(args: TrainArgs):
    ## accelerator
    deepspeed_plugins = {
        "dit": DeepSpeedPlugin(hf_ds_config="config/deepspeed/zero2_config.json"),
        "t5": DeepSpeedPlugin(hf_ds_config="config/deepspeed/zero3_config.json"),
        "clip": DeepSpeedPlugin(hf_ds_config="config/deepspeed/zero3_config.json"),
    }
    accelerator = Accelerator(
        project_dir=args.project_dir,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        deepspeed_plugins=deepspeed_plugins,
        log_with="wandb",
    )
    set_seed(args.seed, device_specific=True)
    accelerator.init_trackers(
        project_name=args.wandb_project_name,
        config=args.__dict__,
        init_kwargs={
            "wandb": {
                "name": args.wandb_run_name,
                "dir": accelerator.project_dir,
            },
        },
    )
    weight_dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "no": torch.float32,
    }.get(accelerator.mixed_precision, torch.float32)

    ## logger
    logging.basicConfig(
        format=f"[RANK {accelerator.process_index}] " + "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        force=True,
    )
    logger.info(accelerator.state)
    logger.info("Training script launched", main_process_only=False)

    ## model
    dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device)
    vae.requires_grad_(False)
    t5.requires_grad_(False)
    clip.requires_grad_(False)
    dit.requires_grad_(False)
    dit = set_lora(dit, args.lora_rank, args.double_blocks_indices, args.single_blocks_indices, accelerator.device)
    dit.train()
    dit.gradient_checkpointing = args.gradient_checkpoint

    ## optimizer and lr scheduler
    optimizer = torch.optim.AdamW(
        [p for p in dit.parameters() if p.requires_grad],
        lr=args.learning_rate,
        betas=args.adam_betas,
        weight_decay=args.adam_weight_decay,
        eps=args.adam_eps,
    )
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    # dataloader
    dataset = FluxPairedDatasetV2(
        data_json=args.train_data_json,
        resolution=args.resolution,
        resolution_ref=args.resolution_ref,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    eval_dataset = FluxPairedDatasetV2(
        data_json=args.eval_data_json,
        resolution=args.resolution,
        resolution_ref=args.resolution_ref,
    )
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.eval_batch_size, shuffle=False)

    dataloader = accelerator.prepare_data_loader(dataloader)
    eval_dataloader = accelerator.prepare_data_loader(eval_dataloader)
    dataloader = itertools.cycle(dataloader)  # cycle endlessly so the step-based loop never exhausts it

    ## parallel
    dit = accelerator.prepare_model(dit)
    optimizer = accelerator.prepare_optimizer(optimizer)
    lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
    accelerator.state.select_deepspeed_plugin("t5")
    t5 = accelerator.prepare_model(t5)
    accelerator.state.select_deepspeed_plugin("clip")
    clip = accelerator.prepare_model(clip)

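    # EMA bookkeeping (when args.ema is set): keep frozen shadow copies of the trainable LoRA parameters and,
    # every args.ema_interval optimizer steps, move them toward the live weights via
    # ema <- ema + (1 - ema_decay) * (param - ema), i.e. an exponential moving average with decay ema_decay.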
    ## ema
    dit_ema_dict = {
        k: deepcopy(v).requires_grad_(False)
        for k, v in dit.named_parameters() if v.requires_grad
    } if args.ema else None

    ## resume
    (
        dit,
        optimizer,
        lr_scheduler,
        dit_ema_dict,
        global_step,
    ) = resume_from_checkpoint(
        args.resume_from_checkpoint,
        project_dir=args.project_dir,
        accelerator=accelerator,
        dit=dit,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dit_ema_dict=dit_ema_dict,
    )

    ## noise scheduler
    timesteps = get_schedule(
        999,
        (args.resolution // 8) * (args.resolution // 8) // 4,
        shift=True,
    )
    timesteps = torch.tensor(timesteps, device=accelerator.device)

    total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Instantaneous batch size per device = {args.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    logger.info(f"  Total validation prompts = {len(eval_dataloader)}")

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        total=args.max_train_steps,
        disable=not accelerator.is_local_main_process,
    )
    train_loss = 0.0
    while global_step < args.max_train_steps:
        batch = next(dataloader)
        prompts = [txt_ if random.random() > args.text_dropout else "" for txt_ in batch["txt"]]
        img = batch["img"]
        ref_imgs = batch["ref_imgs"]

        with torch.no_grad():
            x_1 = vae.encode(img.to(accelerator.device).to(torch.float32))
            x_ref = [vae.encode(ref_img.to(accelerator.device).to(torch.float32)) for ref_img in ref_imgs]
            inp = prepare_multi_ip(t5=t5, clip=clip, img=x_1, prompt=prompts, ref_imgs=tuple(x_ref), pe=args.pe)
            x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            x_ref = [rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for x in x_ref]

        bs = img.shape[0]
        t = torch.randint(0, 1000, (bs,), device=accelerator.device)
        t = timesteps[t]
        x_0 = torch.randn_like(x_1, device=accelerator.device)
        x_t = (1 - t[:, None, None]) * x_1 + t[:, None, None] * x_0
        guidance_vec = torch.full((x_t.shape[0],), 1, device=x_t.device, dtype=x_t.dtype)

        with accelerator.accumulate(dit):
            # Predict the noise residual and compute loss.
            model_pred = dit(
                img=x_t.to(weight_dtype),
                img_ids=inp["img_ids"].to(weight_dtype),
                ref_img=[x.to(weight_dtype) for x in x_ref],
                ref_img_ids=[ref_img_id.to(weight_dtype) for ref_img_id in inp["ref_img_ids"]],
                txt=inp["txt"].to(weight_dtype),
                txt_ids=inp["txt_ids"].to(weight_dtype),
                y=inp["vec"].to(weight_dtype),
                timesteps=t.to(weight_dtype),
                guidance=guidance_vec.to(weight_dtype),
            )
            loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")
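            # Descriptive note: x_t above is the linear interpolation (1 - t) * x_1 + t * x_0 between the clean
            # latent x_1 and Gaussian noise x_0, so the regression target (x_0 - x_1) is the constant velocity of
            # that path, i.e. a rectified-flow / flow-matching style velocity objective rather than
            # epsilon prediction.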
            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean()
            train_loss += avg_loss.item() / args.gradient_accumulation_steps

            # Backpropagate.
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(dit.parameters(), args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes.
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            accelerator.log({"train_loss": train_loss}, step=global_step)
            train_loss = 0.0

        if accelerator.sync_gradients and dit_ema_dict is not None and global_step % args.ema_interval == 0:
            src_dict = dit.state_dict()
            for tgt_name in dit_ema_dict:
                dit_ema_dict[tgt_name].data.lerp_(src_dict[tgt_name].to(dit_ema_dict[tgt_name]), 1 - args.ema_decay)

        if accelerator.sync_gradients and accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
            logger.info(f"saving checkpoint at {global_step=}")
            save_path = os.path.join(args.project_dir, f"checkpoint-{global_step}")
            os.makedirs(save_path, exist_ok=True)

            # save
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(dit)
            unwrapped_model_state = unwrapped_model.state_dict()
            # state_dict() returns detached tensors, so keep the trainable (LoRA) parameters by name rather than
            # filtering on requires_grad, which would always be False here.
            trainable_keys = {k for k, p in unwrapped_model.named_parameters() if p.requires_grad}
            unwrapped_model_state = {k: v for k, v in unwrapped_model_state.items() if k in trainable_keys}
            accelerator.save(
                unwrapped_model_state,
                os.path.join(save_path, "dit_lora.safetensors"),
                safe_serialization=True,
            )
            unwrapped_opt = accelerator.unwrap_model(optimizer)
            accelerator.save(unwrapped_opt.state_dict(), os.path.join(save_path, "optimizer.bin"))
            logger.info(f"Saved state to {save_path}")
            if args.ema:
                accelerator.save(
                    {k.split("module.")[-1]: v for k, v in dit_ema_dict.items()},
                    os.path.join(save_path, "dit_lora_ema.safetensors"),
                )

            # validate
            dit.eval()
            torch.set_grad_enabled(False)
            for i, batch in enumerate(eval_dataloader):
                result = inference(batch, dit, t5, clip, vae, accelerator, seed=0)
                accelerator.log({f"eval_gen_{i}": wandb.Image(result)}, step=global_step)
            if args.ema:
                original_state_dict = dit.state_dict()
                dit.load_state_dict(dit_ema_dict, strict=False)
                for i, batch in enumerate(eval_dataloader):
                    result = inference(batch, dit, t5, clip, vae, accelerator, seed=0)
                    accelerator.log({f"eval_ema_gen_{i}": wandb.Image(result)}, step=global_step)
                dit.load_state_dict(original_state_dict, strict=False)
            torch.cuda.empty_cache()
            gc.collect()
            torch.set_grad_enabled(True)
            dit.train()

        accelerator.wait_for_everyone()
        logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)

    accelerator.wait_for_everyone()
    accelerator.end_training()


if __name__ == "__main__":
    parser = transformers.HfArgumentParser([TrainArgs])
    args_tuple = parser.parse_args_into_dataclasses(args_file_flag="--config")
    main(*args_tuple)
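# Example invocation (illustrative sketch only: the script name, args-file name, and output directory are
# placeholders, not paths defined by this repo):
#
#   accelerate launch train.py --config train_args.txt
#
# where train_args.txt is an HfArgumentParser args file whose flags mirror the TrainArgs fields, e.g.:
#
#   --project_dir output/uno_lora
#   --train_data_json datasets/dreambench_singleip.json
#   --batch_size 1
#   --lora_rank 512
#
# The same flags can also be passed directly on the command line. The DeepSpeed configs referenced in main()
# (config/deepspeed/zero2_config.json and zero3_config.json) must be present.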