ramimu committed on
Commit
c9b1bf6
·
verified ·
1 Parent(s): 09b6938

Create train.py

Files changed (1)
  1. train.py +111 -0
train.py ADDED
@@ -0,0 +1,111 @@
+ import os
+ import json
+ import inspect
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+ from peft import LoraConfig, get_peft_model
+ import torch
+ from huggingface_hub import snapshot_download
+
+ # ─── 1. Read hyperparameters & mode ───────────────────────────────────────────
+ model_id = os.environ.get("BASE_MODEL", "HiDream-ai/HiDream-I1-Dev")
+ trigger_word = os.environ.get("TRIGGER_WORD", "default-style")
+ num_steps = int(os.environ.get("NUM_STEPS", 100))
+ lora_r = int(os.environ.get("LORA_R", 16))
+ lora_alpha = int(os.environ.get("LORA_ALPHA", 16))
+ LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
+
+ # ─── 2. Set up directories ────────────────────────────────────────────────────
+ if LOCAL:
+     DATA_DIR = os.path.join(os.getcwd(), "data")
+     OUTPUT_DIR = os.path.join(os.getcwd(), "lora-trained")
+     LOCAL_MODEL = os.path.join(os.getcwd(), "hidream-model")
+     os.makedirs(DATA_DIR, exist_ok=True)
+     os.makedirs(OUTPUT_DIR, exist_ok=True)
+ else:
+     DATA_DIR = "/tmp/data"
+     OUTPUT_DIR = "/tmp/lora-trained"
+     CACHE_DIR = "/tmp/hidream-model"
+     os.makedirs(DATA_DIR, exist_ok=True)
+     os.makedirs(OUTPUT_DIR, exist_ok=True)
+     os.makedirs(CACHE_DIR, exist_ok=True)
+
+ print(f"📂 Dataset directory: {DATA_DIR}", flush=True)
+ print(f"📥 Preparing base model: {model_id}", flush=True)
+
+ # ─── 3. Resolve model path ────────────────────────────────────────────────────
+ def get_model_path():
+     # If local and predownloaded model exists, use it
+     if LOCAL and os.path.isdir(LOCAL_MODEL) and os.path.isfile(os.path.join(LOCAL_MODEL, "config.json")):
+         print(f"✅ Using local model at: {LOCAL_MODEL}", flush=True)
+         return LOCAL_MODEL
+     # Otherwise download (to ~/.cache on local, or /tmp on Spaces)
+     download_kwargs = {} if LOCAL else {"local_dir": CACHE_DIR}
+     path = snapshot_download(model_id, **download_kwargs)
+     print(f"✅ Downloaded model to: {path}", flush=True)
+     return path
+
+ model_path = get_model_path()
+
+ # ─── 4. Patch model_index.json to remove unsupported scheduler ────────────────
+ mi_file = os.path.join(model_path, "model_index.json")
+ if os.path.isfile(mi_file):
+     with open(mi_file, "r") as f:
+         mi = json.load(f)
+     if "pipeline" in mi and "scheduler" in mi["pipeline"]:
+         print("🔧 Removing 'scheduler' entry from model_index.json", flush=True)
+         mi["pipeline"].pop("scheduler", None)
+         with open(mi_file, "w") as f:
+             json.dump(mi, f, indent=2)
+
+ # ─── 5. Load & filter scheduler_config.json ──────────────────────────────────
+ sched_cfg_path = os.path.join(model_path, "scheduler", "scheduler_config.json")
+ filtered_cfg = {}
+ if os.path.isfile(sched_cfg_path):
+     with open(sched_cfg_path, "r") as f:
+         raw_cfg = json.load(f)
+     sig = inspect.signature(DPMSolverMultistepScheduler.__init__)
+     valid_keys = set(sig.parameters.keys()) - {"self", "args", "kwargs"}
+     filtered_cfg = {k: v for k, v in raw_cfg.items() if k in valid_keys}
+     dropped = set(raw_cfg) - set(filtered_cfg)
+     if dropped:
+         print(f"⚠️ Dropped unsupported scheduler keys: {dropped}", flush=True)
+     try:
+         scheduler = DPMSolverMultistepScheduler(**filtered_cfg)
+         print("✅ Instantiated DPMSolverMultistepScheduler from config", flush=True)
+     except Exception as e:
+         print(f"❌ Failed to init scheduler from config ({e}), using defaults", flush=True)
+         scheduler = DPMSolverMultistepScheduler()
+ else:
+     print("⚠️ No scheduler_config.json found; using default DPMSolverMultistepScheduler", flush=True)
+     scheduler = DPMSolverMultistepScheduler()
+
+ # ─── 6. Load the Stable Diffusion pipeline ────────────────────────────────────
+ print(f"🔧 Loading pipeline from: {model_path}", flush=True)
+ pipe = StableDiffusionPipeline.from_pretrained(
+     model_path,
+     torch_dtype=torch.float16,
+     scheduler=scheduler
+ ).to("cuda")
+
+ # ─── 7. Apply LoRA adapters ───────────────────────────────────────────────────
+ print(f"🧠 Applying LoRA config (r={lora_r}, α={lora_alpha})", flush=True)
+ lora_config = LoraConfig(
+     r=lora_r,
+     lora_alpha=lora_alpha,
+     bias="none",
+     # PEFT cannot infer target modules for a diffusers UNet, so name the attention
+     # projections explicitly; task_type is omitted because the UNet is not a causal LM.
+     target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ )
+ pipe.unet = get_peft_model(pipe.unet, lora_config)
+
+ # ─── 8. Training loop stub ─────────────────────────────────────────────────────
+ print(f"🚀 Starting fine‑tuning for {num_steps} steps (trigger: {trigger_word})", flush=True)
+ for step in range(num_steps):
+     # TODO: replace this stub with your actual training code:
+     #   • Load batches from DATA_DIR
+     #   • Forward/backward pass, optimizer.step(), etc.
+     print(f"🌀 Step {step+1}/{num_steps}", flush=True)
+
+ # ─── 9. Save the fine‑tuned model ─────────────────────────────────────────────
+ print(f"💾 Saving fine‑tuned model to: {OUTPUT_DIR}", flush=True)
+ pipe.save_pretrained(OUTPUT_DIR)
+ print("✅ Training complete!", flush=True)