ramimu committed on
Commit
2ec882e
·
verified ·
1 Parent(s): 35bd3cf

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +32 -65
train.py CHANGED
"""Fine-tune a diffusion model with a LoRA adapter (reconstructed train.py).

NOTE(review): recovered from a diff rendering; the first four lines of the
original file are outside the visible hunk. The script demonstrably uses
`os`, `torch`, and `snapshot_download`, so those imports are made explicit
here to keep the file runnable.
"""
import os

import torch
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model

MODEL_ID = "black-forest-labs/FLUX.1-dev"
dataset_path = "/workspace/data"
output_dir = "/workspace/lora-trained"

# Fall back to CPU instead of crashing on hosts without a GPU;
# fp16 is only safe on CUDA, so pick the dtype alongside the device.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# 1) grab the model locally
print("📥 Downloading Flux‑Dev model…")
model_path = snapshot_download(MODEL_ID, local_dir="./fluxdev-model")

# 2) load each piece with its correct subfolder
# NOTE(review): FLUX.1-dev is a flow-matching transformer model; its repo may
# not ship SD-style `unet` / `text_encoder` subfolders compatible with the
# classes used below — confirm the snapshot layout before relying on this.
print("🔄 Loading scheduler…")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_path, subfolder="scheduler"
)

print("🔄 Loading VAE…")
vae = AutoencoderKL.from_pretrained(
    model_path, subfolder="vae", torch_dtype=dtype
)

print("🔄 Loading text encoder + tokenizer…")
text_encoder = CLIPTextModel.from_pretrained(
    model_path, subfolder="text_encoder", torch_dtype=dtype
)
tokenizer = CLIPTokenizer.from_pretrained(
    model_path, subfolder="tokenizer"
)

print("🔄 Loading U‑Net…")
unet = UNet2DConditionModel.from_pretrained(
    model_path, subfolder="unet", torch_dtype=dtype
)

# 3) assemble the pipeline
print("🛠 Assembling pipeline…")
# `safety_checker` and `feature_extractor` are required constructor arguments
# of StableDiffusionPipeline; pass None explicitly and opt out of the checker
# so construction does not raise.
pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to(device)

# 4) apply LoRA
print("🧠 Applying LoRA…")
# task_type="CAUSAL_LM" is for causal language models and lets PEFT guess
# target layers that do not exist on a U-Net; a diffusion U-Net needs the
# attention projection modules targeted explicitly.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet = get_peft_model(pipe.unet, lora_config)

# 5) your training loop (or dummy loop for illustration)
print("🚀 Starting fine‑tuning…")
for step in range(100):
    print(f"Training step {step+1}/100")
    # …insert your actual data‑loader and loss/backprop here…

# 6) persist the fine-tuned weights
os.makedirs(output_dir, exist_ok=True)
pipe.save_pretrained(output_dir)
print("✅ Done. LoRA weights in", output_dir)