Spaces:
Sleeping
Sleeping
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace sentencepiece { | |
namespace { | |
// Every model type that the type-parameterized tests below iterate over.
const std::vector<TrainerSpec::ModelType> kModelTypes{
    TrainerSpec::UNIGRAM,
    TrainerSpec::BPE,
    TrainerSpec::WORD,
    TrainerSpec::CHAR,
};
// Returns a ModelProto of model type `type` that holds only the three
// reserved pieces: "<unk>" (UNKNOWN, id 0), "<s>" (CONTROL, id 1) and
// "</s>" (CONTROL, id 2). `byte_fallback` is copied into the trainer spec.
ModelProto MakeBaseModelProto(TrainerSpec::ModelType type,
                              bool byte_fallback = false) {
  ModelProto proto;

  TrainerSpec *const spec = proto.mutable_trainer_spec();
  spec->set_model_type(type);
  spec->set_byte_fallback(byte_fallback);

  ModelProto::SentencePiece *const unk = proto.add_pieces();
  unk->set_piece("<unk>");
  unk->set_type(ModelProto::SentencePiece::UNKNOWN);

  ModelProto::SentencePiece *const bos = proto.add_pieces();
  bos->set_piece("<s>");
  bos->set_type(ModelProto::SentencePiece::CONTROL);

  ModelProto::SentencePiece *const eos = proto.add_pieces();
  eos->set_piece("</s>");
  eos->set_type(ModelProto::SentencePiece::CONTROL);

  return proto;
}
// Appends a piece with surface form `piece` and score `score` to the end of
// `model_proto`'s vocabulary. The piece type is left unset (proto default).
void AddPiece(ModelProto *model_proto, const std::string &piece,
              float score = 0.0f) {
  ModelProto::SentencePiece *const entry = model_proto->add_pieces();
  entry->set_score(score);
  entry->set_piece(piece);
}
// Appends a BYTE-typed piece whose surface form is ByteToPiece(`byte`)
// to the end of `model_proto`'s vocabulary.
void AddBytePiece(ModelProto *model_proto, unsigned char byte) {
  ModelProto::SentencePiece *const entry = model_proto->add_pieces();
  entry->set_type(ModelProto::SentencePiece::BYTE);
  entry->set_piece(ByteToPiece(byte));
}
TEST(ModelInterfaceTest, GetDefaultPieceTest) {
  // A default-constructed proto already reports the standard piece names.
  {
    ModelProto empty_proto;
    const auto &spec = empty_proto.trainer_spec();
    EXPECT_EQ("<unk>", spec.unk_piece());
    EXPECT_EQ("<s>", spec.bos_piece());
    EXPECT_EQ("</s>", spec.eos_piece());
    EXPECT_EQ("<pad>", spec.pad_piece());
  }

  // Minimal UNIGRAM proto: the three reserved pieces followed by "a".
  const auto make_proto = [] {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    AddPiece(&proto, "a");
    return proto;
  };

  // Untouched trainer spec: the model exposes the default names.
  {
    const ModelProto proto = make_proto();
    const auto model = ModelFactory::Create(proto);
    EXPECT_EQ("<unk>", model->unk_piece());
    EXPECT_EQ("<s>", model->bos_piece());
    EXPECT_EQ("</s>", model->eos_piece());
    EXPECT_EQ("<pad>", model->pad_piece());
  }

  // Explicitly cleared fields fall back to the same defaults.
  {
    ModelProto proto = make_proto();
    TrainerSpec *const spec = proto.mutable_trainer_spec();
    spec->clear_unk_piece();
    spec->clear_bos_piece();
    spec->clear_eos_piece();
    spec->clear_pad_piece();
    const auto model = ModelFactory::Create(proto);
    EXPECT_EQ("<unk>", model->unk_piece());
    EXPECT_EQ("<s>", model->bos_piece());
    EXPECT_EQ("</s>", model->eos_piece());
    EXPECT_EQ("<pad>", model->pad_piece());
  }

  // Custom names on the trainer spec are reported verbatim.
  {
    ModelProto proto = make_proto();
    TrainerSpec *const spec = proto.mutable_trainer_spec();
    spec->set_unk_piece("UNK");
    spec->set_bos_piece("BOS");
    spec->set_eos_piece("EOS");
    spec->set_pad_piece("PAD");
    const auto model = ModelFactory::Create(proto);
    EXPECT_EQ("UNK", model->unk_piece());
    EXPECT_EQ("BOS", model->bos_piece());
    EXPECT_EQ("EOS", model->eos_piece());
    EXPECT_EQ("PAD", model->pad_piece());
  }
}
TEST(ModelInterfaceTest, SetModelInterfaceTest) {
  // For every model type, the model must hand back exactly the proto it was
  // constructed from.
  for (const auto model_type : kModelTypes) {
    ModelProto proto = MakeBaseModelProto(model_type);
    for (const char *piece : {"a", "b", "c", "d"}) {
      AddPiece(&proto, piece);
    }
    const auto model = ModelFactory::Create(proto);
    const std::string expected = proto.SerializeAsString();
    EXPECT_EQ(expected, model->model_proto().SerializeAsString());
  }
}
TEST(ModelInterfaceTest, PieceToIdTest) {
  // Expected surface form and score for every id. Ids 0-2 are the reserved
  // pieces from MakeBaseModelProto; ids 3-7 are appended below.
  const char *const kPieces[] = {"<unk>", "<s>", "</s>", "a",
                                 "b",     "c",   "d",    "e"};
  const double kScores[] = {0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
  const int kNumPieces = 8;

  for (const auto model_type : kModelTypes) {
    ModelProto model_proto = MakeBaseModelProto(model_type);
    for (int id = 3; id < kNumPieces; ++id) {
      AddPiece(&model_proto, kPieces[id], kScores[id]);
    }
    // "d" (id 6) becomes UNUSED and "e" (id 7) becomes USER_DEFINED.
    model_proto.mutable_pieces(6)->set_type(ModelProto::SentencePiece::UNUSED);
    model_proto.mutable_pieces(7)->set_type(
        ModelProto::SentencePiece::USER_DEFINED);

    const auto model = ModelFactory::Create(model_proto);
    EXPECT_EQ(model_proto.SerializeAsString(),
              model->model_proto().SerializeAsString());

    for (int id = 0; id < kNumPieces; ++id) {
      // Piece <-> id round trip.
      EXPECT_EQ(id, model->PieceToId(kPieces[id]));
      EXPECT_EQ(kPieces[id], model->IdToPiece(id));
      // Each type predicate is true only for the ids given that type.
      EXPECT_EQ(id == 0, model->IsUnknown(id));
      EXPECT_EQ(id == 1 || id == 2, model->IsControl(id));
      EXPECT_EQ(id == 6, model->IsUnused(id));
      EXPECT_EQ(id == 7, model->IsUserDefined(id));
      EXPECT_NEAR(kScores[id], model->GetScore(id), 0.0001);
    }

    // Pieces absent from the vocabulary (including "") map to the <unk> id.
    EXPECT_EQ(0, model->PieceToId("f"));
    EXPECT_EQ(0, model->PieceToId(""));
  }
}
TEST(ModelInterfaceTest, InvalidModelTest) {
  // Builds a model from `proto` and reports whether construction succeeded.
  // `proto` outlives the model created here.
  const auto is_valid = [](const ModelProto &proto) {
    const auto model = ModelFactory::Create(proto);
    return model->status().ok();
  };

  // Empty piece.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    AddPiece(&proto, "");
    EXPECT_FALSE(is_valid(proto));
  }
  // Duplicated pieces.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    AddPiece(&proto, "a");
    AddPiece(&proto, "a");
    EXPECT_FALSE(is_valid(proto));
  }
  // Multiple unknowns: "<s>" (id 1) is re-typed as a second UNKNOWN.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    proto.mutable_pieces(1)->set_type(ModelProto::SentencePiece::UNKNOWN);
    EXPECT_FALSE(is_valid(proto));
  }
  // No unknown: "<unk>" (id 0) is re-typed as CONTROL.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    proto.mutable_pieces(0)->set_type(ModelProto::SentencePiece::CONTROL);
    EXPECT_FALSE(is_valid(proto));
  }
}
TEST(ModelInterfaceTest, ByteFallbackModelTest) {
  // Builds a UNIGRAM proto with the given `byte_fallback` flag, byte pieces
  // for bytes 0 .. num_bytes-1, and a trailing ordinary piece "a".
  const auto build = [](bool byte_fallback, int num_bytes) {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM, byte_fallback);
    for (int b = 0; b < num_bytes; ++b) {
      AddBytePiece(&proto, static_cast<unsigned char>(b));
    }
    AddPiece(&proto, "a");
    return proto;
  };

  // `byte_fallback` is true and all 256 byte pieces are present.
  {
    const ModelProto proto = build(true, 256);
    const auto model = ModelFactory::Create(proto);
    EXPECT_TRUE(model->status().ok());
  }
  // `byte_fallback` is true, but there are not 256 byte pieces.
  {
    const ModelProto proto = build(true, 10);
    const auto model = ModelFactory::Create(proto);
    EXPECT_FALSE(model->status().ok());
  }
  // `byte_fallback` is false, but a byte piece is found.
  {
    const ModelProto proto = build(false, 10);
    const auto model = ModelFactory::Create(proto);
    EXPECT_FALSE(model->status().ok());
  }
}
// Returns a pseudo-random string of between 1 and `length` characters (both
// inclusive). Characters are drawn from digits, "!@#$%^&*" and ASCII
// letters. Uses rand(), so the output depends on the global C PRNG state.
// Returns an empty string when `length` is not positive; the original
// `rand() % length` would divide by zero in that case.
std::string RandomString(int length) {
  if (length <= 0) {
    return "";
  }
  static constexpr char kAlphaNum[] =
      "0123456789"
      "!@#$%^&*"
      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "abcdefghijklmnopqrstuvwxyz";
  // sizeof includes the terminating NUL, which must never be picked.
  constexpr int kAlphaSize = static_cast<int>(sizeof(kAlphaNum)) - 1;
  const int size = rand() % length + 1;
  std::string result;
  result.reserve(size);
  for (int i = 0; i < size; ++i) {
    result += kAlphaNum[rand() % kAlphaSize];
  }
  return result;
}
TEST(ModelInterfaceTest, PieceToIdStressTest) {
  for (const auto model_type : kModelTypes) {
    for (int trial = 0; trial < 100; ++trial) {
      // Reference piece <-> id mappings recorded while the proto is built.
      absl::flat_hash_map<std::string, int> id_of_piece;
      absl::flat_hash_map<int, std::string> piece_of_id;
      ModelProto model_proto = MakeBaseModelProto(model_type);

      for (int attempt = 0; attempt < 1000; ++attempt) {
        const std::string candidate = RandomString(10);
        const int next_id = model_proto.pieces_size();
        // emplace() leaves the map untouched for an already-drawn piece;
        // skip it so the vocabulary stays duplicate-free.
        if (!id_of_piece.emplace(candidate, next_id).second) {
          continue;
        }
        piece_of_id[next_id] = candidate;
        AddPiece(&model_proto, candidate);
      }

      const auto model = ModelFactory::Create(model_proto);
      for (const auto &entry : id_of_piece) {
        EXPECT_EQ(entry.second, model->PieceToId(entry.first));
      }
      for (const auto &entry : piece_of_id) {
        EXPECT_EQ(entry.second, model->IdToPiece(entry.first));
      }
    }
  }
}
TEST(ModelInterfaceTest, SplitIntoWordsTest) {
  // Default mode: each WS marker starts a new word and stays attached to the
  // front of the word that follows it.
  struct Case {
    const char *input;
    std::vector<const char *> expected;
  };
  const Case kCases[] = {
      {WS "this" WS "is" WS "a" WS "pen",
       {WS "this", WS "is", WS "a", WS "pen"}},
      {"this" WS "is" WS "a" WS "pen", {"this", WS "is", WS "a", WS "pen"}},
      {WS "this" WS WS "is", {WS "this", WS, WS "is"}},
      {"", {}},
      {"hello", {"hello"}},
  };

  for (const Case &c : kCases) {
    const auto words = SplitIntoWords(c.input);
    EXPECT_EQ(c.expected.size(), words.size());
    for (size_t i = 0; i < c.expected.size() && i < words.size(); ++i) {
      EXPECT_EQ(c.expected[i], words[i]);
    }
  }
}
TEST(ModelInterfaceTest, SplitIntoWordsSuffixTest) {
  // Suffix mode (second argument true): each WS marker ends the current word
  // and stays attached to its end.
  struct Case {
    const char *input;
    std::vector<const char *> expected;
  };
  const Case kCases[] = {
      {"this" WS "is" WS "a" WS "pen" WS,
       {"this" WS, "is" WS, "a" WS, "pen" WS}},
      {"this" WS "is" WS "a" WS "pen", {"this" WS, "is" WS, "a" WS, "pen"}},
      {WS "this" WS WS "is", {WS, "this" WS, WS, "is"}},
      {"", {}},
      {"hello", {"hello"}},
      {"hello" WS WS, {"hello" WS, WS}},
      {WS WS "hello" WS WS, {WS, WS, "hello" WS, WS}},
  };

  for (const Case &c : kCases) {
    const auto words = SplitIntoWords(c.input, true);
    EXPECT_EQ(c.expected.size(), words.size());
    for (size_t i = 0; i < c.expected.size() && i < words.size(); ++i) {
      EXPECT_EQ(c.expected[i], words[i]);
    }
  }
}
TEST(ModelInterfaceTest, SplitIntoWordsWhiteSpaceOnly) {
  // With the third argument true, runs of consecutive WS markers are kept
  // together instead of being split one marker per word.
  struct Case {
    const char *input;
    bool treat_ws_as_suffix;
    std::vector<const char *> expected;
  };
  const Case kCases[] = {
      {"this" WS "is" WS "a" WS "pen" WS, true,
       {"this" WS, "is" WS, "a" WS, "pen" WS}},
      {WS WS WS "a", false, {WS WS WS "a"}},
      {"a" WS WS WS, true, {"a" WS WS WS}},
      {WS WS, true, {WS WS}},
      {WS WS "a" WS, true, {WS WS, "a" WS}},
      {WS WS "a" WS, false, {WS WS "a", WS}},
  };

  for (const Case &c : kCases) {
    const auto words = SplitIntoWords(c.input, c.treat_ws_as_suffix, true);
    EXPECT_EQ(c.expected.size(), words.size());
    for (size_t i = 0; i < c.expected.size() && i < words.size(); ++i) {
      EXPECT_EQ(c.expected[i], words[i]);
    }
  }
}
TEST(ModelInterfaceTest, ByteToPieceTest) {
  // A byte is rendered as "<0xHH>" with two upper-case hexadecimal digits.
  EXPECT_EQ("<0x00>", ByteToPiece(0));
  EXPECT_EQ("<0x01>", ByteToPiece(1));
  EXPECT_EQ("<0x0A>", ByteToPiece(10));
  EXPECT_EQ("<0x10>", ByteToPiece(16));
  EXPECT_EQ("<0xFF>", ByteToPiece(255));
}
TEST(ModelInterfaceTest, PieceToByteTest) {
  // Well-formed byte pieces decode to their byte value.
  EXPECT_EQ(0, PieceToByte("<0x00>"));
  EXPECT_EQ(1, PieceToByte("<0x01>"));
  EXPECT_EQ(10, PieceToByte("<0x0A>"));
  EXPECT_EQ(16, PieceToByte("<0x10>"));
  EXPECT_EQ(255, PieceToByte("<0xFF>"));

  // Wrong digit count, lower-case hex, a non-hex digit and plain text are
  // all rejected with -1.
  for (const char *malformed :
       {"<0x0>", "<0x000>", "<0x001>", "<0xff>", "<0xFG>", "a"}) {
    EXPECT_EQ(-1, PieceToByte(malformed));
  }
}
TEST(ModelInterfaceTest, VerifyOutputsEquivalent) {
  for (const auto model_type : kModelTypes) {
    ModelProto proto = MakeBaseModelProto(model_type);
    AddPiece(&proto, "a", 1.0f);
    AddPiece(&proto, "b", 2.0f);
    const auto model = ModelFactory::Create(proto);

    // Identical outputs are reported as equivalent...
    EXPECT_TRUE(model->VerifyOutputsEquivalent("", ""));
    EXPECT_TRUE(model->VerifyOutputsEquivalent("a b", "a b"));
    // ...while differing outputs are not.
    EXPECT_FALSE(model->VerifyOutputsEquivalent("a", "a b"));
  }
}
} // namespace | |
} // namespace sentencepiece | |