ramimu committed on
Commit b1dde27 · verified · 1 Parent(s): e469a9f

Update train.py

Files changed (1)
  1. train.py +53 -32
train.py CHANGED
@@ -3,72 +3,93 @@
  import os
  import torch
  from huggingface_hub import snapshot_download
- from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
  from peft import LoraConfig, get_peft_model

- # ── 1) Configuration ───────────────────────────────────────────────────────────

- # Where you put your images + prompts
- DATA_DIR = os.getenv("DATA_DIR", "./data")

- # Where your base model lives (downloaded or cached)
  MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
-
- # Where to save your LoRA‑fine‑tuned model
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")

- # ── 2) Prepare the base model snapshot ────────────────────────────────────────
-
- print(f"📂 Loading dataset from: {DATA_DIR}")
- print("📥 Fetching or verifying base model: HiDream-ai/HiDream-I1-Dev")

- # If you’ve pre‑downloaded into MODEL_DIR, just use it; otherwise pull from HF Hub
  if not os.path.isdir(MODEL_DIR):
      MODEL_DIR = snapshot_download(
          repo_id="HiDream-ai/HiDream-I1-Dev",
          local_dir=MODEL_DIR
      )

- # ── 3) Load the scheduler manually ─────────────────────────────────────────────

- # Diffusers’ scheduler config JSON points at FlowMatchLCMScheduler,
- # but your installed version doesn’t have that class. Instead we
- # force‐load DPMSolverMultistepScheduler via `from_pretrained`.
- print(f"🔄 Loading scheduler from: {MODEL_DIR}/scheduler")
  scheduler = DPMSolverMultistepScheduler.from_pretrained(
-     pretrained_model_name_or_path=MODEL_DIR,
      subfolder="scheduler"
  )

- # ── 4) Build the Stable Diffusion pipeline ────────────────────────────────────

- print("🔧 Creating StableDiffusionPipeline with custom scheduler")
- pipe = StableDiffusionPipeline.from_pretrained(
-     pretrained_model_name_or_path=MODEL_DIR,
      scheduler=scheduler,
-     torch_dtype=torch.float16,
  ).to("cuda")

- # ── 5) Apply PEFT LoRA adapters ───────────────────────────────────────────────

- print("🧠 Configuring LoRA adapter on U‑Net")
  lora_config = LoraConfig(
      r=16,
      lora_alpha=16,
      bias="none",
-     task_type="CAUSAL_LM"
  )
  pipe.unet = get_peft_model(pipe.unet, lora_config)

- # ── 6) (Placeholder) Simulate your training loop ─────────────────────────────

- print("🚀 Starting fine‑tuning loop (simulated)")
  for step in range(100):
-     # Here you'd load your data, compute loss, do optimizer.step(), etc.
-     print(f" Training step {step+1}/100")

- # ── 7) Save your LoRA‑tuned model ────────────────────────────────────────────

  os.makedirs(OUTPUT_DIR, exist_ok=True)
  pipe.save_pretrained(OUTPUT_DIR)
- print("✅ Training complete. Model saved to", OUTPUT_DIR)
 
  import os
  import torch
  from huggingface_hub import snapshot_download
  from peft import LoraConfig, get_peft_model

+ # 1️⃣ Pick your scheduler class
+ from diffusers import (
+     StableDiffusionPipeline,
+     DPMSolverMultistepScheduler,
+     UNet2DConditionModel,
+     AutoencoderKL,
+ )
+ from transformers import CLIPTextModel, CLIPTokenizer

+ # ─── 1) CONFIG ────────────────────────────────────────────────────────────────

+ DATA_DIR = os.getenv("DATA_DIR", "./data")
  MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")

+ # ─── 2) DOWNLOAD OR VERIFY BASE MODEL ──────────────────────────────────────────

  if not os.path.isdir(MODEL_DIR):
      MODEL_DIR = snapshot_download(
          repo_id="HiDream-ai/HiDream-I1-Dev",
          local_dir=MODEL_DIR
      )
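
If only part of the repository is needed, the same call can be narrowed with snapshot_download's allow_patterns argument so unused weight files are skipped. A minimal sketch, assuming the script only ever reads the subfolders it loads below (the pattern list is an illustration, not part of the commit):

# Hypothetical narrower download: fetch only the subfolders this script actually loads.
MODEL_DIR = snapshot_download(
    repo_id="HiDream-ai/HiDream-I1-Dev",
    local_dir=MODEL_DIR,
    allow_patterns=["scheduler/*", "vae/*", "text_encoder/*", "tokenizer/*", "unet/*", "model_index.json"],
)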
 
+ # ─── 3) LOAD EACH PIPELINE COMPONENT ──────────────────────────────────────────

+ # 3a) Scheduler
  scheduler = DPMSolverMultistepScheduler.from_pretrained(
+     MODEL_DIR,
      subfolder="scheduler"
  )

+ # 3b) VAE
+ vae = AutoencoderKL.from_pretrained(
+     MODEL_DIR,
+     subfolder="vae",
+     torch_dtype=torch.float16
+ ).to("cuda")
+
+ # 3c) Text encoder + tokenizer
+ text_encoder = CLIPTextModel.from_pretrained(
+     MODEL_DIR,
+     subfolder="text_encoder",
+     torch_dtype=torch.float16
+ ).to("cuda")
+ tokenizer = CLIPTokenizer.from_pretrained(
+     MODEL_DIR,
+     subfolder="tokenizer"
+ )
+
+ # 3d) U‑Net
+ unet = UNet2DConditionModel.from_pretrained(
+     MODEL_DIR,
+     subfolder="unet",
+     torch_dtype=torch.float16
+ ).to("cuda")
+
+ # ─── 4) BUILD THE PIPELINE ────────────────────────────────────────────────────

+ pipe = StableDiffusionPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
      scheduler=scheduler,
  ).to("cuda")
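
StableDiffusionPipeline's constructor also takes safety_checker and feature_extractor arguments; when a pipeline is assembled from individual components like this, a common pattern is to pass None for both and disable the safety-checker requirement. A sketch of that variant (an assumption about how this snapshot is meant to be used, not part of the commit):

pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,            # assumed: this snapshot ships no safety checker
    feature_extractor=None,
    requires_safety_checker=False,
).to("cuda")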
 
+ # ─── 5) APPLY LORA ────────────────────────────────────────────────────────────

  lora_config = LoraConfig(
      r=16,
      lora_alpha=16,
      bias="none",
+     task_type="CAUSAL_LM",
  )
  pipe.unet = get_peft_model(pipe.unet, lora_config)
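
One caveat on this block: task_type="CAUSAL_LM" is intended for causal language models, not a diffusion U-Net. When adapting a diffusers U-Net with PEFT, it is more common to omit task_type and name the attention projection layers explicitly as target_modules. A minimal sketch, using the standard module names inside UNet2DConditionModel's attention blocks:

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    # Attention projection layers in the U-Net's self/cross-attention blocks.
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet = get_peft_model(pipe.unet, lora_config)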
 
+ # ─── 6) TRAINING LOOP (SIMULATED) ─────────────────────────────────────────────

+ print(f"📂 Data at {DATA_DIR}")
  for step in range(100):
+     # … your real data loading + optimizer here …
+     print(f"Training step {step+1}/100")
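
For reference, a real step with these components would typically encode a batch to VAE latents, add noise at a random timestep, and regress the U-Net's noise prediction against that noise. A rough sketch, assuming a dataloader built from DATA_DIR that yields pixel_values and tokenized input_ids (the dataloader and hyperparameters are illustrative, not part of the commit):

import torch.nn.functional as F

optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=1e-4)

for step, batch in enumerate(dataloader):  # dataloader: assumed to exist
    with torch.no_grad():
        # Encode images into the latent space and apply the VAE scaling factor.
        latents = vae.encode(batch["pixel_values"].to("cuda", dtype=torch.float16)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        # Text conditioning from the frozen CLIP text encoder.
        encoder_hidden_states = text_encoder(batch["input_ids"].to("cuda"))[0]

    # Noise the latents at a random timestep.
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    # The LoRA-adapted U-Net predicts the noise; simple MSE objective.
    noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
    loss = F.mse_loss(noise_pred.float(), noise.float())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()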
 
+ # ─── 7) SAVE THE FINE‑TUNED LO‑RA ─────────────────────────────────────────────

  os.makedirs(OUTPUT_DIR, exist_ok=True)
  pipe.save_pretrained(OUTPUT_DIR)
+ print("✅ Done! Saved to", OUTPUT_DIR)