import os
import json
import inspect

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model

# ─── 1. Read hyperparameters & mode ───────────────────────────────────────────
model_id     = os.environ.get("BASE_MODEL", "HiDream-ai/HiDream-I1-Dev")
trigger_word = os.environ.get("TRIGGER_WORD", "default-style")
num_steps    = int(os.environ.get("NUM_STEPS", 100))
lora_r       = int(os.environ.get("LORA_R", 16))
lora_alpha   = int(os.environ.get("LORA_ALPHA", 16))
LOCAL        = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
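# Example invocation (illustrative values only; the script filename is an
# assumption, not fixed by this file):
#   LOCAL_TRAIN=1 TRIGGER_WORD=my-style NUM_STEPS=200 python train.py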

# ─── 2. Set up directories ────────────────────────────────────────────────────
if LOCAL:
    DATA_DIR    = os.path.join(os.getcwd(), "data")
    OUTPUT_DIR  = os.path.join(os.getcwd(), "lora-trained")
    LOCAL_MODEL = os.path.join(os.getcwd(), "hidream-model")
    os.makedirs(DATA_DIR,   exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
else:
    DATA_DIR   = "/tmp/data"
    OUTPUT_DIR = "/tmp/lora-trained"
    CACHE_DIR  = "/tmp/hidream-model"
    os.makedirs(DATA_DIR,   exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR,  exist_ok=True)

print(f"πŸ“‚ Dataset directory: {DATA_DIR}", flush=True)
print(f"πŸ“₯ Preparing base model: {model_id}", flush=True)

# ─── 3. Resolve model path ────────────────────────────────────────────────────
def get_model_path():
    # If local and predownloaded model exists, use it
    if LOCAL and os.path.isdir(LOCAL_MODEL) and os.path.isfile(os.path.join(LOCAL_MODEL, "config.json")):
        print(f"βœ… Using local model at: {LOCAL_MODEL}", flush=True)
        return LOCAL_MODEL
    # Otherwise download (to ~/.cache on local, or /tmp on Spaces)
    download_kwargs = {} if LOCAL else {"local_dir": CACHE_DIR}
    path = snapshot_download(model_id, **download_kwargs)
    print(f"βœ… Downloaded model to: {path}", flush=True)
    return path

model_path = get_model_path()

# ─── 4. Patch model_index.json to remove unsupported scheduler ────────────────
mi_file = os.path.join(model_path, "model_index.json")
if os.path.isfile(mi_file):
    with open(mi_file, "r") as f:
        mi = json.load(f)
    # Component entries sit at the top level of model_index.json (there is no
    # nested "pipeline" key), so check and pop "scheduler" directly.
    if "scheduler" in mi:
        print("🔧 Removing 'scheduler' entry from model_index.json", flush=True)
        mi.pop("scheduler", None)
        with open(mi_file, "w") as f:
            json.dump(mi, f, indent=2)

# ─── 5. Load & filter scheduler_config.json ──────────────────────────────────
sched_cfg_path = os.path.join(model_path, "scheduler", "scheduler_config.json")
filtered_cfg = {}
if os.path.isfile(sched_cfg_path):
    with open(sched_cfg_path, "r") as f:
        raw_cfg = json.load(f)
    sig = inspect.signature(DPMSolverMultistepScheduler.__init__)
    valid_keys = set(sig.parameters.keys()) - {"self", "args", "kwargs"}
    filtered_cfg = {k: v for k, v in raw_cfg.items() if k in valid_keys}
    dropped = set(raw_cfg) - set(filtered_cfg)
    if dropped:
        print(f"⚠️ Dropped unsupported scheduler keys: {dropped}", flush=True)
    try:
        scheduler = DPMSolverMultistepScheduler(**filtered_cfg)
        print("βœ… Instantiated DPMSolverMultistepScheduler from config", flush=True)
    except Exception as e:
        print(f"❌ Failed to init scheduler from config ({e}), using defaults", flush=True)
        scheduler = DPMSolverMultistepScheduler()
else:
    print("⚠️ No scheduler_config.json found; using default DPMSolverMultistepScheduler", flush=True)
    scheduler = DPMSolverMultistepScheduler()

# ─── 6. Load the Stable Diffusion pipeline ────────────────────────────────────
print(f"πŸ”§ Loading pipeline from: {model_path}", flush=True)
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    scheduler=scheduler
).to("cuda")

# ─── 7. Apply LoRA adapters ───────────────────────────────────────────────────
print(f"🧠 Applying LoRA config (r={lora_r}, α={lora_alpha})", flush=True)
lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    bias="none",
    # A UNet is not a causal language model, so no task_type is set here; PEFT
    # also cannot infer target modules for UNet2DConditionModel, so the
    # attention projection layers are named explicitly.
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
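
# Report how many parameters the adapters leave trainable (PEFT utility).
pipe.unet.print_trainable_parameters()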

# ─── 8. Training loop stub ─────────────────────────────────────────────────────
print(f"🚀 Starting fine-tuning for {num_steps} steps (trigger: {trigger_word})", flush=True)
for step in range(num_steps):
    # TODO: replace this stub with your actual training code:
    #  • Load batches from DATA_DIR
    #  • Forward/backward pass, optimizer.step(), etc.
    #  • (An illustrative, non-wired sketch follows in train_step_sketch below.)
    print(f"🌀 Step {step+1}/{num_steps}", flush=True)

# ─── 9. Save the fine-tuned model ─────────────────────────────────────────────
print(f"💾 Saving fine-tuned model to: {OUTPUT_DIR}", flush=True)
pipe.save_pretrained(OUTPUT_DIR)
print("✅ Training complete!", flush=True)