import os
import shutil
import subprocess
import sys

import streamlit as st

# ─── 1. Mode detection & data directory ───────────────────────────────────────
# LOCAL_TRAIN=1 (or "true") → use "./data"; otherwise (e.g. on Spaces) use "/tmp/data"
LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
DATA_DIR = os.path.join(os.getcwd(), "data") if LOCAL else "/tmp/data"
os.makedirs(DATA_DIR, exist_ok=True)


def _clear_dir(path):
    """Delete every file, symlink and sub-directory inside *path*, keeping *path* itself."""
    with os.scandir(path) as entries:
        for entry in entries:
            # os.remove() raises on directories, so real sub-dirs need rmtree.
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                os.remove(entry.path)


def _safe_filename(name):
    """Return the bare file name of an upload, or None if it is unusable.

    The name is supplied by the client, so any directory components are
    stripped to keep every write inside DATA_DIR.
    """
    base = os.path.basename(name.replace("\\", "/"))
    return None if base in ("", ".", "..") else base


def _stream_training(env, log_area):
    """Run train.py with *env*, mirroring its combined stdout/stderr into *log_area*.

    Returns the child's exit code. If this Streamlit run is interrupted
    (a widget interaction triggers a rerun, or an exception propagates), the
    child is killed so it does not keep running with nobody draining its pipe.
    """
    log_lines = []
    with subprocess.Popen(
        # Same interpreter (and virtualenv) as the one running this app.
        [sys.executable, "train.py"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors="replace",  # undecodable bytes in the log must not crash the UI
        env=env,
    ) as proc:
        try:
            for line in proc.stdout:
                log_lines.append(line)
                # Update the text area with all lines so far
                log_area.text_area("Training Log", value="".join(log_lines), height=400)
        except BaseException:
            # Cleanup only (includes Streamlit's rerun/stop control exceptions);
            # the exception is always re-raised.
            proc.kill()
            raise
    # Popen.__exit__ has closed the pipe and waited, so returncode is set.
    return proc.returncode


# ─── 2. Page layout ───────────────────────────────────────────────────────────
st.set_page_config(page_title="HiDream LoRA Trainer", layout="wide")
st.title("🎨 HiDream LoRA Trainer (Streamlit)")

# Sidebar for configuration
with st.sidebar:
    st.header("🛠 Configuration")
    base_model = st.selectbox(
        "Base Model",
        [
            "HiDream-ai/HiDream-I1-Dev",
            "runwayml/stable-diffusion-v1-5",
            "stabilityai/stable-diffusion-2-1",
        ],
    )
    trigger_word = st.text_input("Trigger Word", value="default-style")
    num_steps = st.slider("Training Steps", min_value=10, max_value=500, value=100, step=10)
    lora_r = st.slider("LoRA Rank (r)", min_value=4, max_value=128, value=16, step=4)
    lora_alpha = st.slider("LoRA Alpha", min_value=4, max_value=128, value=16, step=4)
    st.markdown("---")

    st.header("📂 Upload Dataset")
    uploaded_files = st.file_uploader(
        "Select your images & text files",
        type=["jpg", "jpeg", "png", "txt"],
        accept_multiple_files=True,
    )
    if st.button("Upload Dataset"):
        # Guard: uploader may yield None or [] — never wipe the dataset for nothing.
        if not uploaded_files:
            st.warning("No files selected — existing dataset left unchanged.")
        else:
            # Clear old files (and any sub-directories)
            _clear_dir(DATA_DIR)
            # Write new files; a set so duplicate names are counted once.
            written = set()
            for up in uploaded_files:
                fname = _safe_filename(up.name)
                if fname is None:
                    continue
                with open(os.path.join(DATA_DIR, fname), "wb") as f:
                    f.write(up.getbuffer())
                written.add(fname)
            st.success(f"✅ Uploaded {len(written)} files to `{DATA_DIR}`")
    st.markdown("---")

    # Trigger training
    if st.button("🚀 Start Training"):
        st.session_state.training = True

# ─── 3. Training log area ─────────────────────────────────────────────────────
log_area = st.empty()

# ─── 4. Invoke training when triggered ────────────────────────────────────────
if st.session_state.get("training", False):
    # Reset the trigger *before* launching: if this run is interrupted or fails,
    # the next rerun must not start another training process.
    st.session_state.training = False
    st.info("Training started… Logs below:")

    # Prepare environment for train.py
    env = os.environ.copy()
    env.update({
        "BASE_MODEL": base_model,
        "TRIGGER_WORD": trigger_word,
        "NUM_STEPS": str(num_steps),
        "LORA_R": str(lora_r),
        "LORA_ALPHA": str(lora_alpha),
        "LOCAL_TRAIN": os.environ.get("LOCAL_TRAIN", ""),
        # A Python child block-buffers stdout on a pipe; without this the log
        # would only appear in large bursts instead of streaming.
        "PYTHONUNBUFFERED": "1",
    })

    # Launch train.py as subprocess and stream logs
    returncode = _stream_training(env, log_area)
    if returncode == 0:
        st.success("✅ Training complete!")
    else:
        st.error(f"❌ Training failed (exit code {returncode})")