import os
import shutil
import subprocess
import streamlit as st
# ─── 1. Mode detection & data directory ────────────────────────────────────────
# LOCAL_TRAIN=1 → use "./data"; otherwise Spaces uses "/tmp/data"
LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
DATA_DIR = os.path.join(os.getcwd(), "data") if LOCAL else "/tmp/data"
os.makedirs(DATA_DIR, exist_ok=True)
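# NOTE: on Spaces, /tmp is used as writable scratch space; anything stored
# there is ephemeral and is lost when the Space restarts.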

# ─── 2. Page layout ─────────────────────────────────────────────────────────────
st.set_page_config(page_title="HiDream LoRA Trainer", layout="wide")
st.title("π¨ HiDream LoRA Trainer (Streamlit)")
# Sidebar for configuration
with st.sidebar:
st.header("π Configuration")
base_model = st.selectbox(
"Base Model",
["HiDream-ai/HiDream-I1-Dev",
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1"]
)
trigger_word = st.text_input("Trigger Word", value="default-style")
num_steps = st.slider("Training Steps", min_value=10, max_value=500, value=100, step=10)
lora_r = st.slider("LoRA Rank (r)", min_value=4, max_value=128, value=16, step=4)
lora_alpha = st.slider("LoRA Alpha", min_value=4, max_value=128, value=16, step=4)
st.markdown("---")
st.header("π Upload Dataset")
uploaded_files = st.file_uploader(
"Select your images & text files",
type=["jpg","jpeg","png","txt"],
accept_multiple_files=True
)
if st.button("Upload Dataset"):
# Clear old files
for f in os.listdir(DATA_DIR):
os.remove(os.path.join(DATA_DIR, f))
# Write new files
for up in uploaded_files:
dest = os.path.join(DATA_DIR, up.name)
with open(dest, "wb") as f:
f.write(up.getbuffer())
st.success(f"β
Uploaded {len(uploaded_files)} files to `{DATA_DIR}`")
st.markdown("---")
# Trigger training
if st.button("π Start Training"):
st.session_state.training = True

# ─── 3. Training log area ───────────────────────────────────────────────────────
log_area = st.empty()
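# log_area is a single-element placeholder; each call to log_area.text_area()
# during training replaces its previous contents, so the log updates in place.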

# ─── 4. Invoke training when triggered ──────────────────────────────────────────
if st.session_state.get("training", False):
st.info("Training started⦠Logs below:")
log_lines = []
# Prepare environment for train.py
env = os.environ.copy()
env.update({
"BASE_MODEL": base_model,
"TRIGGER_WORD": trigger_word,
"NUM_STEPS": str(num_steps),
"LORA_R": str(lora_r),
"LORA_ALPHA": str(lora_alpha),
"LOCAL_TRAIN": os.environ.get("LOCAL_TRAIN","")
})
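
    # train.py itself is not shown on this page; it is assumed to read its
    # hyperparameters from the environment variables exported above and to
    # resolve the dataset directory the same way this script does
    # (LOCAL_TRAIN=1 → ./data, otherwise /tmp/data).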
    # Launch train.py as a subprocess and stream its logs
    proc = subprocess.Popen(
        ["python3", "train.py"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        env=env,
    )
    for line in proc.stdout:
        log_lines.append(line)
        # Update the text area with all lines so far
        log_area.text_area("Training Log", value="".join(log_lines), height=400)
    proc.wait()

    if proc.returncode == 0:
        st.success("✅ Training complete!")
    else:
        st.error(f"❌ Training failed (exit code {proc.returncode})")

    # Reset the trigger so the next rerun doesn't restart training
    st.session_state.training = False