Spaces:
Running
on
Zero
Running
on
Zero
batch as cmdline argument
Browse files
msma.py
CHANGED
@@ -331,12 +331,19 @@ def cache_score_norms(preset, dataset_path, outdir, batch_size):
|
|
331 |
default=4,
|
332 |
show_default=True,
|
333 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
@common_args
|
335 |
-
def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
336 |
print("using device:", DEVICE)
|
337 |
device = DEVICE
|
338 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
339 |
-
refimg, reflabel = dsobj[0]
|
340 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
341 |
|
342 |
# Subset of training dataset
|
@@ -351,10 +358,10 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
351 |
val_ds = Subset(dsobj, range(train_len, train_len + val_len))
|
352 |
|
353 |
trainiter = torch.utils.data.DataLoader(
|
354 |
-
train_ds, batch_size=
|
355 |
)
|
356 |
testiter = torch.utils.data.DataLoader(
|
357 |
-
val_ds, batch_size=
|
358 |
)
|
359 |
|
360 |
scorenet = build_model_from_pickle(preset)
|
|
|
331 |
default=4,
|
332 |
show_default=True,
|
333 |
)
|
334 |
+
@click.option(
|
335 |
+
"--batch_size",
|
336 |
+
help="Number of samples per batch",
|
337 |
+
metavar="INT",
|
338 |
+
type=int,
|
339 |
+
default=128,
|
340 |
+
show_default=True,
|
341 |
+
)
|
342 |
@common_args
|
343 |
+
def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
|
344 |
print("using device:", DEVICE)
|
345 |
device = DEVICE
|
346 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
|
|
347 |
print(f"Loaded {len(dsobj)} samples from {dataset_path}")
|
348 |
|
349 |
# Subset of training dataset
|
|
|
358 |
val_ds = Subset(dsobj, range(train_len, train_len + val_len))
|
359 |
|
360 |
trainiter = torch.utils.data.DataLoader(
|
361 |
+
train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
|
362 |
)
|
363 |
testiter = torch.utils.data.DataLoader(
|
364 |
+
val_ds, batch_size=batch_size*2, num_workers=4, prefetch_factor=2
|
365 |
)
|
366 |
|
367 |
scorenet = build_model_from_pickle(preset)
|