# LoRa_Streamlit / app.py
# Author: ramimu — commit 09b6938 ("Create app.py", verified)
import os
import shutil
import subprocess
import sys

import streamlit as st
# ─── 1. Mode detection & data directory ───────────────────────────────────────
# LOCAL_TRAIN set to "1" or "true" (case-insensitive) → data lives in ./data
# under the current working directory; otherwise (Spaces) use /tmp/data.
_TRUTHY = ("1", "true")
LOCAL = os.getenv("LOCAL_TRAIN", "").lower() in _TRUTHY
if LOCAL:
    DATA_DIR = os.path.join(os.getcwd(), "data")
else:
    DATA_DIR = "/tmp/data"
# Ensure the directory exists before any upload or training touches it.
os.makedirs(DATA_DIR, exist_ok=True)
# ─── 2. Page layout ───────────────────────────────────────────────────────────
# Page configuration is kept as the first Streamlit call in the script.
st.set_page_config(layout="wide", page_title="HiDream LoRA Trainer")
st.title("🎨 HiDream LoRA Trainer (Streamlit)")
# Sidebar for configuration
def _clear_directory(path):
    """Remove every entry inside *path* (files, symlinks and sub-directories).

    ``os.remove`` raises on directories, so real sub-directories are deleted
    recursively with ``shutil.rmtree``; symlinks are unlinked, not followed.
    """
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if os.path.isdir(full_path) and not os.path.islink(full_path):
            shutil.rmtree(full_path)
        else:
            os.remove(full_path)


def _save_uploads(files, dest_dir):
    """Write uploaded files into *dest_dir* and return how many were saved.

    Only the base name of each upload is used, so a client-supplied name
    such as ``"../../x"`` cannot write outside *dest_dir*. Names that reduce
    to nothing usable ("", ".", "..") are skipped.
    """
    saved = 0
    for up in files:
        name = os.path.basename(up.name)
        if name in ("", ".", ".."):
            continue
        with open(os.path.join(dest_dir, name), "wb") as f:
            f.write(up.getbuffer())
        saved += 1
    return saved


with st.sidebar:
    st.header("🛠 Configuration")
    base_model = st.selectbox(
        "Base Model",
        ["HiDream-ai/HiDream-I1-Dev",
         "runwayml/stable-diffusion-v1-5",
         "stabilityai/stable-diffusion-2-1"],
    )
    trigger_word = st.text_input("Trigger Word", value="default-style")
    num_steps = st.slider("Training Steps", min_value=10, max_value=500, value=100, step=10)
    lora_r = st.slider("LoRA Rank (r)", min_value=4, max_value=128, value=16, step=4)
    lora_alpha = st.slider("LoRA Alpha", min_value=4, max_value=128, value=16, step=4)

    st.markdown("---")
    st.header("📂 Upload Dataset")
    uploaded_files = st.file_uploader(
        "Select your images & text files",
        type=["jpg", "jpeg", "png", "txt"],
        accept_multiple_files=True,
    )
    if st.button("Upload Dataset"):
        if not uploaded_files:
            # Nothing selected: keep the current dataset instead of wiping it.
            st.warning("No files selected; existing dataset left unchanged.")
        else:
            # Replace the previous dataset with the new selection.
            _clear_directory(DATA_DIR)
            saved = _save_uploads(uploaded_files, DATA_DIR)
            st.success(f"✅ Uploaded {saved} files to `{DATA_DIR}`")

    st.markdown("---")
    # Trigger training; section 4 below picks up the flag in this same run.
    if st.button("🚀 Start Training"):
        st.session_state.training = True
# ─── 3. Training log area ─────────────────────────────────────────────────────
# Placeholder overwritten in place each time a new log line arrives.
log_area = st.empty()

# ─── 4. Invoke training when triggered ────────────────────────────────────────
if st.session_state.get("training", False):
    # Consume the trigger *before* launching. If this run is interrupted (a
    # widget interaction makes Streamlit rerun the script) or raises, a flag
    # left at True would make the next rerun start a second train.py.
    st.session_state.training = False
    st.info("Training started… Logs below:")
    log_lines = []

    # Prepare environment for train.py (hyper-parameters passed as strings).
    env = os.environ.copy()
    env.update({
        "BASE_MODEL": base_model,
        "TRIGGER_WORD": trigger_word,
        "NUM_STEPS": str(num_steps),
        "LORA_R": str(lora_r),
        "LORA_ALPHA": str(lora_alpha),
        "LOCAL_TRAIN": os.environ.get("LOCAL_TRAIN", ""),
        # A Python child block-buffers stdout when it is a pipe; without this
        # the log would arrive in large bursts instead of line by line.
        "PYTHONUNBUFFERED": "1",
        # Make the child's output encoding match how it is decoded below.
        "PYTHONIOENCODING": "utf-8",
    })

    # sys.executable runs train.py with the same interpreter/venv as this
    # app; a bare "python3" may be missing or resolve to another install.
    proc = subprocess.Popen(
        [sys.executable, "train.py"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
        env=env,
    )
    try:
        for line in proc.stdout:
            log_lines.append(line)
            # Update the text area with all lines so far
            log_area.text_area("Training Log", value="".join(log_lines), height=400)
        proc.wait()
    finally:
        # If streaming was interrupted (rerun or exception), nothing will
        # drain the pipe any more; stop the child rather than orphan it.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        proc.stdout.close()

    if proc.returncode == 0:
        st.success("✅ Training complete!")
    else:
        st.error(f"❌ Training failed (exit code {proc.returncode})")