Spaces:
Sleeping
Sleeping
File size: 1,863 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import tempfile
import unittest

import torch

from fairseq.data import Dictionary
class TestDictionary(unittest.TestCase):
    """Unit tests for ``fairseq.data.Dictionary``."""

    def test_finalize(self):
        """Check encoded ids before finalize(), after it, and after save/load.

        A dictionary is built from four lines in which 'D' occurs most often
        and 'A' least often.  The encoded ids are compared against fixed
        references at three points: right after building, after
        ``finalize()``, and after writing the dictionary to disk and
        loading it back.
        """
        txt = [
            'A B C D',
            'B C D',
            'C D',
            'D',
        ]
        # Expected ids before finalize(): symbols are numbered in order of
        # first appearance (A=4, B=5, C=6, D=7).  Every line ends with id 2
        # (presumably the end-of-sentence symbol -- confirm in Dictionary).
        ref_ids1 = list(map(torch.IntTensor, [
            [4, 5, 6, 7, 2],
            [5, 6, 7, 2],
            [6, 7, 2],
            [7, 2],
        ]))
        # Expected ids after finalize(): the numbering is reversed, so the
        # most frequent symbol in ``txt`` ('D') now has the lowest id (4).
        ref_ids2 = list(map(torch.IntTensor, [
            [7, 6, 5, 4, 2],
            [6, 5, 4, 2],
            [5, 4, 2],
            [4, 2],
        ]))

        # build dictionary
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            """Encode every line of ``txt`` without adding new symbols."""
            return [
                dictionary.encode_line(line, add_if_not_exist=False)
                for line in txt
            ]

        def assertMatch(ids, ref_ids):
            """Assert that both sequences hold pairwise-equal tensors."""
            # zip() stops at the shorter input, so a missing entry would go
            # unnoticed without this explicit length comparison.
            self.assertEqual(len(ids), len(ref_ids))
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        # A path inside a temporary directory is used rather than the name of
        # an open NamedTemporaryFile: re-opening such a file by name while it
        # is still open does not work on every platform (it fails on Windows).
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_path = os.path.join(tmp_dir, 'dict.txt')
            d.save(dict_path)
            d = Dictionary.load(dict_path)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|