ahsanMah commited on
Commit
c98586f
·
1 Parent(s): db8f4d5

batch as cmdline argument

Browse files
Files changed (1) hide show
  1. msma.py +11 -4
msma.py CHANGED
@@ -331,12 +331,19 @@ def cache_score_norms(preset, dataset_path, outdir, batch_size):
331
  default=4,
332
  show_default=True,
333
  )
 
 
 
 
 
 
 
 
334
  @common_args
335
- def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
336
  print("using device:", DEVICE)
337
  device = DEVICE
338
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
339
- refimg, reflabel = dsobj[0]
340
  print(f"Loaded {len(dsobj)} samples from {dataset_path}")
341
 
342
  # Subset of training dataset
@@ -351,10 +358,10 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
351
  val_ds = Subset(dsobj, range(train_len, train_len + val_len))
352
 
353
  trainiter = torch.utils.data.DataLoader(
354
- train_ds, batch_size=64, num_workers=4, prefetch_factor=2, shuffle=True
355
  )
356
  testiter = torch.utils.data.DataLoader(
357
- val_ds, batch_size=128, num_workers=4, prefetch_factor=2
358
  )
359
 
360
  scorenet = build_model_from_pickle(preset)
 
331
  default=4,
332
  show_default=True,
333
  )
334
+ @click.option(
335
+ "--batch_size",
336
+ help="Number of samples per batch",
337
+ metavar="INT",
338
+ type=int,
339
+ default=128,
340
+ show_default=True,
341
+ )
342
  @common_args
343
+ def train_flow(dataset_path, preset, outdir, epochs, batch_size, **flow_kwargs):
344
  print("using device:", DEVICE)
345
  device = DEVICE
346
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
 
347
  print(f"Loaded {len(dsobj)} samples from {dataset_path}")
348
 
349
  # Subset of training dataset
 
358
  val_ds = Subset(dsobj, range(train_len, train_len + val_len))
359
 
360
  trainiter = torch.utils.data.DataLoader(
361
+ train_ds, batch_size=batch_size, num_workers=4, prefetch_factor=2, shuffle=True
362
  )
363
  testiter = torch.utils.data.DataLoader(
364
+ val_ds, batch_size=batch_size*2, num_workers=4, prefetch_factor=2
365
  )
366
 
367
  scorenet = build_model_from_pickle(preset)