wuhp committed on
Commit
b8d4606
·
verified ·
1 Parent(s): 43dfbd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -33
app.py CHANGED
@@ -47,8 +47,8 @@ except ImportError:
47
  # ───────────────── Config & Constants ───────────────────────────────────────
48
  TMP_ROOT = Path(tempfile.gettempdir()) / "rf_datasets"
49
  TMP_ROOT.mkdir(parents=True, exist_ok=True)
50
- CPU_COUNT = int(os.getenv("QC_CPU", 1))
51
- BATCH_SIZE = int(os.getenv("QC_BATCH", 4))
52
  SAMPLE_LIMIT = int(os.getenv("QC_SAMPLE", 200))
53
 
54
  DEFAULT_W = {
@@ -64,21 +64,20 @@ _model_cache: dict[str, YOLO] = {}
64
 
65
  @dataclass
66
  class QCConfig:
67
- blur_thr: float
68
- iou_thr: float
69
- conf_thr: float
70
- weights: str | None
71
- cpu_count: int = CPU_COUNT
72
- batch_size: int = BATCH_SIZE
73
- sample_limit: int = SAMPLE_LIMIT
74
-
75
 
76
  # ─────────── Helpers & Caching ─────────────────────────────────────────────
77
  def load_yaml(path: Path) -> Dict:
78
  with path.open('r', encoding='utf-8') as f:
79
  return yaml.safe_load(f)
80
 
81
- def parse_label_file(path: Path) -> List[tuple[int, float, float, float, float]]:
82
  if not path or not path.exists() or path.stat().st_size == 0:
83
  return []
84
  try:
@@ -124,7 +123,6 @@ def get_model(weights: str) -> YOLO | None:
124
  _model_cache[weights] = YOLO(weights)
125
  return _model_cache[weights]
126
 
127
-
128
  # ───────── Functions for I/O-bound concurrency ─────────────────────────────
129
  def _quality_stat_args(args: Tuple[Path, float]) -> Tuple[Path, bool, bool, bool]:
130
  path, thr = args
@@ -146,12 +144,11 @@ def _is_corrupt(path: Path) -> bool:
146
  except:
147
  return True
148
 
149
-
150
  # ───────────────── Quality Checks ──────────────────────────────────────────
151
  def qc_integrity(imgs: List[Path], lbls: List[Path], cfg: QCConfig) -> Dict:
152
  missing = [i for i, l in zip(imgs, lbls) if l is None]
153
  corrupt = []
154
- sample = imgs if len(imgs) <= cfg.sample_limit else imgs[:cfg.sample_limit]
155
  with ThreadPoolExecutor(max_workers=cfg.cpu_count) as ex:
156
  fut = {ex.submit(_is_corrupt, p): p for p in sample}
157
  for f in as_completed(fut):
@@ -193,7 +190,7 @@ def qc_image_quality(imgs: List[Path], cfg: QCConfig) -> Dict:
193
  if cv2 is None:
194
  return {"name":"Image quality","score":100,"details":"cv2 missing"}
195
  blurry, dark, bright = [], [], []
196
- sample = imgs if len(imgs) <= cfg.sample_limit else imgs[:cfg.sample_limit]
197
  with ThreadPoolExecutor(max_workers=cfg.cpu_count) as ex:
198
  args = [(p, cfg.blur_thr) for p in sample]
199
  for p, isb, isd, isB in ex.map(_quality_stat_args, args):
@@ -220,15 +217,23 @@ def qc_duplicates(imgs: List[Path], cfg: QCConfig) -> Dict:
220
  work_dir=str(TMP_ROOT / "fastdup")
221
  )
222
  fd.run()
223
- # try DataFrame API
 
224
  try:
225
  cc = fd.connected_components_grouped(sort_by="comp_size", ascending=False)
226
  if "files" in cc.columns:
227
  clusters = cc["files"].tolist()
228
  else:
229
- clusters = cc.groupby("component")["filename"].apply(list).tolist()
 
 
 
 
 
230
  except Exception:
 
231
  clusters = fd.connected_components()
 
232
  dup = sum(len(c) - 1 for c in clusters)
233
  score = max(0.0, 100 - dup / len(imgs) * 100)
234
  return {
@@ -242,14 +247,14 @@ def qc_duplicates(imgs: List[Path], cfg: QCConfig) -> Dict:
242
  "score": 100.0,
243
  "details": {"fastdup_error": str(e)}
244
  }
245
- return {"name":"Duplicates","score":100.0,"details":{"note":"skipped"}}
246
 
247
  def qc_model_qa(imgs: List[Path], lbls: List[Path], cfg: QCConfig) -> Dict:
248
  model = get_model(cfg.weights)
249
  if model is None:
250
  return {"name":"Model QA","score":100,"details":"skipped"}
251
  ious, mism = [], []
252
- sample = imgs if len(imgs) <= cfg.sample_limit else imgs[:cfg.sample_limit]
253
  for i in range(0, len(sample), cfg.batch_size):
254
  batch = sample[i:i+cfg.batch_size]
255
  results = model.predict(batch, verbose=False, half=True, dynamic=True)
@@ -279,7 +284,7 @@ def qc_label_issues(imgs: List[Path], lbls: List[Path], cfg: QCConfig) -> Dict:
279
  if get_noise_indices is None:
280
  return {"name":"Label issues","score":100,"details":"skipped"}
281
  labels, idxs = [], []
282
- sample = imgs if len(imgs) <= cfg.sample_limit else imgs[:cfg.sample_limit]
283
  for i, p in enumerate(sample):
284
  bs = parse_label_file(lbls[i]) if lbls[i] else []
285
  for cls, *_ in bs:
@@ -382,7 +387,7 @@ with gr.Blocks(title="YOLO Dataset Quality Evaluator v3") as demo:
382
  yaml_in = gr.File(label="Custom YAML", file_types=['.yaml'])
383
  weights_in = gr.File(label="YOLO weights (.pt)")
384
  with gr.Row():
385
- blur_sl = gr.Slider(0.0, 500.0, value=150.0, label="Blur threshold")
386
  iou_sl = gr.Slider(0.0, 1.0, value=0.5, label="IOU threshold")
387
  conf_sl = gr.Slider(0.0, 1.0, value=0.25, label="Min detection confidence")
388
  with gr.Row():
@@ -405,14 +410,16 @@ with gr.Blocks(title="YOLO Dataset Quality Evaluator v3") as demo:
405
 
406
  # Roboflow URLs
407
  if url_txt:
408
- lines = Path(url_txt.name).read_text().splitlines()
409
- for line in lines:
410
  if not line.strip():
411
  continue
412
  try:
413
  ds = download_rf_dataset(line, rf, TMP_ROOT)
414
- md, df = run_quality(ds, None, Path(weights.name) if weights else None,
415
- cfg, run_dup, run_modelqa)
 
 
 
416
  reports.append(md)
417
  dfs.append(df)
418
  except Exception as e:
@@ -422,10 +429,12 @@ with gr.Blocks(title="YOLO Dataset Quality Evaluator v3") as demo:
422
  if zip_file:
423
  tmp = Path(tempfile.mkdtemp())
424
  shutil.unpack_archive(zip_file.name, tmp)
425
- md, df = run_quality(tmp,
426
- Path(yaml_file.name) if yaml_file else None,
427
- Path(weights.name) if weights else None,
428
- cfg, run_dup, run_modelqa)
 
 
429
  reports.append(md)
430
  dfs.append(df)
431
  shutil.rmtree(tmp, ignore_errors=True)
@@ -433,10 +442,12 @@ with gr.Blocks(title="YOLO Dataset Quality Evaluator v3") as demo:
433
  # Server path
434
  if server_path:
435
  ds = Path(server_path)
436
- md, df = run_quality(ds,
437
- Path(yaml_file.name) if yaml_file else None,
438
- Path(weights.name) if weights else None,
439
- cfg, run_dup, run_modelqa)
 
 
440
  reports.append(md)
441
  dfs.append(df)
442
 
 
47
  # ───────────────── Config & Constants ───────────────────────────────────────
48
  TMP_ROOT = Path(tempfile.gettempdir()) / "rf_datasets"
49
  TMP_ROOT.mkdir(parents=True, exist_ok=True)
50
+ CPU_COUNT = int(os.getenv("QC_CPU", 1)) # force single-core by default
51
+ BATCH_SIZE = int(os.getenv("QC_BATCH", 4)) # small batches
52
  SAMPLE_LIMIT = int(os.getenv("QC_SAMPLE", 200))
53
 
54
  DEFAULT_W = {
 
64
 
65
  @dataclass
66
  class QCConfig:
67
+ blur_thr: float
68
+ iou_thr: float
69
+ conf_thr: float
70
+ weights: str | None
71
+ cpu_count: int = CPU_COUNT
72
+ batch_size: int = BATCH_SIZE
73
+ sample_limit:int = SAMPLE_LIMIT
 
74
 
75
  # ─────────── Helpers & Caching ─────────────────────────────────────────────
76
  def load_yaml(path: Path) -> Dict:
77
  with path.open('r', encoding='utf-8') as f:
78
  return yaml.safe_load(f)
79
 
80
+ def parse_label_file(path: Path) -> list[tuple[int, float, float, float, float]]:
81
  if not path or not path.exists() or path.stat().st_size == 0:
82
  return []
83
  try:
 
123
  _model_cache[weights] = YOLO(weights)
124
  return _model_cache[weights]
125
 
 
126
  # ───────── Functions for I/O-bound concurrency ─────────────────────────────
127
  def _quality_stat_args(args: Tuple[Path, float]) -> Tuple[Path, bool, bool, bool]:
128
  path, thr = args
 
144
  except:
145
  return True
146
 
 
147
  # ───────────────── Quality Checks ──────────────────────────────────────────
148
  def qc_integrity(imgs: List[Path], lbls: List[Path], cfg: QCConfig) -> Dict:
149
  missing = [i for i, l in zip(imgs, lbls) if l is None]
150
  corrupt = []
151
+ sample = imgs[:cfg.sample_limit]
152
  with ThreadPoolExecutor(max_workers=cfg.cpu_count) as ex:
153
  fut = {ex.submit(_is_corrupt, p): p for p in sample}
154
  for f in as_completed(fut):
 
190
  if cv2 is None:
191
  return {"name":"Image quality","score":100,"details":"cv2 missing"}
192
  blurry, dark, bright = [], [], []
193
+ sample = imgs[:cfg.sample_limit]
194
  with ThreadPoolExecutor(max_workers=cfg.cpu_count) as ex:
195
  args = [(p, cfg.blur_thr) for p in sample]
196
  for p, isb, isd, isB in ex.map(_quality_stat_args, args):
 
217
  work_dir=str(TMP_ROOT / "fastdup")
218
  )
219
  fd.run()
220
+
221
+ # Try the grouped-DataFrame API first:
222
  try:
223
  cc = fd.connected_components_grouped(sort_by="comp_size", ascending=False)
224
  if "files" in cc.columns:
225
  clusters = cc["files"].tolist()
226
  else:
227
+ # fallback: group by component ID, collect filenames
228
+ clusters = (
229
+ cc.groupby("component")["filename"]
230
+ .apply(list)
231
+ .tolist()
232
+ )
233
  except Exception:
234
+ # final fallback to the old list-based API
235
  clusters = fd.connected_components()
236
+
237
  dup = sum(len(c) - 1 for c in clusters)
238
  score = max(0.0, 100 - dup / len(imgs) * 100)
239
  return {
 
247
  "score": 100.0,
248
  "details": {"fastdup_error": str(e)}
249
  }
250
+ return {"name": "Duplicates", "score": 100.0, "details": {"note": "skipped"}}
251
 
252
  def qc_model_qa(imgs: List[Path], lbls: List[Path], cfg: QCConfig) -> Dict:
253
  model = get_model(cfg.weights)
254
  if model is None:
255
  return {"name":"Model QA","score":100,"details":"skipped"}
256
  ious, mism = [], []
257
+ sample = imgs[:cfg.sample_limit]
258
  for i in range(0, len(sample), cfg.batch_size):
259
  batch = sample[i:i+cfg.batch_size]
260
  results = model.predict(batch, verbose=False, half=True, dynamic=True)
 
284
  if get_noise_indices is None:
285
  return {"name":"Label issues","score":100,"details":"skipped"}
286
  labels, idxs = [], []
287
+ sample = imgs[:cfg.sample_limit]
288
  for i, p in enumerate(sample):
289
  bs = parse_label_file(lbls[i]) if lbls[i] else []
290
  for cls, *_ in bs:
 
387
  yaml_in = gr.File(label="Custom YAML", file_types=['.yaml'])
388
  weights_in = gr.File(label="YOLO weights (.pt)")
389
  with gr.Row():
390
+ blur_sl = gr.Slider(0.0, 500.0, value=100.0, label="Blur threshold")
391
  iou_sl = gr.Slider(0.0, 1.0, value=0.5, label="IOU threshold")
392
  conf_sl = gr.Slider(0.0, 1.0, value=0.25, label="Min detection confidence")
393
  with gr.Row():
 
410
 
411
  # Roboflow URLs
412
  if url_txt:
413
+ for line in Path(url_txt.name).read_text().splitlines():
 
414
  if not line.strip():
415
  continue
416
  try:
417
  ds = download_rf_dataset(line, rf, TMP_ROOT)
418
+ md, df = run_quality(
419
+ ds, None,
420
+ Path(weights.name) if weights else None,
421
+ cfg, run_dup, run_modelqa
422
+ )
423
  reports.append(md)
424
  dfs.append(df)
425
  except Exception as e:
 
429
  if zip_file:
430
  tmp = Path(tempfile.mkdtemp())
431
  shutil.unpack_archive(zip_file.name, tmp)
432
+ md, df = run_quality(
433
+ tmp,
434
+ Path(yaml_file.name) if yaml_file else None,
435
+ Path(weights.name) if weights else None,
436
+ cfg, run_dup, run_modelqa
437
+ )
438
  reports.append(md)
439
  dfs.append(df)
440
  shutil.rmtree(tmp, ignore_errors=True)
 
442
  # Server path
443
  if server_path:
444
  ds = Path(server_path)
445
+ md, df = run_quality(
446
+ ds,
447
+ Path(yaml_file.name) if yaml_file else None,
448
+ Path(weights.name) if weights else None,
449
+ cfg, run_dup, run_modelqa
450
+ )
451
  reports.append(md)
452
  dfs.append(df)
453