wuhp committed (verified)

Commit c450ec7 · Parent(s): ec6801a

Update app.py

Files changed (1): app.py (+79 −76)
app.py CHANGED
@@ -1,28 +1,3 @@
- # gradio_dataset_manager.py – Unified YOLO Dataset Toolkit (Evaluate + Merge/Edit)
- """
- Gradio application that **combines** the functionality of the original evaluation‑only app and
- the Streamlit dataset‑merging dashboard.
- 
- ### Key features added
- * **Multi‑dataset loader**
-   – Roboflow URLs (with automatic latest‑version fallback)
-   – ZIP uploads (GitHub/Ultralytics releases, etc.)
-   – Existing server paths
- * **Interactive class manager** (rename / remove / per‑class max images) using an **editable
-   Gradio `Dataframe`**. Multiple rows can share the same *New name* to merge classes.
- * **Dataset merger** with:
-   – Per‑class image limits
-   – Consistent re‑indexing after renames/removals
-   – Duplicate image avoidance
-   – Final `data.yaml` + ready‑to‑train directory structure
- * One‑click **ZIP download** of the merged dataset
- * Original **Quality Evaluation** preserved (blur/IOU/conf sliders, fastdup, optional
-   Model‑QA/Cleanlab)
- 
- > ⚠️ This is a standalone Python script – drop it into a Hugging Face **Space** or run with
- > `python gradio_dataset_manager.py`. Requires the same pip deps you already list.
- """
- 
  from __future__ import annotations
  
  import base64
@@ -36,6 +11,7 @@ import re
  import shutil
  import stat
  import tempfile
+ import zipfile
  from collections import Counter
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from dataclasses import dataclass
@@ -178,18 +154,35 @@ def get_model(weights: str) -> YOLO | None:
  # -------------------- Roboflow helpers --------------------
  RF_RE = re.compile(r"https?://universe\.roboflow\.com/([^/]+)/([^/]+)/?([^/]*)")
  
- 
- def parse_roboflow_url(url: str):
-     """Return (workspace, project, version|None) – tolerates many RF URL flavours."""
+ def parse_roboflow_url(url: str) -> tuple[str, str, int | None]:
+     """
+     Return (workspace, project, version|None) – tolerates many RF URL flavours.
+     Any non-positive or malformed version is treated as None.
+     """
      m = RF_RE.match(url.strip())
      if not m:
          return None, None, None
      ws, proj, tail = m.groups()
-     ver = None
+     ver: int | None = None
+ 
+     # explicit "dataset/<number>" in path
      if tail.startswith("dataset/"):
-         ver = tail.split("dataset/")[-1]
-     elif "?version=" in url:
-         ver = url.split("?version=")[-1]
+         try:
+             v = int(tail.split("dataset/", 1)[1])
+             if v > 0:
+                 ver = v
+         except ValueError:
+             pass
+ 
+     # explicit "?version=<number>" in query
+     if ver is None and "?version=" in url:
+         try:
+             v = int(url.split("?version=", 1)[1])
+             if v > 0:
+                 ver = v
+         except ValueError:
+             pass
+ 
      return ws, proj, ver
  
  
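For reference, a minimal sketch of how the reworked parser behaves on a few URL flavours – illustrative only, the workspace/project names are placeholders and it assumes app.py is importable as `app` (an assumption, not part of this commit):

    # hypothetical quick check, not part of app.py
    from app import parse_roboflow_url

    parse_roboflow_url("https://universe.roboflow.com/acme/widgets")
    # -> ('acme', 'widgets', None)   no version given; the downloader falls back to the latest
    parse_roboflow_url("https://universe.roboflow.com/acme/widgets/model?version=2")
    # -> ('acme', 'widgets', 2)      explicit positive version is kept as an int
    parse_roboflow_url("https://universe.roboflow.com/acme/widgets/model?version=abc")
    # -> ('acme', 'widgets', None)   malformed version is treated as None
    parse_roboflow_url("not-a-roboflow-url")
    # -> (None, None, None)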
@@ -204,32 +197,44 @@ def get_latest_version(rf: Roboflow, ws: str, proj: str) -> str | None:
      return None
  
  
- def download_roboflow_dataset(url: str, rf_api_key: str, fmt: str = "yolov8") -> Tuple[Path, List[str], List[str]]:
+ def download_roboflow_dataset(
+     url: str,
+     rf_api_key: str,
+     fmt: str = "yolov8",
+ ) -> Tuple[Path, List[str], List[str]]:
      """Return (dataset_location, class_names, splits). Caches by folder name."""
      if Roboflow is None:
          raise RuntimeError("`roboflow` pip package not installed")
+ 
      ws, proj, ver = parse_roboflow_url(url)
      if not (ws and proj):
-         raise ValueError("Bad Roboflow URL")
+         raise ValueError(f"Bad Roboflow URL: {url!r}")
  
      rf = Roboflow(api_key=rf_api_key)
-     if ver is None:
-         ver = get_latest_version(rf, ws, proj)
-         if ver is None:
+ 
+     # if no explicit version or invalid, fetch latest
+     if not ver or ver <= 0:
+         latest = get_latest_version(rf, ws, proj)
+         if latest is None:
              raise RuntimeError("Could not resolve latest Roboflow version")
+         try:
+             ver = int(latest)
+         except ValueError:
+             raise RuntimeError(f"Invalid latest version returned: {latest!r}")
  
      ds_dir = TMP_ROOT / f"{ws}_{proj}_v{ver}"
      if ds_dir.exists():
          yaml_path = ds_dir / "data.yaml"
          class_names = load_yaml(yaml_path).get("names", []) if yaml_path.exists() else []
-         splits = [s for s in ["train", "valid", "test"] if (ds_dir / s).exists()]
+         splits = [s for s in ("train","valid","test") if (ds_dir / s).exists()]
          return ds_dir, class_names, splits
  
      ds_dir.mkdir(parents=True, exist_ok=True)
-     rf.workspace(ws).project(proj).version(int(ver)).download(fmt, location=str(ds_dir))
+     rf.workspace(ws).project(proj).version(ver).download(fmt, location=str(ds_dir))
+ 
      yaml_path = ds_dir / "data.yaml"
      class_names = load_yaml(yaml_path).get("names", []) if yaml_path.exists() else []
-     splits = [s for s in ["train", "valid", "test"] if (ds_dir / s).exists()]
+     splits = [s for s in ("train","valid","test") if (ds_dir / s).exists()]
      return ds_dir, class_names, splits
  
  
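A hedged usage sketch of the updated downloader; the URL, the API key, and the `from app import …` module name are placeholders/assumptions, and a real call needs network access plus a valid Roboflow key:

    # hypothetical usage, not part of app.py
    from app import download_roboflow_dataset

    ds_dir, class_names, splits = download_roboflow_dataset(
        "https://universe.roboflow.com/acme/widgets/model?version=2",
        rf_api_key="YOUR_RF_KEY",
    )
    print(ds_dir)       # e.g. <TMP_ROOT>/acme_widgets_v2 – cached, so a re-run skips the download
    print(class_names)  # "names" list read from the downloaded data.yaml
    print(splits)       # subset of ['train', 'valid', 'test'] that exists on disk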
@@ -252,8 +257,6 @@ def gather_class_counts(dataset_info_list, class_name_mapping):
      return dict(counts)
  
  
- # -- label‑file worker (same logic, no Streamlit)
- 
  def _process_label_file(label_path: Path, class_names_dataset, class_name_mapping):
      im_name = label_path.stem + label_path.suffix.replace(".txt", ".jpg")
      img_classes = set()
@@ -265,10 +268,6 @@ def _process_label_file(label_path: Path, class_names_dataset, class_name_mapping):
      return im_name, img_classes
  
  
- # ---------------------------------------------------------------------------
- # merge_datasets(): **pure‑python** version (no streamlit session state)
- # ---------------------------------------------------------------------------
- 
  def merge_datasets(
      dataset_info_list: List[Tuple[str, List[str], List[str], str]],
      class_map_df: pd.DataFrame,
@@ -284,21 +283,18 @@ def merge_datasets(
      (out_dir / "valid/images").mkdir(parents=True, exist_ok=True)
      (out_dir / "valid/labels").mkdir(parents=True, exist_ok=True)
  
-     # Build mapping dicts ----------------------------------------------------
      class_name_mapping = {
-         row["original_class"]: row["new_name"] if row["remove"] is False else "__REMOVED__"
+         row["original_class"]: row["new_name"] if not row["remove"] else "__REMOVED__"
          for _, row in class_map_df.iterrows()
      }
      limits_per_merged = {
          row["new_name"]: int(row["max_images"])
          for _, row in class_map_df.iterrows()
-         if row["remove"] is False
+         if not row["remove"]
      }
-     # active merged classes only
      active_classes = [c for c in sorted(set(class_name_mapping.values())) if c != "__REMOVED__"]
      id_map = {cls: idx for idx, cls in enumerate(active_classes)}
  
-     # Scan label files -------------------------------------------------------
      image_to_classes: dict[str, set[str]] = {}
      image_to_label: dict[str, Path] = {}
      class_to_images: dict[str, set[str]] = {c: set() for c in active_classes}
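To make the mapping columns concrete, an illustrative `class_map_df` in the shape `merge_datasets()` expects (the class names and limits are invented): rows sharing a `new_name` are merged into one class, and `remove=True` maps the class to `__REMOVED__`, i.e. it is dropped:

    import pandas as pd

    class_map_df = pd.DataFrame([
        {"original_class": "car",   "new_name": "vehicle", "max_images": 500,   "remove": False},
        {"original_class": "truck", "new_name": "vehicle", "max_images": 500,   "remove": False},  # merged into "vehicle"
        {"original_class": "glare", "new_name": "glare",   "max_images": 99999, "remove": True},   # dropped entirely
    ])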
@@ -319,7 +315,6 @@
          for c in cls_set:
              class_to_images[c].add(img_path)
  
-     # Select images respecting per‑class limits -----------------------------
      selected_images: set[str] = set()
      counters = {c: 0 for c in active_classes}
      shuffle_pool = [img for imgs in class_to_images.values() for img in imgs]
@@ -327,14 +322,12 @@
  
      for img in shuffle_pool:
          cls_set = image_to_classes[img]
-         # skip if any class hit its limit
          if any(counters[c] >= limits_per_merged.get(c, 0) for c in cls_set):
              continue
          selected_images.add(img)
          for c in cls_set:
              counters[c] += 1
  
-     # Copy & re‑index --------------------------------------------------------
      for img in selected_images:
          split = "train" if random.random() < 0.9 else "valid"
          dst_img = out_dir / split / "images" / Path(img).name
@@ -352,7 +345,6 @@
              if not parts:
                  continue
              cid = int(parts[0])
-             # find orig class name
              dloc_match = next((cl for dloc2, cl, _, _ in dataset_info_list if str(lp_src).startswith(dloc2)), None)
              if dloc_match is None:
                  continue
@@ -370,7 +362,6 @@
          else:
              (out_dir / split / "images" / Path(img).name).unlink(missing_ok=True)
  
-     # Build data.yaml --------------------------------------------------------
      data_yaml = {
          "path": str(out_dir.resolve()),
          "train": "train/images",
@@ -384,12 +375,14 @@
      return out_dir
  
  
- # Utility: zip a folder to bytes -------------------------------------------
- 
  def zip_directory(folder: Path) -> bytes:
      buf = io.BytesIO()
-     with shutil.make_archive("dataset", "zip", folder) as _:
-         pass  # make_archive writes to disk – we avoid that (quick hack)
+     with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
+         for file in folder.rglob("*"):
+             zf.write(file, arcname=file.relative_to(folder))
+     buf.seek(0)
+     return buf.getvalue()
+ 
  
  # ════════════════════════════════════════════════════════════════════════════
  # UI LAYER
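A small self-contained check of the fixed in-memory zipper – the throwaway folder layout is invented and it assumes app.py is importable as `app`:

    # hypothetical check, not part of app.py
    import io, tempfile, zipfile
    from pathlib import Path

    from app import zip_directory

    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "train" / "images").mkdir(parents=True)
        (root / "data.yaml").write_text("names: [cat]\n")

        blob = zip_directory(root)  # bytes of a zip archive; nothing is written next to the dataset
        print(zipfile.ZipFile(io.BytesIO(blob)).namelist())  # includes 'data.yaml' and the train/ entries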
@@ -420,7 +413,6 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
          out_md = gr.Markdown()
          out_df = gr.Dataframe(label="Class distribution")
  
-         # --- callback (identical logic from v3, omitted for brevity) ---
          def _evaluate_cb(api_key, url_txt, zip_file, server_path, yaml_file, weights,
                           blur_thr, iou_thr, conf_thr, run_dup, run_modelqa):
              return "Evaluation disabled in this trimmed snippet.", pd.DataFrame()
@@ -434,7 +426,7 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
  
      # ------------------------------ MERGE TAB -----------------------------
      with gr.Tab("Merge / Edit"):
-         gr.Markdown("""### 1️⃣ Load one or more datasets""")
+         gr.Markdown("### 1️⃣ Load one or more datasets")
          rf_key = gr.Textbox(label="Roboflow API key", type="password")
          rf_urls = gr.File(label=".txt of RF URLs", file_types=['.txt'])
          zips_in = gr.Files(label="One or more dataset ZIPs")
@@ -446,14 +438,19 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
              global autoinc
              info_list = []
              log_lines = []
+ 
              # Roboflow URLs via txt
              if rf_urls_file is not None:
                  for url in Path(rf_urls_file.name).read_text().splitlines():
                      if not url.strip():
                          continue
-                     ds, names, splits = download_roboflow_dataset(url, rf_key)
-                     info_list.append((str(ds), names, splits, Path(ds).name))
-                     log_lines.append(f"✔️ RF dataset **{Path(ds).name}** loaded ({len(names)} classes)")
+                     try:
+                         ds, names, splits = download_roboflow_dataset(url, rf_key)
+                         info_list.append((str(ds), names, splits, Path(ds).name))
+                         log_lines.append(f"✔️ RF dataset **{Path(ds).name}** loaded ({len(names)} classes)")
+                     except Exception as e:
+                         log_lines.append(f"⚠️ RF load failed for {url!r}: {e}")
+ 
              # ZIPs
              for f in zip_files or []:
                  autoinc += 1
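For reference, the `.txt` of RF URLs is read with `read_text().splitlines()` and blank lines are skipped, so an input file like the following (placeholder URLs) is enough:

    # hypothetical helper to create the upload file, not part of app.py
    from pathlib import Path

    Path("rf_urls.txt").write_text(
        "https://universe.roboflow.com/acme/widgets/model?version=2\n"
        "\n"
        "https://universe.roboflow.com/acme/gadgets\n"
    )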
@@ -464,30 +461,33 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
                  if yaml_path is None:
                      continue
                  names = load_yaml(yaml_path).get("names", [])
-                 splits = [s for s in ["train", "valid", "test"] if (tmp / s).exists()]
+                 splits = [s for s in ("train","valid","test") if (tmp / s).exists()]
                  info_list.append((str(tmp), names, splits, tmp.name))
                  log_lines.append(f"✔️ ZIP **{tmp.name}** loaded")
+ 
              return info_list, "\n".join(log_lines) if log_lines else "No datasets loaded."
  
          load_btn.click(_load_cb, [rf_key, rf_urls, zips_in], [ds_state, load_log])
  
          # ------------- Class map editable table --------------------------
-         gr.Markdown("""### 2️⃣ Edit class mapping / limits / removal""")
-         class_df = gr.Dataframe(headers=["original_class", "new_name", "max_images", "remove"],
-                                 datatype=["str", "str", "number", "bool"],
-                                 interactive=True, elem_id="classdf")
+         gr.Markdown("### 2️⃣ Edit class mapping / limits / removal")
+         class_df = gr.Dataframe(
+             headers=["original_class", "new_name", "max_images", "remove"],
+             datatype=["str", "str", "number", "bool"],
+             interactive=True, elem_id="classdf"
+         )
          refresh_btn = gr.Button("Build class table from loaded datasets")
  
          def _build_class_df(ds_info):
              class_names_all = []
-             for _dloc, names, _spl, _n in ds_info:
+             for _dloc, names, _spl, _ in ds_info:
                  class_names_all.extend(names)
              class_names_all = sorted(set(class_names_all))
              df = pd.DataFrame({
                  "original_class": class_names_all,
                  "new_name": class_names_all,
-                 "max_images": [99999]*len(class_names_all),
-                 "remove": [False]*len(class_names_all),
+                 "max_images": [99999] * len(class_names_all),
+                 "remove": [False] * len(class_names_all),
              })
              return df
  
@@ -499,11 +499,14 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
          merge_log = gr.Markdown()
  
          def _merge_cb(ds_info, class_df):
-             if len(ds_info) == 0:
+             if not ds_info:
                  return None, "⚠️ Load datasets first."
-             out_dir = merge_datasets(ds_info, class_df)  # may be slow
+             out_dir = merge_datasets(ds_info, class_df)
              zip_path = shutil.make_archive(str(out_dir), "zip", out_dir)
-             return zip_path, f"✅ Merged dataset created at **{out_dir}** with {len(list(Path(out_dir).rglob('*.jpg')))} images."
+             return zip_path, (
+                 f"✅ Merged dataset created at **{out_dir}** with "
+                 f"{len(list(Path(out_dir).rglob('*.jpg')))} images."
+             )
  
          merge_btn.click(_merge_cb, [ds_state, class_df], [zip_out, merge_log])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import base64
 
11
  import shutil
12
  import stat
13
  import tempfile
14
+ import zipfile
15
  from collections import Counter
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
  from dataclasses import dataclass
 
154
  # -------------------- Roboflow helpers --------------------
155
  RF_RE = re.compile(r"https?://universe\.roboflow\.com/([^/]+)/([^/]+)/?([^/]*)")
156
 
157
+ def parse_roboflow_url(url: str) -> tuple[str, str, int | None]:
158
+ """
159
+ Return (workspace, project, version|None) – tolerates many RF URL flavours.
160
+ Any non‐positive or malformed version is treated as None.
161
+ """
162
  m = RF_RE.match(url.strip())
163
  if not m:
164
  return None, None, None
165
  ws, proj, tail = m.groups()
166
+ ver: int | None = None
167
+
168
+ # explicit "dataset/<number>" in path
169
  if tail.startswith("dataset/"):
170
+ try:
171
+ v = int(tail.split("dataset/", 1)[1])
172
+ if v > 0:
173
+ ver = v
174
+ except ValueError:
175
+ pass
176
+
177
+ # explicit "?version=<number>" in query
178
+ if ver is None and "?version=" in url:
179
+ try:
180
+ v = int(url.split("?version=", 1)[1])
181
+ if v > 0:
182
+ ver = v
183
+ except ValueError:
184
+ pass
185
+
186
  return ws, proj, ver
187
 
188
 
 
197
  return None
198
 
199
 
200
+ def download_roboflow_dataset(
201
+ url: str,
202
+ rf_api_key: str,
203
+ fmt: str = "yolov8",
204
+ ) -> Tuple[Path, List[str], List[str]]:
205
  """Return (dataset_location, class_names, splits). Caches by folder name."""
206
  if Roboflow is None:
207
  raise RuntimeError("`roboflow` pip package not installed")
208
+
209
  ws, proj, ver = parse_roboflow_url(url)
210
  if not (ws and proj):
211
+ raise ValueError(f"Bad Roboflow URL: {url!r}")
212
 
213
  rf = Roboflow(api_key=rf_api_key)
214
+
215
+ # if no explicit version or invalid, fetch latest
216
+ if not ver or ver <= 0:
217
+ latest = get_latest_version(rf, ws, proj)
218
+ if latest is None:
219
  raise RuntimeError("Could not resolve latest Roboflow version")
220
+ try:
221
+ ver = int(latest)
222
+ except ValueError:
223
+ raise RuntimeError(f"Invalid latest version returned: {latest!r}")
224
 
225
  ds_dir = TMP_ROOT / f"{ws}_{proj}_v{ver}"
226
  if ds_dir.exists():
227
  yaml_path = ds_dir / "data.yaml"
228
  class_names = load_yaml(yaml_path).get("names", []) if yaml_path.exists() else []
229
+ splits = [s for s in ("train","valid","test") if (ds_dir / s).exists()]
230
  return ds_dir, class_names, splits
231
 
232
  ds_dir.mkdir(parents=True, exist_ok=True)
233
+ rf.workspace(ws).project(proj).version(ver).download(fmt, location=str(ds_dir))
234
+
235
  yaml_path = ds_dir / "data.yaml"
236
  class_names = load_yaml(yaml_path).get("names", []) if yaml_path.exists() else []
237
+ splits = [s for s in ("train","valid","test") if (ds_dir / s).exists()]
238
  return ds_dir, class_names, splits
239
 
240
 
 
257
  return dict(counts)
258
 
259
 
 
 
260
  def _process_label_file(label_path: Path, class_names_dataset, class_name_mapping):
261
  im_name = label_path.stem + label_path.suffix.replace(".txt", ".jpg")
262
  img_classes = set()
 
268
  return im_name, img_classes
269
 
270
 
 
 
 
 
271
  def merge_datasets(
272
  dataset_info_list: List[Tuple[str, List[str], List[str], str]],
273
  class_map_df: pd.DataFrame,
 
283
  (out_dir / "valid/images").mkdir(parents=True, exist_ok=True)
284
  (out_dir / "valid/labels").mkdir(parents=True, exist_ok=True)
285
 
 
286
  class_name_mapping = {
287
+ row["original_class"]: row["new_name"] if not row["remove"] else "__REMOVED__"
288
  for _, row in class_map_df.iterrows()
289
  }
290
  limits_per_merged = {
291
  row["new_name"]: int(row["max_images"])
292
  for _, row in class_map_df.iterrows()
293
+ if not row["remove"]
294
  }
 
295
  active_classes = [c for c in sorted(set(class_name_mapping.values())) if c != "__REMOVED__"]
296
  id_map = {cls: idx for idx, cls in enumerate(active_classes)}
297
 
 
298
  image_to_classes: dict[str, set[str]] = {}
299
  image_to_label: dict[str, Path] = {}
300
  class_to_images: dict[str, set[str]] = {c: set() for c in active_classes}
 
315
  for c in cls_set:
316
  class_to_images[c].add(img_path)
317
 
 
318
  selected_images: set[str] = set()
319
  counters = {c: 0 for c in active_classes}
320
  shuffle_pool = [img for imgs in class_to_images.values() for img in imgs]
 
322
 
323
  for img in shuffle_pool:
324
  cls_set = image_to_classes[img]
 
325
  if any(counters[c] >= limits_per_merged.get(c, 0) for c in cls_set):
326
  continue
327
  selected_images.add(img)
328
  for c in cls_set:
329
  counters[c] += 1
330
 
 
331
  for img in selected_images:
332
  split = "train" if random.random() < 0.9 else "valid"
333
  dst_img = out_dir / split / "images" / Path(img).name
 
345
  if not parts:
346
  continue
347
  cid = int(parts[0])
 
348
  dloc_match = next((cl for dloc2, cl, _, _ in dataset_info_list if str(lp_src).startswith(dloc2)), None)
349
  if dloc_match is None:
350
  continue
 
362
  else:
363
  (out_dir / split / "images" / Path(img).name).unlink(missing_ok=True)
364
 
 
365
  data_yaml = {
366
  "path": str(out_dir.resolve()),
367
  "train": "train/images",
 
375
  return out_dir
376
 
377
 
 
 
378
  def zip_directory(folder: Path) -> bytes:
379
  buf = io.BytesIO()
380
+ with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
381
+ for file in folder.rglob("*"):
382
+ zf.write(file, arcname=file.relative_to(folder))
383
+ buf.seek(0)
384
+ return buf.getvalue()
385
+
386
 
387
  # ════════════════════════════════════════════════════════════════════════════
388
  # UI LAYER
 
413
  out_md = gr.Markdown()
414
  out_df = gr.Dataframe(label="Class distribution")
415
 
 
416
  def _evaluate_cb(api_key, url_txt, zip_file, server_path, yaml_file, weights,
417
  blur_thr, iou_thr, conf_thr, run_dup, run_modelqa):
418
  return "Evaluation disabled in this trimmed snippet.", pd.DataFrame()
 
426
 
427
  # ------------------------------ MERGE TAB -----------------------------
428
  with gr.Tab("Merge / Edit"):
429
+ gr.Markdown("### 1️⃣ Load one or more datasets")
430
  rf_key = gr.Textbox(label="Roboflow API key", type="password")
431
  rf_urls = gr.File(label=".txt of RF URLs", file_types=['.txt'])
432
  zips_in = gr.Files(label="One or more dataset ZIPs")
 
438
  global autoinc
439
  info_list = []
440
  log_lines = []
441
+
442
  # Roboflow URLs via txt
443
  if rf_urls_file is not None:
444
  for url in Path(rf_urls_file.name).read_text().splitlines():
445
  if not url.strip():
446
  continue
447
+ try:
448
+ ds, names, splits = download_roboflow_dataset(url, rf_key)
449
+ info_list.append((str(ds), names, splits, Path(ds).name))
450
+ log_lines.append(f"✔️ RF dataset **{Path(ds).name}** loaded ({len(names)} classes)")
451
+ except Exception as e:
452
+ log_lines.append(f"⚠️ RF load failed for {url!r}: {e}")
453
+
454
  # ZIPs
455
  for f in zip_files or []:
456
  autoinc += 1
 
461
  if yaml_path is None:
462
  continue
463
  names = load_yaml(yaml_path).get("names", [])
464
+ splits = [s for s in ("train","valid","test") if (tmp / s).exists()]
465
  info_list.append((str(tmp), names, splits, tmp.name))
466
  log_lines.append(f"✔️ ZIP **{tmp.name}** loaded")
467
+
468
  return info_list, "\n".join(log_lines) if log_lines else "No datasets loaded."
469
 
470
  load_btn.click(_load_cb, [rf_key, rf_urls, zips_in], [ds_state, load_log])
471
 
472
  # ------------- Class map editable table --------------------------
473
+ gr.Markdown("### 2️⃣ Edit class mapping / limits / removal")
474
+ class_df = gr.Dataframe(
475
+ headers=["original_class", "new_name", "max_images", "remove"],
476
+ datatype=["str", "str", "number", "bool"],
477
+ interactive=True, elem_id="classdf"
478
+ )
479
  refresh_btn = gr.Button("Build class table from loaded datasets")
480
 
481
  def _build_class_df(ds_info):
482
  class_names_all = []
483
+ for _dloc, names, _spl, _ in ds_info:
484
  class_names_all.extend(names)
485
  class_names_all = sorted(set(class_names_all))
486
  df = pd.DataFrame({
487
  "original_class": class_names_all,
488
  "new_name": class_names_all,
489
+ "max_images": [99999] * len(class_names_all),
490
+ "remove": [False] * len(class_names_all),
491
  })
492
  return df
493
 
 
499
  merge_log = gr.Markdown()
500
 
501
  def _merge_cb(ds_info, class_df):
502
+ if not ds_info:
503
  return None, "⚠️ Load datasets first."
504
+ out_dir = merge_datasets(ds_info, class_df)
505
  zip_path = shutil.make_archive(str(out_dir), "zip", out_dir)
506
+ return zip_path, (
507
+ f"✅ Merged dataset created at **{out_dir}** with "
508
+ f"{len(list(Path(out_dir).rglob('*.jpg')))} images."
509
+ )
510
 
511
  merge_btn.click(_merge_cb, [ds_state, class_df], [zip_out, merge_log])
512