wuhp committed on
Commit 5b4455d · verified · 1 Parent(s): 6e60466

Update app.py

Files changed (1)
  1. app.py +319 -299
app.py CHANGED
@@ -1,7 +1,6 @@
  from __future__ import annotations
 
  import base64
- import functools
  import imghdr
  import io
  import json
@@ -12,13 +11,12 @@ import re
  import shutil
  import stat
  import tempfile
- import time
  import zipfile
  from collections import Counter
- from concurrent.futures import ProcessPoolExecutor, as_completed
  from dataclasses import dataclass
  from pathlib import Path
- from typing import Dict, List, Optional, Tuple
 
  import gradio as gr
  import numpy as np
@@ -51,25 +49,15 @@ try:
      from cleanlab.pruning import get_noise_indices
  except ImportError:
      get_noise_indices = None
- try:
-     import torch
- except ImportError:
-     torch = None
 
  # ───────────────── Config & Constants ───────────────────────────────────────
  TMP_ROOT = Path(tempfile.gettempdir()) / "rf_datasets"
  TMP_ROOT.mkdir(parents=True, exist_ok=True)
 
- # Clean out any subfolders older than 1 hour
- _now = time.time()
- for sub in TMP_ROOT.iterdir():
-     if sub.is_dir() and (_now - sub.stat().st_mtime) > 3600:
-         shutil.rmtree(sub, ignore_errors=True)
-
- CPU_COUNT = int(os.getenv("QC_CPU", os.cpu_count() or 1))
- BATCH_SIZE = int(os.getenv("QC_BATCH", 4))
- SAMPLE_LIMIT = int(os.getenv("QC_SAMPLE", 200))
- DEFAULT_W = {
      "Integrity": 0.25,
      "Class balance": 0.10,
      "Image quality": 0.15,
@@ -77,32 +65,26 @@ DEFAULT_W = {
      "Model QA": 0.30,
      "Label issues": 0.10,
  }
  logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
 
  _model_cache: dict[str, YOLO] = {}
- autoinc = 0  # for ZIP unpack dirs
 
  # ────────────────────────────────────────────────────────────────────────────
  @dataclass
  class QCConfig:
      blur_thr: float
      iou_thr: float
      conf_thr: float
-     weights: Optional[str]
      cpu_count: int = CPU_COUNT
      batch_size: int = BATCH_SIZE
      sample_limit:int = SAMPLE_LIMIT
 
- @dataclass
- class DatasetInfo:
-     path: str
-     class_names: List[str]
-     splits: List[str]
-     name: str
-
- # ────────────────────────────────────────────────────────────────────────────
  def load_yaml(path: Path) -> Dict:
-     """Load a YAML file safely."""
      with path.open('r', encoding='utf-8') as f:
          return yaml.safe_load(f)
 
@@ -113,10 +95,8 @@ def load_class_names(yaml_path: Path) -> List[str]:
          return [names[k] for k in sorted(names, key=lambda x: int(x))]
      return list(names)
 
- @functools.lru_cache(maxsize=None)
- def parse_label_file(path: Path) -> List[Tuple[int,float,float,float,float]]:
-     """Parse a YOLO label file; empty or missing yields empty list."""
-     if not path.exists() or path.stat().st_size == 0:
          return []
      try:
          arr = np.loadtxt(path, dtype=float)
@@ -136,8 +116,7 @@ def guess_image_dirs(root: Path) -> List[Path]:
      ]
      return [d for d in candidates if d.exists()]
 
- def gather_dataset(root: Path, yaml_path: Optional[Path]) -> Tuple[List[Path], List[Optional[Path]], Dict]:
-     """Return lists of image paths, label paths, and metadata."""
      if yaml_path is None:
          yamls = list(root.glob('*.yaml'))
          if not yamls:
@@ -155,16 +134,26 @@ def gather_dataset(root: Path, yaml_path: Optional[Path]) -> Tuple[List[Path], List[Optional[Path]], Dict]:
      ]
      return imgs, lbls, meta
 
- def get_model(weights: str) -> Optional[YOLO]:
-     """Load/cache a YOLO model, preferring GPU if available."""
      if not weights or YOLO is None:
          return None
      if weights not in _model_cache:
-         device = "cuda" if torch and torch.cuda.is_available() else "cpu"
-         _model_cache[weights] = YOLO(weights, device=device)
      return _model_cache[weights]
 
- # ────────────────────────────────────────────────────────────────────────────
  def _is_corrupt(path: Path) -> bool:
      try:
          with Image.open(path) as im:
@@ -173,172 +162,171 @@ def _is_corrupt(path: Path) -> bool:
      except Exception:
          return True
 
- def _quality_stat_args(args: Tuple[str, float]) -> Tuple[str, bool, bool, bool]:
-     """Return (path, is_blurry, is_dark, is_bright)."""
-     path_str, thr = args
-     if cv2 is None:
-         return path_str, False, False, False
-     im = cv2.imread(path_str)
-     if im is None:
-         return path_str, False, False, False
-     gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
-     lap = cv2.Laplacian(gray, cv2.CV_64F).var()
-     mean = float(gray.mean())
-     return path_str, lap < thr, mean < 25, mean > 230
-
- def qc_integrity(imgs: List[Path], lbls: List[Optional[Path]], cfg: QCConfig) -> Dict:
-     """Check for missing labels & corrupt images."""
      missing = [i for i, l in zip(imgs, lbls) if l is None]
      corrupt = []
      sample = imgs[:cfg.sample_limit]
-     with ProcessPoolExecutor(max_workers=cfg.cpu_count) as ex:
-         futures = {ex.submit(_is_corrupt, str(p)): p for p in sample}
-         for f in as_completed(futures):
              if f.result():
-                 corrupt.append(futures[f])
-     score = max(0.0, 100 - (len(missing) + len(corrupt)) / max(len(imgs),1) * 100)
      return {
          "name": "Integrity",
-         "score": score,
          "details": {
              "missing_label_files": [str(p) for p in missing],
-             "corrupt_images": [str(p) for p in corrupt],
          }
      }
 
- def qc_class_balance(lbls: List[Optional[Path]], cfg: QCConfig) -> Dict:
      counts, boxes = Counter(), []
      for l in lbls[:cfg.sample_limit]:
          bs = parse_label_file(l) if l else []
          boxes.append(len(bs))
          counts.update(int(b[0]) for b in bs)
      if not counts:
-         return {"name": "Class balance", "score": 0.0, "details": "No labels"}
      bal = min(counts.values()) / max(counts.values()) * 100
      return {
          "name": "Class balance",
          "score": bal,
          "details": {
-             "class_counts": dict(counts),
-             "boxes_per_image": {"min": min(boxes), "max": max(boxes), "mean": float(np.mean(boxes))}
          }
      }
 
  def qc_image_quality(imgs: List[Path], cfg: QCConfig) -> Dict:
      if cv2 is None:
-         return {"name": "Image quality", "score": 100.0, "details": "cv2 missing"}
      blurry, dark, bright = [], [], []
      sample = imgs[:cfg.sample_limit]
-     args = [(str(p), cfg.blur_thr) for p in sample]
-     with ProcessPoolExecutor(max_workers=cfg.cpu_count) as ex:
-         for path_str, isb, isd, isB in ex.map(_quality_stat_args, args):
-             if isb: blurry.append(path_str)
-             if isd: dark.append(path_str)
-             if isB: bright.append(path_str)
-     bad = len(set(blurry) | set(dark) | set(bright))
-     score = max(0.0, 100 - bad / max(len(sample),1) * 100)
      return {
          "name": "Image quality",
          "score": score,
-         "details": {"blurry": blurry, "dark": dark, "bright": bright}
      }
 
  def qc_duplicates(imgs: List[Path], cfg: QCConfig) -> Dict:
-     if fastdup and len(imgs) > 50:
          try:
              fd = fastdup.create(
                  input_dir=str(Path(imgs[0]).parent.parent),
-                 work_dir=str(TMP_ROOT/"fastdup")
              )
              fd.run()
-             clusters = []
              try:
                  cc = fd.connected_components_grouped(sort_by="comp_size", ascending=False)
-                 clusters = cc["files"].tolist()
              except Exception:
                  clusters = fd.connected_components()
-             dup_count = sum(len(c)-1 for c in clusters)
-             score = max(0.0, 100 - dup_count/len(imgs)*100)
-             return {"name":"Duplicates","score":score,"details":{"groups":clusters[:50]}}
          except Exception as e:
-             return {"name":"Duplicates","score":100.0,"details":{"error":str(e)}}
-     return {"name":"Duplicates","score":100.0,"details":{"note":"skipped"}}
-
- def _rel_iou(b1, b2) -> float:
-     x1,y1,w1,h1 = b1; x2,y2,w2,h2 = b2
-     xa1,ya1 = x1-w1/2, y1-h1/2
-     xa2,ya2 = x1+w1/2, y1+h1/2
-     xb1,yb1 = x2-w2/2, y2-h2/2
-     xb2,yb2 = x2+w2/2, y2+h2/2
-     ix = max(0, min(xa2,xb2)-max(xa1,xb1))
-     iy = max(0, min(ya2,yb2)-max(ya1,yb1))
-     inter = ix*iy
      union = w1*h1 + w2*h2 - inter
      return inter/union if union else 0.0
 
- def qc_model_qa(imgs: List[Path], lbls: List[Optional[Path]], cfg: QCConfig) -> Dict:
      model = get_model(cfg.weights)
      if model is None:
-         return {"name":"Model QA","score":100.0,"details":"skipped"}
-     ious, mismatches = [], []
      sample = imgs[:cfg.sample_limit]
      for i in range(0, len(sample), cfg.batch_size):
          batch = sample[i:i+cfg.batch_size]
-         with torch.no_grad() if torch else contextlib.nullcontext():
-             results = model.predict(batch, verbose=False, half=True, dynamic=True)
          for p, res in zip(batch, results):
-             gt = parse_label_file(Path(p).with_suffix('.txt').parent) # adjust as needed
              for cls, x, y, w, h in gt:
-                 best = max(
-                     (_rel_iou((x,y,w,h), tuple(b))
-                      for b,c,conf in zip(res.boxes.xywh.cpu().numpy(),
-                                          res.boxes.cls.cpu().numpy(),
-                                          res.boxes.conf.cpu().numpy())
-                      if conf>=cfg.conf_thr and int(c)==cls
-                     ), default=0.0
-                 )
                  ious.append(best)
                  if best < cfg.iou_thr:
-                     mismatches.append(str(p))
      miou = float(np.mean(ious)) if ious else 1.0
-     return {"name":"Model QA","score":miou*100,"details":{"mean_iou":miou,"mismatches":mismatches[:50]}}
 
- def qc_label_issues(imgs: List[Path], lbls: List[Optional[Path]], cfg: QCConfig) -> Dict:
-     if not get_noise_indices:
-         return {"name":"Label issues","score":100.0,"details":"skipped"}
      labels, idxs = [], []
      sample = imgs[:cfg.sample_limit]
-     for idx, p in enumerate(sample):
-         bs = parse_label_file(lbls[idx]) if lbls[idx] else []
-         for cls,*_ in bs:
-             labels.append(int(cls)); idxs.append(idx)
      if not labels:
-         return {"name":"Label issues","score":100.0,"details":"no ground truth"}
-     arr = np.array(labels)
-     uniq = sorted(set(arr))
-     probs = np.eye(len(uniq))[np.searchsorted(uniq, arr)]
-     noise = get_noise_indices(labels=arr, probabilities=probs)
-     flagged = sorted({idxs[n] for n in noise})
-     files = [str(sample[i]) for i in flagged]
-     score = max(0.0, 100 - len(flagged)/len(labels)*100)
-     return {"name":"Label issues","score":score,"details":{"files":files[:50]}}
 
  def aggregate(results: List[Dict]) -> float:
      return sum(DEFAULT_W[r["name"]]*r["score"] for r in results)
 
- def gather_class_counts(infos: List[DatasetInfo]) -> Counter[str]:
-     ctr = Counter()
-     for info in infos:
-         for split in info.splits:
-             labels_dir = Path(info.path)/split/'labels'
-             if not labels_dir.exists(): continue
-             for lp in labels_dir.rglob('*.txt'):
-                 for cls, *_ in parse_label_file(lp):
-                     idx = int(cls)
-                     if 0 <= idx < len(info.class_names):
-                         ctr[info.class_names[idx]] += 1
-     return ctr
-
  RF_RE = re.compile(r"https?://universe\.roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)")
 
  def download_rf_dataset(url: str, rf_api: Roboflow, dest: Path) -> Path:
@@ -346,17 +334,18 @@ def download_rf_dataset(url: str, rf_api: Roboflow, dest: Path) -> Path:
      if not m:
          raise ValueError(f"Bad RF URL: {url}")
      ws, proj, ver = m.groups()
-     ds_dir = dest/f"{ws}_{proj}_v{ver}"
      if ds_dir.exists():
          return ds_dir
      pr = rf_api.workspace(ws).project(proj)
      pr.version(int(ver)).download("yolov8", location=str(ds_dir))
      return ds_dir
 
  def run_quality(
      root: Path,
-     yaml_file: Optional[Path],
-     weights: Optional[Path],
      cfg: QCConfig,
      run_dup: bool,
      run_modelqa: bool
@@ -366,167 +355,197 @@ def run_quality(
          qc_integrity(imgs, lbls, cfg),
          qc_class_balance(lbls, cfg),
          qc_image_quality(imgs, cfg),
-         qc_duplicates(imgs, cfg) if run_dup else {"name":"Duplicates","score":100.0,"details":"skipped"},
-         qc_model_qa(imgs, lbls, cfg) if run_modelqa else {"name":"Model QA","score":100.0,"details":"skipped"},
-         qc_label_issues(imgs, lbls, cfg) if run_modelqa else {"name":"Label issues","score":100.0,"details":"skipped"},
      ]
      final = aggregate(results)
      md = [f"## **{meta.get('name', root.name)}** — Score {final:.1f}/100"]
      for r in results:
-         md += [
-             f"### {r['name']} {r['score']:.1f}",
-             "<details><summary>details</summary>\n```json",
-             json.dumps(r["details"], indent=2),
-             "```\n</details>\n"
-         ]
-     # build class-balance df
-     cb = next(r for r in results if r["name"]=="Class balance")["details"]["class_counts"]
-     df = pd.DataFrame.from_dict(cb, orient='index', columns=['count'])
      df.index.name = "class"
      return "\n".join(md), df
 
  def merge_datasets(
-     infos: List[DatasetInfo],
      class_map_df: pd.DataFrame,
      out_dir: Path = Path("merged_dataset"),
-     seed: int = 1234
  ) -> Path:
      random.seed(seed)
      if out_dir.exists():
-         shutil.rmtree(out_dir, onerror=lambda f,p,e: (os.chmod(p, stat.S_IWRITE), f(p)))
      for sub in ("train/images","train/labels","valid/images","valid/labels"):
-         (out_dir/sub).mkdir(parents=True, exist_ok=True)
 
-     # build mapping
-     mapping = {
-         row.original_class: ("__REMOVED__" if row.remove else row.new_name)
-         for row in class_map_df.itertuples()
      }
-     limits = {
-         row.new_name: row.max_images
-         for row in class_map_df.itertuples() if not row.remove
      }
-     active = sorted({v for v in mapping.values() if v!="__REMOVED__"})
-     id_map = {c:i for i,c in enumerate(active)}
-
-     # collect image→classes
-     img2cls, img2lbl = {}, {}
-     cls2imgs = {c:set() for c in active}
-     for info in infos:
-         for split in info.splits:
-             for lp in Path(info.path)/split/'labels'.rglob('*.txt'):
-                 origs = [info.class_names[int(cls)] for cls,*_ in parse_label_file(lp)]
-                 newset = {mapping[o] for o in origs if mapping.get(o) in active}
-                 if not newset: continue
-                 img_path = str(lp.parent.parent/'images'/f"{lp.stem}.jpg")
-                 img2cls[img_path] = newset
-                 img2lbl[img_path] = lp
-                 for c in newset:
-                     cls2imgs[c].add(img_path)
-
-     # select under class limits
-     selected, counters = set(), {c:0 for c in active}
-     pool = list({img for imgs in cls2imgs.values() for img in imgs})
      random.shuffle(pool)
      for img in pool:
-         cs = img2cls[img]
-         if any(counters[c]>=limits[c] for c in cs): continue
-         selected.add(img)
-         for c in cs: counters[c]+=1
-
-     # copy & rewrite labels
-     for img in selected:
-         split = "train" if random.random()<0.9 else "valid"
-         dst_im = out_dir/split/'images'/Path(img).name
-         dst_im.parent.mkdir(parents=True, exist_ok=True)
-         shutil.copy(img, dst_im)
-         lp = img2lbl[img]
-         dst_lbl = out_dir/split/'labels'/lp.name
          dst_lbl.parent.mkdir(parents=True, exist_ok=True)
-         new = []
-         for line in lp.read_text().splitlines():
              parts = line.split()
              cid = int(parts[0])
-             orig = info.class_names[cid] if cid<len(info.class_names) else None
-             merged = mapping.get(orig)
-             if merged in active:
-                 new.append(" ".join([str(id_map[merged])]+parts[1:]))
-         if new:
-             dst_lbl.write_text("\n".join(new))
          else:
-             dst_im.unlink(missing_ok=True)
 
-     # write data.yaml
-     meta = {
          "path": str(out_dir.resolve()),
          "train": "train/images",
-         "val": "valid/images",
-         "nc": len(active),
-         "names": active
      }
-     (out_dir/"data.yaml").write_text(yaml.safe_dump(meta))
      return out_dir
 
  # ════════════════════════════════════════════════════════════════════════════
- # UI Layer
  # ════════════════════════════════════════════════════════════════════════════
  with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
-     gr.Markdown("# 🏹 **YOLO Dataset Toolkit**\n_Evaluate • Merge • Edit • Download_")
 
-     # Evaluate Tab
      with gr.Tab("Evaluate"):
-         api_in = gr.Textbox(label="Roboflow API key", type="password")
-         url_txt = gr.File(label=".txt of RF dataset URLs", file_types=['.txt'])
-         zip_in = gr.File(label="Dataset ZIP")
-         path_in = gr.Textbox(label="Server path")
-         yaml_in = gr.File(label="Custom YAML", file_types=['.yaml'])
-         weights_in = gr.File(label="YOLO weights (.pt)")
-         blur_sl = gr.Slider(0.0,500.0,value=100.0,label="Blur threshold")
-         iou_sl = gr.Slider(0.0,1.0,value=0.5,label="IOU threshold")
-         conf_sl = gr.Slider(0.0,1.0,value=0.25,label="Min detection confidence")
-         run_dup = gr.Checkbox(label="Check duplicates", value=False)
-         run_modelqa = gr.Checkbox(label="Run Model QA & cleanlab", value=False)
-         run_eval = gr.Button("Run Evaluation")
-         out_md = gr.Markdown()
-         out_df = gr.Dataframe()
 
          def _evaluate_cb(
-             api_key, url_file, zip_file, server_path, yaml_file, weights,
-             blur_thr, iou_thr, conf_thr, dup, modelqa
          ):
              reports, dfs = [], []
-             cfg = QCConfig(blur_thr,iou_thr,conf_thr, weights.name if weights else None)
-             rf = Roboflow(api_key) if api_key and Roboflow else None
 
-             # Roboflow URLs
-             if url_file and rf:
-                 for line in Path(url_file.name).read_text().splitlines():
                      if not line.strip(): continue
                      try:
                          ds = download_rf_dataset(line, rf, TMP_ROOT)
-                         md, df = run_quality(Path(ds), None, Path(weights.name) if weights else None,
-                                              cfg, dup, modelqa)
                          reports.append(md); dfs.append(df)
                      except Exception as e:
                          reports.append(f"### {line}\n⚠️ {e}")
 
-             # ZIP upload
              if zip_file:
                  tmp = Path(tempfile.mkdtemp())
                  shutil.unpack_archive(zip_file.name, tmp)
-                 md, df = run_quality(tmp,
-                                      Path(yaml_file.name) if yaml_file else None,
-                                      Path(weights.name) if weights else None,
-                                      cfg, dup, modelqa)
                  reports.append(md); dfs.append(df)
                  shutil.rmtree(tmp, ignore_errors=True)
 
-             # Server path
              if server_path:
-                 md, df = run_quality(Path(server_path),
-                                      Path(yaml_file.name) if yaml_file else None,
-                                      Path(weights.name) if weights else None,
-                                      cfg, dup, modelqa)
                  reports.append(md); dfs.append(df)
 
              summary = "\n---\n".join(reports) if reports else ""
@@ -535,24 +554,24 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
 
          run_eval.click(
              _evaluate_cb,
-             inputs=[api_in,url_txt,zip_in,path_in,yaml_in,weights_in,
-                     blur_sl,iou_sl,conf_sl,run_dup,run_modelqa],
-             outputs=[out_md,out_df]
          )
 
-     # Merge/Edit Tab
      with gr.Tab("Merge / Edit"):
-         gr.Markdown("### 1️⃣ Load datasets")
-         rf_key = gr.Textbox(label="Roboflow API key", type="password")
-         rf_urls = gr.File(label=".txt of RF URLs", file_types=['.txt'])
-         zips = gr.Files(label="Dataset ZIPs")
-         load_btn= gr.Button("Load")
-         load_log= gr.Markdown()
-         ds_state= gr.State([])
 
          def _load_cb(rf_key, rf_urls_file, zip_files):
-             nonlocal autoinc
-             infos, logs = [], []
              rf = Roboflow(rf_key) if rf_key and Roboflow else None
 
              if rf_urls_file and rf:
@@ -561,61 +580,62 @@ with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
                      if not url: continue
                      try:
                          ds = download_rf_dataset(url, rf, TMP_ROOT)
-                         names = load_class_names(Path(ds)/"data.yaml")
-                         splits= [s for s in ("train","valid","test") if (Path(ds)/s).exists()]
-                         infos.append(DatasetInfo(str(ds),names,splits,Path(ds).name))
-                         logs.append(f"✔️ Loaded RF dataset {Path(ds).name}")
                      except Exception as e:
-                         logs.append(f"⚠️ RF failed {url!r}: {e}")
 
              for f in zip_files or []:
                  autoinc += 1
-                 tmp = TMP_ROOT/f"zip_{autoinc}"
                  tmp.mkdir(parents=True, exist_ok=True)
                  shutil.unpack_archive(f.name, tmp)
-                 y = next(tmp.rglob("*.yaml"), None)
-                 if y:
-                     names = load_class_names(y)
-                     splits= [s for s in ("train","valid","test") if (tmp/s).exists()]
-                     infos.append(DatasetInfo(str(tmp),names,splits,tmp.name))
-                     logs.append(f"✔️ Loaded ZIP {tmp.name}")
 
-             return infos, "\n".join(logs) or "No datasets loaded."
 
-         load_btn.click(_load_cb, [rf_key, rf_urls, zips], [ds_state, load_log])
 
-         gr.Markdown("### 2️⃣ Edit classes")
          class_df = gr.Dataframe(
              headers=["original_class","new_name","max_images","remove"],
              datatype=["str","str","number","bool"],
              interactive=True, elem_id="classdf"
          )
-         refresh = gr.Button("Build table")
-         merge_btn = gr.Button("Merge")
-         zip_out = gr.File(label="Download merged ZIP")
-         merge_log = gr.Markdown()
 
-         def _build_df(infos):
-             counts = gather_class_counts(infos)
-             names = sorted(counts)
              return pd.DataFrame({
-                 "original_class": names,
-                 "new_name": names,
-                 "max_images": [counts[n] for n in names],
-                 "remove": [False]*len(names),
              })
 
-         refresh.click(_build_df, [ds_state], [class_df])
 
-         def _merge_cb(infos, df):
-             if not infos:
-                 return None, "⚠️ Load datasets first."
-             out = merge_datasets(infos, df)
-             zipf = shutil.make_archive(str(out), "zip", out)
-             cnt = len(list(Path(out).rglob("*.jpg")))
-             return zipf, f"✅ Merged to {out} ({cnt} images)"
 
          merge_btn.click(_merge_cb, [ds_state, class_df], [zip_out, merge_log])
 
  if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT",7860)))
 
app.py (updated version; additions marked with +)
  from __future__ import annotations
 
  import base64
  import imghdr
  import io
  import json
 
  import shutil
  import stat
  import tempfile
  import zipfile
  from collections import Counter
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  from dataclasses import dataclass
  from pathlib import Path
+ from typing import Dict, List, Tuple
 
  import gradio as gr
  import numpy as np
 
      from cleanlab.pruning import get_noise_indices
  except ImportError:
      get_noise_indices = None
 
  # ───────────────── Config & Constants ───────────────────────────────────────
  TMP_ROOT = Path(tempfile.gettempdir()) / "rf_datasets"
  TMP_ROOT.mkdir(parents=True, exist_ok=True)
+ CPU_COUNT = int(os.getenv("QC_CPU", 1))
+ BATCH_SIZE = int(os.getenv("QC_BATCH", 4))
+ SAMPLE_LIMIT = int(os.getenv("QC_SAMPLE", 200))
 
+ DEFAULT_W = {
      "Integrity": 0.25,
      "Class balance": 0.10,
      "Image quality": 0.15,
      "Model QA": 0.30,
      "Label issues": 0.10,
  }
+
  logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
 
  _model_cache: dict[str, YOLO] = {}
+ autoinc = 0  # helper for tmp‑dir names
 
+ # ────────────────────────────────────────────────────────────────────────────
+ # Data‑class & basic helpers
  # ────────────────────────────────────────────────────────────────────────────
  @dataclass
  class QCConfig:
      blur_thr: float
      iou_thr: float
      conf_thr: float
+     weights: str | None
      cpu_count: int = CPU_COUNT
      batch_size: int = BATCH_SIZE
      sample_limit:int = SAMPLE_LIMIT
 
  def load_yaml(path: Path) -> Dict:
      with path.open('r', encoding='utf-8') as f:
          return yaml.safe_load(f)
 
          return [names[k] for k in sorted(names, key=lambda x: int(x))]
      return list(names)
 
+ def parse_label_file(path: Path) -> list[tuple[int, float, float, float, float]]:
+     if not path or not path.exists() or path.stat().st_size == 0:
          return []
      try:
          arr = np.loadtxt(path, dtype=float)
 
      ]
      return [d for d in candidates if d.exists()]
 
+ def gather_dataset(root: Path, yaml_path: Path | None):
      if yaml_path is None:
          yamls = list(root.glob('*.yaml'))
          if not yamls:
 
      ]
      return imgs, lbls, meta
 
+ def get_model(weights: str) -> YOLO | None:
      if not weights or YOLO is None:
          return None
      if weights not in _model_cache:
+         _model_cache[weights] = YOLO(weights)
      return _model_cache[weights]
 
+ # ───────── Concurrency helpers & QC functions ───────────────────────────────
+ def _quality_stat_args(args: Tuple[Path, float]) -> Tuple[Path, bool, bool, bool]:
+     path, thr = args
+     if cv2 is None:
+         return path, False, False, False
+     im = cv2.imread(str(path))
+     if im is None:
+         return path, False, False, False
+     gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
+     lap = cv2.Laplacian(gray, cv2.CV_64F).var()
+     mean = gray.mean()
+     return path, lap < thr, mean < 25, mean > 230
+
  def _is_corrupt(path: Path) -> bool:
      try:
          with Image.open(path) as im:
 
      except Exception:
          return True
 
+ def qc_integrity(imgs: List[Path], lbls: List[Path], cfg: QCConfig) -> Dict:
      missing = [i for i, l in zip(imgs, lbls) if l is None]
      corrupt = []
      sample = imgs[:cfg.sample_limit]
+     with ThreadPoolExecutor(max_workers=cfg.cpu_count) as ex:
+         fut = {ex.submit(_is_corrupt, p): p for p in sample}
+         for f in as_completed(fut):
              if f.result():
+                 corrupt.append(fut[f])
+     score = 100 - (len(missing) + len(corrupt)) / max(len(imgs), 1) * 100
      return {
          "name": "Integrity",
+         "score": max(score, 0),
          "details": {
              "missing_label_files": [str(p) for p in missing],
+             "corrupt_images": [str(p) for p in corrupt],
          }
      }
 
+ def qc_class_balance(lbls: List[Path], cfg: QCConfig) -> Dict:
      counts, boxes = Counter(), []
      for l in lbls[:cfg.sample_limit]:
          bs = parse_label_file(l) if l else []
          boxes.append(len(bs))
          counts.update(int(b[0]) for b in bs)
      if not counts:
+         return {"name": "Class balance", "score": 0, "details": "No labels"}
      bal = min(counts.values()) / max(counts.values()) * 100
      return {
          "name": "Class balance",
          "score": bal,
          "details": {
+             "class_counts": dict(counts),
+             "boxes_per_image": {
+                 "min": min(boxes),
+                 "max": max(boxes),
+                 "mean": float(np.mean(boxes))
+             }
          }
      }
 
  def qc_image_quality(imgs: List[Path], cfg: QCConfig) -> Dict:
      if cv2 is None:
+         return {"name": "Image quality", "score": 100, "details": "cv2 missing"}
      blurry, dark, bright = [], [], []
      sample = imgs[:cfg.sample_limit]
+     with ThreadPoolExecutor(max_workers=cfg.cpu_count) as ex:
+         args = [(p, cfg.blur_thr) for p in sample]
+         for p, isb, isd, isB in ex.map(_quality_stat_args, args):
+             if isb: blurry.append(p)
+             if isd: dark.append(p)
+             if isB: bright.append(p)
+     bad = len({*blurry, *dark, *bright})
+     score = 100 - bad / max(len(sample), 1) * 100
      return {
          "name": "Image quality",
          "score": score,
+         "details": {
+             "blurry": [str(p) for p in blurry],
+             "dark": [str(p) for p in dark],
+             "bright": [str(p) for p in bright]
+         }
      }
 
  def qc_duplicates(imgs: List[Path], cfg: QCConfig) -> Dict:
+     if fastdup is not None and len(imgs) > 50:
          try:
              fd = fastdup.create(
                  input_dir=str(Path(imgs[0]).parent.parent),
+                 work_dir=str(TMP_ROOT / "fastdup")
              )
              fd.run()
              try:
                  cc = fd.connected_components_grouped(sort_by="comp_size", ascending=False)
+                 clusters = cc["files"].tolist() if "files" in cc.columns else cc.groupby("component")["filename"].apply(list).tolist()
              except Exception:
                  clusters = fd.connected_components()
+             dup = sum(len(c) - 1 for c in clusters)
+             score = max(0.0, 100 - dup / len(imgs) * 100)
+             return {"name": "Duplicates", "score": score, "details": {"groups": clusters[:50]}}
          except Exception as e:
+             return {"name": "Duplicates", "score": 100.0, "details": {"fastdup_error": str(e)}}
+     return {"name": "Duplicates", "score": 100.0, "details": {"note": "skipped"}}
+
+ def _rel_iou(b1, b2):
+     x1, y1, w1, h1 = b1
+     x2, y2, w2, h2 = b2
+     xa1, ya1 = x1 - w1/2, y1 - h1/2
+     xa2, ya2 = x1 + w1/2, y1 + h1/2
+     xb1, yb1 = x2 - w2/2, y2 - h2/2
+     xb2, yb2 = x2 + w2/2, y2 + h2/2
+     ix1 = max(xa1, xb1); iy1 = max(ya1, yb1)
+     ix2 = min(xa2, xb2); iy2 = min(ya2, yb2)
+     inter = max(ix2 - ix1, 0) * max(iy2 - iy1, 0)
      union = w1*h1 + w2*h2 - inter
      return inter/union if union else 0.0
 
+ def qc_model_qa(imgs: List[Path], lbls: List[Path], cfg: QCConfig) -> Dict:
      model = get_model(cfg.weights)
      if model is None:
+         return {"name": "Model QA", "score": 100, "details": "skipped"}
+     ious, mism = [], []
      sample = imgs[:cfg.sample_limit]
      for i in range(0, len(sample), cfg.batch_size):
          batch = sample[i:i+cfg.batch_size]
+         results = model.predict(batch, verbose=False, half=True, dynamic=True)
          for p, res in zip(batch, results):
+             gt = parse_label_file(Path(p).parent.parent/'labels'/f"{Path(p).stem}.txt")
              for cls, x, y, w, h in gt:
+                 best = 0.0
+                 for b, c, conf in zip(
+                     res.boxes.xywh.cpu().numpy(),
+                     res.boxes.cls.cpu().numpy(),
+                     res.boxes.conf.cpu().numpy()
+                 ):
+                     if conf < cfg.conf_thr or int(c) != cls:
+                         continue
+                     best = max(best, _rel_iou((x, y, w, h), tuple(b)))
                  ious.append(best)
                  if best < cfg.iou_thr:
+                     mism.append(str(p))
      miou = float(np.mean(ious)) if ious else 1.0
+     return {"name": "Model QA", "score": miou*100, "details": {"mean_iou": miou, "mismatches": mism[:50]}}
 
+ def qc_label_issues(imgs: List[Path], lbls: List[Path], cfg: QCConfig) -> Dict:
+     if get_noise_indices is None:
+         return {"name": "Label issues", "score": 100, "details": "skipped"}
      labels, idxs = [], []
      sample = imgs[:cfg.sample_limit]
+     for i, p in enumerate(sample):
+         bs = parse_label_file(lbls[i]) if lbls[i] else []
+         for cls, *_ in bs:
+             labels.append(int(cls)); idxs.append(i)
      if not labels:
+         return {"name": "Label issues", "score": 100, "details": "no GT"}
+     labels_arr = np.array(labels)
+     uniq = sorted(set(labels_arr))
+     probs = np.eye(len(uniq))[np.searchsorted(uniq, labels_arr)]
+     noise = get_noise_indices(labels=labels_arr, probabilities=probs)
+     flags = sorted({idxs[n] for n in noise})
+     files = [str(sample[i]) for i in flags]
+     score = 100 - len(flags)/len(labels)*100
+     return {"name": "Label issues", "score": score, "details": {"files": files[:50]}}
 
  def aggregate(results: List[Dict]) -> float:
      return sum(DEFAULT_W[r["name"]]*r["score"] for r in results)
 
+ # ───────── gathering actual per-class counts ────────────────────────────────
+ def gather_class_counts(
+     dataset_info_list: List[Tuple[str, List[str], List[str], str]]
+ ) -> Counter[str]:
+     counts: Counter[str] = Counter()
+     for dloc, class_names, splits, _ in dataset_info_list:
+         for split in splits:
+             labels_dir = Path(dloc) / split / "labels"
+             if not labels_dir.exists():
+                 continue
+             for lp in labels_dir.rglob("*.txt"):
+                 for cls_id_float, *_ in parse_label_file(lp):
+                     idx = int(cls_id_float)
+                     if 0 <= idx < len(class_names):
+                         counts[class_names[idx]] += 1
+     return counts
+
+ # ───────────────── Roboflow TXT‑loading logic ─────────────────────────────
  RF_RE = re.compile(r"https?://universe\.roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)")
 
  def download_rf_dataset(url: str, rf_api: Roboflow, dest: Path) -> Path:
 
      if not m:
          raise ValueError(f"Bad RF URL: {url}")
      ws, proj, ver = m.groups()
+     ds_dir = dest / f"{ws}_{proj}_v{ver}"
      if ds_dir.exists():
          return ds_dir
      pr = rf_api.workspace(ws).project(proj)
      pr.version(int(ver)).download("yolov8", location=str(ds_dir))
      return ds_dir
 
+ # ───────────────── run_quality & merge_datasets ────────────────────────────
  def run_quality(
      root: Path,
+     yaml_file: Path | None,
+     weights: Path | None,
      cfg: QCConfig,
      run_dup: bool,
      run_modelqa: bool
 
          qc_integrity(imgs, lbls, cfg),
          qc_class_balance(lbls, cfg),
          qc_image_quality(imgs, cfg),
+         qc_duplicates(imgs, cfg) if run_dup else {"name":"Duplicates","score":100,"details":"skipped"},
+         qc_model_qa(imgs, lbls, cfg) if run_modelqa else {"name":"Model QA","score":100,"details":"skipped"},
+         qc_label_issues(imgs, lbls, cfg) if run_modelqa else {"name":"Label issues","score":100,"details":"skipped"},
      ]
      final = aggregate(results)
      md = [f"## **{meta.get('name', root.name)}** — Score {final:.1f}/100"]
      for r in results:
+         md.append(f"### {r['name']} {r['score']:.1f}")
+         md.append("<details><summary>details</summary>\n```json")
+         md.append(json.dumps(r["details"], indent=2))
+         md.append("```\n</details>\n")
+     df = pd.DataFrame.from_dict(
+         next(r for r in results if r["name"] == "Class balance")["details"]["class_counts"],
+         orient="index", columns=["count"]
+     )
      df.index.name = "class"
      return "\n".join(md), df
 
  def merge_datasets(
+     dataset_info_list: List[Tuple[str, List[str], List[str], str]],
      class_map_df: pd.DataFrame,
      out_dir: Path = Path("merged_dataset"),
+     seed: int = 1234,
  ) -> Path:
      random.seed(seed)
      if out_dir.exists():
+         shutil.rmtree(out_dir, onerror=lambda f, p, _: (os.chmod(p, stat.S_IWRITE), f(p)))
      for sub in ("train/images","train/labels","valid/images","valid/labels"):
+         (out_dir / sub).mkdir(parents=True, exist_ok=True)
 
+     class_name_mapping = {
+         row["original_class"]: row["new_name"] if not row["remove"] else "__REMOVED__"
+         for _, row in class_map_df.iterrows()
      }
+     limits_per_merged = {
+         row["new_name"]: int(row["max_images"])
+         for _, row in class_map_df.iterrows()
+         if not row["remove"]
      }
+     active_classes = [c for c in sorted(set(class_name_mapping.values())) if c != "__REMOVED__"]
+     id_map = {cls: idx for idx, cls in enumerate(active_classes)}
+
+     image_to_classes: dict[str, set[str]] = {}
+     image_to_label: dict[str, Path] = {}
+     class_to_images: dict[str, set[str]] = {c: set() for c in active_classes}
+
+     for dloc, class_names_dataset, splits, _ in dataset_info_list:
+         for split in splits:
+             labels_root = Path(dloc) / split / "labels"
+             if not labels_root.exists():
+                 continue
+             for lp in labels_root.rglob("*.txt"):
+                 cls_set: set[str] = set()
+                 for cls_id_float, *rest in parse_label_file(lp):
+                     idx = int(cls_id_float)
+                     if 0 <= idx < len(class_names_dataset):
+                         orig = class_names_dataset[idx]
+                         new = class_name_mapping.get(orig, orig)
+                         if new in active_classes:
+                             cls_set.add(new)
+                 if not cls_set:
+                     continue
+                 img_path = str(lp.parent.parent / "images" / f"{lp.stem}.jpg")
+                 image_to_classes[img_path] = cls_set
+                 image_to_label[img_path] = lp
+                 for c in cls_set:
+                     class_to_images[c].add(img_path)
+
+     selected_images: set[str] = set()
+     counters = {c: 0 for c in active_classes}
+     pool = [img for imgs in class_to_images.values() for img in imgs]
      random.shuffle(pool)
+
      for img in pool:
+         cs = image_to_classes[img]
+         if any(counters[c] >= limits_per_merged.get(c, 0) for c in cs):
+             continue
+         selected_images.add(img)
+         for c in cs:
+             counters[c] += 1
+
+     for img in selected_images:
+         split = "train" if random.random() < 0.9 else "valid"
+         dst_img = out_dir / split / "images" / Path(img).name
+         dst_img.parent.mkdir(parents=True, exist_ok=True)
+         shutil.copy(img, dst_img)
+
+         lp_src = image_to_label[img]
+         dst_lbl = out_dir / split / "labels" / lp_src.name
          dst_lbl.parent.mkdir(parents=True, exist_ok=True)
+         lines = lp_src.read_text().splitlines()
+         new_lines: List[str] = []
+         for line in lines:
              parts = line.split()
              cid = int(parts[0])
+             orig = None
+             # find which dataset tuple this lp_src belongs to, to get class_names_dataset
+             for dloc, class_names_dataset, splits, _ in dataset_info_list:
+                 if str(lp_src).startswith(dloc):
+                     orig = class_names_dataset[cid] if cid < len(class_names_dataset) else None
+                     break
+             merged = class_name_mapping.get(orig, orig) if orig else None
+             if merged and merged in active_classes:
+                 new_id = id_map[merged]
+                 new_lines.append(" ".join([str(new_id)] + parts[1:]))
+         if new_lines:
+             dst_lbl.write_text("\n".join(new_lines))
          else:
+             dst_img.unlink(missing_ok=True)
 
+     data_yaml = {
          "path": str(out_dir.resolve()),
          "train": "train/images",
+         "val": "valid/images",
+         "nc": len(active_classes),
+         "names": active_classes,
      }
+     (out_dir / "data.yaml").write_text(yaml.safe_dump(data_yaml))
      return out_dir
 
  # ════════════════════════════════════════════════════════════════════════════
+ # UI LAYER
  # ════════════════════════════════════════════════════════════════════════════
  with gr.Blocks(css="#classdf td{min-width:120px}") as demo:
+     gr.Markdown("""
+ # 🏹 **YOLO Dataset Toolkit**
+ _Evaluate • Merge • Edit • Download_
+ """)
 
+     # Evaluate Tab ...
      with gr.Tab("Evaluate"):
+         api_in = gr.Textbox(label="Roboflow API key", type="password")
+         url_txt = gr.File(label=".txt of RF dataset URLs", file_types=['.txt'])
+         zip_in = gr.File(label="Dataset ZIP")
+         path_in = gr.Textbox(label="Server path")
+         yaml_in = gr.File(label="Custom YAML", file_types=['.yaml'])
+         weights_in = gr.File(label="YOLO weights (.pt)")
+
+         blur_sl = gr.Slider(0.0, 500.0, value=100.0, label="Blur threshold")
+         iou_sl = gr.Slider(0.0, 1.0, value=0.5, label="IOU threshold")
+         conf_sl = gr.Slider(0.0, 1.0, value=0.25, label="Min detection confidence")
+
+         run_dup = gr.Checkbox(label="Check duplicates (fastdup)", value=False)
+         run_modelqa= gr.Checkbox(label="Run Model QA & cleanlab", value=False)
+
+         run_eval = gr.Button("Run Evaluation")
+         out_md = gr.Markdown()
+         out_df = gr.Dataframe()
 
          def _evaluate_cb(
+             api_key, url_txt, zip_file, server_path, yaml_file, weights,
+             blur_thr, iou_thr, conf_thr, run_dup, run_modelqa
          ):
              reports, dfs = [], []
+             cfg = QCConfig(blur_thr, iou_thr, conf_thr, weights.name if weights else None)
+             rf = Roboflow(api_key) if api_key and Roboflow else None
 
+             if url_txt and rf:
+                 for line in Path(url_txt.name).read_text().splitlines():
                      if not line.strip(): continue
                      try:
                          ds = download_rf_dataset(line, rf, TMP_ROOT)
+                         md, df = run_quality(
+                             ds, None,
+                             Path(weights.name) if weights else None,
+                             cfg, run_dup, run_modelqa
+                         )
                          reports.append(md); dfs.append(df)
                      except Exception as e:
                          reports.append(f"### {line}\n⚠️ {e}")
 
              if zip_file:
                  tmp = Path(tempfile.mkdtemp())
                  shutil.unpack_archive(zip_file.name, tmp)
+                 md, df = run_quality(
+                     tmp,
+                     Path(yaml_file.name) if yaml_file else None,
+                     Path(weights.name) if weights else None,
+                     cfg, run_dup, run_modelqa
+                 )
                  reports.append(md); dfs.append(df)
                  shutil.rmtree(tmp, ignore_errors=True)
 
              if server_path:
+                 ds = Path(server_path)
+                 md, df = run_quality(
+                     ds,
+                     Path(yaml_file.name) if yaml_file else None,
+                     Path(weights.name) if weights else None,
+                     cfg, run_dup, run_modelqa
+                 )
                  reports.append(md); dfs.append(df)
 
              summary = "\n---\n".join(reports) if reports else ""
 
          run_eval.click(
              _evaluate_cb,
+             inputs=[api_in, url_txt, zip_in, path_in, yaml_in, weights_in,
+                     blur_sl, iou_sl, conf_sl, run_dup, run_modelqa],
+             outputs=[out_md, out_df]
          )
 
+     # Merge / Edit Tab
      with gr.Tab("Merge / Edit"):
+         gr.Markdown("### 1️⃣ Load one or more datasets")
+         rf_key = gr.Textbox(label="Roboflow API key", type="password")
+         rf_urls = gr.File(label=".txt of RF URLs", file_types=['.txt'])
+         zips_in = gr.Files(label="One or more dataset ZIPs")
+         load_btn = gr.Button("Load datasets")
+         load_log = gr.Markdown()
+         ds_state = gr.State([])
 
          def _load_cb(rf_key, rf_urls_file, zip_files):
+             global autoinc
+             info_list, log_lines = [], []
              rf = Roboflow(rf_key) if rf_key and Roboflow else None
 
              if rf_urls_file and rf:
 
                      if not url: continue
                      try:
                          ds = download_rf_dataset(url, rf, TMP_ROOT)
+                         names = load_class_names(ds/"data.yaml")
+                         splits = [s for s in ("train","valid","test") if (ds/s).exists()]
+                         info_list.append((str(ds), names, splits, Path(ds).name))
+                         log_lines.append(f"✔️ RF dataset **{Path(ds).name}** loaded ({len(names)} classes)")
                      except Exception as e:
+                         log_lines.append(f"⚠️ RF load failed for {url!r}: {e}")
 
              for f in zip_files or []:
                  autoinc += 1
+                 tmp = TMP_ROOT / f"zip_{autoinc}"
                  tmp.mkdir(parents=True, exist_ok=True)
                  shutil.unpack_archive(f.name, tmp)
+                 yaml_p = next(tmp.rglob("*.yaml"), None)
+                 if yaml_p:
+                     names = load_class_names(yaml_p)
+                     splits = [s for s in ("train","valid","test") if (tmp/s).exists()]
+                     info_list.append((str(tmp), names, splits, tmp.name))
+                     log_lines.append(f"✔️ ZIP **{tmp.name}** loaded")
 
+             return info_list, "\n".join(log_lines) or "No datasets loaded."
 
+         load_btn.click(_load_cb, [rf_key, rf_urls, zips_in], [ds_state, load_log])
 
+         gr.Markdown("### 2️⃣ Edit class mapping / limits / removal")
          class_df = gr.Dataframe(
              headers=["original_class","new_name","max_images","remove"],
              datatype=["str","str","number","bool"],
              interactive=True, elem_id="classdf"
          )
+         refresh_btn = gr.Button("Build class table from loaded datasets")
 
+         def _build_class_df(ds_info):
+             counts = gather_class_counts(ds_info)
+             all_names = sorted(counts.keys())
              return pd.DataFrame({
+                 "original_class": all_names,
+                 "new_name": all_names,
+                 "max_images": [counts[n] for n in all_names],
+                 "remove": [False]*len(all_names),
              })
 
+         refresh_btn.click(_build_class_df, [ds_state], [class_df])
+
+         merge_btn = gr.Button("Merge datasets ✨")
+         zip_out = gr.File(label="Download merged ZIP")
+         merge_log = gr.Markdown()
 
+         def _merge_cb(ds_info, class_df):
+             if not ds_info:
+                 return None, "⚠️ Load datasets first."
+             out_dir = merge_datasets(ds_info, class_df)
+             zip_path = shutil.make_archive(str(out_dir), "zip", out_dir)
+             count = len(list(Path(out_dir).rglob("*.jpg")))
+             return zip_path, f"✅ Merged dataset at **{out_dir}** with {count} images."
 
          merge_btn.click(_merge_cb, [ds_state, class_df], [zip_out, merge_log])
 
  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
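
For reference, a minimal sketch (not part of the commit) of how the updated quality-check entry points compose when called outside the Gradio UI; the import name app and the ./my_dataset path are assumptions, and the threshold values simply mirror the UI slider defaults:

# Usage sketch under stated assumptions: app.py is importable as `app`, and a
# YOLO-format dataset containing a data.yaml lives at ./my_dataset.
from pathlib import Path
from app import QCConfig, run_quality

cfg = QCConfig(blur_thr=100.0, iou_thr=0.5, conf_thr=0.25, weights=None)  # same defaults as the sliders
md, df = run_quality(
    Path("./my_dataset"),   # dataset root (assumed path)
    None,                   # let gather_dataset discover the data.yaml
    None,                   # no weights file, so Model QA stays skipped
    cfg,
    run_dup=False,          # skip the fastdup duplicate check
    run_modelqa=False,      # skip Model QA / cleanlab label checks
)
print(md)   # markdown report with per-check scores
print(df)   # per-class counts from the class-balance check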