ramimu committed on
Commit
0a3593d
·
verified ·
1 Parent(s): 7fc0abf

Update train.py

Files changed (1)
  1. train.py +58 -95
train.py CHANGED
@@ -1,111 +1,74 @@
  import os
- import json
- import inspect
- from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
- from peft import LoraConfig, get_peft_model
  import torch
  from huggingface_hub import snapshot_download

- # ─── 1. Read hyperparameters & mode ───────────────────────────────────────────
- model_id = os.environ.get("BASE_MODEL", "HiDream-ai/HiDream-I1-Dev")
- trigger_word = os.environ.get("TRIGGER_WORD", "default-style")
- num_steps = int(os.environ.get("NUM_STEPS", 100))
- lora_r = int(os.environ.get("LORA_R", 16))
- lora_alpha = int(os.environ.get("LORA_ALPHA", 16))
- LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
-
- # ─── 2. Set up directories ────────────────────────────────────────────────────
- if LOCAL:
-     DATA_DIR = os.path.join(os.getcwd(), "data")
-     OUTPUT_DIR = os.path.join(os.getcwd(), "lora-trained")
-     LOCAL_MODEL = os.path.join(os.getcwd(), "hidream-model")
-     os.makedirs(DATA_DIR, exist_ok=True)
-     os.makedirs(OUTPUT_DIR, exist_ok=True)
- else:
-     DATA_DIR = "/tmp/data"
-     OUTPUT_DIR = "/tmp/lora-trained"
-     CACHE_DIR = "/tmp/hidream-model"
-     os.makedirs(DATA_DIR, exist_ok=True)
-     os.makedirs(OUTPUT_DIR, exist_ok=True)
-     os.makedirs(CACHE_DIR, exist_ok=True)
-
- print(f"📂 Dataset directory: {DATA_DIR}", flush=True)
- print(f"📥 Preparing base model: {model_id}", flush=True)
-
- # ─── 3. Resolve model path ────────────────────────────────────────────────────
- def get_model_path():
-     # If local and predownloaded model exists, use it
-     if LOCAL and os.path.isdir(LOCAL_MODEL) and os.path.isfile(os.path.join(LOCAL_MODEL, "config.json")):
-         print(f"✅ Using local model at: {LOCAL_MODEL}", flush=True)
-         return LOCAL_MODEL
-     # Otherwise download (to ~/.cache on local, or /tmp on Spaces)
-     download_kwargs = {} if LOCAL else {"local_dir": CACHE_DIR}
-     path = snapshot_download(model_id, **download_kwargs)
-     print(f"✅ Downloaded model to: {path}", flush=True)
-     return path
-
- model_path = get_model_path()
-
- # ─── 4. Patch model_index.json to remove unsupported scheduler ────────────────
- mi_file = os.path.join(model_path, "model_index.json")
- if os.path.isfile(mi_file):
-     with open(mi_file, "r") as f:
-         mi = json.load(f)
-     if "pipeline" in mi and "scheduler" in mi["pipeline"]:
-         print("🔧 Removing 'scheduler' entry from model_index.json", flush=True)
-         mi["pipeline"].pop("scheduler", None)
-         with open(mi_file, "w") as f:
-             json.dump(mi, f, indent=2)
-
- # ─── 5. Load & filter scheduler_config.json ──────────────────────────────────
- sched_cfg_path = os.path.join(model_path, "scheduler", "scheduler_config.json")
- filtered_cfg = {}
- if os.path.isfile(sched_cfg_path):
-     with open(sched_cfg_path, "r") as f:
-         raw_cfg = json.load(f)
-     sig = inspect.signature(DPMSolverMultistepScheduler.__init__)
-     valid_keys = set(sig.parameters.keys()) - {"self", "args", "kwargs"}
-     filtered_cfg = {k: v for k, v in raw_cfg.items() if k in valid_keys}
-     dropped = set(raw_cfg) - set(filtered_cfg)
-     if dropped:
-         print(f"⚠️ Dropped unsupported scheduler keys: {dropped}", flush=True)
-     try:
-         scheduler = DPMSolverMultistepScheduler(**filtered_cfg)
-         print("✅ Instantiated DPMSolverMultistepScheduler from config", flush=True)
-     except Exception as e:
-         print(f"❌ Failed to init scheduler from config ({e}), using defaults", flush=True)
-         scheduler = DPMSolverMultistepScheduler()
- else:
-     print("⚠️ No scheduler_config.json found; using default DPMSolverMultistepScheduler", flush=True)
-     scheduler = DPMSolverMultistepScheduler()
-
- # ─── 6. Load the Stable Diffusion pipeline ────────────────────────────────────
- print(f"🔧 Loading pipeline from: {model_path}", flush=True)
  pipe = StableDiffusionPipeline.from_pretrained(
-     model_path,
      torch_dtype=torch.float16,
-     scheduler=scheduler
  ).to("cuda")

- # ─── 7. Apply LoRA adapters ───────────────────────────────────────────────────
- print(f"🧠 Applying LoRA config (r={lora_r}, α={lora_alpha})", flush=True)
  lora_config = LoraConfig(
-     r=lora_r,
-     lora_alpha=lora_alpha,
      bias="none",
      task_type="CAUSAL_LM"
  )
  pipe.unet = get_peft_model(pipe.unet, lora_config)

- # ─── 8. Training loop stub ─────────────────────────────────────────────────────
- print(f"🚀 Starting fine‑tuning for {num_steps} steps (trigger: {trigger_word})", flush=True)
- for step in range(num_steps):
-     # TODO: replace this stub with your actual training code:
-     #   • Load batches from DATA_DIR
-     #   • Forward/backward pass, optimizer.step(), etc.
-     print(f"🌀 Step {step+1}/{num_steps}", flush=True)

- # ─── 9. Save the fine‑tuned model ─────────────────────────────────────────────
- print(f"💾 Saving fine‑tuned model to: {OUTPUT_DIR}", flush=True)
  pipe.save_pretrained(OUTPUT_DIR)
- print("✅ Training complete!", flush=True)
+ # train.py
+
  import os
  import torch
  from huggingface_hub import snapshot_download
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+ from peft import LoraConfig, get_peft_model
+
+ # ── 1) Configuration ───────────────────────────────────────────────────────────
+
+ # Where you put your images + prompts
+ DATA_DIR = os.getenv("DATA_DIR", "./data")
+
+ # Where your base model lives (downloaded or cached)
+ MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
+
+ # Where to save your LoRA‑fine‑tuned model
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
+
+ # ── 2) Prepare the base model snapshot ────────────────────────────────────────

+ print(f"📂 Loading dataset from: {DATA_DIR}")
+ print("📥 Fetching or verifying base model: HiDream-ai/HiDream-I1-Dev")
+
+ # If you've pre‑downloaded into MODEL_DIR, just use it; otherwise pull from HF Hub
+ if not os.path.isdir(MODEL_DIR):
+     MODEL_DIR = snapshot_download(
+         repo_id="HiDream-ai/HiDream-I1-Dev",
+         local_dir=MODEL_DIR
+     )
+
+ # ── 3) Load the scheduler manually ─────────────────────────────────────────────
+
+ # Diffusers' scheduler config JSON points at FlowMatchLCMScheduler,
+ # but your installed version doesn't have that class. Instead we
+ # force-load DPMSolverMultistepScheduler via `from_pretrained`.
+ print(f"🔄 Loading scheduler from: {MODEL_DIR}/scheduler")
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(
+     pretrained_model_name_or_path=MODEL_DIR,
+     subfolder="scheduler"
+ )
+
+ # ── 4) Build the Stable Diffusion pipeline ────────────────────────────────────
+
+ print("🔧 Creating StableDiffusionPipeline with custom scheduler")
  pipe = StableDiffusionPipeline.from_pretrained(
+     pretrained_model_name_or_path=MODEL_DIR,
+     scheduler=scheduler,
      torch_dtype=torch.float16,
  ).to("cuda")

+ # ── 5) Apply PEFT LoRA adapters ───────────────────────────────────────────────
+
+ print("🧠 Configuring LoRA adapter on U‑Net")
  lora_config = LoraConfig(
+     r=16,
+     lora_alpha=16,
      bias="none",
      task_type="CAUSAL_LM"
  )
  pipe.unet = get_peft_model(pipe.unet, lora_config)

+ # ── 6) (Placeholder) Simulate your training loop ─────────────────────────────
+
+ print("🚀 Starting fine‑tuning loop (simulated)")
+ for step in range(100):
+     # Here you'd load your data, compute loss, do optimizer.step(), etc.
+     print(f"   Training step {step+1}/100")
+
+ # ── 7) Save your LoRA‑tuned model ────────────────────────────────────────────

+ os.makedirs(OUTPUT_DIR, exist_ok=True)
  pipe.save_pretrained(OUTPUT_DIR)
+ print("✅ Training complete. Model saved to", OUTPUT_DIR)