#!/usr/bin/python
# -*- coding: utf-8 -*-

# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

# Make the in-tree wrapper under ./src win over any installed copy.
sys.path.insert(0, 'src')

from collections import defaultdict
import io
import os
import pickle
import unittest

import sentencepiece as spm

print('VERSION={}'.format(spm.__version__))

data_dir = 'test'
if sys.platform == 'win32':
  data_dir = os.path.join('..', 'data')

# Paths shared by the tests. The model files are always read from 'test',
# only the training corpus follows data_dir.
_MODEL_FILE = os.path.join('test', 'test_model.model')
_JA_MODEL_FILE = os.path.join('test', 'test_ja_model.model')
_CORPUS_FILE = os.path.join(data_dir, 'botchan.txt')

# Pieces expected at ids 0, 1 and 2 (unk, bos, eos).
_SPECIAL_PIECES = ('<unk>', '<s>', '</s>')

# Every output representation accepted by the `out_type` argument.
_OUT_TYPES = (str, int, 'serialized_proto', 'immutable_proto')

# Full-width "KADOKAWA" followed by ASCII "ABC": 11 characters, 27 UTF-8
# bytes.  Written with escapes so the full-width letters cannot be silently
# folded to ASCII; the byte offsets asserted below depend on them.
_FULLWIDTH_TEXT = '\uff2b\uff21\uff24\uff2f\uff2b\uff21\uff37\uff21ABC'


class TestSentencepieceProcessor(unittest.TestCase):
  """Test case for SentencePieceProcessor"""

  def setUp(self):
    """Loads the English and Japanese models by path and from raw bytes."""
    self.sp_ = spm.SentencePieceProcessor()
    self.jasp_ = spm.SentencePieceProcessor()
    self.assertTrue(self.sp_.Load(_MODEL_FILE))
    self.assertTrue(self.jasp_.Load(_JA_MODEL_FILE))
    with open(_MODEL_FILE, 'rb') as f:
      self.assertTrue(self.sp_.LoadFromSerializedProto(f.read()))
    with open(_JA_MODEL_FILE, 'rb') as f:
      self.assertTrue(self.jasp_.LoadFromSerializedProto(f.read()))

  def _check_vocab(self, sp, expected_size):
    """Checks vocab size and id <-> piece mapping through both API styles."""
    # CamelCase API.
    self.assertEqual(expected_size, sp.GetPieceSize())
    for piece_id, piece in enumerate(_SPECIAL_PIECES):
      self.assertEqual(piece_id, sp.PieceToId(piece))
    for piece_id, piece in enumerate(_SPECIAL_PIECES):
      self.assertEqual(piece, sp.IdToPiece(piece_id))
    for i in range(sp.GetPieceSize()):
      self.assertEqual(i, sp.PieceToId(sp.IdToPiece(i)))

    # snake_case API.
    self.assertEqual(expected_size, sp.get_piece_size())
    for piece_id, piece in enumerate(_SPECIAL_PIECES):
      self.assertEqual(piece_id, sp.piece_to_id(piece))
    for piece_id, piece in enumerate(_SPECIAL_PIECES):
      self.assertEqual(piece, sp.id_to_piece(piece_id))
    for i in range(sp.get_piece_size()):
      self.assertEqual(i, sp.piece_to_id(sp.id_to_piece(i)))

  def test_load(self):
    """The English model has 1000 pieces and the expected special ids."""
    self._check_vocab(self.sp_, 1000)
    self.assertEqual(0, self.sp_.unk_id())
    self.assertEqual(1, self.sp_.bos_id())
    self.assertEqual(2, self.sp_.eos_id())
    self.assertEqual(-1, self.sp_.pad_id())

  def test_roundtrip(self):
    """Encoding then decoding English text returns the original text."""
    text = 'I saw a girl with a telescope.'
    sp = self.sp_

    ids = sp.EncodeAsIds(text)
    pieces1 = sp.EncodeAsPieces(text)
    pieces2 = sp.NBestEncodeAsPieces(text, 10)[0]
    self.assertEqual(pieces1, pieces2)
    self.assertEqual(text, sp.DecodePieces(pieces1))
    self.assertEqual(text, sp.DecodeIds(ids))
    for _ in range(100):
      for nbest_size in (64, -1):
        self.assertEqual(
            text,
            sp.DecodePieces(sp.SampleEncodeAsPieces(text, nbest_size, 0.5)),
        )
        self.assertEqual(
            text, sp.DecodeIds(sp.SampleEncodeAsIds(text, nbest_size, 0.5))
        )

    ids2 = sp.encode_as_ids(text)
    pieces3 = sp.encode_as_pieces(text)
    pieces4 = sp.nbest_encode_as_pieces(text, 10)[0]
    self.assertEqual(pieces3, pieces4)
    self.assertEqual(pieces1, pieces3)
    self.assertEqual(ids, ids2)
    self.assertEqual(text, sp.decode_pieces(pieces3))
    self.assertEqual(text, sp.decode_ids(ids2))
    for _ in range(100):
      for nbest_size in (64, -1):
        self.assertEqual(
            text,
            sp.decode_pieces(
                sp.sample_encode_as_pieces(text, nbest_size, 0.5)
            ),
        )
        self.assertEqual(
            text,
            sp.decode_ids(sp.sample_encode_as_ids(text, nbest_size, 0.5)),
        )

    self.assertEqual(
        sp.calculate_entropy(text, 0.1), sp.CalculateEntropy(text, 0.1)
    )

  def test_ja_load(self):
    """The Japanese model has 8000 pieces and the expected special pieces."""
    self._check_vocab(self.jasp_, 8000)

  def test_ja_roundtrip(self):
    """Encoding then decoding Japanese text returns the original text."""
    text = '清水寺は京都にある。'
    sp = self.jasp_

    ids = sp.EncodeAsIds(text)
    pieces1 = sp.EncodeAsPieces(text)
    pieces2 = sp.NBestEncodeAsPieces(text, 10)[0]
    self.assertEqual(pieces1, pieces2)
    self.assertEqual(text, sp.DecodePieces(pieces1))
    self.assertEqual(text, sp.DecodeIds(ids))
    for _ in range(100):
      for nbest_size in (64, -1):
        self.assertEqual(
            text,
            sp.DecodePieces(sp.SampleEncodeAsPieces(text, nbest_size, 0.5)),
        )

    ids2 = sp.encode_as_ids(text)
    pieces3 = sp.encode_as_pieces(text)
    pieces4 = sp.nbest_encode_as_pieces(text, 10)[0]
    self.assertEqual(pieces3, pieces4)
    self.assertEqual(pieces1, pieces3)
    self.assertEqual(ids, ids2)
    self.assertEqual(text, sp.decode_pieces(pieces1))
    self.assertEqual(text, sp.decode_ids(ids2))
    for _ in range(100):
      for nbest_size in (64, -1):
        self.assertEqual(
            text,
            sp.decode_pieces(
                sp.sample_encode_as_pieces(text, nbest_size, 0.5)
            ),
        )

    self.assertEqual(
        sp.calculate_entropy(text, 0.1), sp.CalculateEntropy(text, 0.1)
    )

  def test_train(self):
    """Trains from a flag string and runs the model over the corpus."""
    spm.SentencePieceTrainer.Train(
        '--input=' + _CORPUS_FILE + ' --model_prefix=m --vocab_size=1000'
    )
    sp = spm.SentencePieceProcessor()
    sp.Load('m.model')
    with open(_CORPUS_FILE, 'r') as file:
      for line in file:
        sp.DecodePieces(sp.EncodeAsPieces(line))
        sp.DecodeIds(sp.EncodeAsIds(line))

  def test_train_iterator(self):
    """Training from a file and from a sentence iterator agree on vocab."""
    spm.SentencePieceTrainer.Train(
        '--input=' + _CORPUS_FILE + ' --model_prefix=m --vocab_size=1000'
    )
    # Load as 'rb' for Python3.5/2.7.
    os1 = io.BytesIO()
    os2 = io.BytesIO()

    # suppress logging (redirect to /dev/null)
    # A fresh devnull handle is opened for every call, exactly as before.
    spm.SentencePieceTrainer.train(
        input=_CORPUS_FILE,
        model_prefix='m',
        vocab_size=1000,
        logstream=open(os.devnull, 'w'),
    )

    with open(_CORPUS_FILE, 'rb') as is1:
      spm.SentencePieceTrainer.train(
          sentence_iterator=is1,
          model_prefix='m',
          vocab_size=1000,
          logstream=open(os.devnull, 'w'),
      )

    spm.SentencePieceTrainer.train(
        input=_CORPUS_FILE,
        model_writer=os1,
        vocab_size=1000,
        logstream=open(os.devnull, 'w'),
    )

    with open(_CORPUS_FILE, 'rb') as is2:
      spm.SentencePieceTrainer.train(
          sentence_iterator=is2,
          model_writer=os2,
          vocab_size=1000,
          logstream=open(os.devnull, 'w'),
      )

    sp1 = spm.SentencePieceProcessor(model_proto=os1.getvalue())
    sp2 = spm.SentencePieceProcessor(model_proto=os2.getvalue())
    self.assertEqual(
        [sp1.id_to_piece(i) for i in range(sp1.get_piece_size())],
        [sp2.id_to_piece(i) for i in range(sp2.get_piece_size())],
    )

  def test_train_kwargs(self):
    """Trains via keyword arguments with whitespace user-defined symbols."""
    # suppress logging (redirect to /dev/null)
    spm.SentencePieceTrainer.train(
        input=[_CORPUS_FILE],
        model_prefix='m',
        vocab_size=1002,
        user_defined_symbols=['foo', 'bar', ',', ' ', '\t', '\b', '\n', '\r'],
        logstream=open(os.devnull, 'w'),
    )
    sp = spm.SentencePieceProcessor()
    sp.Load('m.model')
    with open(_CORPUS_FILE, 'r') as file:
      for line in file:
        sp.DecodePieces(sp.EncodeAsPieces(line))
        sp.DecodeIds(sp.EncodeAsIds(line))

    s = 'hello\tworld\r\nthis\tis a \b pen'
    self.assertEqual(s, sp.decode(sp.encode(s)))

  def test_serialized_proto(self):
    """The three spellings of the serialized-proto API return equal bytes."""
    text = 'I saw a girl with a telescope.'
    s1 = self.sp_.EncodeAsSerializedProto(text)
    s2 = self.sp_.SampleEncodeAsSerializedProto(text, 10, 0.2)
    s3 = self.sp_.NBestEncodeAsSerializedProto(text, 10)
    s4 = self.sp_.DecodePiecesAsSerializedProto(['foo', 'bar'])
    s5 = self.sp_.DecodeIdsAsSerializedProto([20, 30])

    t1 = self.sp_.encode_as_serialized_proto(text)
    t2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
    t3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
    t4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
    t5 = self.sp_.decode_ids_as_serialized_proto([20, 30])

    y1 = self.sp_.encode(text, out_type='serialized_proto')
    # Sampled output is random, so it is only produced, never compared.
    self.sp_.encode(text, enable_sampling=True, out_type='serialized_proto')
    y3 = self.sp_.nbest_encode(text, out_type='serialized_proto', nbest_size=10)
    y4 = self.sp_.decode(['foo', 'bar'], out_type='serialized_proto')
    y5 = self.sp_.decode([20, 30], out_type='serialized_proto')

    for blob in (s1, s2, t2, s3, s4, s5):
      self.assertEqual(type(blob), bytes)

    self.assertEqual(s1, t1)
    self.assertEqual(s3, t3)
    self.assertEqual(s4, t4)
    self.assertEqual(s5, t5)
    self.assertEqual(s1, y1)
    self.assertEqual(s3, y3)
    self.assertEqual(s4, y4)
    self.assertEqual(s5, y5)

    ids = self.jasp_.EncodeAsIds(text)
    s1 = self.jasp_.EncodeAsSerializedProto(text)
    s2 = self.jasp_.DecodeIdsAsSerializedProto(ids)
    # NOTE(review): ids, not pieces, are handed to the pieces API here.
    # Kept unchanged; confirm whether EncodeAsPieces(text) was intended.
    s3 = self.jasp_.DecodePiecesAsSerializedProto(ids)
    self.assertEqual(s2, s1)
    self.assertEqual(s3, s1)

  def _check_bytes_views(self, proto):
    """Checks that every *_as_bytes field is the UTF-8 form of its twin."""
    self.assertEqual(proto.text, proto.text_as_bytes.decode(encoding='utf-8'))
    for sp in proto.pieces:
      self.assertEqual(sp.piece, sp.piece_as_bytes.decode(encoding='utf-8'))
      self.assertEqual(
          sp.surface, sp.surface_as_bytes.decode(encoding='utf-8')
      )

  def test_decode_bytes(self):
    """decode() honours out_type=bytes/str for batch and single input."""
    texts = ['Hello world', '清水寺は京都にある。']
    ids = self.jasp_.encode(texts, out_type=int)
    s1 = self.jasp_.decode(ids, out_type=bytes)
    s2 = self.jasp_.decode(ids, out_type=str)
    self.assertEqual(len(s1), 2)
    self.assertEqual(type(s1[0]), bytes)
    self.assertEqual(type(s1[1]), bytes)
    self.assertEqual(len(s2), 2)
    self.assertEqual(type(s2[0]), str)
    self.assertEqual(type(s2[1]), str)
    self.assertEqual(s1[0].decode(encoding='utf-8'), s2[0])
    self.assertEqual(s1[1].decode(encoding='utf-8'), s2[1])

    text = 'Hello world'
    ids = self.jasp_.encode(text, out_type=int)
    s1 = self.jasp_.decode(ids, out_type=bytes)
    s2 = self.jasp_.decode(ids, out_type=str)
    self.assertEqual(type(s1), bytes)
    self.assertEqual(type(s2), str)
    self.assertEqual(s1.decode(encoding='utf-8'), s2)

    self._check_bytes_views(self.jasp_.encode(text, out_type='immutable_proto'))
    self._check_bytes_views(self.jasp_.decode(ids, out_type='immutable_proto'))

  def test_immutable_proto(self):
    """Immutable protos compare, hash, index, slice and serialize properly."""
    text = 'I saw a girl with a telescope.'
    s1 = self.sp_.EncodeAsImmutableProto(text)
    s2 = self.sp_.SampleEncodeAsImmutableProto(text, 10, 0.2)
    s3 = self.sp_.NBestEncodeAsImmutableProto(text, 10)
    s4 = self.sp_.DecodePiecesAsImmutableProto(['foo', 'bar'])
    s5 = self.sp_.DecodeIdsAsImmutableProto([20, 30])

    # Printing exercises the string conversion of every proto flavour.
    for proto in (s1, s2, s3, s4, s5):
      print(proto)

    t1 = self.sp_.encode_as_immutable_proto(text)
    # Sampled results are random, so they are only produced, never compared.
    self.sp_.sample_encode_as_immutable_proto(text, 10, 0.2)
    t3 = self.sp_.nbest_encode_as_immutable_proto(text, 10)
    t4 = self.sp_.decode_pieces_as_immutable_proto(['foo', 'bar'])
    t5 = self.sp_.decode_ids_as_immutable_proto([20, 30])

    y1 = self.sp_.encode(text, out_type='immutable_proto')
    self.sp_.encode(text, enable_sampling=True, out_type='immutable_proto')
    y3 = self.sp_.nbest_encode(text, out_type='immutable_proto', nbest_size=10)
    y4 = self.sp_.decode(['foo', 'bar'], out_type='immutable_proto')
    y5 = self.sp_.decode([20, 30], out_type='immutable_proto')

    self.assertEqual(s1, t1)
    self.assertEqual(s3, t3)
    self.assertEqual(s4, t4)
    self.assertEqual(s5, t5)
    self.assertEqual(s1, y1)
    self.assertEqual(s3, y3)
    self.assertEqual(s4, y4)
    self.assertEqual(s5, y5)

    hset_piece = defaultdict(int)

    # eq test
    # Index-based access is deliberate: it exercises pieces[i].
    for i in range(len(s1.pieces)):
      self.assertEqual(s1.pieces[i], t1.pieces[i])
      hset_piece[s1.pieces[i]] += 1
      hset_piece[t1.pieces[i]] += 1

    self.assertEqual(len(hset_piece), len(s1.pieces))

    # has test
    hset = defaultdict(int)
    for proto in (s1, t1, s3, t3):
      hset[proto] += 1

    self.assertEqual(len(hset), 2)
    for proto in (s1, s3, t1, t3):
      self.assertEqual(hset[proto], 2)

    x1 = self.sp_.encode_as_serialized_proto(text)
    self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
    x3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
    x4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
    x5 = self.sp_.decode_ids_as_serialized_proto([20, 30])

    self.assertEqual(x1, t1.SerializeAsString())
    self.assertEqual(x3, t3.SerializeAsString())
    self.assertEqual(x4, t4.SerializeAsString())
    self.assertEqual(x5, t5.SerializeAsString())

    v1 = self.sp_.EncodeAsIds(text)
    v2 = self.sp_.EncodeAsPieces(text)
    self.assertEqual([x.id for x in s1.pieces], v1)
    self.assertEqual([x.piece for x in s1.pieces], v2)
    self.assertEqual(text, s1.text)

    surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
    surfaces2 = [x.surface for x in s1.pieces]
    self.assertEqual(surfaces1, surfaces2)

    # Same checks again through indexing rather than iteration.
    ids = []
    for i in range(len(s1.pieces)):
      ids.append(s1.pieces[i].id)
    self.assertEqual(ids, v1)

    pieces = []
    for i in range(len(s1.pieces)):
      pieces.append(s1.pieces[i].piece)
    self.assertEqual(pieces, v2)

    for v in s3.nbests:
      self.assertEqual(text, v.text)
      self.assertEqual(self.sp_.Decode([x.id for x in v.pieces]), text)

    for i in range(len(s3.nbests)):
      self.assertEqual(text, s3.nbests[i].text)
      self.assertEqual(
          self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text
      )

    # slice
    self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces)))
    self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests)))

    # Japanese offset
    s1 = self.jasp_.EncodeAsImmutableProto(
        '吾輩は猫である。Hello world. ABC 123'
    )
    surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
    surfaces2 = [x.surface for x in s1.pieces]
    self.assertEqual(surfaces1, surfaces2)

    ids = [x.id for x in s1.pieces]
    s2 = self.jasp_.DecodeIdsAsImmutableProto(ids)
    self.assertEqual(s2, s1)

    pieces = [x.piece for x in s1.pieces]
    s2 = self.jasp_.DecodePiecesAsImmutableProto(pieces)
    self.assertEqual(s2, s1)

  def test_new_api(self):
    """encode()/decode() agree with the explicit EncodeAs*/Decode* calls."""
    sp = spm.SentencePieceProcessor(model_file=_MODEL_FILE)
    text = 'hello world'
    text2 = 'Tokyo'
    ids = self.sp_.EncodeAsIds(text)
    ids2 = self.sp_.EncodeAsIds(text2)
    pieces = self.sp_.EncodeAsPieces(text)
    pieces2 = self.sp_.EncodeAsPieces(text2)
    sprotos = self.sp_.EncodeAsSerializedProto(text)
    self.sp_.EncodeAsSerializedProto(text2)
    iprotos = self.sp_.EncodeAsImmutableProto(text)
    iprotos2 = self.sp_.EncodeAsImmutableProto(text2)

    self.assertEqual(sp.encode(text, out_type=int), ids)
    self.assertEqual(sp.encode(text, out_type=str), pieces)
    self.assertEqual(sp.encode(text, out_type='serialized_proto'), sprotos)
    self.assertEqual(sp.encode(text, out_type='immutable_proto'), iprotos)

    self.assertEqual(sp.encode([text], out_type=int), [ids])
    self.assertEqual(sp.encode([text], out_type=str), [pieces])
    self.assertEqual(sp.encode([text], out_type='serialized_proto'), [sprotos])
    self.assertEqual(sp.encode([text], out_type='immutable_proto'), [iprotos])

    self.assertEqual(len(iprotos.pieces), len(pieces))
    self.assertEqual(len(iprotos.pieces), len(ids))
    self.assertEqual(iprotos.text, text)

    self.assertEqual(len(iprotos2.pieces), len(pieces2))
    self.assertEqual(len(iprotos2.pieces), len(ids2))
    self.assertEqual(iprotos2.text, text2)

    # Both access paths are covered on purpose: indexing and iteration.
    for i in range(len(iprotos.pieces)):
      self.assertEqual(ids[i], iprotos.pieces[i].id)
      self.assertEqual(pieces[i], iprotos.pieces[i].piece)

    for i, piece in enumerate(iprotos.pieces):
      self.assertEqual(ids[i], piece.id)
      self.assertEqual(pieces[i], piece.piece)

    for i in range(len(iprotos2.pieces)):
      self.assertEqual(ids2[i], iprotos2.pieces[i].id)
      self.assertEqual(pieces2[i], iprotos2.pieces[i].piece)

    for i, piece in enumerate(iprotos2.pieces):
      self.assertEqual(ids2[i], piece.id)
      self.assertEqual(pieces2[i], piece.piece)

    detok_ids = self.sp_.DecodeIds(ids)
    detok_pieces = self.sp_.DecodePieces(pieces)
    self.assertEqual(sp.decode(ids), detok_ids)
    self.assertEqual(sp.decode(pieces), detok_pieces)
    self.assertEqual(sp.decode([]), '')
    self.assertEqual(sp.decode([[]]), [''])

    # add_bos, add_eos, reverse
    self.assertEqual([sp.bos_id()] + ids, sp.encode(text, add_bos=True))
    self.assertEqual(ids + [sp.eos_id()], sp.encode(text, add_eos=True))
    self.assertEqual(ids + [sp.eos_id()], sp.EncodeAsIds(text, add_eos=True))
    rids = ids[:]
    rids.reverse()

    self.assertEqual(rids, sp.encode(text, reverse=True))
    self.assertEqual(rids, sp.EncodeAsIds(text, reverse=True))

    # different shape.
    self.assertEqual([ids, ids2], sp.encode([text, text2]))
    self.assertEqual([pieces, pieces2], sp.encode([text, text2], out_type=str))
    self.assertEqual([text, text2], sp.decode([ids, ids2]))
    self.assertEqual([text, text2], sp.decode([pieces, pieces2]))

    pieces = list(reversed(self.sp_.EncodeAsPieces(text)))
    self.assertEqual(pieces, sp.encode(text, reverse=True, out_type=str))

    # emit unk piece
    unk_char = '藤'
    pieces = self.sp_.EncodeAsIds(unk_char, emit_unk_piece=True)
    pieces2 = self.sp_.encode(unk_char, out_type=int, emit_unk_piece=True)
    self.assertEqual(pieces[1], sp.unk_id())
    self.assertEqual(pieces2[1], sp.unk_id())
    self.assertEqual(pieces, pieces2)

    pieces = self.sp_.EncodeAsPieces(unk_char, emit_unk_piece=True)
    pieces2 = self.sp_.encode(unk_char, out_type=str, emit_unk_piece=True)
    self.assertEqual(pieces[1], '<unk>')
    self.assertEqual(pieces2[1], '<unk>')
    self.assertEqual(pieces, pieces2)

    pieces = self.sp_.EncodeAsPieces(unk_char, emit_unk_piece=False)
    pieces2 = self.sp_.encode(unk_char, out_type=str, emit_unk_piece=False)
    self.assertEqual(pieces[1], unk_char)
    self.assertEqual(pieces2[1], unk_char)
    self.assertEqual(pieces, pieces2)

  def test_new_api_init(self):
    """Constructor defaults apply to encode() and can be overridden."""
    sp = spm.SentencePieceProcessor(
        model_file=_MODEL_FILE,
        add_bos=True,
        add_eos=True,
        out_type=str,
    )
    text = 'hello world'
    pieces = ['<s>'] + self.sp_.EncodeAsPieces(text) + ['</s>']
    self.assertEqual(pieces, sp.encode(text))

    pieces = self.sp_.EncodeAsPieces(text) + ['</s>']
    self.assertEqual(pieces, sp.encode(text, add_bos=False, add_eos=True))

  def test_sampling(self):
    """Sampling yields several segmentations; without it exactly one."""
    sp = self.sp_

    for out_type in _OUT_TYPES:
      ids = defaultdict(int)
      for _ in range(100):
        out = sp.encode('hello world', out_type=out_type, enable_sampling=True)
        if isinstance(out, list):
          out = tuple(out)  # lists are unhashable
        ids[out] += 1
      self.assertGreater(len(ids), 1)

      ids2 = defaultdict(int)
      for _ in range(100):
        out = sp.encode('hello world', out_type=out_type, enable_sampling=False)
        if isinstance(out, list):
          out = tuple(out)
        ids2[out] += 1
      self.assertEqual(len(ids2), 1)

      out = sp.encode(
          ['hello world', 'this is a test'],
          out_type=out_type,
          enable_sampling=True,
      )
      self.assertEqual(len(out), 2)
      out = sp.encode(
          ['hello world', 'this is a test'],
          out_type=out_type,
          enable_sampling=False,
      )
      self.assertEqual(len(out), 2)

  def test_nbest(self):
    """nbest_encode matches NBestEncode and decodes back to the input."""
    sp = self.sp_
    text = 'hello world'
    text2 = 'I have a pen.'

    for out_type in _OUT_TYPES:
      results = sp.nbest_encode(text, nbest_size=10, out_type=out_type)
      self.assertEqual(
          results, sp.NBestEncode(text, nbest_size=10, out_type=out_type)
      )

      if out_type in [str, int]:
        for n in results:
          self.assertEqual(sp.decode(n), text)

        for n in sp.decode(results):
          self.assertEqual(n, text)

      # batch test
      results = sp.nbest_encode([text, text2], nbest_size=10, out_type=out_type)
      self.assertEqual(
          results,
          sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type),
      )
      self.assertEqual(len(results), 2)

      if out_type in [str, int]:
        for expected, nbest in zip((text, text2), results):
          for n in nbest:
            self.assertEqual(sp.decode(n), expected)

        for expected, nbest in zip((text, text2), results):
          decoded = sp.decode(nbest)
          self.assertEqual(len(decoded), 10)
          for n in decoded:
            self.assertEqual(n, expected)

    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type=str),
        sp.nbest_encode_as_pieces(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type=int),
        sp.nbest_encode_as_ids(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type='serialized_proto'),
        sp.nbest_encode_as_serialized_proto(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type='immutable_proto'),
        sp.nbest_encode_as_immutable_proto(text, nbest_size=10),
    )

  def test_sample_and_score(self):
    """Sampled-and-scored segmentations decode back to the input text."""
    sp = self.sp_
    text = 'hello world'
    text2 = 'I have a pen.'
    for out_type in _OUT_TYPES:
      # The first call is made only to exercise the snake_case spelling.
      sp.sample_encode_and_score(
          text, wor=True, num_samples=10, out_type=out_type
      )
      results = sp.SampleEncodeAndScore(
          text, wor=False, num_samples=10, out_type=out_type
      )

      if out_type in [str, int]:
        for n in results:
          self.assertEqual(sp.decode(n[0]), text)

      sp.sample_encode_and_score(
          [text, text2], wor=True, num_samples=10, out_type=out_type
      )
      results = sp.SampleEncodeAndScore(
          [text, text2], wor=True, num_samples=10, out_type=out_type
      )

      if out_type in [str, int]:
        for n in results[0]:
          self.assertEqual(sp.decode(n[0]), text)
        for n in results[1]:
          self.assertEqual(sp.decode(n[0]), text2)

    sp.sample_encode_and_score_as_pieces(text, 10)
    sp.sample_encode_and_score_as_ids(text, 10)
    sp.sample_encode_and_score_as_immutable_proto(text, 10)
    sp.sample_encode_and_score_as_serialized_proto(text, 10)

  def test_valid_range(self):
    """Id-taking methods accept in-range ids and reject id == piece_size."""
    size = self.sp_.piece_size()
    funcs = [
        'IdToPiece',
        'GetScore',
        'IsUnknown',
        'IsControl',
        'IsUnused',
        'IsByte',
        'DecodeIds',
        'DecodeIdsAsSerializedProto',
    ]
    for m in funcs:
      getattr(self.sp_, m)([10, 20, 30])

    # The previous try / bare-except form also swallowed the AssertionError
    # raised by its own failure branch, so this check could never fail.
    for m in funcs:
      with self.assertRaises(
          Exception, msg='{} accepted out-of-range id'.format(m)
      ):
        getattr(self.sp_, m)([size])

  def test_batch(self):
    """Batch encode/decode/entropy are independent of num_threads."""
    sp = spm.SentencePieceProcessor(model_file=_MODEL_FILE)
    with open(_CORPUS_FILE, 'r') as file:
      texts = file.readlines()

    for out_type in _OUT_TYPES:
      r1 = sp.encode(texts, out_type=out_type, num_threads=None)
      r2 = sp.encode(texts, out_type=out_type, num_threads=1)
      r3 = sp.encode(texts, out_type=out_type, num_threads=-1)
      r4 = sp.encode(texts, out_type=out_type, num_threads=8)
      r5 = [sp.encode(s, out_type=out_type) for s in texts]
      for other in (r2, r3, r4, r5):
        self.assertEqual(r1, other)

      if out_type in [str, int]:
        d1 = sp.decode(r1, num_threads=None)
        d2 = sp.decode(r2, num_threads=1)
        d3 = sp.decode(r3, num_threads=-1)
        d4 = sp.decode(r4, num_threads=8)
        d5 = [sp.decode(s) for s in r5]
        for other in (d2, d3, d4, d5):
          self.assertEqual(d1, other)

    e1 = sp.calculate_entropy(texts, alpha=1.0, num_threads=10)
    e2 = sp.CalculateEntropy(texts, alpha=1.0, num_threads=10)
    e3 = [sp.calculate_entropy(s, alpha=1.0) for s in texts]
    self.assertEqual(e1, e2)
    self.assertEqual(e1, e3)

  def test_pickle(self):
    """A processor restored from a pickle file encodes identically."""
    with open('sp.pickle', 'wb') as f:
      pickle.dump(self.sp_, f)

    id1 = self.sp_.encode('hello world.', out_type=int)

    with open('sp.pickle', 'rb') as f:
      sp = pickle.load(f)

    id2 = sp.encode('hello world.', out_type=int)

    self.assertEqual(id1, id2)

  def test_global_params(self):
    """The module-level setters are callable in both naming styles."""
    spm.SetRandomGeneratorSeed(0)
    spm.SetMinLogLevel(2)
    spm.set_random_generator_seed(1)
    spm.set_min_log_level(3)

  def test_normalize(self):
    """Processor normalization of str/bytes input, with and without offsets."""
    sp = spm.SentencePieceProcessor(model_file=_MODEL_FILE)
    normalized = '▁KADOKAWAABC'
    char_offsets = [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    # Offsets step by 3 across the full-width letters and by 1 across "ABC".
    byte_offsets = [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27]

    self.assertEqual(normalized, sp.normalize(_FULLWIDTH_TEXT))
    self.assertEqual(normalized, sp.Normalize(_FULLWIDTH_TEXT))

    x = sp.Normalize(_FULLWIDTH_TEXT, with_offsets=True)
    self.assertEqual(normalized, x[0])
    self.assertEqual(char_offsets, x[1])

    x = sp.Normalize(_FULLWIDTH_TEXT.encode('utf8'), with_offsets=True)
    self.assertEqual(normalized.encode('utf8'), x[0])
    self.assertEqual(byte_offsets, x[1])

    self.assertEqual(
        [normalized, '▁平成'], sp.normalize([_FULLWIDTH_TEXT, '㍻'])
    )
    self.assertEqual(
        [normalized, '▁平成'], sp.Normalize([_FULLWIDTH_TEXT, '㍻'])
    )

    x = sp.Normalize(
        [_FULLWIDTH_TEXT.encode('utf8'), '㍻'.encode('utf8')],
        with_offsets=True,
    )
    self.assertEqual(len(x), 2)
    self.assertEqual(normalized.encode('utf8'), x[0][0])
    self.assertEqual(byte_offsets, x[0][1])

    x = sp.Normalize([_FULLWIDTH_TEXT, '㍻'], with_offsets=True)
    self.assertEqual(len(x), 2)
    self.assertEqual(normalized, x[0][0])
    self.assertEqual(char_offsets, x[0][1])
    self.assertEqual('▁平成', x[1][0])
    self.assertEqual([0, 0, 0, 1], x[1][1])

  def test_normalizer(self):
    """Stand-alone normalizer output, offsets and whitespace options."""
    sp = spm.SentencePieceNormalizer(model_file=_MODEL_FILE)
    normalized = 'KADOKAWAABC'
    char_offsets = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    # Offsets step by 3 across the full-width letters and by 1 across "ABC".
    byte_offsets = [0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27]

    self.assertEqual(normalized, sp.normalize(_FULLWIDTH_TEXT))
    self.assertEqual(normalized, sp.Normalize(_FULLWIDTH_TEXT))

    x = sp.Normalize(_FULLWIDTH_TEXT.encode('utf8'), with_offsets=True)
    self.assertEqual(normalized.encode('utf8'), x[0])
    self.assertEqual(byte_offsets, x[1])

    x = sp.Normalize(_FULLWIDTH_TEXT, with_offsets=True)
    self.assertEqual(normalized, x[0])
    self.assertEqual(char_offsets, x[1])

    self.assertEqual(
        [normalized, '平成'], sp.normalize([_FULLWIDTH_TEXT, '㍻'])
    )
    self.assertEqual(
        [normalized, '平成'], sp.Normalize([_FULLWIDTH_TEXT, '㍻'])
    )

    x = sp.Normalize(
        [_FULLWIDTH_TEXT.encode('utf8'), '㍻'.encode('utf8')],
        with_offsets=True,
    )
    self.assertEqual(len(x), 2)
    self.assertEqual(normalized.encode('utf8'), x[0][0])
    self.assertEqual(byte_offsets, x[0][1])

    x = sp.Normalize([_FULLWIDTH_TEXT, '㍻'], with_offsets=True)
    self.assertEqual(len(x), 2)
    self.assertEqual(normalized, x[0][0])
    self.assertEqual(char_offsets, x[0][1])
    self.assertEqual('平成', x[1][0])
    self.assertEqual([0, 0, 1], x[1][1])

    sp = spm.SentencePieceNormalizer(
        model_file=_MODEL_FILE,
        add_dummy_prefix=True,
        escape_whitespaces=True,
        remove_extra_whitespaces=False,
    )
    # Two spaces in the input: the doubled escape in the expected output
    # can only come from a doubled space that was not removed.
    self.assertEqual('▁hello▁▁world', sp.normalize('hello  world'))

    sp = spm.SentencePieceNormalizer(
        model_file=_MODEL_FILE,
        add_dummy_prefix=True,
        escape_whitespaces=True,
        remove_extra_whitespaces=True,
    )
    self.assertEqual('▁hello▁world', sp.normalize(' hello world '))

    sp = spm.SentencePieceNormalizer(
        model_file=_MODEL_FILE,
        add_dummy_prefix=False,
        escape_whitespaces=False,
        remove_extra_whitespaces=True,
    )
    self.assertEqual('hello world', sp.normalize(' hello world '))

  def test_normalizer_rule(self):
    """Normalizers built from a rule name apply that rule."""
    sp = spm.SentencePieceNormalizer(rule_name='identity')
    self.assertEqual('ABC', sp.Normalize('ABC'))

    sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
    self.assertEqual('abc', sp.Normalize('ABC'))

  def test_override_normalize_spec(self):
    """override_normalizer_spec changes how whitespace is tokenized."""
    sp = spm.SentencePieceProcessor(model_file=_MODEL_FILE)

    self.assertEqual(
        sp.EncodeAsPieces(' hello world '), ['▁he', 'll', 'o', '▁world']
    )

    sp.override_normalizer_spec(add_dummy_prefix=False)
    sp.override_normalizer_spec(remove_extra_whitespaces=False)
    sp.override_normalizer_spec(escape_whitespaces=False)
    self.assertEqual(
        sp.EncodeAsPieces(' hello world '),
        [' ', 'he', 'll', 'o', ' ', 'w', 'or', 'l', 'd', ' '],
    )


def suite():
  """Returns a suite holding every TestSentencepieceProcessor test."""
  suite = unittest.TestSuite()
  # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13.
  suite.addTests(
      unittest.defaultTestLoader.loadTestsFromTestCase(
          TestSentencepieceProcessor
      )
  )
  return suite


if __name__ == '__main__':
  unittest.main()