import os
import torch
from aitoolkit import (
    LoRATrainer,
    StableDiffusionModel,
    LoRAConfig,
    ImageTextDataset,
)

# 1. Configuration
MODEL_ID = "HiDream-ai/HiDream-I1-Dev" # or your gated FLUX model if you have access
DATA_DIR = "/workspace/data"
OUTPUT_DIR = "/workspace/lora-trained"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

lora_cfg = LoRAConfig(
    rank=16,
    alpha=16,
    bias="none",
)
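# In most LoRA implementations the weight update is scaled by alpha/rank,
# so rank=16 with alpha=16 gives a scale of 1.0 (a common starting point).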

training_args = {
    "num_train_steps": 100,
    "batch_size": 4,
    "learning_rate": 1e-4,
    "save_every_n_steps": 50,
    "output_dir": OUTPUT_DIR,
}
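# Note: 100 steps at batch size 4 works out to about 400 image-caption
# samples seen over the run.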

# 2. Load the base diffusion model
model = StableDiffusionModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device=DEVICE,
    use_auth_token=True,  # needed if the repo is gated
)
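# Caveat: float16 generally assumes a GPU; if DEVICE falls back to "cpu",
# consider switching torch_dtype to torch.float32.
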
# 3. Prepare your dataset
# Expects pairs of image files + .txt captions in DATA_DIR
dataset = ImageTextDataset(data_root=DATA_DIR, image_size=512)
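# Illustrative layout (filenames are arbitrary; just keep the pairs aligned):
#   /workspace/data/
#     img_0001.png
#     img_0001.txt   <- caption for img_0001.png
#     img_0002.png
#     img_0002.txt
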
# 4. Hook up the LoRA adapter
model.apply_lora(lora_cfg)
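# (LoRA freezes the base weights and trains small low-rank matrices that are
# added to selected layers, typically the attention projections.)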

# 5. Create the trainer and kick off training
trainer = LoRATrainer(
    model=model,
    dataset=dataset,
    args=training_args,
)

print("🚀 Starting training with AI‑Toolkit…")
trainer.train()
print(f"✅ Done! Fine-tuned weights saved to {OUTPUT_DIR}")