import os
import shutil
import subprocess

import streamlit as st

# ─── 1. Mode detection & data directory ───────────────────────────────────────
# LOCAL_TRAIN=1 → use "./data"; otherwise Spaces uses "/tmp/data"
LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
DATA_DIR = os.path.join(os.getcwd(), "data") if LOCAL else "/tmp/data"
os.makedirs(DATA_DIR, exist_ok=True)
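# Note: Spaces containers run on an ephemeral filesystem, so anything written
# to /tmp/data only survives for the current session of the Space.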

# ─── 2. Page layout ────────────────────────────────────────────────────────────
st.set_page_config(page_title="HiDream LoRA Trainer", layout="wide")
st.title("🎨 HiDream LoRA Trainer (Streamlit)")

# Sidebar for configuration
with st.sidebar:
    st.header("🔧 Configuration")

    base_model = st.selectbox(
        "Base Model",
        ["HiDream-ai/HiDream-I1-Dev",
         "runwayml/stable-diffusion-v1-5",
         "stabilityai/stable-diffusion-2-1"]
    )
    trigger_word = st.text_input("Trigger Word", value="default-style")
    num_steps = st.slider("Training Steps", min_value=10, max_value=500, value=100, step=10)
    lora_r = st.slider("LoRA Rank (r)", min_value=4, max_value=128, value=16, step=4)
    lora_alpha = st.slider("LoRA Alpha", min_value=4, max_value=128, value=16, step=4)

    st.markdown("---")
    st.header("📤 Upload Dataset")
    uploaded_files = st.file_uploader(
        "Select your images & text files",
        type=["jpg", "jpeg", "png", "txt"],
        accept_multiple_files=True
    )
    if st.button("Upload Dataset"):
        if not uploaded_files:
            st.warning("No files selected yet.")
        else:
            # Clear old files
            for f in os.listdir(DATA_DIR):
                os.remove(os.path.join(DATA_DIR, f))
            # Write new files
            for up in uploaded_files:
                dest = os.path.join(DATA_DIR, up.name)
                with open(dest, "wb") as f:
                    f.write(up.getbuffer())
            st.success(f"✅ Uploaded {len(uploaded_files)} files to `{DATA_DIR}`")

    st.markdown("---")
    # Trigger training
    if st.button("🚀 Start Training"):
        st.session_state.training = True
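# Streamlit reruns this script from top to bottom on every widget interaction;
# the "Start Training" button is only True on the rerun right after the click,
# so the trigger is kept in st.session_state to drive the training block below.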

# ─── 3. Training log area ──────────────────────────────────────────────────────
log_area = st.empty()

# ─── 4. Invoke training when triggered ─────────────────────────────────────────
if st.session_state.get("training", False):
    st.info("Training started… Logs below:")
    log_lines = []

    # Prepare environment for train.py
    env = os.environ.copy()
    env.update({
        "BASE_MODEL": base_model,
        "TRIGGER_WORD": trigger_word,
        "NUM_STEPS": str(num_steps),
        "LORA_R": str(lora_r),
        "LORA_ALPHA": str(lora_alpha),
        "LOCAL_TRAIN": os.environ.get("LOCAL_TRAIN", "")
    })
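    # train.py (not shown here) is expected to read these settings back from
    # the environment, e.g. int(os.environ.get("NUM_STEPS", "100")); this is
    # only a sketch of the contract, and the actual script may parse them
    # differently.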

    # Launch train.py as a subprocess and stream its logs
    proc = subprocess.Popen(
        ["python3", "train.py"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        env=env
    )
    for line in proc.stdout:
        log_lines.append(line)
        # Update the text area with all lines so far
        log_area.text_area("Training Log", value="".join(log_lines), height=400)
    proc.wait()

    if proc.returncode == 0:
        st.success("✅ Training complete!")
    else:
        st.error(f"❌ Training failed (exit code {proc.returncode})")

    # Reset trigger
    st.session_state.training = False