# KuangDW
# add laser tool
# 2aebc50
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.insert(0, 'src')
from collections import defaultdict
import io
import os
import pickle
import unittest
import sentencepiece as spm
# Report the bound library version so test logs show what was exercised.
print('VERSION={}'.format(spm.__version__))
# Directory holding the training corpus; Windows CI keeps it elsewhere.
data_dir = 'test'
if sys.platform == 'win32':
  data_dir = os.path.join('..', 'data')
class TestSentencepieceProcessor(unittest.TestCase):
  """Test case for SentencePieceProcessor"""

  def setUp(self):
    # Two processors: an English model (sp_) and a Japanese model (jasp_).
    self.sp_ = spm.SentencePieceProcessor()
    self.jasp_ = spm.SentencePieceProcessor()
    self.assertTrue(self.sp_.Load(os.path.join('test', 'test_model.model')))
    self.assertTrue(
        self.jasp_.Load(os.path.join('test', 'test_ja_model.model'))
    )
    # Re-loading the same models from raw serialized bytes must also succeed.
    with open(os.path.join('test', 'test_model.model'), 'rb') as f:
      self.assertTrue(self.sp_.LoadFromSerializedProto(f.read()))
    with open(os.path.join('test', 'test_ja_model.model'), 'rb') as f:
      self.assertTrue(self.jasp_.LoadFromSerializedProto(f.read()))
  def test_load(self):
    """Checks vocab size, reserved symbol ids, and the piece<->id bijection."""
    self.assertEqual(1000, self.sp_.GetPieceSize())
    self.assertEqual(0, self.sp_.PieceToId('<unk>'))
    self.assertEqual(1, self.sp_.PieceToId('<s>'))
    self.assertEqual(2, self.sp_.PieceToId('</s>'))
    self.assertEqual('<unk>', self.sp_.IdToPiece(0))
    self.assertEqual('<s>', self.sp_.IdToPiece(1))
    self.assertEqual('</s>', self.sp_.IdToPiece(2))
    self.assertEqual(0, self.sp_.unk_id())
    self.assertEqual(1, self.sp_.bos_id())
    self.assertEqual(2, self.sp_.eos_id())
    self.assertEqual(-1, self.sp_.pad_id())  # pad is disabled in this model.
    # Every id must round-trip: id -> piece -> same id.
    for i in range(self.sp_.GetPieceSize()):
      piece = self.sp_.IdToPiece(i)
      self.assertEqual(i, self.sp_.PieceToId(piece))
    # Repeat the same checks through the snake_case API.
    self.assertEqual(1000, self.sp_.get_piece_size())
    self.assertEqual(0, self.sp_.piece_to_id('<unk>'))
    self.assertEqual(1, self.sp_.piece_to_id('<s>'))
    self.assertEqual(2, self.sp_.piece_to_id('</s>'))
    self.assertEqual('<unk>', self.sp_.id_to_piece(0))
    self.assertEqual('<s>', self.sp_.id_to_piece(1))
    self.assertEqual('</s>', self.sp_.id_to_piece(2))
    for i in range(self.sp_.get_piece_size()):
      piece = self.sp_.id_to_piece(i)
      self.assertEqual(i, self.sp_.piece_to_id(piece))
  def test_roundtrip(self):
    """Encode/decode round trips on English text, including sampled encodes."""
    text = 'I saw a girl with a telescope.'
    ids = self.sp_.EncodeAsIds(text)
    pieces1 = self.sp_.EncodeAsPieces(text)
    pieces2 = self.sp_.NBestEncodeAsPieces(text, 10)[0]
    # The 1-best of the n-best list equals the plain encode.
    self.assertEqual(pieces1, pieces2)
    self.assertEqual(text, self.sp_.DecodePieces(pieces1))
    self.assertEqual(text, self.sp_.DecodeIds(ids))
    # Sampled segmentations vary, but must always decode back to the input.
    for n in range(100):
      self.assertEqual(
          text,
          self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, 64, 0.5)),
      )
      self.assertEqual(
          text,
          self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, -1, 0.5)),
      )
      self.assertEqual(
          text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5))
      )
      self.assertEqual(
          text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5))
      )
    # The snake_case API must agree with the CamelCase API.
    ids2 = self.sp_.encode_as_ids(text)
    pieces3 = self.sp_.encode_as_pieces(text)
    pieces4 = self.sp_.nbest_encode_as_pieces(text, 10)[0]
    self.assertEqual(pieces3, pieces4)
    self.assertEqual(pieces1, pieces3)
    self.assertEqual(ids, ids2)
    self.assertEqual(text, self.sp_.decode_pieces(pieces3))
    self.assertEqual(text, self.sp_.decode_ids(ids2))
    for n in range(100):
      self.assertEqual(
          text,
          self.sp_.decode_pieces(
              self.sp_.sample_encode_as_pieces(text, 64, 0.5)
          ),
      )
      self.assertEqual(
          text,
          self.sp_.decode_pieces(
              self.sp_.sample_encode_as_pieces(text, -1, 0.5)
          ),
      )
      self.assertEqual(
          text,
          self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, 64, 0.5)),
      )
      self.assertEqual(
          text,
          self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, -1, 0.5)),
      )
    # Entropy must be identical through both naming styles.
    self.assertEqual(
        self.sp_.calculate_entropy(text, 0.1),
        self.sp_.CalculateEntropy(text, 0.1),
    )
  def test_ja_load(self):
    """Same vocab / bijection checks for the 8000-piece Japanese model."""
    self.assertEqual(8000, self.jasp_.GetPieceSize())
    self.assertEqual(0, self.jasp_.PieceToId('<unk>'))
    self.assertEqual(1, self.jasp_.PieceToId('<s>'))
    self.assertEqual(2, self.jasp_.PieceToId('</s>'))
    self.assertEqual('<unk>', self.jasp_.IdToPiece(0))
    self.assertEqual('<s>', self.jasp_.IdToPiece(1))
    self.assertEqual('</s>', self.jasp_.IdToPiece(2))
    # id -> piece -> id must be the identity over the whole vocab.
    for i in range(self.jasp_.GetPieceSize()):
      piece = self.jasp_.IdToPiece(i)
      self.assertEqual(i, self.jasp_.PieceToId(piece))
    # Same checks through the snake_case API.
    self.assertEqual(8000, self.jasp_.get_piece_size())
    self.assertEqual(0, self.jasp_.piece_to_id('<unk>'))
    self.assertEqual(1, self.jasp_.piece_to_id('<s>'))
    self.assertEqual(2, self.jasp_.piece_to_id('</s>'))
    self.assertEqual('<unk>', self.jasp_.id_to_piece(0))
    self.assertEqual('<s>', self.jasp_.id_to_piece(1))
    self.assertEqual('</s>', self.jasp_.id_to_piece(2))
    for i in range(self.jasp_.get_piece_size()):
      piece = self.jasp_.id_to_piece(i)
      self.assertEqual(i, self.jasp_.piece_to_id(piece))
  def test_ja_roundtrip(self):
    """Encode/decode round trips on Japanese text, including sampled encodes."""
    text = '清水寺は京都にある。'
    ids = self.jasp_.EncodeAsIds(text)
    pieces1 = self.jasp_.EncodeAsPieces(text)
    pieces2 = self.jasp_.NBestEncodeAsPieces(text, 10)[0]
    # The 1-best of the n-best list equals the plain encode.
    self.assertEqual(pieces1, pieces2)
    self.assertEqual(text, self.jasp_.DecodePieces(pieces1))
    self.assertEqual(text, self.jasp_.DecodeIds(ids))
    # Sampled segmentations must always decode back to the original text.
    for n in range(100):
      self.assertEqual(
          text,
          self.jasp_.DecodePieces(
              self.jasp_.SampleEncodeAsPieces(text, 64, 0.5)
          ),
      )
      self.assertEqual(
          text,
          self.jasp_.DecodePieces(
              self.jasp_.SampleEncodeAsPieces(text, -1, 0.5)
          ),
      )
    # The snake_case API must agree with the CamelCase API.
    ids2 = self.jasp_.encode_as_ids(text)
    pieces3 = self.jasp_.encode_as_pieces(text)
    pieces4 = self.jasp_.nbest_encode_as_pieces(text, 10)[0]
    self.assertEqual(pieces3, pieces4)
    self.assertEqual(pieces1, pieces3)
    self.assertEqual(ids, ids2)
    self.assertEqual(text, self.jasp_.decode_pieces(pieces1))
    self.assertEqual(text, self.jasp_.decode_ids(ids2))
    for n in range(100):
      self.assertEqual(
          text,
          self.jasp_.decode_pieces(
              self.jasp_.sample_encode_as_pieces(text, 64, 0.5)
          ),
      )
      self.assertEqual(
          text,
          self.jasp_.decode_pieces(
              self.jasp_.sample_encode_as_pieces(text, -1, 0.5)
          ),
      )
    # Entropy must be identical through both naming styles.
    self.assertEqual(
        self.jasp_.calculate_entropy(text, 0.1),
        self.jasp_.CalculateEntropy(text, 0.1),
    )
def test_train(self):
spm.SentencePieceTrainer.Train(
'--input='
+ os.path.join(data_dir, 'botchan.txt')
+ ' --model_prefix=m --vocab_size=1000'
)
sp = spm.SentencePieceProcessor()
sp.Load('m.model')
with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
for line in file:
sp.DecodePieces(sp.EncodeAsPieces(line))
sp.DecodeIds(sp.EncodeAsIds(line))
def test_train_iterator(self):
spm.SentencePieceTrainer.Train(
'--input='
+ os.path.join(data_dir, 'botchan.txt')
+ ' --model_prefix=m --vocab_size=1000'
)
# Load as 'rb' for Python3.5/2.7.
os1 = io.BytesIO()
os2 = io.BytesIO()
# suppress logging (redirect to /dev/null)
spm.SentencePieceTrainer.train(
input=os.path.join(data_dir, 'botchan.txt'),
model_prefix='m',
vocab_size=1000,
logstream=open(os.devnull, 'w'),
)
with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is1:
spm.SentencePieceTrainer.train(
sentence_iterator=is1,
model_prefix='m',
vocab_size=1000,
logstream=open(os.devnull, 'w'),
)
spm.SentencePieceTrainer.train(
input=os.path.join(data_dir, 'botchan.txt'),
model_writer=os1,
vocab_size=1000,
logstream=open(os.devnull, 'w'),
)
with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is2:
spm.SentencePieceTrainer.train(
sentence_iterator=is2,
model_writer=os2,
vocab_size=1000,
logstream=open(os.devnull, 'w'),
)
sp1 = spm.SentencePieceProcessor(model_proto=os1.getvalue())
sp2 = spm.SentencePieceProcessor(model_proto=os2.getvalue())
self.assertEqual(
[sp1.id_to_piece(i) for i in range(sp1.get_piece_size())],
[sp2.id_to_piece(i) for i in range(sp2.get_piece_size())],
)
def test_train_kwargs(self):
# suppress logging (redirect to /dev/null)
spm.SentencePieceTrainer.train(
input=[os.path.join(data_dir, 'botchan.txt')],
model_prefix='m',
vocab_size=1002,
user_defined_symbols=['foo', 'bar', ',', ' ', '\t', '\b', '\n', '\r'],
logstream=open(os.devnull, 'w'),
)
sp = spm.SentencePieceProcessor()
sp.Load('m.model')
with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
for line in file:
sp.DecodePieces(sp.EncodeAsPieces(line))
sp.DecodeIds(sp.EncodeAsIds(line))
s = 'hello\tworld\r\nthis\tis a \b pen'
self.assertEqual(s, sp.decode(sp.encode(s)))
def test_serialized_proto(self):
text = 'I saw a girl with a telescope.'
s1 = self.sp_.EncodeAsSerializedProto(text)
s2 = self.sp_.SampleEncodeAsSerializedProto(text, 10, 0.2)
s3 = self.sp_.NBestEncodeAsSerializedProto(text, 10)
s4 = self.sp_.DecodePiecesAsSerializedProto(['foo', 'bar'])
s5 = self.sp_.DecodeIdsAsSerializedProto([20, 30])
t1 = self.sp_.encode_as_serialized_proto(text)
t2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
t3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
t4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
t5 = self.sp_.decode_ids_as_serialized_proto([20, 30])
y1 = self.sp_.encode(text, out_type='serialized_proto')
y2 = self.sp_.encode(
text, enable_sampling=True, out_type='serialized_proto'
)
y3 = self.sp_.nbest_encode(text, out_type='serialized_proto', nbest_size=10)
y4 = self.sp_.decode(['foo', 'bar'], out_type='serialized_proto')
y5 = self.sp_.decode([20, 30], out_type='serialized_proto')
self.assertEqual(type(s1), bytes)
self.assertEqual(type(s2), bytes)
self.assertEqual(type(t2), bytes)
self.assertEqual(type(s3), bytes)
self.assertEqual(type(s4), bytes)
self.assertEqual(type(s5), bytes)
self.assertEqual(s1, t1)
self.assertEqual(s3, t3)
self.assertEqual(s4, t4)
self.assertEqual(s5, t5)
self.assertEqual(s1, y1)
self.assertEqual(s3, y3)
self.assertEqual(s4, y4)
self.assertEqual(s5, y5)
ids = self.jasp_.EncodeAsIds(text)
pieces = self.jasp_.EncodeAsPieces(text)
s1 = self.jasp_.EncodeAsSerializedProto(text)
s2 = self.jasp_.DecodeIdsAsSerializedProto(ids)
s3 = self.jasp_.DecodePiecesAsSerializedProto(ids)
self.assertEqual(s2, s1)
self.assertEqual(s3, s1)
  def test_decode_bytes(self):
    """decode(out_type=bytes) must be the UTF-8 encoding of the str decode."""
    texts = ['Hello world', '清水寺は京都にある。']
    ids = self.jasp_.encode(texts, out_type=int)
    s1 = self.jasp_.decode(ids, out_type=bytes)
    s2 = self.jasp_.decode(ids, out_type=str)
    # Batched input: one result per sentence, typed per out_type.
    self.assertEqual(len(s1), 2)
    self.assertEqual(type(s1[0]), bytes)
    self.assertEqual(type(s1[1]), bytes)
    self.assertEqual(len(s2), 2)
    self.assertEqual(type(s2[0]), str)
    self.assertEqual(type(s2[1]), str)
    self.assertEqual(s1[0].decode(encoding='utf-8'), s2[0])
    self.assertEqual(s1[1].decode(encoding='utf-8'), s2[1])
    # Single (non-batched) input returns a scalar, not a list.
    text = 'Hello world'
    ids = self.jasp_.encode(text, out_type=int)
    s1 = self.jasp_.decode(ids, out_type=bytes)
    s2 = self.jasp_.decode(ids, out_type=str)
    self.assertEqual(type(s1), bytes)
    self.assertEqual(type(s2), str)
    self.assertEqual(s1.decode(encoding='utf-8'), s2)
    # Immutable protos expose both str and *_as_bytes fields; they must match.
    x = self.jasp_.encode(text, out_type='immutable_proto')
    self.assertEqual(x.text, x.text_as_bytes.decode(encoding='utf-8'))
    for sp in x.pieces:
      self.assertEqual(sp.piece, sp.piece_as_bytes.decode(encoding='utf-8'))
      self.assertEqual(sp.surface, sp.surface_as_bytes.decode(encoding='utf-8'))
    x = self.jasp_.decode(ids, out_type='immutable_proto')
    self.assertEqual(x.text, x.text_as_bytes.decode(encoding='utf-8'))
    for sp in x.pieces:
      self.assertEqual(sp.piece, sp.piece_as_bytes.decode(encoding='utf-8'))
      self.assertEqual(sp.surface, sp.surface_as_bytes.decode(encoding='utf-8'))
  def test_immutable_proto(self):
    """Immutable-proto results: API agreement, equality/hash, offsets, slicing."""
    text = 'I saw a girl with a telescope.'
    s1 = self.sp_.EncodeAsImmutableProto(text)
    s2 = self.sp_.SampleEncodeAsImmutableProto(text, 10, 0.2)
    s3 = self.sp_.NBestEncodeAsImmutableProto(text, 10)
    s4 = self.sp_.DecodePiecesAsImmutableProto(['foo', 'bar'])
    s5 = self.sp_.DecodeIdsAsImmutableProto([20, 30])
    # Smoke-check that the protos are printable (repr/str do not raise).
    print(s1)
    print(s2)
    print(s3)
    print(s4)
    print(s5)
    t1 = self.sp_.encode_as_immutable_proto(text)
    t2 = self.sp_.sample_encode_as_immutable_proto(text, 10, 0.2)
    t3 = self.sp_.nbest_encode_as_immutable_proto(text, 10)
    t4 = self.sp_.decode_pieces_as_immutable_proto(['foo', 'bar'])
    t5 = self.sp_.decode_ids_as_immutable_proto([20, 30])
    y1 = self.sp_.encode(text, out_type='immutable_proto')
    y2 = self.sp_.encode(text, enable_sampling=True, out_type='immutable_proto')
    y3 = self.sp_.nbest_encode(text, out_type='immutable_proto', nbest_size=10)
    y4 = self.sp_.decode(['foo', 'bar'], out_type='immutable_proto')
    y5 = self.sp_.decode([20, 30], out_type='immutable_proto')
    # Deterministic calls must compare equal across all three API spellings
    # (s2/t2/y2 are sampled and therefore excluded).
    self.assertEqual(s1, t1)
    self.assertEqual(s3, t3)
    self.assertEqual(s4, t4)
    self.assertEqual(s5, t5)
    self.assertEqual(s1, y1)
    self.assertEqual(s3, y3)
    self.assertEqual(s4, y4)
    self.assertEqual(s5, y5)
    hset_piece = defaultdict(int)
    # eq test: equal pieces must also hash equally (dict key collapse).
    for i in range(len(s1.pieces)):
      self.assertEqual(s1.pieces[i], t1.pieces[i])
      hset_piece[s1.pieces[i]] += 1
      hset_piece[t1.pieces[i]] += 1
    self.assertEqual(len(hset_piece), len(s1.pieces))
    # hash test: s1/t1 and s3/t3 collapse to two distinct dict keys.
    hset = defaultdict(int)
    hset[s1] += 1
    hset[t1] += 1
    hset[s3] += 1
    hset[t3] += 1
    self.assertEqual(len(hset), 2)
    self.assertEqual(hset[s1], 2)
    self.assertEqual(hset[s3], 2)
    self.assertEqual(hset[t1], 2)
    self.assertEqual(hset[t3], 2)
    # SerializeAsString() must match the serialized-proto API byte-for-byte.
    x1 = self.sp_.encode_as_serialized_proto(text)
    x2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
    x3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
    x4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
    x5 = self.sp_.decode_ids_as_serialized_proto([20, 30])
    self.assertEqual(x1, t1.SerializeAsString())
    self.assertEqual(x3, t3.SerializeAsString())
    self.assertEqual(x4, t4.SerializeAsString())
    self.assertEqual(x5, t5.SerializeAsString())
    # Piece ids/strings inside the proto match the plain encode outputs.
    v1 = self.sp_.EncodeAsIds(text)
    v2 = self.sp_.EncodeAsPieces(text)
    self.assertEqual([x.id for x in s1.pieces], v1)
    self.assertEqual([x.piece for x in s1.pieces], v2)
    self.assertEqual(text, s1.text)
    # Each piece's [begin, end) offsets slice its surface out of the text.
    surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
    surfaces2 = [x.surface for x in s1.pieces]
    self.assertEqual(surfaces1, surfaces2)
    # Index-based access agrees with iteration.
    ids = []
    for i in range(len(s1.pieces)):
      ids.append(s1.pieces[i].id)
    self.assertEqual(ids, v1)
    pieces = []
    for i in range(len(s1.pieces)):
      pieces.append(s1.pieces[i].piece)
    self.assertEqual(pieces, v2)
    # Every n-best hypothesis decodes back to the input text.
    for v in s3.nbests:
      self.assertEqual(text, v.text)
      self.assertEqual(self.sp_.Decode([x.id for x in v.pieces]), text)
    for i in range(len(s3.nbests)):
      self.assertEqual(text, s3.nbests[i].text)
      self.assertEqual(
          self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text
      )
    # slice: extended slicing on proto repeated fields works like a list.
    self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces)))
    self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests)))
    # Japanese offset: offsets must be correct for multi-byte characters too.
    s1 = self.jasp_.EncodeAsImmutableProto(
        '吾輩は猫である。Hello world. ABC 123'
    )
    surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
    surfaces2 = [x.surface for x in s1.pieces]
    self.assertEqual(surfaces1, surfaces2)
    # Decoding the extracted ids / pieces reproduces the same proto.
    ids = [x.id for x in s1.pieces]
    s2 = self.jasp_.DecodeIdsAsImmutableProto(ids)
    self.assertEqual(s2, s1)
    pieces = [x.piece for x in s1.pieces]
    s2 = self.jasp_.DecodePiecesAsImmutableProto(pieces)
    self.assertEqual(s2, s1)
  def test_new_api(self):
    """encode()/decode() keyword API matches the legacy per-type methods."""
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )
    text = 'hello world'
    text2 = 'Tokyo'
    # Reference outputs from the legacy API on the setUp processor.
    ids = self.sp_.EncodeAsIds(text)
    ids2 = self.sp_.EncodeAsIds(text2)
    pieces = self.sp_.EncodeAsPieces(text)
    pieces2 = self.sp_.EncodeAsPieces(text2)
    sprotos = self.sp_.EncodeAsSerializedProto(text)
    sproto2 = self.sp_.EncodeAsSerializedProto(text2)
    iprotos = self.sp_.EncodeAsImmutableProto(text)
    iprotos2 = self.sp_.EncodeAsImmutableProto(text2)
    # out_type selects the result representation.
    self.assertEqual(sp.encode(text, out_type=int), ids)
    self.assertEqual(sp.encode(text, out_type=str), pieces)
    self.assertEqual(sp.encode(text, out_type='serialized_proto'), sprotos)
    self.assertEqual(sp.encode(text, out_type='immutable_proto'), iprotos)
    # A list input yields a list of results of the same shape.
    self.assertEqual(sp.encode([text], out_type=int), [ids])
    self.assertEqual(sp.encode([text], out_type=str), [pieces])
    self.assertEqual(sp.encode([text], out_type='serialized_proto'), [sprotos])
    self.assertEqual(sp.encode([text], out_type='immutable_proto'), [iprotos])
    # Immutable proto contents line up with the ids/pieces encodes.
    self.assertEqual(len(iprotos.pieces), len(pieces))
    self.assertEqual(len(iprotos.pieces), len(ids))
    self.assertEqual(iprotos.text, text)
    self.assertEqual(len(iprotos2.pieces), len(pieces2))
    self.assertEqual(len(iprotos2.pieces), len(ids2))
    self.assertEqual(iprotos2.text, text2)
    for i in range(len(iprotos.pieces)):
      self.assertEqual(ids[i], iprotos.pieces[i].id)
      self.assertEqual(pieces[i], iprotos.pieces[i].piece)
    for i, piece in enumerate(iprotos.pieces):
      self.assertEqual(ids[i], piece.id)
      self.assertEqual(pieces[i], piece.piece)
    for i in range(len(iprotos2.pieces)):
      self.assertEqual(ids2[i], iprotos2.pieces[i].id)
      self.assertEqual(pieces2[i], iprotos2.pieces[i].piece)
    for i, piece in enumerate(iprotos2.pieces):
      self.assertEqual(ids2[i], piece.id)
      self.assertEqual(pieces2[i], piece.piece)
    # decode() accepts either ids or pieces.
    detok_ids = self.sp_.DecodeIds(ids)
    detok_pieces = self.sp_.DecodePieces(pieces)
    self.assertEqual(sp.decode(ids), detok_ids)
    self.assertEqual(sp.decode(pieces), detok_pieces)
    # Empty inputs: scalar empty -> '', batch of one empty -> [''].
    self.assertEqual(sp.decode([]), '')
    self.assertEqual(sp.decode([[]]), [''])
    # add_bos, add_eos, reverse
    self.assertEqual([sp.bos_id()] + ids, sp.encode(text, add_bos=True))
    self.assertEqual(ids + [sp.eos_id()], sp.encode(text, add_eos=True))
    self.assertEqual(ids + [sp.eos_id()], sp.EncodeAsIds(text, add_eos=True))
    rids = ids[:]
    rids.reverse()
    self.assertEqual(rids, sp.encode(text, reverse=True))
    self.assertEqual(rids, sp.EncodeAsIds(text, reverse=True))
    # different shape: batched encode/decode keeps per-sentence ordering.
    self.assertEqual([ids, ids2], sp.encode([text, text2]))
    self.assertEqual([pieces, pieces2], sp.encode([text, text2], out_type=str))
    self.assertEqual([text, text2], sp.decode([ids, ids2]))
    self.assertEqual([text, text2], sp.decode([pieces, pieces2]))
    pieces = list(reversed(self.sp_.EncodeAsPieces(text)))
    self.assertEqual(pieces, sp.encode(text, reverse=True, out_type=str))
    # emit unk piece: an OOV character becomes <unk> only when requested.
    unk_char = '藤'
    pieces = self.sp_.EncodeAsIds(unk_char, emit_unk_piece=True)
    pieces2 = self.sp_.encode(unk_char, out_type=int, emit_unk_piece=True)
    self.assertEqual(pieces[1], sp.unk_id())
    self.assertEqual(pieces2[1], sp.unk_id())
    self.assertEqual(pieces, pieces2)
    pieces = self.sp_.EncodeAsPieces(unk_char, emit_unk_piece=True)
    pieces2 = self.sp_.encode(unk_char, out_type=str, emit_unk_piece=True)
    self.assertEqual(pieces[1], '<unk>')
    self.assertEqual(pieces2[1], '<unk>')
    self.assertEqual(pieces, pieces2)
    pieces = self.sp_.EncodeAsPieces(unk_char, emit_unk_piece=False)
    pieces2 = self.sp_.encode(unk_char, out_type=str, emit_unk_piece=False)
    self.assertEqual(pieces[1], unk_char)
    self.assertEqual(pieces2[1], unk_char)
    self.assertEqual(pieces, pieces2)
def test_new_api_init(self):
sp = spm.SentencePieceProcessor(
model_file=os.path.join('test', 'test_model.model'),
add_bos=True,
add_eos=True,
out_type=str,
)
text = 'hello world'
pieces = ['<s>'] + self.sp_.EncodeAsPieces(text) + ['</s>']
self.assertEqual(pieces, sp.encode(text))
pieces = self.sp_.EncodeAsPieces(text) + ['</s>']
self.assertEqual(pieces, sp.encode(text, add_bos=False, add_eos=True))
def test_sampling(self):
sp = self.sp_
for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
ids = defaultdict(int)
for n in range(100):
out = sp.encode('hello world', out_type=out_type, enable_sampling=True)
if type(out) is list:
out = tuple(out)
++ids[out]
self.assertGreater(len(ids), 1)
ids2 = defaultdict(int)
for n in range(100):
out = sp.encode('hello world', out_type=out_type, enable_sampling=False)
if type(out) is list:
out = tuple(out)
++ids2[out]
self.assertEqual(len(ids2), 1)
out = sp.encode(
['hello world', 'this is a test'],
out_type=out_type,
enable_sampling=True,
)
self.assertEqual(len(out), 2)
out = sp.encode(
['hello world', 'this is a test'],
out_type=out_type,
enable_sampling=False,
)
self.assertEqual(len(out), 2)
  def test_nbest(self):
    """nbest_encode across out_types; every hypothesis decodes to the input."""
    sp = self.sp_
    text = 'hello world'
    text2 = 'I have a pen.'
    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
      results = sp.nbest_encode(text, nbest_size=10, out_type=out_type)
      # snake_case and CamelCase spellings must agree.
      self.assertEqual(
          results, sp.NBestEncode(text, nbest_size=10, out_type=out_type)
      )
      if out_type in [str, int]:
        for n in results:
          self.assertEqual(sp.decode(n), text)
        for n in sp.decode(results):
          self.assertEqual(n, text)
      # batch test
      results = sp.nbest_encode([text, text2], nbest_size=10, out_type=out_type)
      self.assertEqual(
          results,
          sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type),
      )
      self.assertEqual(len(results), 2)
      if out_type in [str, int]:
        for n in results[0]:
          self.assertEqual(sp.decode(n), text)
        for n in results[1]:
          self.assertEqual(sp.decode(n), text2)
        decoded = sp.decode(results[0])
        self.assertEqual(len(decoded), 10)
        for n in decoded:
          self.assertEqual(n, text)
        decoded = sp.decode(results[1])
        self.assertEqual(len(decoded), 10)
        for n in decoded:
          self.assertEqual(n, text2)
    # nbest_encode(out_type=...) matches the dedicated per-type methods.
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type=str),
        sp.nbest_encode_as_pieces(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type=int),
        sp.nbest_encode_as_ids(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type='serialized_proto'),
        sp.nbest_encode_as_serialized_proto(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type='immutable_proto'),
        sp.nbest_encode_as_immutable_proto(text, nbest_size=10),
    )
  def test_sample_and_score(self):
    """sample_encode_and_score: each (sample, score) pair decodes to the input.

    NOTE(review): `wor` presumably toggles sampling without replacement —
    confirm against the sentencepiece API docs.
    """
    sp = self.sp_
    text = 'hello world'
    text2 = 'I have a pen.'
    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
      results = sp.sample_encode_and_score(
          text, wor=True, num_samples=10, out_type=out_type
      )
      results = sp.SampleEncodeAndScore(
          text, wor=False, num_samples=10, out_type=out_type
      )
      if out_type in [str, int]:
        # Each result is (segmentation, score); element 0 must decode back.
        for n in results:
          self.assertEqual(sp.decode(n[0]), text)
      # Batched input: one result list per sentence.
      results = sp.sample_encode_and_score(
          [text, text2], wor=True, num_samples=10, out_type=out_type
      )
      results = sp.SampleEncodeAndScore(
          [text, text2], wor=True, num_samples=10, out_type=out_type
      )
      if out_type in [str, int]:
        for n in results[0]:
          self.assertEqual(sp.decode(n[0]), text)
        for n in results[1]:
          self.assertEqual(sp.decode(n[0]), text2)
    # Smoke-check the per-type convenience wrappers (no return-value checks).
    sp.sample_encode_and_score_as_pieces(text, 10)
    sp.sample_encode_and_score_as_ids(text, 10)
    sp.sample_encode_and_score_as_immutable_proto(text, 10)
    sp.sample_encode_and_score_as_serialized_proto(text, 10)
def test_valid_range(self):
size = self.sp_.piece_size()
funcs = [
'IdToPiece',
'GetScore',
'IsUnknown',
'IsControl',
'IsUnused',
'IsByte',
'DecodeIds',
'DecodeIdsAsSerializedProto',
]
for m in funcs:
getattr(self.sp_, m)([10, 20, 30])
for m in funcs:
try:
getattr(self.sp_, m)([size])
self.assertTrue(False)
except:
self.assertTrue(True)
  def test_batch(self):
    """Multi-threaded batch encode/decode/entropy must match sequential calls."""
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )
    with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
      texts = file.readlines()
    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
      # num_threads=None/1/-1/8 and a per-sentence loop must all agree.
      r1 = sp.encode(texts, out_type=out_type, num_threads=None)
      r2 = sp.encode(texts, out_type=out_type, num_threads=1)
      r3 = sp.encode(texts, out_type=out_type, num_threads=-1)
      r4 = sp.encode(texts, out_type=out_type, num_threads=8)
      r5 = [sp.encode(s, out_type=out_type) for s in texts]
      self.assertEqual(r1, r2)
      self.assertEqual(r1, r3)
      self.assertEqual(r1, r4)
      self.assertEqual(r1, r5)
      if out_type in [str, int]:
        d1 = sp.decode(r1, num_threads=None)
        d2 = sp.decode(r2, num_threads=1)
        d3 = sp.decode(r3, num_threads=-1)
        d4 = sp.decode(r4, num_threads=8)
        d5 = [sp.decode(s) for s in r5]
        self.assertEqual(d1, d2)
        self.assertEqual(d1, d3)
        self.assertEqual(d1, d4)
        self.assertEqual(d1, d5)
    # Entropy: threaded, CamelCase, and sequential results must all agree.
    e1 = sp.calculate_entropy(texts, alpha=1.0, num_threads=10)
    e2 = sp.CalculateEntropy(texts, alpha=1.0, num_threads=10)
    e3 = [sp.calculate_entropy(s, alpha=1.0) for s in texts]
    self.assertEqual(e1, e2)
    self.assertEqual(e1, e3)
def test_pickle(self):
with open('sp.pickle', 'wb') as f:
pickle.dump(self.sp_, f)
id1 = self.sp_.encode('hello world.', out_type=int)
with open('sp.pickle', 'rb') as f:
sp = pickle.load(f)
id2 = sp.encode('hello world.', out_type=int)
self.assertEqual(id1, id2)
def test_global_params(self):
spm.SetRandomGeneratorSeed(0)
spm.SetMinLogLevel(2)
spm.set_random_generator_seed(1)
spm.set_min_log_level(3)
  def test_normalize(self):
    """Processor-level normalize(): str/bytes inputs, batches, and offsets."""
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )
    # Full-width input is NFKC-normalized and a dummy prefix '▁' is prepended.
    self.assertEqual('▁KADOKAWAABC', sp.normalize('KADOKAWAABC'))
    self.assertEqual('▁KADOKAWAABC', sp.Normalize('KADOKAWAABC'))
    # with_offsets=True returns (normalized, offsets-into-the-input).
    x = sp.Normalize('KADOKAWAABC', with_offsets=True)
    self.assertEqual('▁KADOKAWAABC', x[0])
    self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])
    # bytes input yields bytes output with byte-level offsets.
    x = sp.Normalize('KADOKAWAABC'.encode('utf8'), with_offsets=True)
    self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0])
    self.assertEqual(
        [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1]
    )
    # Batched normalization ('㍻' expands to '平成').
    self.assertEqual(
        ['▁KADOKAWAABC', '▁平成'], sp.normalize(['KADOKAWAABC', '㍻'])
    )
    self.assertEqual(
        ['▁KADOKAWAABC', '▁平成'], sp.Normalize(['KADOKAWAABC', '㍻'])
    )
    x = sp.Normalize(
        ['KADOKAWAABC'.encode('utf8'), '㍻'.encode('utf8')],
        with_offsets=True,
    )
    self.assertEqual(len(x), 2)
    self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0][0])
    self.assertEqual(
        [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1]
    )
    x = sp.Normalize(['KADOKAWAABC', '㍻'], with_offsets=True)
    self.assertEqual(len(x), 2)
    self.assertEqual('▁KADOKAWAABC', x[0][0])
    self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])
    self.assertEqual('▁平成', x[1][0])
    self.assertEqual([0, 0, 0, 1], x[1][1])
  def test_normalizer(self):
    """Standalone SentencePieceNormalizer: offsets, batches, spec overrides."""
    sp = spm.SentencePieceNormalizer(
        model_file=os.path.join('test', 'test_model.model')
    )
    # Unlike the processor, no dummy '▁' prefix is added by default.
    self.assertEqual('KADOKAWAABC', sp.normalize('KADOKAWAABC'))
    self.assertEqual('KADOKAWAABC', sp.Normalize('KADOKAWAABC'))
    # bytes input: byte-level offsets into the original input.
    x = sp.Normalize('KADOKAWAABC'.encode('utf8'), with_offsets=True)
    self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0])
    self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1])
    # str input: character-level offsets.
    x = sp.Normalize('KADOKAWAABC', with_offsets=True)
    self.assertEqual('KADOKAWAABC', x[0])
    self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])
    # Batched normalization ('㍻' expands to '平成').
    self.assertEqual(
        ['KADOKAWAABC', '平成'], sp.normalize(['KADOKAWAABC', '㍻'])
    )
    self.assertEqual(
        ['KADOKAWAABC', '平成'], sp.Normalize(['KADOKAWAABC', '㍻'])
    )
    x = sp.Normalize(
        ['KADOKAWAABC'.encode('utf8'), '㍻'.encode('utf8')],
        with_offsets=True,
    )
    self.assertEqual(len(x), 2)
    self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0][0])
    self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1])
    x = sp.Normalize(['KADOKAWAABC', '㍻'], with_offsets=True)
    self.assertEqual(len(x), 2)
    self.assertEqual('KADOKAWAABC', x[0][0])
    self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])
    self.assertEqual('平成', x[1][0])
    self.assertEqual([0, 0, 1], x[1][1])
    # Constructor flags control prefix/escape/extra-whitespace handling.
    sp = spm.SentencePieceNormalizer(
        model_file=os.path.join('test', 'test_model.model'),
        add_dummy_prefix=True,
        escape_whitespaces=True,
        remove_extra_whitespaces=False,
    )
    self.assertEqual('▁hello▁▁world', sp.normalize('hello world'))
    sp = spm.SentencePieceNormalizer(
        model_file=os.path.join('test', 'test_model.model'),
        add_dummy_prefix=True,
        escape_whitespaces=True,
        remove_extra_whitespaces=True,
    )
    self.assertEqual('▁hello▁world', sp.normalize(' hello world '))
    sp = spm.SentencePieceNormalizer(
        model_file=os.path.join('test', 'test_model.model'),
        add_dummy_prefix=False,
        escape_whitespaces=False,
        remove_extra_whitespaces=True,
    )
    self.assertEqual('hello world', sp.normalize(' hello world '))
def test_normalizer_rule(self):
sp = spm.SentencePieceNormalizer(rule_name='identity')
self.assertEqual('ABC', sp.Normalize('ABC'))
sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
self.assertEqual('abc', sp.Normalize('ABC'))
def test_override_normalize_spec(self):
sp = spm.SentencePieceProcessor(
model_file=os.path.join('test', 'test_model.model')
)
self.assertEqual(
sp.EncodeAsPieces(' hello world '), ['▁he', 'll', 'o', '▁world']
)
sp.override_normalizer_spec(add_dummy_prefix=False)
sp.override_normalizer_spec(remove_extra_whitespaces=False)
sp.override_normalizer_spec(escape_whitespaces=False)
self.assertEqual(
sp.EncodeAsPieces(' hello world '),
[' ', 'he', 'll', 'o', ' ', 'w', 'or', 'l', 'd', ' '],
)
def suite():
  """Return a TestSuite with all SentencePieceProcessor tests.

  Uses TestLoader.loadTestsFromTestCase because unittest.makeSuite was
  deprecated in Python 3.11 and removed in 3.13.
  """
  suite = unittest.TestSuite()
  suite.addTests(
      unittest.TestLoader().loadTestsFromTestCase(TestSentencepieceProcessor)
  )
  return suite
# Script entry point: run the whole test case via unittest's CLI runner.
if __name__ == '__main__':
  unittest.main()