Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import contextlib | |
from io import StringIO | |
import json | |
import os | |
import tempfile | |
import unittest | |
from . import test_binaries | |
class TestReproducibility(unittest.TestCase):
    """Check that resuming training from a checkpoint reproduces an uninterrupted run.

    Each test does the following:

    - Trains a small ``fconv_iwslt_de_en`` model on dummy data with
      ``--max-epoch 3`` in a single run.
    - Renames the epoch-1 checkpoint to ``checkpoint_last.pt`` and trains again
      with the same flags.
    - Asserts that the final train/valid statistics of both runs agree to three
      decimal places.
    """

    # Model architecture trained by every reproducibility test.
    _ARCH = 'fconv_iwslt_de_en'

    # Flags shared by the uninterrupted and the resumed run. JSON logging is
    # required because the final statistics are parsed with json.loads;
    # dropout is presumably disabled to remove a source of run-to-run variance.
    _BASE_TRAIN_FLAGS = (
        '--dropout', '0.0',
        '--log-format', 'json',
        '--log-interval', '1',
        '--max-epoch', '3',
    )

    # Final-epoch statistics that must match between the two runs.
    _TRAIN_KEYS = ('train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm')
    _VALID_KEYS = ('valid_loss', 'valid_ppl', 'valid_num_updates', 'valid_best_loss')

    @staticmethod
    def _cast(value):
        """Convert a logged statistic to float rounded to 3 decimal places."""
        return round(float(value), 3)

    def _train_and_read_final_logs(self, data_dir, extra_flags):
        """Train once in ``data_dir`` and return the final train/valid log records.

        Args:
            data_dir: Directory holding the preprocessed data and checkpoints.
            extra_flags: Additional command-line flags appended after the
                shared base flags.

        Returns:
            A ``(train_log, valid_log)`` pair of dicts parsed from the captured
            stdout of the training run.
        """
        captured = StringIO()
        with contextlib.redirect_stdout(captured):
            test_binaries.train_translation_model(
                data_dir,
                self._ARCH,
                list(self._BASE_TRAIN_FLAGS) + list(extra_flags),
            )
        output_lines = captured.getvalue().split('\n')
        # NOTE(review): assumes the final train and valid JSON summaries are the
        # 5th- and 4th-to-last entries of the split output — confirm against the
        # trainer's end-of-training logging if this starts failing.
        train_log, valid_log = map(json.loads, output_lines[-5:-3])
        return train_log, valid_log

    def _test_reproducibility(self, name, extra_flags=None):
        """Assert that a resumed run ends with the same statistics as a single run.

        Args:
            name: Suffix for the temporary data directory.
            extra_flags: Optional list of extra training flags (e.g. fp16
                options) used for both runs.
        """
        if extra_flags is None:
            extra_flags = []

        with tempfile.TemporaryDirectory(suffix=name) as data_dir:
            # Data generation and preprocessing output is discarded.
            with contextlib.redirect_stdout(StringIO()):
                test_binaries.create_dummy_data(data_dir)
                test_binaries.preprocess_translation_data(data_dir)

            # Uninterrupted run: train up to --max-epoch 3 in one go.
            train_log, valid_log = self._train_and_read_final_logs(data_dir, extra_flags)

            # Resumed run: promote the epoch-1 checkpoint to checkpoint_last.pt
            # so training presumably resumes after epoch 1 and repeats the
            # remaining epochs. os.replace overwrites the destination on every
            # platform; os.rename raises on Windows when checkpoint_last.pt
            # already exists (which it presumably does after the first run).
            os.replace(
                os.path.join(data_dir, 'checkpoint1.pt'),
                os.path.join(data_dir, 'checkpoint_last.pt'),
            )
            train_res_log, valid_res_log = self._train_and_read_final_logs(
                data_dir, extra_flags,
            )

            # subTest reports every mismatching statistic, not just the first.
            for key in self._TRAIN_KEYS:
                with self.subTest(key=key):
                    self.assertEqual(
                        self._cast(train_log[key]), self._cast(train_res_log[key]),
                    )
            for key in self._VALID_KEYS:
                with self.subTest(key=key):
                    self.assertEqual(
                        self._cast(valid_log[key]), self._cast(valid_res_log[key]),
                    )

    def test_reproducibility(self):
        """Resuming reproduces the uninterrupted run with default precision."""
        self._test_reproducibility('test_reproducibility')

    def test_reproducibility_fp16(self):
        """Resuming reproduces the uninterrupted run with ``--fp16``."""
        self._test_reproducibility('test_reproducibility_fp16', [
            '--fp16',
            '--fp16-init-scale', '4096',
        ])

    def test_reproducibility_memory_efficient_fp16(self):
        """Resuming reproduces the uninterrupted run with ``--memory-efficient-fp16``."""
        self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
            '--memory-efficient-fp16',
            '--fp16-init-scale', '4096',
        ])
# Run the test cases defined in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main(module='__main__')