#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

# Put ./src (relative to the working directory) first on the import path so
# an in-tree build of `sentencepiece` is picked up before any installed copy.
sys.path.insert(0, 'src')

from collections import defaultdict
import io
import os
import pickle
import unittest

import sentencepiece as spm

print('VERSION={}'.format(spm.__version__))

# Directory holding the training text (botchan.txt). The test models
# themselves are always read from ./test; only the training data moves to
# ../data when running on Windows.
data_dir = 'test'
if sys.platform == 'win32':
    data_dir = os.path.join('..', 'data')


class TestSentencepieceProcessor(unittest.TestCase):
    """Test case for SentencePieceProcessor.

    ``self.sp_`` holds the English test model and ``self.jasp_`` the Japanese
    test model; both are reloaded in ``setUp`` before every test.
    """

    def setUp(self):
        """Load both test models, first by path and then from raw bytes."""
        self.sp_ = spm.SentencePieceProcessor()
        self.jasp_ = spm.SentencePieceProcessor()
        self.assertTrue(
            self.sp_.Load(os.path.join('test', 'test_model.model'))
        )
        self.assertTrue(
            self.jasp_.Load(os.path.join('test', 'test_ja_model.model'))
        )
        # Reload from the serialized model bytes so that the
        # LoadFromSerializedProto path is exercised as well.
        with open(os.path.join('test', 'test_model.model'), 'rb') as f:
            self.assertTrue(self.sp_.LoadFromSerializedProto(f.read()))
        with open(os.path.join('test', 'test_ja_model.model'), 'rb') as f:
            self.assertTrue(self.jasp_.LoadFromSerializedProto(f.read()))

    def test_load(self):
        """Vocabulary size, special ids and id<->piece mappings agree."""
        self.assertEqual(1000, self.sp_.GetPieceSize())
        self.assertEqual(0, self.sp_.PieceToId('<unk>'))
        self.assertEqual(1, self.sp_.PieceToId('<s>'))
        self.assertEqual(2, self.sp_.PieceToId('</s>'))
        self.assertEqual('<unk>', self.sp_.IdToPiece(0))
        self.assertEqual('<s>', self.sp_.IdToPiece(1))
        self.assertEqual('</s>', self.sp_.IdToPiece(2))
        self.assertEqual(0, self.sp_.unk_id())
        self.assertEqual(1, self.sp_.bos_id())
        self.assertEqual(2, self.sp_.eos_id())
        self.assertEqual(-1, self.sp_.pad_id())
        for i in range(self.sp_.GetPieceSize()):
            piece = self.sp_.IdToPiece(i)
            self.assertEqual(i, self.sp_.PieceToId(piece))

        # Same checks through the snake_case aliases.
        self.assertEqual(1000, self.sp_.get_piece_size())
        self.assertEqual(0, self.sp_.piece_to_id('<unk>'))
        self.assertEqual(1, self.sp_.piece_to_id('<s>'))
        self.assertEqual(2, self.sp_.piece_to_id('</s>'))
        self.assertEqual('<unk>', self.sp_.id_to_piece(0))
        self.assertEqual('<s>', self.sp_.id_to_piece(1))
        self.assertEqual('</s>', self.sp_.id_to_piece(2))
        for i in range(self.sp_.get_piece_size()):
            piece = self.sp_.id_to_piece(i)
            self.assertEqual(i, self.sp_.piece_to_id(piece))

    def test_roundtrip(self):
        """Plain, n-best and sampled encodings all decode back to the text."""
        text = 'I saw a girl with a telescope.'
        ids = self.sp_.EncodeAsIds(text)
        pieces1 = self.sp_.EncodeAsPieces(text)
        pieces2 = self.sp_.NBestEncodeAsPieces(text, 10)[0]
        self.assertEqual(pieces1, pieces2)
        self.assertEqual(text, self.sp_.DecodePieces(pieces1))
        self.assertEqual(text, self.sp_.DecodeIds(ids))
        for _ in range(100):
            self.assertEqual(
                text,
                self.sp_.DecodePieces(
                    self.sp_.SampleEncodeAsPieces(text, 64, 0.5)
                ),
            )
            self.assertEqual(
                text,
                self.sp_.DecodePieces(
                    self.sp_.SampleEncodeAsPieces(text, -1, 0.5)
                ),
            )
            self.assertEqual(
                text,
                self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5)),
            )
            self.assertEqual(
                text,
                self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5)),
            )

        # Same checks through the snake_case aliases.
        ids2 = self.sp_.encode_as_ids(text)
        pieces3 = self.sp_.encode_as_pieces(text)
        pieces4 = self.sp_.nbest_encode_as_pieces(text, 10)[0]
        self.assertEqual(pieces3, pieces4)
        self.assertEqual(pieces1, pieces3)
        self.assertEqual(ids, ids2)
        self.assertEqual(text, self.sp_.decode_pieces(pieces3))
        self.assertEqual(text, self.sp_.decode_ids(ids2))
        for _ in range(100):
            self.assertEqual(
                text,
                self.sp_.decode_pieces(
                    self.sp_.sample_encode_as_pieces(text, 64, 0.5)
                ),
            )
            self.assertEqual(
                text,
                self.sp_.decode_pieces(
                    self.sp_.sample_encode_as_pieces(text, -1, 0.5)
                ),
            )
            self.assertEqual(
                text,
                self.sp_.decode_ids(
                    self.sp_.sample_encode_as_ids(text, 64, 0.5)
                ),
            )
            self.assertEqual(
                text,
                self.sp_.decode_ids(
                    self.sp_.sample_encode_as_ids(text, -1, 0.5)
                ),
            )

        self.assertEqual(
            self.sp_.calculate_entropy(text, 0.1),
            self.sp_.CalculateEntropy(text, 0.1),
        )

    def test_ja_load(self):
        """Japanese model: vocabulary size and id<->piece mappings agree."""
        self.assertEqual(8000, self.jasp_.GetPieceSize())
        self.assertEqual(0, self.jasp_.PieceToId('<unk>'))
        self.assertEqual(1, self.jasp_.PieceToId('<s>'))
        self.assertEqual(2, self.jasp_.PieceToId('</s>'))
        self.assertEqual('<unk>', self.jasp_.IdToPiece(0))
        self.assertEqual('<s>', self.jasp_.IdToPiece(1))
        self.assertEqual('</s>', self.jasp_.IdToPiece(2))
        for i in range(self.jasp_.GetPieceSize()):
            piece = self.jasp_.IdToPiece(i)
            self.assertEqual(i, self.jasp_.PieceToId(piece))

        # Same checks through the snake_case aliases.
        self.assertEqual(8000, self.jasp_.get_piece_size())
        self.assertEqual(0, self.jasp_.piece_to_id('<unk>'))
        self.assertEqual(1, self.jasp_.piece_to_id('<s>'))
        self.assertEqual(2, self.jasp_.piece_to_id('</s>'))
        self.assertEqual('<unk>', self.jasp_.id_to_piece(0))
        self.assertEqual('<s>', self.jasp_.id_to_piece(1))
        self.assertEqual('</s>', self.jasp_.id_to_piece(2))
        for i in range(self.jasp_.get_piece_size()):
            piece = self.jasp_.id_to_piece(i)
            self.assertEqual(i, self.jasp_.piece_to_id(piece))

    def test_ja_roundtrip(self):
        """Japanese model: all encodings decode back to the input text."""
        text = '清水寺は京都にある。'
        ids = self.jasp_.EncodeAsIds(text)
        pieces1 = self.jasp_.EncodeAsPieces(text)
        pieces2 = self.jasp_.NBestEncodeAsPieces(text, 10)[0]
        self.assertEqual(pieces1, pieces2)
        self.assertEqual(text, self.jasp_.DecodePieces(pieces1))
        self.assertEqual(text, self.jasp_.DecodeIds(ids))
        for _ in range(100):
            self.assertEqual(
                text,
                self.jasp_.DecodePieces(
                    self.jasp_.SampleEncodeAsPieces(text, 64, 0.5)
                ),
            )
            self.assertEqual(
                text,
                self.jasp_.DecodePieces(
                    self.jasp_.SampleEncodeAsPieces(text, -1, 0.5)
                ),
            )

        # Same checks through the snake_case aliases.
        ids2 = self.jasp_.encode_as_ids(text)
        pieces3 = self.jasp_.encode_as_pieces(text)
        pieces4 = self.jasp_.nbest_encode_as_pieces(text, 10)[0]
        self.assertEqual(pieces3, pieces4)
        self.assertEqual(pieces1, pieces3)
        self.assertEqual(ids, ids2)
        self.assertEqual(text, self.jasp_.decode_pieces(pieces3))
        self.assertEqual(text, self.jasp_.decode_ids(ids2))
        for _ in range(100):
            self.assertEqual(
                text,
                self.jasp_.decode_pieces(
                    self.jasp_.sample_encode_as_pieces(text, 64, 0.5)
                ),
            )
            self.assertEqual(
                text,
                self.jasp_.decode_pieces(
                    self.jasp_.sample_encode_as_pieces(text, -1, 0.5)
                ),
            )

        self.assertEqual(
            self.jasp_.calculate_entropy(text, 0.1),
            self.jasp_.CalculateEntropy(text, 0.1),
        )

    def test_train(self):
        """Train from a flag string, then smoke-test encode/decode on the data."""
        spm.SentencePieceTrainer.Train(
            '--input='
            + os.path.join(data_dir, 'botchan.txt')
            + ' --model_prefix=m --vocab_size=1000'
        )
        sp = spm.SentencePieceProcessor()
        sp.Load('m.model')
        # Decoded lines are not compared: the trailing newline is not
        # expected to survive normalization.
        path = os.path.join(data_dir, 'botchan.txt')
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                sp.DecodePieces(sp.EncodeAsPieces(line))
                sp.DecodeIds(sp.EncodeAsIds(line))

    def test_train_iterator(self):
        """Training from a file and from a line iterator gives the same vocab."""
        spm.SentencePieceTrainer.Train(
            '--input='
            + os.path.join(data_dir, 'botchan.txt')
            + ' --model_prefix=m --vocab_size=1000'
        )
        buf1 = io.BytesIO()
        buf2 = io.BytesIO()

        # Suppress logging (redirect to /dev/null).
        # NOTE(review): the devnull handles are not closed here; presumably
        # the trainer's log redirection closes them — confirm before sharing
        # one handle across calls.
        spm.SentencePieceTrainer.train(
            input=os.path.join(data_dir, 'botchan.txt'),
            model_prefix='m',
            vocab_size=1000,
            logstream=open(os.devnull, 'w'),
        )

        # Open as 'rb' for Python3.5/2.7.
        with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is1:
            spm.SentencePieceTrainer.train(
                sentence_iterator=is1,
                model_prefix='m',
                vocab_size=1000,
                logstream=open(os.devnull, 'w'),
            )

        # Write the models into in-memory buffers and compare their vocabs.
        spm.SentencePieceTrainer.train(
            input=os.path.join(data_dir, 'botchan.txt'),
            model_writer=buf1,
            vocab_size=1000,
            logstream=open(os.devnull, 'w'),
        )

        with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is2:
            spm.SentencePieceTrainer.train(
                sentence_iterator=is2,
                model_writer=buf2,
                vocab_size=1000,
                logstream=open(os.devnull, 'w'),
            )

        sp1 = spm.SentencePieceProcessor(model_proto=buf1.getvalue())
        sp2 = spm.SentencePieceProcessor(model_proto=buf2.getvalue())
        self.assertEqual(
            [sp1.id_to_piece(i) for i in range(sp1.get_piece_size())],
            [sp2.id_to_piece(i) for i in range(sp2.get_piece_size())],
        )

    def test_train_kwargs(self):
        """Train via keyword arguments, including whitespace/control symbols."""
        # suppress logging (redirect to /dev/null)
        spm.SentencePieceTrainer.train(
            input=[os.path.join(data_dir, 'botchan.txt')],
            model_prefix='m',
            vocab_size=1002,
            user_defined_symbols=['foo', 'bar', ',', ' ', '\t', '\b', '\n', '\r'],
            logstream=open(os.devnull, 'w'),
        )
        sp = spm.SentencePieceProcessor()
        sp.Load('m.model')
        path = os.path.join(data_dir, 'botchan.txt')
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                sp.DecodePieces(sp.EncodeAsPieces(line))
                sp.DecodeIds(sp.EncodeAsIds(line))

        # The user-defined whitespace/control symbols must round-trip.
        s = 'hello\tworld\r\nthis\tis a \b pen'
        self.assertEqual(s, sp.decode(sp.encode(s)))

    def test_serialized_proto(self):
        """Serialized-proto outputs agree across all equivalent entry points."""
        text = 'I saw a girl with a telescope.'
        s1 = self.sp_.EncodeAsSerializedProto(text)
        s2 = self.sp_.SampleEncodeAsSerializedProto(text, 10, 0.2)
        s3 = self.sp_.NBestEncodeAsSerializedProto(text, 10)
        s4 = self.sp_.DecodePiecesAsSerializedProto(['foo', 'bar'])
        s5 = self.sp_.DecodeIdsAsSerializedProto([20, 30])

        t1 = self.sp_.encode_as_serialized_proto(text)
        t2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
        t3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
        t4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
        t5 = self.sp_.decode_ids_as_serialized_proto([20, 30])

        y1 = self.sp_.encode(text, out_type='serialized_proto')
        y2 = self.sp_.encode(
            text, enable_sampling=True, out_type='serialized_proto'
        )
        y3 = self.sp_.nbest_encode(
            text, out_type='serialized_proto', nbest_size=10
        )
        y4 = self.sp_.decode(['foo', 'bar'], out_type='serialized_proto')
        y5 = self.sp_.decode([20, 30], out_type='serialized_proto')

        # Sampled outputs (s2, t2, y2) are random, so only their type is
        # checked.
        self.assertEqual(type(s1), bytes)
        self.assertEqual(type(s2), bytes)
        self.assertEqual(type(t2), bytes)
        self.assertEqual(type(y2), bytes)
        self.assertEqual(type(s3), bytes)
        self.assertEqual(type(s4), bytes)
        self.assertEqual(type(s5), bytes)

        self.assertEqual(s1, t1)
        self.assertEqual(s3, t3)
        self.assertEqual(s4, t4)
        self.assertEqual(s5, t5)
        self.assertEqual(s1, y1)
        self.assertEqual(s3, y3)
        self.assertEqual(s4, y4)
        self.assertEqual(s5, y5)

        # Decoding the ids and decoding the pieces must each rebuild the
        # proto produced by encoding.
        ids = self.jasp_.EncodeAsIds(text)
        pieces = self.jasp_.EncodeAsPieces(text)
        s1 = self.jasp_.EncodeAsSerializedProto(text)
        s2 = self.jasp_.DecodeIdsAsSerializedProto(ids)
        s3 = self.jasp_.DecodePiecesAsSerializedProto(pieces)
        self.assertEqual(s2, s1)
        self.assertEqual(s3, s1)

    def test_decode_bytes(self):
        """Decoding to bytes equals decoding to str, once UTF-8 decoded."""
        texts = ['Hello world', '清水寺は京都にある。']
        ids = self.jasp_.encode(texts, out_type=int)
        s1 = self.jasp_.decode(ids, out_type=bytes)
        s2 = self.jasp_.decode(ids, out_type=str)
        self.assertEqual(len(s1), 2)
        self.assertEqual(type(s1[0]), bytes)
        self.assertEqual(type(s1[1]), bytes)
        self.assertEqual(len(s2), 2)
        self.assertEqual(type(s2[0]), str)
        self.assertEqual(type(s2[1]), str)
        self.assertEqual(s1[0].decode(encoding='utf-8'), s2[0])
        self.assertEqual(s1[1].decode(encoding='utf-8'), s2[1])

        text = 'Hello world'
        ids = self.jasp_.encode(text, out_type=int)
        s1 = self.jasp_.decode(ids, out_type=bytes)
        s2 = self.jasp_.decode(ids, out_type=str)
        self.assertEqual(type(s1), bytes)
        self.assertEqual(type(s2), str)
        self.assertEqual(s1.decode(encoding='utf-8'), s2)

        # The *_as_bytes accessors must be the UTF-8 form of the str ones.
        x = self.jasp_.encode(text, out_type='immutable_proto')
        self.assertEqual(x.text, x.text_as_bytes.decode(encoding='utf-8'))
        for p in x.pieces:
            self.assertEqual(p.piece, p.piece_as_bytes.decode(encoding='utf-8'))
            self.assertEqual(
                p.surface, p.surface_as_bytes.decode(encoding='utf-8')
            )

        x = self.jasp_.decode(ids, out_type='immutable_proto')
        self.assertEqual(x.text, x.text_as_bytes.decode(encoding='utf-8'))
        for p in x.pieces:
            self.assertEqual(p.piece, p.piece_as_bytes.decode(encoding='utf-8'))
            self.assertEqual(
                p.surface, p.surface_as_bytes.decode(encoding='utf-8')
            )

    def test_immutable_proto(self):
        """Immutable protos: API agreement, eq/hash, indexing, offsets."""
        text = 'I saw a girl with a telescope.'
        s1 = self.sp_.EncodeAsImmutableProto(text)
        s2 = self.sp_.SampleEncodeAsImmutableProto(text, 10, 0.2)
        s3 = self.sp_.NBestEncodeAsImmutableProto(text, 10)
        s4 = self.sp_.DecodePiecesAsImmutableProto(['foo', 'bar'])
        s5 = self.sp_.DecodeIdsAsImmutableProto([20, 30])
        # Printing exercises the string conversion of each proto kind.
        print(s1)
        print(s2)
        print(s3)
        print(s4)
        print(s5)

        t1 = self.sp_.encode_as_immutable_proto(text)
        t2 = self.sp_.sample_encode_as_immutable_proto(text, 10, 0.2)
        t3 = self.sp_.nbest_encode_as_immutable_proto(text, 10)
        t4 = self.sp_.decode_pieces_as_immutable_proto(['foo', 'bar'])
        t5 = self.sp_.decode_ids_as_immutable_proto([20, 30])

        y1 = self.sp_.encode(text, out_type='immutable_proto')
        y2 = self.sp_.encode(
            text, enable_sampling=True, out_type='immutable_proto'
        )
        y3 = self.sp_.nbest_encode(
            text, out_type='immutable_proto', nbest_size=10
        )
        y4 = self.sp_.decode(['foo', 'bar'], out_type='immutable_proto')
        y5 = self.sp_.decode([20, 30], out_type='immutable_proto')

        # Sampled outputs (s2, t2, y2) are random and are not compared.
        self.assertEqual(s1, t1)
        self.assertEqual(s3, t3)
        self.assertEqual(s4, t4)
        self.assertEqual(s5, t5)
        self.assertEqual(s1, y1)
        self.assertEqual(s3, y3)
        self.assertEqual(s4, y4)
        self.assertEqual(s5, y5)

        # eq test: matching pieces compare equal and share a dict key.
        hset_piece = defaultdict(int)
        for i in range(len(s1.pieces)):
            self.assertEqual(s1.pieces[i], t1.pieces[i])
            hset_piece[s1.pieces[i]] += 1
            hset_piece[t1.pieces[i]] += 1
        self.assertEqual(len(hset_piece), len(s1.pieces))

        # hash test: equal protos must collapse onto the same dict key.
        hset = defaultdict(int)
        hset[s1] += 1
        hset[t1] += 1
        hset[s3] += 1
        hset[t3] += 1
        self.assertEqual(len(hset), 2)
        self.assertEqual(hset[s1], 2)
        self.assertEqual(hset[s3], 2)
        self.assertEqual(hset[t1], 2)
        self.assertEqual(hset[t3], 2)

        # Serializing an immutable proto matches the serialized-proto API.
        x1 = self.sp_.encode_as_serialized_proto(text)
        x2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
        x3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
        x4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
        x5 = self.sp_.decode_ids_as_serialized_proto([20, 30])
        self.assertEqual(x1, t1.SerializeAsString())
        self.assertEqual(x3, t3.SerializeAsString())
        self.assertEqual(x4, t4.SerializeAsString())
        self.assertEqual(x5, t5.SerializeAsString())

        v1 = self.sp_.EncodeAsIds(text)
        v2 = self.sp_.EncodeAsPieces(text)
        self.assertEqual([x.id for x in s1.pieces], v1)
        self.assertEqual([x.piece for x in s1.pieces], v2)
        self.assertEqual(text, s1.text)

        # begin/end offsets must slice out each piece's surface string.
        surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
        surfaces2 = [x.surface for x in s1.pieces]
        self.assertEqual(surfaces1, surfaces2)

        # Explicit integer indexing (not just iteration) must work.
        ids = []
        for i in range(len(s1.pieces)):
            ids.append(s1.pieces[i].id)
        self.assertEqual(ids, v1)

        pieces = []
        for i in range(len(s1.pieces)):
            pieces.append(s1.pieces[i].piece)
        self.assertEqual(pieces, v2)

        for v in s3.nbests:
            self.assertEqual(text, v.text)
            self.assertEqual(self.sp_.Decode([x.id for x in v.pieces]), text)

        for i in range(len(s3.nbests)):
            self.assertEqual(text, s3.nbests[i].text)
            self.assertEqual(
                self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text
            )

        # slice
        self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces)))
        self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests)))

        # Japanese offset
        s1 = self.jasp_.EncodeAsImmutableProto(
            '吾輩は猫である。Hello world. ABC 123'
        )
        surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
        surfaces2 = [x.surface for x in s1.pieces]
        self.assertEqual(surfaces1, surfaces2)

        ids = [x.id for x in s1.pieces]
        s2 = self.jasp_.DecodeIdsAsImmutableProto(ids)
        self.assertEqual(s2, s1)

        pieces = [x.piece for x in s1.pieces]
        s2 = self.jasp_.DecodePiecesAsImmutableProto(pieces)
        self.assertEqual(s2, s1)

    def test_new_api(self):
        """The unified encode/decode API matches the legacy typed methods."""
        sp = spm.SentencePieceProcessor(
            model_file=os.path.join('test', 'test_model.model')
        )
        text = 'hello world'
        text2 = 'Tokyo'

        ids = self.sp_.EncodeAsIds(text)
        ids2 = self.sp_.EncodeAsIds(text2)
        pieces = self.sp_.EncodeAsPieces(text)
        pieces2 = self.sp_.EncodeAsPieces(text2)
        sprotos = self.sp_.EncodeAsSerializedProto(text)
        sproto2 = self.sp_.EncodeAsSerializedProto(text2)
        iprotos = self.sp_.EncodeAsImmutableProto(text)
        iprotos2 = self.sp_.EncodeAsImmutableProto(text2)

        self.assertEqual(sp.encode(text, out_type=int), ids)
        self.assertEqual(sp.encode(text, out_type=str), pieces)
        self.assertEqual(sp.encode(text, out_type='serialized_proto'), sprotos)
        self.assertEqual(sp.encode(text, out_type='immutable_proto'), iprotos)
        self.assertEqual(
            sp.encode(text2, out_type='serialized_proto'), sproto2
        )

        self.assertEqual(sp.encode([text], out_type=int), [ids])
        self.assertEqual(sp.encode([text], out_type=str), [pieces])
        self.assertEqual(
            sp.encode([text], out_type='serialized_proto'), [sprotos]
        )
        self.assertEqual(
            sp.encode([text], out_type='immutable_proto'), [iprotos]
        )

        self.assertEqual(len(iprotos.pieces), len(pieces))
        self.assertEqual(len(iprotos.pieces), len(ids))
        self.assertEqual(iprotos.text, text)

        self.assertEqual(len(iprotos2.pieces), len(pieces2))
        self.assertEqual(len(iprotos2.pieces), len(ids2))
        self.assertEqual(iprotos2.text, text2)

        # Index-based and iterator-based access must agree.
        for i in range(len(iprotos.pieces)):
            self.assertEqual(ids[i], iprotos.pieces[i].id)
            self.assertEqual(pieces[i], iprotos.pieces[i].piece)

        for i, piece in enumerate(iprotos.pieces):
            self.assertEqual(ids[i], piece.id)
            self.assertEqual(pieces[i], piece.piece)

        for i in range(len(iprotos2.pieces)):
            self.assertEqual(ids2[i], iprotos2.pieces[i].id)
            self.assertEqual(pieces2[i], iprotos2.pieces[i].piece)

        for i, piece in enumerate(iprotos2.pieces):
            self.assertEqual(ids2[i], piece.id)
            self.assertEqual(pieces2[i], piece.piece)

        detok_ids = self.sp_.DecodeIds(ids)
        detok_pieces = self.sp_.DecodePieces(pieces)
        self.assertEqual(sp.decode(ids), detok_ids)
        self.assertEqual(sp.decode(pieces), detok_pieces)
        self.assertEqual(sp.decode([]), '')
        self.assertEqual(sp.decode([[]]), [''])

        # add_bos, add_eos, reverse
        self.assertEqual([sp.bos_id()] + ids, sp.encode(text, add_bos=True))
        self.assertEqual(ids + [sp.eos_id()], sp.encode(text, add_eos=True))
        self.assertEqual(
            ids + [sp.eos_id()], sp.EncodeAsIds(text, add_eos=True)
        )
        rids = ids[:]
        rids.reverse()
        self.assertEqual(rids, sp.encode(text, reverse=True))
        self.assertEqual(rids, sp.EncodeAsIds(text, reverse=True))

        # different shape.
        self.assertEqual([ids, ids2], sp.encode([text, text2]))
        self.assertEqual(
            [pieces, pieces2], sp.encode([text, text2], out_type=str)
        )
        self.assertEqual([text, text2], sp.decode([ids, ids2]))
        self.assertEqual([text, text2], sp.decode([pieces, pieces2]))

        pieces = list(reversed(self.sp_.EncodeAsPieces(text)))
        self.assertEqual(pieces, sp.encode(text, reverse=True, out_type=str))

        # emit unk piece
        unk_char = '藤'
        pieces = self.sp_.EncodeAsIds(unk_char, emit_unk_piece=True)
        pieces2 = self.sp_.encode(unk_char, out_type=int, emit_unk_piece=True)
        self.assertEqual(pieces[1], sp.unk_id())
        self.assertEqual(pieces2[1], sp.unk_id())
        self.assertEqual(pieces, pieces2)

        pieces = self.sp_.EncodeAsPieces(unk_char, emit_unk_piece=True)
        pieces2 = self.sp_.encode(unk_char, out_type=str, emit_unk_piece=True)
        self.assertEqual(pieces[1], '<unk>')
        self.assertEqual(pieces2[1], '<unk>')
        self.assertEqual(pieces, pieces2)

        pieces = self.sp_.EncodeAsPieces(unk_char, emit_unk_piece=False)
        pieces2 = self.sp_.encode(
            unk_char, out_type=str, emit_unk_piece=False
        )
        self.assertEqual(pieces[1], unk_char)
        self.assertEqual(pieces2[1], unk_char)
        self.assertEqual(pieces, pieces2)

    def test_new_api_init(self):
        """Constructor defaults apply to encode() and can be overridden."""
        sp = spm.SentencePieceProcessor(
            model_file=os.path.join('test', 'test_model.model'),
            add_bos=True,
            add_eos=True,
            out_type=str,
        )
        text = 'hello world'
        pieces = ['<s>'] + self.sp_.EncodeAsPieces(text) + ['</s>']
        self.assertEqual(pieces, sp.encode(text))

        pieces = self.sp_.EncodeAsPieces(text) + ['</s>']
        self.assertEqual(pieces, sp.encode(text, add_bos=False, add_eos=True))

    def test_sampling(self):
        """Sampling produces varied outputs; without it output is fixed."""
        sp = self.sp_

        for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
            ids = defaultdict(int)
            for _ in range(100):
                out = sp.encode(
                    'hello world', out_type=out_type, enable_sampling=True
                )
                if isinstance(out, list):
                    out = tuple(out)  # lists are unhashable
                ids[out] += 1
            self.assertGreater(len(ids), 1)

            ids2 = defaultdict(int)
            for _ in range(100):
                out = sp.encode(
                    'hello world', out_type=out_type, enable_sampling=False
                )
                if isinstance(out, list):
                    out = tuple(out)
                ids2[out] += 1
            self.assertEqual(len(ids2), 1)

            out = sp.encode(
                ['hello world', 'this is a test'],
                out_type=out_type,
                enable_sampling=True,
            )
            self.assertEqual(len(out), 2)
            out = sp.encode(
                ['hello world', 'this is a test'],
                out_type=out_type,
                enable_sampling=False,
            )
            self.assertEqual(len(out), 2)

    def test_nbest(self):
        """N-best results agree across APIs and decode back to the input."""
        sp = self.sp_
        text = 'hello world'
        text2 = 'I have a pen.'

        for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
            results = sp.nbest_encode(text, nbest_size=10, out_type=out_type)
            self.assertEqual(
                results,
                sp.NBestEncode(text, nbest_size=10, out_type=out_type),
            )

            if out_type in [str, int]:
                for n in results:
                    self.assertEqual(sp.decode(n), text)

                for n in sp.decode(results):
                    self.assertEqual(n, text)

            # batch test
            results = sp.nbest_encode(
                [text, text2], nbest_size=10, out_type=out_type
            )
            self.assertEqual(
                results,
                sp.NBestEncode(
                    [text, text2], nbest_size=10, out_type=out_type
                ),
            )
            self.assertEqual(len(results), 2)

            if out_type in [str, int]:
                for n in results[0]:
                    self.assertEqual(sp.decode(n), text)

                for n in results[1]:
                    self.assertEqual(sp.decode(n), text2)

                decoded = sp.decode(results[0])
                self.assertEqual(len(decoded), 10)
                for n in decoded:
                    self.assertEqual(n, text)
                decoded = sp.decode(results[1])
                self.assertEqual(len(decoded), 10)
                for n in decoded:
                    self.assertEqual(n, text2)

        self.assertEqual(
            sp.nbest_encode(text, nbest_size=10, out_type=str),
            sp.nbest_encode_as_pieces(text, nbest_size=10),
        )
        self.assertEqual(
            sp.nbest_encode(text, nbest_size=10, out_type=int),
            sp.nbest_encode_as_ids(text, nbest_size=10),
        )
        self.assertEqual(
            sp.nbest_encode(text, nbest_size=10, out_type='serialized_proto'),
            sp.nbest_encode_as_serialized_proto(text, nbest_size=10),
        )
        self.assertEqual(
            sp.nbest_encode(text, nbest_size=10, out_type='immutable_proto'),
            sp.nbest_encode_as_immutable_proto(text, nbest_size=10),
        )

    def test_sample_and_score(self):
        """Every sampled-and-scored segmentation decodes back to the input."""
        sp = self.sp_
        text = 'hello world'
        text2 = 'I have a pen.'

        for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
            # Keep both results (wor=True and wor=False) and inspect each,
            # rather than overwriting the first before it is checked.
            results_a = sp.sample_encode_and_score(
                text, wor=True, num_samples=10, out_type=out_type
            )
            results_b = sp.SampleEncodeAndScore(
                text, wor=False, num_samples=10, out_type=out_type
            )
            if out_type in [str, int]:
                for results in (results_a, results_b):
                    for n in results:
                        self.assertEqual(sp.decode(n[0]), text)

            batch_a = sp.sample_encode_and_score(
                [text, text2], wor=True, num_samples=10, out_type=out_type
            )
            batch_b = sp.SampleEncodeAndScore(
                [text, text2], wor=True, num_samples=10, out_type=out_type
            )
            if out_type in [str, int]:
                for results in (batch_a, batch_b):
                    for n in results[0]:
                        self.assertEqual(sp.decode(n[0]), text)
                    for n in results[1]:
                        self.assertEqual(sp.decode(n[0]), text2)

        # Smoke-test the typed convenience wrappers.
        sp.sample_encode_and_score_as_pieces(text, 10)
        sp.sample_encode_and_score_as_ids(text, 10)
        sp.sample_encode_and_score_as_immutable_proto(text, 10)
        sp.sample_encode_and_score_as_serialized_proto(text, 10)

    def test_valid_range(self):
        """In-range ids are accepted; an id equal to piece_size() is rejected."""
        size = self.sp_.piece_size()
        funcs = [
            'IdToPiece',
            'GetScore',
            'IsUnknown',
            'IsControl',
            'IsUnused',
            'IsByte',
            'DecodeIds',
            'DecodeIdsAsSerializedProto',
        ]
        for m in funcs:
            getattr(self.sp_, m)([10, 20, 30])

        for m in funcs:
            with self.subTest(method=m):
                # assertRaises fails the test if no exception is raised; a
                # try/bare-except would also swallow that failure.
                with self.assertRaises(Exception):
                    getattr(self.sp_, m)([size])

    def test_batch(self):
        """Batched calls match per-item calls for every num_threads value."""
        sp = spm.SentencePieceProcessor(
            model_file=os.path.join('test', 'test_model.model')
        )
        path = os.path.join(data_dir, 'botchan.txt')
        with open(path, 'r', encoding='utf-8') as f:
            texts = f.readlines()

        for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
            r1 = sp.encode(texts, out_type=out_type, num_threads=None)
            r2 = sp.encode(texts, out_type=out_type, num_threads=1)
            r3 = sp.encode(texts, out_type=out_type, num_threads=-1)
            r4 = sp.encode(texts, out_type=out_type, num_threads=8)
            r5 = [sp.encode(s, out_type=out_type) for s in texts]
            self.assertEqual(r1, r2)
            self.assertEqual(r1, r3)
            self.assertEqual(r1, r4)
            self.assertEqual(r1, r5)

            if out_type in [str, int]:
                d1 = sp.decode(r1, num_threads=None)
                d2 = sp.decode(r2, num_threads=1)
                d3 = sp.decode(r3, num_threads=-1)
                d4 = sp.decode(r4, num_threads=8)
                d5 = [sp.decode(s) for s in r5]
                self.assertEqual(d1, d2)
                self.assertEqual(d1, d3)
                self.assertEqual(d1, d4)
                self.assertEqual(d1, d5)

        e1 = sp.calculate_entropy(texts, alpha=1.0, num_threads=10)
        e2 = sp.CalculateEntropy(texts, alpha=1.0, num_threads=10)
        e3 = [sp.calculate_entropy(s, alpha=1.0) for s in texts]
        self.assertEqual(e1, e2)
        self.assertEqual(e1, e3)

    def test_pickle(self):
        """A pickled-then-unpickled processor encodes identically."""
        # Round-trip through memory instead of leaving 'sp.pickle' behind
        # in the working directory.
        buf = io.BytesIO()
        pickle.dump(self.sp_, buf)

        id1 = self.sp_.encode('hello world.', out_type=int)

        buf.seek(0)
        sp = pickle.load(buf)
        id2 = sp.encode('hello world.', out_type=int)

        self.assertEqual(id1, id2)

    def test_global_params(self):
        """Module-level setters are callable in both naming styles."""
        spm.SetRandomGeneratorSeed(0)
        spm.SetMinLogLevel(2)
        spm.set_random_generator_seed(1)
        spm.set_min_log_level(3)

    def test_normalize(self):
        """Processor normalization, with char offsets (str) and byte offsets."""
        sp = spm.SentencePieceProcessor(
            model_file=os.path.join('test', 'test_model.model')
        )
        # The first eight characters are full-width (3 UTF-8 bytes each) and
        # are expected to normalize to ASCII; the byte offsets below
        # (multiples of 3, ending at 27) depend on that.
        text = 'ＫＡＤＯＫＡＷＡABC'

        self.assertEqual('▁KADOKAWAABC', sp.normalize(text))
        self.assertEqual('▁KADOKAWAABC', sp.Normalize(text))

        x = sp.Normalize(text, with_offsets=True)
        self.assertEqual('▁KADOKAWAABC', x[0])
        self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])

        x = sp.Normalize(text.encode('utf8'), with_offsets=True)
        self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0])
        self.assertEqual(
            [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1]
        )

        self.assertEqual(
            ['▁KADOKAWAABC', '▁平成'], sp.normalize([text, '㍻'])
        )
        self.assertEqual(
            ['▁KADOKAWAABC', '▁平成'], sp.Normalize([text, '㍻'])
        )

        x = sp.Normalize(
            [text.encode('utf8'), '㍻'.encode('utf8')],
            with_offsets=True,
        )
        self.assertEqual(len(x), 2)
        self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0][0])
        self.assertEqual(
            [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1]
        )

        x = sp.Normalize([text, '㍻'], with_offsets=True)
        self.assertEqual(len(x), 2)
        self.assertEqual('▁KADOKAWAABC', x[0][0])
        self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])

        self.assertEqual('▁平成', x[1][0])
        self.assertEqual([0, 0, 0, 1], x[1][1])

    def test_normalizer(self):
        """Standalone normalizer: offsets and whitespace-handling options."""
        sp = spm.SentencePieceNormalizer(
            model_file=os.path.join('test', 'test_model.model')
        )
        # Full-width input; see test_normalize for why the offsets depend
        # on it.
        text = 'ＫＡＤＯＫＡＷＡABC'

        self.assertEqual('KADOKAWAABC', sp.normalize(text))
        self.assertEqual('KADOKAWAABC', sp.Normalize(text))

        x = sp.Normalize(text.encode('utf8'), with_offsets=True)
        self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0])
        self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1])

        x = sp.Normalize(text, with_offsets=True)
        self.assertEqual('KADOKAWAABC', x[0])
        self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])

        self.assertEqual(['KADOKAWAABC', '平成'], sp.normalize([text, '㍻']))
        self.assertEqual(['KADOKAWAABC', '平成'], sp.Normalize([text, '㍻']))

        x = sp.Normalize(
            [text.encode('utf8'), '㍻'.encode('utf8')],
            with_offsets=True,
        )
        self.assertEqual(len(x), 2)
        self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0][0])
        self.assertEqual(
            [0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1]
        )

        x = sp.Normalize([text, '㍻'], with_offsets=True)
        self.assertEqual(len(x), 2)
        self.assertEqual('KADOKAWAABC', x[0][0])
        self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])
        self.assertEqual('平成', x[1][0])
        self.assertEqual([0, 0, 1], x[1][1])

        # remove_extra_whitespaces=False keeps both spaces, each escaped.
        sp = spm.SentencePieceNormalizer(
            model_file=os.path.join('test', 'test_model.model'),
            add_dummy_prefix=True,
            escape_whitespaces=True,
            remove_extra_whitespaces=False,
        )
        self.assertEqual('▁hello▁▁world', sp.normalize('hello  world'))

        # remove_extra_whitespaces=True strips and collapses whitespace runs.
        sp = spm.SentencePieceNormalizer(
            model_file=os.path.join('test', 'test_model.model'),
            add_dummy_prefix=True,
            escape_whitespaces=True,
            remove_extra_whitespaces=True,
        )
        self.assertEqual('▁hello▁world', sp.normalize(' hello  world '))

        sp = spm.SentencePieceNormalizer(
            model_file=os.path.join('test', 'test_model.model'),
            add_dummy_prefix=False,
            escape_whitespaces=False,
            remove_extra_whitespaces=True,
        )
        self.assertEqual('hello world', sp.normalize(' hello  world '))

    def test_normalizer_rule(self):
        """Built-in normalization rules can be selected by name."""
        sp = spm.SentencePieceNormalizer(rule_name='identity')
        self.assertEqual('ABC', sp.Normalize('ABC'))
        sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
        self.assertEqual('abc', sp.Normalize('ABC'))

    def test_override_normalize_spec(self):
        """Overriding normalizer flags changes how whitespace is tokenized."""
        sp = spm.SentencePieceProcessor(
            model_file=os.path.join('test', 'test_model.model')
        )

        self.assertEqual(
            sp.EncodeAsPieces(' hello world '), ['▁he', 'll', 'o', '▁world']
        )

        sp.override_normalizer_spec(add_dummy_prefix=False)
        sp.override_normalizer_spec(remove_extra_whitespaces=False)
        sp.override_normalizer_spec(escape_whitespaces=False)
        self.assertEqual(
            sp.EncodeAsPieces(' hello world '),
            [' ', 'he', 'll', 'o', ' ', 'w', 'or', 'l', 'd', ' '],
        )


def suite():
    """Return a TestSuite containing every TestSentencepieceProcessor test.

    Uses ``unittest.TestLoader`` because ``unittest.makeSuite`` is deprecated
    and was removed in Python 3.13.
    """
    tests = unittest.TestSuite()
    tests.addTests(
        unittest.defaultTestLoader.loadTestsFromTestCase(
            TestSentencepieceProcessor
        )
    )
    return tests


# Run all tests in this module when executed as a script.
if __name__ == '__main__':
    unittest.main()