# train.py
import os

import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
# ── 1) Configuration ──────────────────────────────────────────────────────────
# Directory holding the training images + prompt files.
DATA_DIR = os.getenv("DATA_DIR", "./data")
# Directory holding the base model (pre-downloaded or cached).
MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
# Directory where the LoRA fine-tuned model is written.
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")

# ── 2) Prepare the base model snapshot ────────────────────────────────────────
print(f"Loading dataset from: {DATA_DIR}")
print("Fetching or verifying base model: HiDream-ai/HiDream-I1-Dev")
# If the model was pre-downloaded into MODEL_DIR, use it as-is; otherwise
# pull the snapshot from the Hugging Face Hub into that directory.
if not os.path.isdir(MODEL_DIR):
    MODEL_DIR = snapshot_download(
        repo_id="HiDream-ai/HiDream-I1-Dev",
        local_dir=MODEL_DIR,
    )

# ── 3) Load the scheduler manually ────────────────────────────────────────────
# The snapshot's scheduler config JSON points at FlowMatchLCMScheduler, which
# the installed diffusers version may not provide, so force-load
# DPMSolverMultistepScheduler from the same subfolder instead.
print(f"Loading scheduler from: {MODEL_DIR}/scheduler")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    pretrained_model_name_or_path=MODEL_DIR,
    subfolder="scheduler",
)

# ── 4) Build the Stable Diffusion pipeline ────────────────────────────────────
# fp16 halves GPU memory; fall back to fp32 on CPU-only hosts, where half
# precision is unsupported for most ops (the original hard-coded "cuda" and
# crashed without a GPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print("Creating StableDiffusionPipeline with custom scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path=MODEL_DIR,
    scheduler=scheduler,
    torch_dtype=dtype,
).to(device)

# ── 5) Apply PEFT LoRA adapters ───────────────────────────────────────────────
print("Configuring LoRA adapter on U-Net")
# BUGFIX: task_type="CAUSAL_LM" is for causal language models and is wrong for
# a diffusion U-Net; PEFT needs explicit target_modules for non-transformers
# models, so inject LoRA into the attention projection layers.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet = get_peft_model(pipe.unet, lora_config)

# ── 6) (Placeholder) Simulate the training loop ───────────────────────────────
# NOTE(review): no data loading, loss, or optimizer yet — this only logs steps.
print("Starting fine-tuning loop (simulated)")
NUM_STEPS = 100
for step in range(NUM_STEPS):
    # Real training would load a batch, compute the diffusion loss,
    # and call optimizer.step() here.
    print(f"  Training step {step + 1}/{NUM_STEPS}")

# ── 7) Save the LoRA-tuned model ──────────────────────────────────────────────
os.makedirs(OUTPUT_DIR, exist_ok=True)
pipe.save_pretrained(OUTPUT_DIR)
print("Training complete. Model saved to", OUTPUT_DIR)