ramimu committed (verified)
Commit 35bd3cf · 1 Parent(s): b1dde27

Update train.py

Files changed (1):
  train.py  +45 -39
train.py CHANGED
@@ -1,68 +1,73 @@
-# train.py
-
 import os
 import torch
 from huggingface_hub import snapshot_download
-from peft import LoraConfig, get_peft_model
-
-# 1️⃣ Pick your scheduler class
 from diffusers import (
     StableDiffusionPipeline,
     DPMSolverMultistepScheduler,
-    UNet2DConditionModel,
     AutoencoderKL,
+    UNet2DConditionModel,
 )
 from transformers import CLIPTextModel, CLIPTokenizer
+from peft import LoraConfig, get_peft_model
 
-# ─── 1) CONFIG ────────────────────────────────────────────────────────────────
+# ─── CONFIG ───────────────────────────────────────────────────────────────────
 
-DATA_DIR = os.getenv("DATA_DIR", "./data")
-MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
-OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
+DATA_DIR = os.getenv("DATA_DIR", "./data")
+MODEL_CACHE = os.getenv("MODEL_DIR", "./hidream-model")
+OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
+REPO_ID = "HiDream-ai/HiDream-I1-Dev"
 
-# ─── 2) DOWNLOAD OR VERIFY BASE MODEL ──────────────────────────────────────────
+# ─── STEP 1: ENSURE YOU HAVE A COMPLETE SNAPSHOT WITH CONFIGS ─────────────────
 
-if not os.path.isdir(MODEL_DIR):
-    MODEL_DIR = snapshot_download(
-        repo_id="HiDream-ai/HiDream-I1-Dev",
-        local_dir=MODEL_DIR
-    )
+print(f"📥 Downloading full model snapshot to {MODEL_CACHE}")
+MODEL_ROOT = snapshot_download(
+    repo_id=REPO_ID,
+    local_dir=MODEL_CACHE,
+    local_dir_use_symlinks=False,  # force a copy so config.json ends up there
+)
 
-# ─── 3) LOAD EACH PIPELINE COMPONENT ──────────────────────────────────────────
+# ─── STEP 2: LOAD SCHEDULER ────────────────────────────────────────────────────
 
-# 3a) Scheduler
+print("🔄 Loading scheduler")
 scheduler = DPMSolverMultistepScheduler.from_pretrained(
-    MODEL_DIR,
-    subfolder="scheduler"
+    MODEL_ROOT,
+    subfolder="scheduler",
 )
 
-# 3b) VAE
+# ─── STEP 3: LOAD VAE ──────────────────────────────────────────────────────────
+
+print("🔄 Loading VAE")
 vae = AutoencoderKL.from_pretrained(
-    MODEL_DIR,
+    MODEL_ROOT,
     subfolder="vae",
-    torch_dtype=torch.float16
+    torch_dtype=torch.float16,
 ).to("cuda")
 
-# 3c) Text encoder + tokenizer
+# ─── STEP 4: LOAD TEXT ENCODER + TOKENIZER ─────────────────────────────────────
+
+print("🔄 Loading text encoder + tokenizer")
 text_encoder = CLIPTextModel.from_pretrained(
-    MODEL_DIR,
+    MODEL_ROOT,
     subfolder="text_encoder",
-    torch_dtype=torch.float16
+    torch_dtype=torch.float16,
 ).to("cuda")
-tokenizer = CLIPTokenizer.from_pretrained(
-    MODEL_DIR,
-    subfolder="tokenizer"
+tokenizer = CLIPTokenizer.from_pretrained(
+    MODEL_ROOT,
+    subfolder="tokenizer",
 )
 
-# 3d) U‑Net
+# ─── STEP 5: LOAD U‑NET ───────────────────────────────────────────────────────
+
+print("🔄 Loading U‑Net")
 unet = UNet2DConditionModel.from_pretrained(
-    MODEL_DIR,
+    MODEL_ROOT,
     subfolder="unet",
-    torch_dtype=torch.float16
+    torch_dtype=torch.float16,
 ).to("cuda")
 
-# ─── 4) BUILD THE PIPELINE ────────────────────────────────────────────────────
+# ─── STEP 6: BUILD THE PIPELINE ───────────────────────────────────────────────
 
+print("🌟 Building StableDiffusionPipeline")
 pipe = StableDiffusionPipeline(
     vae=vae,
     text_encoder=text_encoder,
@@ -71,8 +76,9 @@ pipe = StableDiffusionPipeline(
     scheduler=scheduler,
 ).to("cuda")
 
-# ─── 5) APPLY LORA ────────────────────────────────────────────────────────────
+# ─── STEP 7: APPLY LORA ADAPTER ───────────────────────────────────────────────
 
+print("🧠 Applying LoRA adapter")
 lora_config = LoraConfig(
     r=16,
     lora_alpha=16,
@@ -81,15 +87,15 @@ lora_config = LoraConfig(
 )
 pipe.unet = get_peft_model(pipe.unet, lora_config)
 
-# ─── 6) TRAINING LOOP (SIMULATED) ─────────────────────────────────────────────
+# ─── STEP 8: YOUR TRAINING LOOP (SIMULATED) ────────────────────────────────────
 
-print(f"📂 Data at {DATA_DIR}")
+print(f"📂 Loading dataset from: {DATA_DIR}")
 for step in range(100):
-    # … your real data loading + optimizer here …
+    # ←– here’s where you’d load your images, run forward/backward, optimizer, etc.
    print(f"Training step {step+1}/100")
 
-# ─── 7) SAVE THE FINE‑TUNED LO‑RA ─────────────────────────────────────────────
+# ─── STEP 9: SAVE THE FINE‑TUNED LO‑RA WEIGHTS ───────────────────────────────
 
 os.makedirs(OUTPUT_DIR, exist_ok=True)
 pipe.save_pretrained(OUTPUT_DIR)
-print("✅ Done! Saved to", OUTPUT_DIR)
+print("✅ Training complete. Saved to", OUTPUT_DIR)
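Note: the loop under STEP 8 only prints progress and does not yet train anything. Below is a minimal sketch of what one real training step could look like with the objects loaded above (pipe, vae, text_encoder, tokenizer, MODEL_ROOT, DATA_DIR). The dataset layout (*.jpg files directly under DATA_DIR), the placeholder prompt, the 512x512 resize, and the AdamW hyperparameters are assumptions, not part of the commit; a DDPMScheduler is used only to add noise during training, since the DPMSolverMultistepScheduler above is an inference scheduler. Mixed-precision and gradient-accumulation details are omitted, so treat this as illustrative rather than a drop-in replacement.

import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from torchvision import transforms
from diffusers import DDPMScheduler

# Training-time noise scheduler (assumes the repo's scheduler config is DDPM-compatible).
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_ROOT, subfolder="scheduler")

# Optimize only the trainable (LoRA) parameters injected by get_peft_model.
optimizer = torch.optim.AdamW(
    [p for p in pipe.unet.parameters() if p.requires_grad],
    lr=1e-4,  # hypothetical learning rate
)

preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # map pixels to [-1, 1]
])
image_paths = sorted(Path(DATA_DIR).glob("*.jpg"))  # hypothetical dataset layout
prompt = "a photo in the target style"              # hypothetical caption

for step, path in enumerate(image_paths):
    pixels = preprocess(Image.open(path).convert("RGB")).unsqueeze(0).to("cuda", dtype=torch.float16)

    # Tokenize the caption once per image.
    input_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to("cuda")

    # The frozen VAE and text encoder do not need gradients.
    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor
        encoder_hidden_states = text_encoder(input_ids)[0]

    # Add noise at a random timestep.
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cuda")
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Predict the noise with the LoRA-wrapped U-Net and step the optimizer.
    noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
    loss = F.mse_loss(noise_pred.float(), noise.float())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Training step {step + 1}/{len(image_paths)} - loss {loss.item():.4f}")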