ramimu commited on
Commit
09b6938
Β·
verified Β·
1 Parent(s): f458d2e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ import streamlit as st
5
+
6
+ # ─── 1. Mode detection & data directory ───────────────────────────────────────
7
+ # LOCAL_TRAIN=1 β†’ use "./data"; otherwise Spaces uses "/tmp/data"
8
+ LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
9
+ DATA_DIR = os.path.join(os.getcwd(), "data") if LOCAL else "/tmp/data"
10
+ os.makedirs(DATA_DIR, exist_ok=True)
11
+
12
+ # ─── 2. Page layout ───────────────────────────────────────────────────────────
13
+ st.set_page_config(page_title="HiDream LoRA Trainer", layout="wide")
14
+ st.title("🎨 HiDream LoRA Trainer (Streamlit)")
15
+
16
+ # Sidebar for configuration
17
+ with st.sidebar:
18
+ st.header("πŸ›  Configuration")
19
+ base_model = st.selectbox(
20
+ "Base Model",
21
+ ["HiDream-ai/HiDream-I1-Dev",
22
+ "runwayml/stable-diffusion-v1-5",
23
+ "stabilityai/stable-diffusion-2-1"]
24
+ )
25
+ trigger_word = st.text_input("Trigger Word", value="default-style")
26
+ num_steps = st.slider("Training Steps", min_value=10, max_value=500, value=100, step=10)
27
+ lora_r = st.slider("LoRA Rank (r)", min_value=4, max_value=128, value=16, step=4)
28
+ lora_alpha = st.slider("LoRA Alpha", min_value=4, max_value=128, value=16, step=4)
29
+
30
+ st.markdown("---")
31
+ st.header("πŸ“‚ Upload Dataset")
32
+ uploaded_files = st.file_uploader(
33
+ "Select your images & text files",
34
+ type=["jpg","jpeg","png","txt"],
35
+ accept_multiple_files=True
36
+ )
37
+ if st.button("Upload Dataset"):
38
+ # Clear old files
39
+ for f in os.listdir(DATA_DIR):
40
+ os.remove(os.path.join(DATA_DIR, f))
41
+ # Write new files
42
+ for up in uploaded_files:
43
+ dest = os.path.join(DATA_DIR, up.name)
44
+ with open(dest, "wb") as f:
45
+ f.write(up.getbuffer())
46
+ st.success(f"βœ… Uploaded {len(uploaded_files)} files to `{DATA_DIR}`")
47
+
48
+ st.markdown("---")
49
+ # Trigger training
50
+ if st.button("πŸš€ Start Training"):
51
+ st.session_state.training = True
52
+
53
+ # ─── 3. Training log area ─────────────────────────────────────────────────────
54
+ log_area = st.empty()
55
+
56
+ # ─── 4. Invoke training when triggered ────────────────────────────────────────
57
+ if st.session_state.get("training", False):
58
+ st.info("Training started… Logs below:")
59
+ log_lines = []
60
+ # Prepare environment for train.py
61
+ env = os.environ.copy()
62
+ env.update({
63
+ "BASE_MODEL": base_model,
64
+ "TRIGGER_WORD": trigger_word,
65
+ "NUM_STEPS": str(num_steps),
66
+ "LORA_R": str(lora_r),
67
+ "LORA_ALPHA": str(lora_alpha),
68
+ "LOCAL_TRAIN": os.environ.get("LOCAL_TRAIN","")
69
+ })
70
+
71
+ # Launch train.py as subprocess and stream logs
72
+ proc = subprocess.Popen(
73
+ ["python3", "train.py"],
74
+ stdout=subprocess.PIPE,
75
+ stderr=subprocess.STDOUT,
76
+ text=True,
77
+ env=env
78
+ )
79
+
80
+ for line in proc.stdout:
81
+ log_lines.append(line)
82
+ # Update the text area with all lines so far
83
+ log_area.text_area("Training Log", value="".join(log_lines), height=400)
84
+ proc.wait()
85
+
86
+ if proc.returncode == 0:
87
+ st.success("βœ… Training complete!")
88
+ else:
89
+ st.error(f"❌ Training failed (exit code {proc.returncode})")
90
+ # Reset trigger
91
+ st.session_state.training = False