# train.py
import os
import torch
from huggingface_hub import snapshot_download
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from peft import LoraConfig, get_peft_model
# ── 1) Configuration ──────────────────────────────────────────────────────────
# Where you put your images + prompts
DATA_DIR = os.getenv("DATA_DIR", "./data")
# Where your base model lives (downloaded or cached)
MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
# Where to save your LoRA fine-tuned model
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
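# Example invocation (paths are illustrative, not part of the original script):
#   DATA_DIR=/path/to/images MODEL_DIR=/path/to/hidream OUTPUT_DIR=./lora-trained python train.py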
# ── 2) Prepare the base model snapshot ────────────────────────────────────────
print(f"Loading dataset from: {DATA_DIR}")
print("Fetching or verifying base model: HiDream-ai/HiDream-I1-Dev")
# If you've pre-downloaded into MODEL_DIR, just use it; otherwise pull from the HF Hub
if not os.path.isdir(MODEL_DIR):
    MODEL_DIR = snapshot_download(
        repo_id="HiDream-ai/HiDream-I1-Dev",
        local_dir=MODEL_DIR
    )
# ── 3) Load the scheduler manually ────────────────────────────────────────────
# Diffusers' scheduler config JSON points at FlowMatchLCMScheduler,
# but your installed version doesn't have that class. Instead we
# force-load DPMSolverMultistepScheduler via `from_pretrained`.
print(f"Loading scheduler from: {MODEL_DIR}/scheduler")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    pretrained_model_name_or_path=MODEL_DIR,
    subfolder="scheduler"
)
# ── 4) Build the Stable Diffusion pipeline ────────────────────────────────────
print("Creating StableDiffusionPipeline with custom scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path=MODEL_DIR,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")
# ── 5) Apply PEFT LoRA adapters ───────────────────────────────────────────────
print("Configuring LoRA adapter on U-Net")
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    # Target the U-Net's attention projections explicitly: PEFT cannot infer
    # target modules for a diffusers U-Net, and the language-model task_type
    # ("CAUSAL_LM") does not apply here.
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
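# Optional sanity check (assumes the PEFT wrapping above succeeded): report how
# many parameters the LoRA adapter trains versus the frozen U-Net weights.
pipe.unet.print_trainable_parameters()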
# ── 6) (Placeholder) Simulate your training loop ──────────────────────────────
print("Starting fine-tuning loop (simulated)")
for step in range(100):
    # Here you'd load your data, compute the loss, call optimizer.step(), etc.
    # (see the sketch of a real training step below)
    print(f"  Training step {step+1}/100")
# ── 7) Save your LoRA-tuned model ─────────────────────────────────────────────
os.makedirs(OUTPUT_DIR, exist_ok=True)
pipe.save_pretrained(OUTPUT_DIR)
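# Note: because `pipe.unet` is PEFT-wrapped, you can also save just the
# lightweight LoRA adapter weights instead of the full pipeline (the subfolder
# name below is illustrative):
pipe.unet.save_pretrained(os.path.join(OUTPUT_DIR, "lora_adapter"))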
print("β
Training complete. Model saved to", OUTPUT_DIR)