File size: 1,264 Bytes
c9b1bf6
 
81e7d73
 
 
 
 
b1dde27
0a3593d
81e7d73
 
 
 
 
aff7e63
81e7d73
 
 
 
aff7e63
 
81e7d73
 
 
 
 
 
 
0a3593d
81e7d73
 
 
 
 
 
0a3593d
 
81e7d73
 
 
35bd3cf
81e7d73
 
b1dde27
81e7d73
 
 
 
 
2ec882e
0a3593d
81e7d73
 
0a3593d
81e7d73
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import torch
from aitoolkit import (
    LoRATrainer,
    StableDiffusionModel,
    LoRAConfig,
    ImageTextDataset,
)

# 1. Configuration
# Base checkpoint to fine-tune; a gated FLUX checkpoint can be substituted here if you have access.
MODEL_ID = "HiDream-ai/HiDream-I1-Dev"
DATA_DIR = "/workspace/data"          # training images + captions
OUTPUT_DIR = "/workspace/lora-trained"  # where the trainer is told to write results

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# LoRA adapter settings: rank 16, alpha 16, no bias terms trained ("none").
lora_cfg = LoRAConfig(bias="none", rank=16, alpha=16)

# Options handed verbatim to LoRATrainer as its `args` mapping.
# NOTE(review): save_every_n_steps presumably controls checkpoint frequency — confirm in aitoolkit.
training_args = dict(
    num_train_steps=100,
    batch_size=4,
    learning_rate=1e-4,
    save_every_n_steps=50,
    output_dir=OUTPUT_DIR,
)

# 2. Load base diffusion model
# Half precision is only requested on CUDA. DEVICE may fall back to "cpu", where
# float16 is slow and (depending on the PyTorch version) some ops have no Half
# kernel, so use float32 there instead of failing or crawling.
model = StableDiffusionModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device=DEVICE,
    # NOTE(review): Hugging Face loaders deprecated `use_auth_token` in favour of
    # `token`; confirm which keyword aitoolkit's loader forwards.
    use_auth_token=True,   # if it's a gated repo
)

# 3. Prepare your dataset
# DATA_DIR is expected to hold image files paired with .txt caption files;
# samples are requested at 512 px.
dataset = ImageTextDataset(image_size=512, data_root=DATA_DIR)

# 4. Hook up the LoRA adapter
# Attach the adapter described by lora_cfg to the loaded base model.
model.apply_lora(lora_cfg)

# 5. Build the trainer from the model, the data and the option mapping, then run it.
trainer = LoRATrainer(args=training_args, dataset=dataset, model=model)

print("🚀 Starting training with AI‑Toolkit…")
trainer.train()

print(f"✅ Done! Fine-tuned weights saved to {OUTPUT_DIR}")