ramimu committed
Commit 81e7d73 · verified · 1 Parent(s): aff7e63

Update train.py

Files changed (1)
  1. train.py +40 -65
train.py CHANGED
@@ -1,80 +1,55 @@
  import os
  import torch
- from huggingface_hub import snapshot_download
- from diffusers import (
-     StableDiffusionPipeline,
-     DPMSolverMultistepScheduler,
-     AutoencoderKL,
-     UNet2DConditionModel
+ from aitoolkit import (
+     LoRATrainer,
+     StableDiffusionModel,
+     LoRAConfig,
+     ImageTextDataset,
  )
- from transformers import CLIPTextModel, CLIPTokenizer
- from peft import LoraConfig, get_peft_model
-
- MODEL_ID = "black-forest-labs/FLUX.1-dev"
-
- # download
- model_path = snapshot_download(
-     MODEL_ID,
-     local_dir="./fluxdev-model",
-     use_auth_token=True
+
+ # 1. Configuration
+ MODEL_ID = "HiDream-ai/HiDream-I1-Dev"  # or your gated FLUX model if you have access
+ DATA_DIR = "/workspace/data"
+ OUTPUT_DIR = "/workspace/lora-trained"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ lora_cfg = LoRAConfig(
+     rank=16,
+     alpha=16,
+     bias="none",
  )

- # later loading
- pipe = StableDiffusionPipeline.from_pretrained(
-     model_path,
-     torch_dtype=torch.float16,
-     use_auth_token=True
- ).to("cuda")
-
- # 1) grab the model locally
- print("📥 Downloading Flux‑Dev model…")
- model_path = snapshot_download(MODEL_ID, local_dir="./fluxdev-model")
-
- # 2) load each piece with its correct subfolder
- print("🔄 Loading scheduler…")
- scheduler = DPMSolverMultistepScheduler.from_pretrained(
-     model_path, subfolder="scheduler"
+ training_args = {
+     "num_train_steps": 100,
+     "batch_size": 4,
+     "learning_rate": 1e-4,
+     "save_every_n_steps": 50,
+     "output_dir": OUTPUT_DIR,
+ }
+
+ # 2. Load base diffusion model
+ model = StableDiffusionModel.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.float16,
+     device=DEVICE,
+     use_auth_token=True,  # if it’s a gated repo
  )

- print("🔄 Loading VAE…")
- vae = AutoencoderKL.from_pretrained(
-     model_path, subfolder="vae", torch_dtype=torch.float16
- )
-
- print("🔄 Loading text encoder + tokenizer…")
- text_encoder = CLIPTextModel.from_pretrained(
-     model_path, subfolder="text_encoder", torch_dtype=torch.float16
- )
- tokenizer = CLIPTokenizer.from_pretrained(
-     model_path, subfolder="tokenizer"
- )
-
- print("🔄 Loading U‑Net…")
- unet = UNet2DConditionModel.from_pretrained(
-     model_path, subfolder="unet", torch_dtype=torch.float16
+ # 3. Prepare your dataset
+ # Expects pairs of image files + .txt captions in DATA_DIR
+ dataset = ImageTextDataset(data_root=DATA_DIR, image_size=512)
+
+ # 4. Hook up the LoRA adapter
+ model.apply_lora(lora_cfg)
+
+ # 5. Create the trainer and kickoff
+ trainer = LoRATrainer(
+     model=model,
+     dataset=dataset,
+     args=training_args,
  )

- # 3) assemble the pipeline
- print("🛠 Assembling pipeline…")
- pipe = StableDiffusionPipeline(
-     vae=vae,
-     text_encoder=text_encoder,
-     tokenizer=tokenizer,
-     unet=unet,
-     scheduler=scheduler
- ).to("cuda")
-
- # 4) apply LoRA
- print("🧠 Applying LoRA…")
- lora_config = LoraConfig(r=16, lora_alpha=16, bias="none", task_type="CAUSAL_LM")
- pipe.unet = get_peft_model(pipe.unet, lora_config)
-
- # 5) your training loop (or dummy loop for illustration)
- print("🚀 Starting fine‑tuning…")
- for step in range(100):
-     print(f"Training step {step+1}/100")
-     # …insert your actual data‑loader and loss/backprop here…
-
- os.makedirs(output_dir, exist_ok=True)
- pipe.save_pretrained(output_dir)
- print("✅ Done. LoRA weights in", output_dir)
+ print("🚀 Starting training with AI‑Toolkit…")
+ trainer.train()
+
+ print(f"✅ Done! Fine-tuned weights saved to {OUTPUT_DIR}")
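
Note on the new script: the aitoolkit import (LoRATrainer, StableDiffusionModel, LoRAConfig, ImageTextDataset) is assumed to be a local AI-Toolkit wrapper rather than the diffusers/peft APIs the removed code used, so its call signatures are taken on faith here. For readers without such a wrapper, the following is a minimal sketch of the same idea, attaching a rank-16 LoRA adapter to a diffusion U-Net, using diffusers' built-in PEFT integration (a recent diffusers release with peft installed). The checkpoint ID and target_modules below are illustrative assumptions, not taken from this commit, and the removed script's task_type="CAUSAL_LM" is dropped because that setting targets causal language models, not a U-Net.

# Minimal sketch, not the committed script: rank-16 LoRA on a Stable Diffusion
# U-Net with diffusers + peft. Checkpoint ID and target_modules are assumptions.
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  # stand-in, non-gated checkpoint
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# Freeze the base U-Net, then inject trainable LoRA layers into its
# attention projections via diffusers' PEFT integration.
pipe.unet.requires_grad_(False)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet.add_adapter(lora_config)

trainable = sum(p.numel() for p in pipe.unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in pipe.unet.parameters())
print(f"Trainable LoRA parameters: {trainable:,} of {total:,}")

A real run would still need the denoising training loop (sample noise and timesteps, predict the noise, take an MSE loss and optimizer step) that the removed script only stubbed out, plus a step to save the adapter weights.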