Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -41,13 +41,15 @@ def get_args():
|
|
41 |
|
42 |
parser.add_argument("--max_epochs", default=200, type=int)
|
43 |
|
|
|
|
|
44 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
45 |
parser.add_argument("--patience", default=5, type=int)
|
46 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
|
|
47 |
|
48 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
49 |
|
50 |
-
parser.add_argument("--seed", default=1234, type=int)
|
51 |
|
52 |
args = parser.parse_args()
|
53 |
return args
|
@@ -139,7 +141,7 @@ def main():
|
|
139 |
)
|
140 |
train_data_loader = DataLoader(
|
141 |
dataset=train_dataset,
|
142 |
-
batch_size=
|
143 |
# shuffle=True,
|
144 |
sampler=None,
|
145 |
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
@@ -150,7 +152,7 @@ def main():
|
|
150 |
)
|
151 |
valid_data_loader = DataLoader(
|
152 |
dataset=valid_dataset,
|
153 |
-
batch_size=
|
154 |
# shuffle=True,
|
155 |
sampler=None,
|
156 |
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
|
|
41 |
|
42 |
parser.add_argument("--max_epochs", default=200, type=int)
|
43 |
|
44 |
+
parser.add_argument("--batch_size", default=64, type=int)
|
45 |
+
parser.add_argument("--learning_rate", default=1e-4, type=float)
|
46 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
47 |
parser.add_argument("--patience", default=5, type=int)
|
48 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
49 |
+
parser.add_argument("--seed", default=1234, type=int)
|
50 |
|
51 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
52 |
|
|
|
53 |
|
54 |
args = parser.parse_args()
|
55 |
return args
|
|
|
141 |
)
|
142 |
train_data_loader = DataLoader(
|
143 |
dataset=train_dataset,
|
144 |
+
batch_size=args.batch_size,
|
145 |
# shuffle=True,
|
146 |
sampler=None,
|
147 |
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
|
|
152 |
)
|
153 |
valid_data_loader = DataLoader(
|
154 |
dataset=valid_dataset,
|
155 |
+
batch_size=args.batch_size,
|
156 |
# shuffle=True,
|
157 |
sampler=None,
|
158 |
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|