HoneyTian commited on
Commit
ccf5554
·
1 Parent(s): b9f223d
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -41,13 +41,15 @@ def get_args():
41
 
42
  parser.add_argument("--max_epochs", default=200, type=int)
43
 
 
 
44
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
45
  parser.add_argument("--patience", default=5, type=int)
46
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
47
 
48
  parser.add_argument("--config_file", default="config.yaml", type=str)
49
 
50
- parser.add_argument("--seed", default=1234, type=int)
51
 
52
  args = parser.parse_args()
53
  return args
@@ -139,7 +141,7 @@ def main():
139
  )
140
  train_data_loader = DataLoader(
141
  dataset=train_dataset,
142
- batch_size=config.batch_size,
143
  # shuffle=True,
144
  sampler=None,
145
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
@@ -150,7 +152,7 @@ def main():
150
  )
151
  valid_data_loader = DataLoader(
152
  dataset=valid_dataset,
153
- batch_size=config.batch_size,
154
  # shuffle=True,
155
  sampler=None,
156
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
 
41
 
42
  parser.add_argument("--max_epochs", default=200, type=int)
43
 
44
+ parser.add_argument("--batch_size", default=64, type=int)
45
+ parser.add_argument("--learning_rate", default=1e-4, type=float)
46
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
47
  parser.add_argument("--patience", default=5, type=int)
48
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
49
+ parser.add_argument("--seed", default=1234, type=int)
50
 
51
  parser.add_argument("--config_file", default="config.yaml", type=str)
52
 
 
53
 
54
  args = parser.parse_args()
55
  return args
 
141
  )
142
  train_data_loader = DataLoader(
143
  dataset=train_dataset,
144
+ batch_size=args.batch_size,
145
  # shuffle=True,
146
  sampler=None,
147
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
 
152
  )
153
  valid_data_loader = DataLoader(
154
  dataset=valid_dataset,
155
+ batch_size=args.batch_size,
156
  # shuffle=True,
157
  sampler=None,
158
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.