# Reference copy of the `model:` section of the YAML config loaded below:
#
# vae:
#   class_path: src.models.vae.LatentVAE
#   init_args:
#     precompute: true
#     weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
# denoiser:
#   class_path: src.models.denoiser.decoupled_improved_dit.DDT
#   init_args:
#     in_channels: 4
#     patch_size: 2
#     num_groups: 16
#     hidden_size: &hidden_dim 1152
#     num_blocks: 28
#     num_encoder_blocks: 22
#     num_classes: 1000
# conditioner:
#   class_path: src.models.conditioner.LabelConditioner
#   init_args:
#     null_class: 1000
# diffusion_sampler:
#   class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
#   init_args:
#     num_steps: 250
#     guidance: 3.0
#     state_refresh_rate: 1
#     guidance_interval_min: 0.3
#     guidance_interval_max: 1.0
#     timeshift: 1.0
#     last_step: 0.04
#     scheduler: *scheduler
#     w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
#     guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
#     step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn

import os
import argparse

import torch
import spaces
import gradio as gr
from omegaconf import OmegaConf
from PIL import Image
from huggingface_hub import snapshot_download

from src.models.vae import fp2uint8
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.stateful_flow_matching.sharing_sampling import EulerSampler
from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler


def instantiate_class(config):
    """Instantiate a class from a config node with `class_path` and `init_args`."""
    kwargs = config.get("init_args", {})
    class_module, class_name = config["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(**kwargs)


def load_model(weight_dict, denoiser):
    """Copy the EMA denoiser weights from a checkpoint state dict into `denoiser`."""
    prefix = "ema_denoiser."
    for k, v in denoiser.state_dict().items():
        try:
            v.copy_(weight_dict["state_dict"][prefix + k])
        except (KeyError, RuntimeError):
            print(f"Failed to copy {prefix + k} to denoiser weight")
    return denoiser


class Pipeline:
    def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution, classlabels2ids):
        self.vae = vae
        self.denoiser = denoiser
        self.conditioner = conditioner
        self.diffusion_sampler = diffusion_sampler
        self.resolution = resolution
        self.classlabels2ids = classlabels2ids

    @spaces.GPU
    @torch.no_grad()
    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    def __call__(self, y, num_images, seed, num_steps,
                 guidance, state_refresh_rate, guidance_interval_min,
                 guidance_interval_max, timeshift):
        # Override the sampler hyperparameters with the values from the UI.
        self.diffusion_sampler.num_steps = num_steps
        self.diffusion_sampler.guidance = guidance
        self.diffusion_sampler.state_refresh_rate = state_refresh_rate
        self.diffusion_sampler.guidance_interval_min = guidance_interval_min
        self.diffusion_sampler.guidance_interval_max = guidance_interval_max
        self.diffusion_sampler.timeshift = timeshift
        # Draw the initial latent noise on CPU so a given seed is reproducible
        # across devices, then move it to the GPU.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        xT = torch.randn((num_images, 4, self.resolution // 8, self.resolution // 8),
                         device="cpu", dtype=torch.float32, generator=generator)
        xT = xT.to("cuda")
        condition, uncondition = self.conditioner([self.classlabels2ids[y]] * num_images)
        # Sample latents, then decode them to pixel space with the VAE.
        samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
        samples = self.vae.decode(samples)
        # fp32 in [-1, 1] -> uint8 in [0, 255]
        samples = fp2uint8(samples)
        samples = samples.permute(0, 2, 3, 1).cpu().numpy()
        return [Image.fromarray(samples[i]) for i in range(num_images)]
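
# A usage sketch for `instantiate_class` (hypothetical config node, mirroring
# the conditioner entry in the commented YAML above): each node is a mapping
# with a `class_path` and optional `init_args`, resolved to a class and
# constructed with those arguments.
#
#   conditioner = instantiate_class({
#       "class_path": "src.models.conditioner.LabelConditioner",
#       "init_args": {"null_class": 1000},
#   })
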
parser.add_argument("--resolution", type=int, default=512) parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R512") parser.add_argument("--ckpt_path", type=str, default="models") args = parser.parse_args() if not os.path.exists(args.ckpt_path): snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path) config = OmegaConf.load(args.config) vae_config = config.model.vae diffusion_sampler_config = config.model.diffusion_sampler denoiser_config = config.model.denoiser conditioner_config = config.model.conditioner vae = instantiate_class(vae_config) denoiser = instantiate_class(denoiser_config) conditioner = instantiate_class(conditioner_config) diffusion_sampler = EulerSampler( scheduler=LinearScheduler(), w_scheduler=LinearScheduler(), guidance_fn=simple_guidance_fn, num_steps=50, guidance=4.0, state_refresh_rate=1, guidance_interval_min=0.3, guidance_interval_max=1.0, timeshift=1.0 ) ckpt_path = os.path.join(args.ckpt_path, "model.ckpt") ckpt = torch.load(ckpt_path, map_location="cpu") denoiser = load_model(ckpt, denoiser) denoiser = denoiser.cuda() vae = vae.cuda() denoiser.eval() # read imagenet classlabels with open("imagenet_classlabels.txt", "r") as f: classlabels = f.readlines() classlabels = [label.strip() for label in classlabels] classlabels2id = {label: i for i, label in enumerate(classlabels)} id2classlabels = {i: label for i, label in enumerate(classlabels)} pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution, classlabels2id) with gr.Blocks() as demo: gr.Markdown("DDT: Decoupled Diffusion Transformer on ImageNet 512x512") with gr.Row(): with gr.Column(scale=1): num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50) guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0) num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=4) label = gr.Dropdown(choices=classlabels, value=id2classlabels[950], label="label") seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0) state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1) guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0) guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0) timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0) with gr.Column(scale=2): btn = gr.Button("Generate") output = gr.Gallery(label="Images") btn.click(fn=pipeline, inputs=[ label, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift ], outputs=[output]) demo.launch()