import os
import json
import inspect
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from peft import LoraConfig, get_peft_model
import torch
from huggingface_hub import snapshot_download
# ─── 1. Read hyperparameters & mode ───────────────────────────────────────────
model_id = os.environ.get("BASE_MODEL", "HiDream-ai/HiDream-I1-Dev")
trigger_word = os.environ.get("TRIGGER_WORD", "default-style")
num_steps = int(os.environ.get("NUM_STEPS", 100))
lora_r = int(os.environ.get("LORA_R", 16))
lora_alpha = int(os.environ.get("LORA_ALPHA", 16))
LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
# ─── 2. Set up directories ────────────────────────────────────────────────────
if LOCAL:
    DATA_DIR = os.path.join(os.getcwd(), "data")
    OUTPUT_DIR = os.path.join(os.getcwd(), "lora-trained")
    LOCAL_MODEL = os.path.join(os.getcwd(), "hidream-model")
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
else:
    DATA_DIR = "/tmp/data"
    OUTPUT_DIR = "/tmp/lora-trained"
    CACHE_DIR = "/tmp/hidream-model"
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
print(f"📂 Dataset directory: {DATA_DIR}", flush=True)
print(f"📥 Preparing base model: {model_id}", flush=True)
# ─── 3. Resolve model path ────────────────────────────────────────────────────
def get_model_path():
    # If local and a pre-downloaded model exists, use it
    if LOCAL and os.path.isdir(LOCAL_MODEL) and os.path.isfile(os.path.join(LOCAL_MODEL, "config.json")):
        print(f"✅ Using local model at: {LOCAL_MODEL}", flush=True)
        return LOCAL_MODEL
    # Otherwise download (to ~/.cache on local, or /tmp on Spaces)
    download_kwargs = {} if LOCAL else {"local_dir": CACHE_DIR}
    path = snapshot_download(model_id, **download_kwargs)
    print(f"✅ Downloaded model to: {path}", flush=True)
    return path
model_path = get_model_path()
# ─── 4. Patch model_index.json to remove unsupported scheduler ────────────────
mi_file = os.path.join(model_path, "model_index.json")
if os.path.isfile(mi_file):
    with open(mi_file, "r") as f:
        mi = json.load(f)
    if "pipeline" in mi and "scheduler" in mi["pipeline"]:
        print("🔧 Removing 'scheduler' entry from model_index.json", flush=True)
        mi["pipeline"].pop("scheduler", None)
        with open(mi_file, "w") as f:
            json.dump(mi, f, indent=2)
# ─── 5. Load & filter scheduler_config.json ───────────────────────────────────
sched_cfg_path = os.path.join(model_path, "scheduler", "scheduler_config.json")
filtered_cfg = {}
if os.path.isfile(sched_cfg_path):
    with open(sched_cfg_path, "r") as f:
        raw_cfg = json.load(f)
    sig = inspect.signature(DPMSolverMultistepScheduler.__init__)
    valid_keys = set(sig.parameters.keys()) - {"self", "args", "kwargs"}
    filtered_cfg = {k: v for k, v in raw_cfg.items() if k in valid_keys}
    dropped = set(raw_cfg) - set(filtered_cfg)
    if dropped:
        print(f"⚠️ Dropped unsupported scheduler keys: {dropped}", flush=True)
    try:
        scheduler = DPMSolverMultistepScheduler(**filtered_cfg)
        print("✅ Instantiated DPMSolverMultistepScheduler from config", flush=True)
    except Exception as e:
        print(f"❌ Failed to init scheduler from config ({e}), using defaults", flush=True)
        scheduler = DPMSolverMultistepScheduler()
else:
    print("⚠️ No scheduler_config.json found; using default DPMSolverMultistepScheduler", flush=True)
    scheduler = DPMSolverMultistepScheduler()
# ─── 6. Load the Stable Diffusion pipeline ────────────────────────────────────
print(f"🔧 Loading pipeline from: {model_path}", flush=True)
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    scheduler=scheduler,
).to("cuda")
# ─── 7. Apply LoRA adapters ───────────────────────────────────────────────────
print(f"🔧 Applying LoRA config (r={lora_r}, α={lora_alpha})", flush=True)
lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    bias="none",
    # Target the UNet's attention projections explicitly; the language-model
    # task_type="CAUSAL_LM" does not apply to a diffusion UNet, and without
    # target_modules PEFT cannot infer where to inject LoRA layers.
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
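# Optional sanity check: PEFT's print_trainable_parameters() reports how many
# UNet weights the LoRA adapter actually trains versus the frozen base model.
pipe.unet.print_trainable_parameters()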
# ─── 8. Training loop stub ────────────────────────────────────────────────────
print(f"🚀 Starting fine-tuning for {num_steps} steps (trigger: {trigger_word})", flush=True)
for step in range(num_steps):
    # TODO: replace this stub with your actual training code (see the sketch below):
    #   • Load batches from DATA_DIR
    #   • Forward/backward pass, optimizer.step(), etc.
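    # A minimal sketch of what one real denoising-loss step might look like,
    # for orientation only. It assumes helpers this stub does not define: an
    # `optimizer` over pipe.unet.parameters(), a `batch` of VAE latents and
    # text embeddings built from DATA_DIR, and a training noise scheduler
    # such as DDPMScheduler (`noise_scheduler`):
    #   noise = torch.randn_like(batch["latents"])
    #   timesteps = torch.randint(0, 1000, (batch["latents"].shape[0],), device="cuda")
    #   noisy_latents = noise_scheduler.add_noise(batch["latents"], noise, timesteps)
    #   pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=batch["text_embeds"]).sample
    #   loss = torch.nn.functional.mse_loss(pred.float(), noise.float())
    #   loss.backward(); optimizer.step(); optimizer.zero_grad()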
    print(f"🔁 Step {step+1}/{num_steps}", flush=True)
# ─── 9. Save the fine-tuned model ─────────────────────────────────────────────
print(f"💾 Saving fine-tuned model to: {OUTPUT_DIR}", flush=True)
pipe.save_pretrained(OUTPUT_DIR)
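# Note: save_pretrained writes the whole pipeline. Because pipe.unet is now a
# PEFT model, the LoRA adapter alone (much smaller) could also be saved, e.g.
# pipe.unet.save_pretrained(os.path.join(OUTPUT_DIR, "unet_lora")); the exact
# layout you want to ship is an open choice, not something this script defines.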
print("✅ Training complete!", flush=True)