Spaces:
Running
on
Zero
Running
on
Zero
added option for batch size to caching func
Browse files
msma.py
CHANGED
@@ -276,8 +276,16 @@ def train_gmm(preset, outdir, gridsearch=False, **kwargs):
|
|
276 |
|
277 |
|
278 |
@cmdline.command(name="cache-scores")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
@common_args
|
280 |
-
def cache_score_norms(preset, dataset_path, outdir):
|
281 |
device = DEVICE
|
282 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
283 |
refimg, reflabel = dsobj[0]
|
@@ -286,7 +294,7 @@ def cache_score_norms(preset, dataset_path, outdir):
|
|
286 |
f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
|
287 |
)
|
288 |
dsloader = torch.utils.data.DataLoader(
|
289 |
-
dsobj, batch_size=
|
290 |
)
|
291 |
|
292 |
model = build_model_from_pickle(preset=preset, device=device)
|
|
|
276 |
|
277 |
|
278 |
@cmdline.command(name="cache-scores")
|
279 |
+
@click.option(
|
280 |
+
"--batch_size",
|
281 |
+
help="Number of samples per batch",
|
282 |
+
metavar="INT",
|
283 |
+
type=int,
|
284 |
+
default=64,
|
285 |
+
show_default=True,
|
286 |
+
)
|
287 |
@common_args
|
288 |
+
def cache_score_norms(preset, dataset_path, outdir, batch_size):
|
289 |
device = DEVICE
|
290 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
291 |
refimg, reflabel = dsobj[0]
|
|
|
294 |
f"Number of Samples: {len(dsobj)} - shape: {refimg.shape}, dtype: {refimg.dtype}, labels {reflabel}"
|
295 |
)
|
296 |
dsloader = torch.utils.data.DataLoader(
|
297 |
+
dsobj, batch_size=batch_size, num_workers=4, prefetch_factor=2
|
298 |
)
|
299 |
|
300 |
model = build_model_from_pickle(preset=preset, device=device)
|