File size: 3,679 Bytes
09b6938
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import shutil
import subprocess
import sys

import streamlit as st

# ─── 1. Mode detection & data directory ───────────────────────────────────────
# LOCAL_TRAIN=1 (or "true", case-insensitive) → use "./data" under the current
# working directory; otherwise (e.g. on Spaces) use "/tmp/data".
# NOTE: the raw LOCAL_TRAIN value is also forwarded to train.py, so keep this
# parsing in sync with however train.py interprets it.
LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
DATA_DIR = os.path.join(os.getcwd(), "data") if LOCAL else "/tmp/data"
# Make sure the upload target exists before the sidebar tries to write to it.
os.makedirs(DATA_DIR, exist_ok=True)

# ─── 2. Page layout ───────────────────────────────────────────────────────────
st.set_page_config(page_title="HiDream LoRA Trainer", layout="wide")
st.title("🎨 HiDream LoRA Trainer (Streamlit)")

# Sidebar for configuration
with st.sidebar:
    st.header("🛠 Configuration")
    base_model   = st.selectbox(
        "Base Model",
        ["HiDream-ai/HiDream-I1-Dev",
         "runwayml/stable-diffusion-v1-5",
         "stabilityai/stable-diffusion-2-1"]
    )
    trigger_word = st.text_input("Trigger Word", value="default-style")
    num_steps    = st.slider("Training Steps", min_value=10, max_value=500, value=100, step=10)
    lora_r       = st.slider("LoRA Rank (r)", min_value=4, max_value=128, value=16, step=4)
    lora_alpha   = st.slider("LoRA Alpha", min_value=4, max_value=128, value=16, step=4)

    st.markdown("---")
    st.header("📂 Upload Dataset")
    uploaded_files = st.file_uploader(
        "Select your images & text files",
        type=["jpg","jpeg","png","txt"],
        accept_multiple_files=True
    )
    if st.button("Upload Dataset"):
        if not uploaded_files:
            # Nothing selected: keep the current dataset rather than wiping it
            # and reporting "0 files uploaded".
            st.warning("No files selected — the existing dataset was left untouched.")
        else:
            # Clear old entries so DATA_DIR holds exactly the new upload.
            # Sub-directories need rmtree; os.remove only handles files/links.
            for entry in os.listdir(DATA_DIR):
                path = os.path.join(DATA_DIR, entry)
                if os.path.isdir(path) and not os.path.islink(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
            # Write new files. The file name comes from the client, so only its
            # base name is used; this keeps every write inside DATA_DIR.
            written = 0
            for up in uploaded_files:
                safe_name = os.path.basename(up.name.replace("\\", "/"))
                if safe_name in ("", ".", ".."):
                    continue
                dest = os.path.join(DATA_DIR, safe_name)
                with open(dest, "wb") as f:
                    f.write(up.getbuffer())
                written += 1
            st.success(f"✅ Uploaded {written} files to `{DATA_DIR}`")

    st.markdown("---")
    # Trigger training: the flag is consumed by the training section below.
    if st.button("🚀 Start Training"):
        st.session_state.training = True

# ─── 3. Training log area ─────────────────────────────────────────────────────
# The status placeholder is created *above* the log placeholder so the
# "Logs below" banner really sits above the streamed output.
status_area = st.empty()
log_area = st.empty()

# ─── 4. Invoke training when triggered ────────────────────────────────────────
if st.session_state.get("training", False):
    # Consume the trigger immediately. Streamlit aborts the current run when
    # the user interacts with a widget, and the run may also raise. If the
    # flag were only reset at the end, the next rerun would launch a second
    # training job.
    st.session_state.training = False
    status_area.info("Training started… Logs below:")
    log_lines = []
    # Prepare environment for train.py (hyper-parameters are passed as env vars)
    env = os.environ.copy()
    env.update({
        "BASE_MODEL":    base_model,
        "TRIGGER_WORD":  trigger_word,
        "NUM_STEPS":     str(num_steps),
        "LORA_R":        str(lora_r),
        "LORA_ALPHA":    str(lora_alpha),
        "LOCAL_TRAIN":   os.environ.get("LOCAL_TRAIN", ""),
        # The child's stdout is a pipe, which Python block-buffers by default.
        # Without this the log would only arrive in large chunks or at exit.
        "PYTHONUNBUFFERED": "1",
        # Make the child encode its output the same way we decode it below.
        "PYTHONIOENCODING": "utf-8",
    })

    # Run train.py (located next to this script) with the same interpreter
    # that runs the app, so it sees the same installed packages.
    train_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "train.py")
    proc = subprocess.Popen(
        [sys.executable, train_script],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        encoding="utf-8",
        errors="replace",
        env=env,
    )

    try:
        # Stream logs: re-render the text area with all lines so far.
        for line in proc.stdout:
            log_lines.append(line)
            log_area.text_area("Training Log", value="".join(log_lines), height=400)
        proc.wait()
    finally:
        # If streaming was interrupted (rerun/stop/exception), do not leave an
        # orphaned training process running unattended.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        proc.stdout.close()

    if proc.returncode == 0:
        st.success("✅ Training complete!")
    else:
        st.error(f"❌ Training failed (exit code {proc.returncode})")