# train.py
import os
import torch
from huggingface_hub import snapshot_download
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from peft import LoraConfig, get_peft_model
# ── 1) Configuration ───────────────────────────────────────────────────────────
# Where you put your images + prompts
DATA_DIR = os.getenv("DATA_DIR", "./data")
# Where your base model lives (downloaded or cached)
MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
# Where to save your LoRA‑fine‑tuned model
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
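# NOTE: the layout below is an assumption about how DATA_DIR is organised
# (nothing in this script enforces it): one image per sample with a matching
# .txt file holding its caption, e.g.
#   data/
#     0001.png
#     0001.txt   <- "a photo of ..."
#     0002.png
#     0002.txt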
# ── 2) Prepare the base model snapshot ────────────────────────────────────────
print(f"πŸ“‚ Loading dataset from: {DATA_DIR}")
print("πŸ“₯ Fetching or verifying base model: HiDream-ai/HiDream-I1-Dev")
# If you’ve pre‑downloaded into MODEL_DIR, just use it; otherwise pull from HF Hub
if not os.path.isdir(MODEL_DIR):
    MODEL_DIR = snapshot_download(
        repo_id="HiDream-ai/HiDream-I1-Dev",
        local_dir=MODEL_DIR,
    )
# ── 3) Load the scheduler manually ─────────────────────────────────────────────
# Diffusers’ scheduler config JSON points at FlowMatchLCMScheduler,
# but your installed version doesn’t have that class. Instead we
# force‐load DPMSolverMultistepScheduler via `from_pretrained`.
print(f"πŸ”„ Loading scheduler from: {MODEL_DIR}/scheduler")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
pretrained_model_name_or_path=MODEL_DIR,
subfolder="scheduler"
)
# ── 4) Build the Stable Diffusion pipeline ────────────────────────────────────
print("πŸ”§ Creating StableDiffusionPipeline with custom scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path=MODEL_DIR,
scheduler=scheduler,
torch_dtype=torch.float16,
).to("cuda")
# ── 5) Apply PEFT LoRA adapters ───────────────────────────────────────────────
print("🧠 Configuring LoRA adapter on U‑Net")
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    # Target the UNet's attention projections explicitly; PEFT cannot infer
    # target modules for a diffusers UNet, and the original
    # task_type="CAUSAL_LM" is meant for language models, not a UNet.
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    bias="none",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
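# Optional sanity check (not in the original script): get_peft_model returns
# a PeftModel, so we can confirm that only the small LoRA adapter weights
# will receive gradients.
pipe.unet.print_trainable_parameters()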
# ── 6) (Placeholder) Simulate your training loop ─────────────────────────────
print("πŸš€ Starting fine‑tuning loop (simulated)")
for step in range(100):
# Here you'd load your data, compute loss, do optimizer.step(), etc.
print(f" Training step {step+1}/100")
# ── 7) Save your LoRA‑tuned model ────────────────────────────────────────────
os.makedirs(OUTPUT_DIR, exist_ok=True)
pipe.save_pretrained(OUTPUT_DIR)
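# The UNet is now a PeftModel, so as an optional extra (not part of the
# original flow) the small LoRA adapter can also be exported on its own:
pipe.unet.save_pretrained(os.path.join(OUTPUT_DIR, "lora_adapter"))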
print("βœ… Training complete. Model saved to", OUTPUT_DIR)