ahsanMah commited on
Commit
d23ef26
·
1 Parent(s): a202591

added option for batch size to caching func

Browse files
Files changed (1) hide show
  1. msma.py +10 -2
msma.py CHANGED
@@ -276,8 +276,16 @@ def train_gmm(preset, outdir, gridsearch=False, **kwargs):
276
 
277
 
278
  @cmdline.command(name="cache-scores")
 
 
 
 
 
 
 
 
279
  @common_args
280
- def cache_score_norms(preset, dataset_path, outdir):
281
  device = DEVICE
282
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
283
  refimg, reflabel = dsobj[0]
@@ -286,7 +294,7 @@ def cache_score_norms(preset, dataset_path, outdir):
286
  f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
287
  )
288
  dsloader = torch.utils.data.DataLoader(
289
- dsobj, batch_size=64, num_workers=4, prefetch_factor=2
290
  )
291
 
292
  model = build_model_from_pickle(preset=preset, device=device)
 
276
 
277
 
278
  @cmdline.command(name="cache-scores")
279
+ @click.option(
280
+ "--batch_size",
281
+ help="Number of samples per batch",
282
+ metavar="INT",
283
+ type=int,
284
+ default=64,
285
+ show_default=True,
286
+ )
287
  @common_args
288
+ def cache_score_norms(preset, dataset_path, outdir, batch_size):
289
  device = DEVICE
290
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
291
  refimg, reflabel = dsobj[0]
 
294
  f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
295
  )
296
  dsloader = torch.utils.data.DataLoader(
297
+ dsobj, batch_size=batch_size, num_workers=4, prefetch_factor=2
298
  )
299
 
300
  model = build_model_from_pickle(preset=preset, device=device)