Update app.py
app.py
CHANGED
@@ -1,28 +1,3 @@
-# gradio_dataset_manager.py – Unified YOLO Dataset Toolkit (Evaluate + Merge/Edit)
-"""
-Gradio application that **combines** the functionality of the original evaluation‑only app and
-the Streamlit dataset‑merging dashboard.
-
-### Key features added
-* **Multi‑dataset loader**
-  – Roboflow URLs (with automatic latest‑version fallback)
-  – ZIP uploads (GitHub/Ultralytics releases, etc.)
-  – Existing server paths
-* **Interactive class manager** (rename / remove / per‑class max images) using an **editable
-  Gradio `Dataframe`**. Multiple rows can share the same *New name* to merge classes.
-* **Dataset merger** with:
-  – Per‑class image limits
-  – Consistent re‑indexing after renames/removals
-  – Duplicate image avoidance
-  – Final `data.yaml` + ready‑to‑train directory structure
-* One‑click **ZIP download** of the merged dataset
-* Original **Quality Evaluation** preserved (blur/IOU/conf sliders, fastdup, optional
-  Model‑QA/Cleanlab)
-
-> ⚠️ This is a standalone Python script – drop it into a Hugging Face **Space** or run with
-> `python gradio_dataset_manager.py`. Requires the same pip deps you already list.
-"""
-
 from __future__ import annotations
 
 import base64
@@ -36,6 +11,7 @@ import re
 import shutil
 import stat
 import tempfile
+import zipfile
 from collections import Counter
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
@@ -178,18 +154,35 @@ def get_model(weights: str) -> YOLO | None:
 # -------------------- Roboflow helpers --------------------
 RF_RE = re.compile(r"https?://universe\.roboflow\.com/([^/]+)/([^/]+)/?([^/]*)")
 
-
-
-
+def parse_roboflow_url(url: str) -> tuple[str, str, int | None]:
+    """
+    Return (workspace, project, version|None) – tolerates many RF URL flavours.
+    Any non‑positive or malformed version is treated as None.
+    """
     m = RF_RE.match(url.strip())
     if not m:
         return None, None, None
     ws, proj, tail = m.groups()
-    ver = None
+    ver: int | None = None
+
+    # explicit "dataset/<number>" in path
     if tail.startswith("dataset/"):
-
-
-
+        try:
+            v = int(tail.split("dataset/", 1)[1])
+            if v > 0:
+                ver = v
+        except ValueError:
+            pass
+
+    # explicit "?version=<number>" in query
+    if ver is None and "?version=" in url:
+        try:
+            v = int(url.split("?version=", 1)[1])
+            if v > 0:
+                ver = v
+        except ValueError:
+            pass
+
     return ws, proj, ver
 
 
@@ -204,32 +197,44 @@ def get_latest_version(rf: Roboflow, ws: str, proj: str) -> str | None:
     return None
 
 
-def download_roboflow_dataset(
+def download_roboflow_dataset(
+    url: str,
+    rf_api_key: str,
+    fmt: str = "yolov8",
+) -> Tuple[Path, List[str], List[str]]:
     """Return (dataset_location, class_names, splits). Caches by folder name."""
     if Roboflow is None:
         raise RuntimeError("`roboflow` pip package not installed")
+
     ws, proj, ver = parse_roboflow_url(url)
     if not (ws and proj):
-        raise ValueError("Bad Roboflow URL")
+        raise ValueError(f"Bad Roboflow URL: {url!r}")
 
     rf = Roboflow(api_key=rf_api_key)
-
-
-
+
+    # if no explicit version or invalid, fetch latest
+    if not ver or ver <= 0:
+        latest = get_latest_version(rf, ws, proj)
+        if latest is None:
             raise RuntimeError("Could not resolve latest Roboflow version")
+        try:
+            ver = int(latest)
+        except ValueError:
+            raise RuntimeError(f"Invalid latest version returned: {latest!r}")
 
     ds_dir = TMP_ROOT / f"{ws}_{proj}_v{ver}"
     if ds_dir.exists():
         yaml_path = ds_dir / "data.yaml"
         class_names = load_yaml(yaml_path).get("names", []) if yaml_path.exists() else []
-        splits = [s for s in
+        splits = [s for s in ("train","valid","test") if (ds_dir / s).exists()]
         return ds_dir, class_names, splits
 
     ds_dir.mkdir(parents=True, exist_ok=True)
-    rf.workspace(ws).project(proj).version(
+    rf.workspace(ws).project(proj).version(ver).download(fmt, location=str(ds_dir))
+
     yaml_path = ds_dir / "data.yaml"
     class_names = load_yaml(yaml_path).get("names", []) if yaml_path.exists() else []
-    splits = [s for s in
+    splits = [s for s in ("train","valid","test") if (ds_dir / s).exists()]
     return ds_dir, class_names, splits
 
 
@@ -252,8 +257,6 @@ def gather_class_counts(dataset_info_list, class_name_mapping):
     return dict(counts)
 
 
-# -- label‑file worker (same logic, no Streamlit)
-
 def _process_label_file(label_path: Path, class_names_dataset, class_name_mapping):
     im_name = label_path.stem + label_path.suffix.replace(".txt", ".jpg")
     img_classes = set()
@@ -265,10 +268,6 @@ def _process_label_file(label_path: Path, class_names_dataset, class_name_mapping):
     return im_name, img_classes
 
 
-# ---------------------------------------------------------------------------
-# merge_datasets(): **pure‑python** version (no streamlit session state)
-# ---------------------------------------------------------------------------
-
 def merge_datasets(
     dataset_info_list: List[Tuple[str, List[str], List[str], str]],
     class_map_df: pd.DataFrame,
@@ -284,21 +283,18 @@ def merge_datasets(
     (out_dir / "valid/images").mkdir(parents=True, exist_ok=True)
     (out_dir / "valid/labels").mkdir(parents=True, exist_ok=True)
 
-    # Build mapping dicts ----------------------------------------------------
     class_name_mapping = {
-        row["original_class"]: row["new_name"] if row["remove"]
+        row["original_class"]: row["new_name"] if not row["remove"] else "__REMOVED__"
        for _, row in class_map_df.iterrows()
     }
     limits_per_merged = {
        row["new_name"]: int(row["max_images"])
        for _, row in class_map_df.iterrows()
-        if row["remove"]
+        if not row["remove"]
     }
-    # active merged classes only
     active_classes = [c for c in sorted(set(class_name_mapping.values())) if c != "__REMOVED__"]
     id_map = {cls: idx for idx, cls in enumerate(active_classes)}
 
-    # Scan label files -------------------------------------------------------
     image_to_classes: dict[str, set[str]] = {}
     image_to_label: dict[str, Path] = {}
     class_to_images: dict[str, set[str]] = {c: set() for c in active_classes}
@@ -319,7 +315,6 @@
         for c in cls_set:
             class_to_images[c].add(img_path)
 
-    # Select images respecting per‑class limits -----------------------------
     selected_images: set[str] = set()
     counters = {c: 0 for c in active_classes}
     shuffle_pool = [img for imgs in class_to_images.values() for img in imgs]
@@ -327,14 +322,12 @@
 
     for img in shuffle_pool:
         cls_set = image_to_classes[img]
-        # skip if any class hit its limit
         if any(counters[c] >= limits_per_merged.get(c, 0) for c in cls_set):
             continue
         selected_images.add(img)
         for c in cls_set:
             counters[c] += 1
 
-    # Copy & re‑index --------------------------------------------------------
     for img in selected_images:
         split = "train" if random.random() < 0.9 else "valid"
         dst_img = out_dir / split / "images" / Path(img).name
@@ -352,7 +345,6 @@
             if not parts:
                 continue
             cid = int(parts[0])
-            # find orig class name
             dloc_match = next((cl for dloc2, cl, _, _ in dataset_info_list if str(lp_src).startswith(dloc2)), None)
             if dloc_match is None:
                 continue
@@ -370,7 +362,6 @@
        else:
            (out_dir / split / "images" / Path(img).name).unlink(missing_ok=True)
 
-    # Build data.yaml --------------------------------------------------------
     data_yaml = {
        "path": str(out_dir.resolve()),
        "train": "train/images",
@@ -384,12 +375,14 @@
     return out_dir
 
 
-# Utility: zip a folder to bytes -------------------------------------------
-
 def zip_directory(folder: Path) -> bytes:
     buf = io.BytesIO()
-    with
-
+    with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
+        for file in folder.rglob("*"):
+            zf.write(file, arcname=file.relative_to(folder))
+    buf.seek(0)
+    return buf.getvalue()
+
 
 # ════════════════════════════════════════════════════════════════════════════
 # UI LAYER
@@ -420,7 +413,6 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
        out_md = gr.Markdown()
        out_df = gr.Dataframe(label="Class distribution")
 
-        # --- callback (identical logic from v3, omitted for brevity) ---
        def _evaluate_cb(api_key, url_txt, zip_file, server_path, yaml_file, weights,
                         blur_thr, iou_thr, conf_thr, run_dup, run_modelqa):
            return "Evaluation disabled in this trimmed snippet.", pd.DataFrame()
@@ -434,7 +426,7 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
 
    # ------------------------------ MERGE TAB -----------------------------
    with gr.Tab("Merge / Edit"):
-        gr.Markdown("
+        gr.Markdown("### 1️⃣ Load one or more datasets")
        rf_key = gr.Textbox(label="Roboflow API key", type="password")
        rf_urls = gr.File(label=".txt of RF URLs", file_types=['.txt'])
        zips_in = gr.Files(label="One or more dataset ZIPs")
@@ -446,14 +438,19 @@
            global autoinc
            info_list = []
            log_lines = []
+
            # Roboflow URLs via txt
            if rf_urls_file is not None:
                for url in Path(rf_urls_file.name).read_text().splitlines():
                    if not url.strip():
                        continue
-
-
-
+                    try:
+                        ds, names, splits = download_roboflow_dataset(url, rf_key)
+                        info_list.append((str(ds), names, splits, Path(ds).name))
+                        log_lines.append(f"✔️ RF dataset **{Path(ds).name}** loaded ({len(names)} classes)")
+                    except Exception as e:
+                        log_lines.append(f"⚠️ RF load failed for {url!r}: {e}")
+
            # ZIPs
            for f in zip_files or []:
                autoinc += 1
@@ -464,30 +461,33 @@
                if yaml_path is None:
                    continue
                names = load_yaml(yaml_path).get("names", [])
-                splits = [s for s in
+                splits = [s for s in ("train","valid","test") if (tmp / s).exists()]
                info_list.append((str(tmp), names, splits, tmp.name))
                log_lines.append(f"✔️ ZIP **{tmp.name}** loaded")
+
            return info_list, "\n".join(log_lines) if log_lines else "No datasets loaded."
 
        load_btn.click(_load_cb, [rf_key, rf_urls, zips_in], [ds_state, load_log])
 
        # ------------- Class map editable table --------------------------
-        gr.Markdown("
-        class_df = gr.Dataframe(
-
-
+        gr.Markdown("### 2️⃣ Edit class mapping / limits / removal")
+        class_df = gr.Dataframe(
+            headers=["original_class", "new_name", "max_images", "remove"],
+            datatype=["str", "str", "number", "bool"],
+            interactive=True, elem_id="classdf"
+        )
        refresh_btn = gr.Button("Build class table from loaded datasets")
 
        def _build_class_df(ds_info):
            class_names_all = []
-            for _dloc, names, _spl,
+            for _dloc, names, _spl, _ in ds_info:
                class_names_all.extend(names)
            class_names_all = sorted(set(class_names_all))
            df = pd.DataFrame({
                "original_class": class_names_all,
                "new_name": class_names_all,
-                "max_images": [99999]*len(class_names_all),
-                "remove": [False]*len(class_names_all),
+                "max_images": [99999] * len(class_names_all),
+                "remove": [False] * len(class_names_all),
            })
            return df
 
@@ -499,11 +499,14 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
        merge_log = gr.Markdown()
 
        def _merge_cb(ds_info, class_df):
-            if
+            if not ds_info:
                return None, "⚠️ Load datasets first."
-            out_dir = merge_datasets(ds_info, class_df)
+            out_dir = merge_datasets(ds_info, class_df)
            zip_path = shutil.make_archive(str(out_dir), "zip", out_dir)
-            return zip_path,
+            return zip_path, (
+                f"✅ Merged dataset created at **{out_dir}** with "
+                f"{len(list(Path(out_dir).rglob('*.jpg')))} images."
+            )
 
        merge_btn.click(_merge_cb, [ds_state, class_df], [zip_out, merge_log])
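
The new Roboflow URL handling can be sanity-checked on its own. A minimal sketch, reusing the exact RF_RE pattern this commit adds; the workspace/project URL below is made up for illustration:

import re

# Same pattern the commit adds to app.py (copied here so the sketch runs standalone).
RF_RE = re.compile(r"https?://universe\.roboflow\.com/([^/]+)/([^/]+)/?([^/]*)")

# Hypothetical dataset URL, workspace and project names are placeholders.
url = "https://universe.roboflow.com/acme-workspace/traffic-signs"
ws, proj, tail = RF_RE.match(url).groups()
print(ws, proj, repr(tail))   # acme-workspace traffic-signs ''

# With no version in the URL, parse_roboflow_url() returns ver=None, so
# download_roboflow_dataset() asks get_latest_version() for the newest
# version before downloading into TMP_ROOT / f"{ws}_{proj}_v{ver}".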
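The editable class table is what drives the merge. A small sketch of how merge_datasets() interprets it, using the comprehensions from this commit verbatim; the class names (car, truck, bicycle) are invented for the example:

import pandas as pd

# Hypothetical class table: merge car + truck into "vehicle", drop bicycle,
# and cap the merged class at 500 images.
class_map_df = pd.DataFrame({
    "original_class": ["car", "truck", "bicycle"],
    "new_name":       ["vehicle", "vehicle", "bicycle"],
    "max_images":     [500, 500, 100],
    "remove":         [False, False, True],
})

# Comprehensions copied from the new merge_datasets() body.
class_name_mapping = {
    row["original_class"]: row["new_name"] if not row["remove"] else "__REMOVED__"
    for _, row in class_map_df.iterrows()
}
limits_per_merged = {
    row["new_name"]: int(row["max_images"])
    for _, row in class_map_df.iterrows()
    if not row["remove"]
}
active_classes = [c for c in sorted(set(class_name_mapping.values())) if c != "__REMOVED__"]
id_map = {cls: idx for idx, cls in enumerate(active_classes)}

print(class_name_mapping)   # {'car': 'vehicle', 'truck': 'vehicle', 'bicycle': '__REMOVED__'}
print(limits_per_merged)    # {'vehicle': 500}
print(active_classes)       # ['vehicle']
print(id_map)               # {'vehicle': 0}

Rows that share a new_name merge into one class, removed rows map to the "__REMOVED__" sentinel and drop out of id_map, and id_map is what the copy step uses to re-index label files consistently.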