"""Fine-tune a LoRA adapter on a base diffusion model using ``aitoolkit``.

The script reads image/caption data from ``DATA_DIR``, attaches a LoRA
adapter to the model identified by ``MODEL_ID``, runs the trainer, and
reports ``OUTPUT_DIR`` as the destination of the trained weights.

NOTE(review): the ``aitoolkit`` API used below (``LoRATrainer``,
``StableDiffusionModel``, ``LoRAConfig``, ``ImageTextDataset``) is not
defined in this file; argument names are kept exactly as written and
should be confirmed against the installed package.
"""

import os

import torch
from aitoolkit import (
    LoRATrainer,
    StableDiffusionModel,
    LoRAConfig,
    ImageTextDataset,
)

# 1. Configuration
MODEL_ID = "HiDream-ai/HiDream-I1-Dev"  # or your gated FLUX model if you have access
DATA_DIR = "/workspace/data"
OUTPUT_DIR = "/workspace/lora-trained"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only requested on GPU; on the CPU fallback float16 is
# poorly supported by many operators, so full precision is used instead.
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

lora_cfg = LoRAConfig(
    rank=16,
    alpha=16,
    bias="none",
)

training_args = {
    "num_train_steps": 100,
    "batch_size": 4,
    "learning_rate": 1e-4,
    "save_every_n_steps": 50,
    "output_dir": OUTPUT_DIR,
}

# Create the output directory up front so a bad path fails before the
# (expensive) model load rather than at the first checkpoint.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 2. Load base diffusion model
model = StableDiffusionModel.from_pretrained(
    MODEL_ID,
    torch_dtype=TORCH_DTYPE,
    device=DEVICE,
    use_auth_token=True,  # if it's a gated repo
)

# 3. Prepare your dataset
# Expects pairs of image files + .txt captions in DATA_DIR
dataset = ImageTextDataset(data_root=DATA_DIR, image_size=512)

# 4. Hook up the LoRA adapter
model.apply_lora(lora_cfg)

# 5. Create the trainer and kickoff
trainer = LoRATrainer(
    model=model,
    dataset=dataset,
    args=training_args,
)

print("🚀 Starting training with AI‑Toolkit…")
trainer.train()
print(f"✅ Done! Fine-tuned weights saved to {OUTPUT_DIR}")