// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace sentencepiece { | |
// Space symbol | |
class MockModel : public ModelInterface { | |
public: | |
void SetEncodeResult(absl::string_view input, const EncodeResult &output) { | |
input_ = input; | |
output_ = output; | |
} | |
void SetNBestEncodeResult(absl::string_view input, | |
const NBestEncodeResult &output) { | |
input_ = input; | |
nbest_output_ = output; | |
} | |
EncodeResult Encode(absl::string_view normalized) const { | |
EXPECT_EQ(normalized, input_); | |
return output_; | |
} | |
EncodeResult SampleEncode(absl::string_view normalized, float alpha) const { | |
EXPECT_EQ(normalized, input_); | |
return output_; | |
} | |
NBestEncodeResult NBestEncode(absl::string_view normalized, | |
int nbest_size) const { | |
EXPECT_EQ(normalized, input_); | |
return nbest_output_; | |
} | |
bool IsSampleEncodeAvailable() const override { return true; } | |
bool IsNBestEncodeAvailable() const override { return true; } | |
bool IsControl(int id) const { return id == 1 || id == 2; } | |
bool IsUnknown(int id) const { return id == 0; } | |
int GetPieceSize() const { return 10; } | |
int PieceToId(absl::string_view piece) const { return 0; } | |
const std::string &IdToPiece(int id) const { return kEmptyString; } | |
float GetScore(int id) const { return 0.0; } | |
private: | |
absl::string_view input_; | |
EncodeResult output_; | |
NBestEncodeResult nbest_output_; | |
const std::string kEmptyString; | |
}; | |
class ByteFallbackMockModel : public MockModel { | |
public: | |
bool ByteFallbackEnabled() const override { return true; } | |
}; | |
std::vector<std::string> GetSpVec(const EncodeResult &pieces) { | |
std::vector<std::string> sps; | |
for (const auto &p : pieces) { | |
sps.emplace_back(std::string(p.first)); | |
} | |
return sps; | |
} | |
std::vector<int> GetIdVec(const EncodeResult &pieces) { | |
std::vector<int> ids; | |
for (const auto &p : pieces) { | |
ids.emplace_back(p.second); | |
} | |
return ids; | |
} | |
std::vector<std::string> GetSpVec(const SentencePieceText &spt) { | |
std::vector<std::string> sps; | |
for (auto &sp : spt.pieces()) { | |
sps.emplace_back(sp.piece()); | |
} | |
return sps; | |
} | |
// Returns the normalizer spec registered under the name "nmt_nfkc", which the
// tests in this file use as their default.
NormalizerSpec MakeDefaultNormalizerSpec() {
  NormalizerSpec spec = SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc");
  return spec;
}
// The processor status is expected to be non-OK both on a freshly constructed
// processor and after only a model has been installed.
TEST(SentencepieceProcessorTest, StatusTest) {
  SentencePieceProcessor processor;
  EXPECT_FALSE(processor.status().ok());

  processor.SetModel(std::make_unique<MockModel>());
  EXPECT_FALSE(processor.status().ok());
}
// Exercises SentencePieceProcessor::Encode() against MockModel for: a plain
// segmentation, merging of consecutive unknown pieces, byte fallback,
// malformed model output, and normalizations that change the byte length of
// the text (surface/begin/end must refer to the ORIGINAL input bytes).
TEST(SentencepieceProcessorTest, EncodeTest) {
  const absl::string_view kInput = WS "ABC" WS "DEF";
  SentencePieceProcessor sp;

  const auto normalization_spec = MakeDefaultNormalizerSpec();

  {
    auto mock = std::make_unique<MockModel>();

    const EncodeResult result = {
        {WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}};
    mock->SetEncodeResult(kInput, result);

    sp.SetModel(std::move(mock));
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    std::vector<std::string> output;
    EXPECT_TRUE(sp.Encode("ABC DEF", &output).ok());
    EXPECT_EQ(GetSpVec(result), output);

    std::vector<int> ids;
    EXPECT_TRUE(sp.Encode("ABC DEF", &ids).ok());
    EXPECT_EQ(GetIdVec(result), ids);

    SentencePieceText spt;
    EXPECT_TRUE(sp.Encode("ABC DEF", &spt).ok());
    // Fatal on mismatch: the checks below index pieces 0..3.
    ASSERT_EQ(4, spt.pieces_size());
    for (int i = 0; i < 4; ++i) {
      EXPECT_EQ(result[i].first, spt.pieces(i).piece());
    }

    SentencePieceText spt2;
    EXPECT_TRUE(spt2.ParseFromString(sp.EncodeAsSerializedProto("ABC DEF")));
    EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString());

    EXPECT_EQ("ABC", spt.pieces(0).surface());
    EXPECT_EQ(" DE", spt.pieces(1).surface());
    EXPECT_EQ("F", spt.pieces(2).surface());
    EXPECT_EQ("", spt.pieces(3).surface());  // </s>

    EXPECT_EQ(3, spt.pieces(0).id());
    EXPECT_EQ(4, spt.pieces(1).id());
    EXPECT_EQ(0, spt.pieces(2).id());
    EXPECT_EQ(2, spt.pieces(3).id());

    EXPECT_EQ(0, spt.pieces(0).begin());
    EXPECT_EQ(3, spt.pieces(0).end());
    EXPECT_EQ(3, spt.pieces(1).begin());
    EXPECT_EQ(6, spt.pieces(1).end());
    EXPECT_EQ(6, spt.pieces(2).begin());
    EXPECT_EQ(7, spt.pieces(2).end());
    EXPECT_EQ(7, spt.pieces(3).begin());
    EXPECT_EQ(7, spt.pieces(3).end());
  }

  // Unknown sequences: the consecutive unknown pieces "E" and "F" are expected
  // to come back as the single piece "EF".
  {
    auto mock = std::make_unique<MockModel>();

    const EncodeResult result = {
        {WS "ABC", 3}, {WS "D", 4}, {"E", 0}, {"F", 0}, {"</s>", 2}};
    const EncodeResult expected = {
        {WS "ABC", 3}, {WS "D", 4}, {"EF", 0}, {"</s>", 2}};

    mock->SetEncodeResult(kInput, result);
    sp.SetModel(std::move(mock));
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    std::vector<std::string> output;
    EXPECT_TRUE(sp.Encode("ABC DEF", &output).ok());
    EXPECT_EQ(GetSpVec(expected), output);

    std::vector<int> ids;
    EXPECT_TRUE(sp.Encode("ABC DEF", &ids).ok());
    EXPECT_EQ(GetIdVec(expected), ids);

    SentencePieceText spt;
    EXPECT_TRUE(sp.Encode("ABC DEF", &spt).ok());
    // Fatal on mismatch: the checks below index pieces 0..3.
    ASSERT_EQ(4, spt.pieces_size());
    for (int i = 0; i < 4; ++i) {
      EXPECT_EQ(expected[i].first, spt.pieces(i).piece());
    }

    EXPECT_EQ("ABC", spt.pieces(0).surface());
    EXPECT_EQ(" D", spt.pieces(1).surface());
    EXPECT_EQ("EF", spt.pieces(2).surface());
    EXPECT_EQ("", spt.pieces(3).surface());  // </s>

    EXPECT_EQ(3, spt.pieces(0).id());
    EXPECT_EQ(4, spt.pieces(1).id());
    EXPECT_EQ(0, spt.pieces(2).id());
    EXPECT_EQ(2, spt.pieces(3).id());

    EXPECT_EQ(0, spt.pieces(0).begin());
    EXPECT_EQ(3, spt.pieces(0).end());
    EXPECT_EQ(3, spt.pieces(1).begin());
    EXPECT_EQ(5, spt.pieces(1).end());
    EXPECT_EQ(5, spt.pieces(2).begin());
    EXPECT_EQ(7, spt.pieces(2).end());
    EXPECT_EQ(7, spt.pieces(3).begin());
    EXPECT_EQ(7, spt.pieces(3).end());
  }

  // Byte-fallback.
  {
    const absl::string_view kInput2 = WS "ABC" WS "DEFあ";
    auto mock = std::make_unique<ByteFallbackMockModel>();

    const EncodeResult result = {{WS "ABC", 3}, {WS "D", 4}, {"E", 0},
                                 {"F", 0},      {"あ", 0},   {"</s>", 2}};
    // "E" -> 0x45
    // "F" -> 0x46
    // "あ" -> 0xe38182
    const EncodeResult expected = {{WS "ABC", 3}, {WS "D", 4},  {"<0x45>", 0},
                                   {"<0x46>", 0}, {"<0xE3>", 0}, {"<0x81>", 0},
                                   {"<0x82>", 0}, {"</s>", 2}};

    mock->SetEncodeResult(kInput2, result);
    sp.SetModel(std::move(mock));
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    std::vector<std::string> output;
    EXPECT_TRUE(sp.Encode("ABC DEFあ", &output).ok());
    EXPECT_EQ(GetSpVec(expected), output);

    std::vector<int> ids;
    EXPECT_TRUE(sp.Encode("ABC DEFあ", &ids).ok());
    EXPECT_EQ(GetIdVec(expected), ids);

    SentencePieceText spt;
    EXPECT_TRUE(sp.Encode("ABC DEFあ", &spt).ok());
    // Fatal on mismatch: the checks below index pieces 0..7.
    ASSERT_EQ(8, spt.pieces_size());
    for (int i = 0; i < 8; ++i) {
      EXPECT_EQ(expected[i].first, spt.pieces(i).piece());
    }

    // For a multi-byte character the whole surface is attached to the last
    // byte piece; the preceding byte pieces have an empty surface.
    EXPECT_EQ("ABC", spt.pieces(0).surface());
    EXPECT_EQ(" D", spt.pieces(1).surface());
    EXPECT_EQ("E", spt.pieces(2).surface());
    EXPECT_EQ("F", spt.pieces(3).surface());
    EXPECT_EQ("", spt.pieces(4).surface());    // あ
    EXPECT_EQ("", spt.pieces(5).surface());    // あ
    EXPECT_EQ("あ", spt.pieces(6).surface());  // あ
    EXPECT_EQ("", spt.pieces(7).surface());    // </s>

    EXPECT_EQ(3, spt.pieces(0).id());
    EXPECT_EQ(4, spt.pieces(1).id());
    EXPECT_EQ(0, spt.pieces(2).id());
    EXPECT_EQ(0, spt.pieces(3).id());
    EXPECT_EQ(0, spt.pieces(4).id());
    EXPECT_EQ(0, spt.pieces(5).id());
    EXPECT_EQ(0, spt.pieces(6).id());
    EXPECT_EQ(2, spt.pieces(7).id());

    EXPECT_EQ(0, spt.pieces(0).begin());
    EXPECT_EQ(3, spt.pieces(0).end());
    EXPECT_EQ(3, spt.pieces(1).begin());
    EXPECT_EQ(5, spt.pieces(1).end());
    EXPECT_EQ(5, spt.pieces(2).begin());
    EXPECT_EQ(6, spt.pieces(2).end());
    EXPECT_EQ(6, spt.pieces(3).begin());
    EXPECT_EQ(7, spt.pieces(3).end());
    EXPECT_EQ(7, spt.pieces(4).begin());  // あ
    EXPECT_EQ(7, spt.pieces(4).end());
    EXPECT_EQ(7, spt.pieces(5).begin());  // あ
    EXPECT_EQ(7, spt.pieces(5).end());
    EXPECT_EQ(7, spt.pieces(6).begin());  // あ
    EXPECT_EQ(10, spt.pieces(6).end());
    EXPECT_EQ(10, spt.pieces(7).begin());  // </s>
    EXPECT_EQ(10, spt.pieces(7).end());
  }

  // A non-OK status is expected when ModelInterface::Encode() returns
  // shorter results than the normalized input.
  {
    auto mock = std::make_unique<MockModel>();
    const EncodeResult result = {{WS "ABC", 3}};
    mock->SetEncodeResult(kInput, result);
    sp.SetModel(std::move(mock));
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));
    SentencePieceText spt;
    // Must be reported as an error status.
    EXPECT_FALSE(sp.Encode("ABC DEF", &spt).ok());
  }

  // A non-OK status is expected when ModelInterface::Encode() returns
  // longer results than the normalized input.
  {
    auto mock = std::make_unique<MockModel>();
    const EncodeResult result = {
        {WS "ABC", 3}, {WS "DE", 4}, {"F", 5}, {"G", 6}};
    mock->SetEncodeResult(kInput, result);
    sp.SetModel(std::move(mock));
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));
    SentencePieceText spt;
    // Must be reported as an error status.
    EXPECT_FALSE(sp.Encode("ABC DEF", &spt).ok());
  }

  // A non-OK status is expected when ModelInterface::Encode() returns an
  // empty piece.
  {
    auto mock = std::make_unique<MockModel>();
    const EncodeResult result = {
        {WS "ABC", 3}, {WS "DE", 4}, {"", 5}, {"F", 6}};
    mock->SetEncodeResult(kInput, result);
    sp.SetModel(std::move(mock));
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));
    SentencePieceText spt;
    // Must be reported as an error status.
    EXPECT_FALSE(sp.Encode("ABC DEF", &spt).ok());
  }

  // Halfwidth to Fullwidth katakana normalization.
  {
    // The raw input must be HALFWIDTH katakana: the asserted byte offsets
    // (0-9, 9-18) only fit the 18-byte halfwidth spelling, not the 12-byte
    // fullwidth one. The halfwidth text is written as UTF-8 escapes so it
    // cannot be silently re-normalized by editors or tooling.
    //   U+FF78 (KU) U+FF9E (voiced mark) U+FF70 (long vowel)
    const std::string kHalfwidthGuu =
        "\xEF\xBD\xB8\xEF\xBE\x9E\xEF\xBD\xB0";
    //   U+FF78 (KU) U+FF9E (voiced mark) U+FF99 (RU)
    const std::string kHalfwidthGuru =
        "\xEF\xBD\xB8\xEF\xBE\x9E\xEF\xBE\x99";
    const std::string kHalfwidthInput = kHalfwidthGuu + kHalfwidthGuru;

    auto mock = std::make_unique<MockModel>();
    const EncodeResult result = {{WS "グー", 3}, {"グル", 4}, {"</s>", 2}};
    const absl::string_view input = WS "グーグル";
    mock->SetEncodeResult(input, result);
    sp.SetModel(std::move(mock));

    std::vector<std::string> output;
    EXPECT_TRUE(sp.Encode(kHalfwidthInput, &output).ok());
    EXPECT_EQ(GetSpVec(result), output);

    SentencePieceText spt;
    EXPECT_TRUE(sp.Encode(kHalfwidthInput, &spt).ok());
    // Fatal on mismatch: the checks below index pieces 0..2.
    ASSERT_EQ(3, spt.pieces_size());
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(result[i].first, spt.pieces(i).piece());
    }

    // Surfaces are slices of the original (halfwidth) input.
    EXPECT_EQ(kHalfwidthGuu, spt.pieces(0).surface());
    EXPECT_EQ(kHalfwidthGuru, spt.pieces(1).surface());
    EXPECT_EQ("", spt.pieces(2).surface());

    EXPECT_EQ(3, spt.pieces(0).id());
    EXPECT_EQ(4, spt.pieces(1).id());
    EXPECT_EQ(2, spt.pieces(2).id());

    EXPECT_EQ(0, spt.pieces(0).begin());
    EXPECT_EQ(9, spt.pieces(0).end());
    EXPECT_EQ(9, spt.pieces(1).begin());
    EXPECT_EQ(18, spt.pieces(1).end());
    EXPECT_EQ(18, spt.pieces(2).begin());  // </s>
    EXPECT_EQ(18, spt.pieces(2).end());
  }

  // One to many normalization.
  {
    auto mock = std::make_unique<MockModel>();
    const EncodeResult result = {{WS "株式", 3}, {"会社", 4}, {"</s>", 2}};
    const absl::string_view input = WS "株式会社";
    mock->SetEncodeResult(input, result);
    sp.SetModel(std::move(mock));

    std::vector<std::string> output;
    EXPECT_TRUE(sp.Encode("㍿", &output).ok());
    EXPECT_EQ(GetSpVec(result), output);

    SentencePieceText spt;
    EXPECT_TRUE(sp.Encode("㍿", &spt).ok());
    // Fatal on mismatch: the checks below index pieces 0..2.
    ASSERT_EQ(3, spt.pieces_size());
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(result[i].first, spt.pieces(i).piece());
    }

    // The single input character is attributed to the second piece.
    EXPECT_EQ("", spt.pieces(0).surface());
    EXPECT_EQ("㍿", spt.pieces(1).surface());
    EXPECT_EQ("", spt.pieces(2).surface());

    EXPECT_EQ(3, spt.pieces(0).id());
    EXPECT_EQ(4, spt.pieces(1).id());
    EXPECT_EQ(2, spt.pieces(2).id());

    EXPECT_EQ(0, spt.pieces(0).begin());  // 株式
    EXPECT_EQ(0, spt.pieces(0).end());
    EXPECT_EQ(0, spt.pieces(1).begin());  // 会社
    EXPECT_EQ(3, spt.pieces(1).end());
    EXPECT_EQ(3, spt.pieces(2).begin());  // </s>
    EXPECT_EQ(3, spt.pieces(2).end());
  }
}
// Checks that NBestEncode() forwards the mock's n-best list unchanged through
// the string, id and proto overloads, and that an empty n-best list from the
// model yields a non-OK status.
TEST(SentencepieceProcessorTest, NBestEncodeTest) {
  const std::string kInput = WS "ABC" WS "DEF";
  SentencePieceProcessor sp;

  const auto normalization_spec = MakeDefaultNormalizerSpec();

  auto mock = std::make_unique<MockModel>();

  const NBestEncodeResult result = {
      {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}},
       static_cast<float>(1.0)},
      {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}},
       static_cast<float>(0.9)}};

  mock->SetNBestEncodeResult(kInput, result);
  sp.SetModel(std::move(mock));
  sp.SetNormalizer(
      std::make_unique<normalizer::Normalizer>(normalization_spec));

  std::vector<std::vector<std::string>> output;
  EXPECT_TRUE(sp.NBestEncode("ABC DEF", 2, &output).ok());
  // Fatal on mismatch: output[0] and output[1] are indexed below.
  ASSERT_EQ(2, output.size());
  EXPECT_EQ(GetSpVec(result[0].first), output[0]);
  EXPECT_EQ(GetSpVec(result[1].first), output[1]);

  std::vector<std::vector<int>> ids;
  EXPECT_TRUE(sp.NBestEncode("ABC DEF", 2, &ids).ok());
  // Fatal on mismatch: ids[0] and ids[1] are indexed below.
  ASSERT_EQ(2, ids.size());
  EXPECT_EQ(GetIdVec(result[0].first), ids[0]);
  EXPECT_EQ(GetIdVec(result[1].first), ids[1]);

  NBestSentencePieceText spt;
  EXPECT_TRUE(sp.NBestEncode("ABC DEF", 2, &spt).ok());
  // Fatal on mismatch: both hypotheses and their 4 pieces are indexed below.
  ASSERT_EQ(2, spt.nbests_size());
  ASSERT_EQ(4, spt.nbests(0).pieces_size());
  ASSERT_EQ(4, spt.nbests(1).pieces_size());
  EXPECT_NEAR(result[0].second, spt.nbests(0).score(), 0.001);
  EXPECT_NEAR(result[1].second, spt.nbests(1).score(), 0.001);
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(result[0].first[i].first, spt.nbests(0).pieces(i).piece());
    EXPECT_EQ(result[1].first[i].first, spt.nbests(1).pieces(i).piece());
  }

  NBestSentencePieceText spt2;
  EXPECT_TRUE(
      spt2.ParseFromString(sp.NBestEncodeAsSerializedProto("ABC DEF", 2)));
  EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString());

  // A model that returns no hypotheses must surface as an error status.
  auto mock_empty = std::make_unique<MockModel>();
  mock_empty->SetNBestEncodeResult(kInput, {});
  sp.SetModel(std::move(mock_empty));
  EXPECT_FALSE(sp.NBestEncode("ABC DEF", 2, &output).ok());
}
// Checks SampleEncode() through the string, id and proto overloads, the
// accepted/rejected nbest_size values, and that sampling over the mock's two
// hypotheses follows the expected softmax-style distribution.
TEST(SentencepieceProcessorTest, SampleEncodeTest) {
  const std::string kInput = WS "ABC" WS "DEF";
  SentencePieceProcessor sp;

  const auto normalization_spec = MakeDefaultNormalizerSpec();

  auto mock = std::make_unique<MockModel>();

  const EncodeResult result = {
      {WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}};
  const NBestEncodeResult nbest_result = {
      {{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}},
       static_cast<float>(1.0)},
      {{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}},
       static_cast<float>(0.1)}};

  mock->SetNBestEncodeResult(kInput, nbest_result);
  mock->SetEncodeResult(kInput, result);
  sp.SetModel(std::move(mock));
  sp.SetNormalizer(
      std::make_unique<normalizer::Normalizer>(normalization_spec));

  std::vector<std::string> output;
  EXPECT_TRUE(sp.SampleEncode("ABC DEF", -1, 0.5, &output).ok());
  EXPECT_EQ(4, output.size());
  EXPECT_EQ(GetSpVec(result), output);

  std::vector<int> ids;
  EXPECT_TRUE(sp.SampleEncode("ABC DEF", -1, 0.5, &ids).ok());
  EXPECT_EQ(4, ids.size());
  EXPECT_EQ(GetIdVec(result), ids);

  SentencePieceText spt;
  EXPECT_TRUE(sp.SampleEncode("ABC DEF", -1, 0.5, &spt).ok());
  // Fatal on mismatch: the loop below indexes pieces 0..3.
  ASSERT_EQ(4, spt.pieces_size());
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(result[i].first, spt.pieces(i).piece());
    EXPECT_EQ(result[i].second, spt.pieces(i).id());
  }

  SentencePieceText spt2;
  EXPECT_TRUE(spt2.ParseFromString(
      sp.SampleEncodeAsSerializedProto("ABC DEF", -1, 0.5)));
  EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString());

  // nbest_size of 1024 is rejected; 0 and 1 are accepted.
  EXPECT_FALSE(sp.SampleEncode("ABC DEF", 1024, 0.5, &output).ok());
  EXPECT_TRUE(sp.SampleEncode("ABC DEF", 0, 0.5, &output).ok());
  EXPECT_TRUE(sp.SampleEncode("ABC DEF", 1, 0.5, &output).ok());

  // Sample repeatedly and count how often each of the two hypotheses wins.
  const std::vector<std::string> first_hypothesis =
      GetSpVec(nbest_result[0].first);
  const std::vector<std::string> second_hypothesis =
      GetSpVec(nbest_result[1].first);
  std::vector<int> freq(2, 0);
  for (int i = 0; i < 5000; ++i) {
    EXPECT_TRUE(sp.SampleEncode("ABC DEF", 20, 0.5, &output).ok());
    EXPECT_EQ(4, output.size());
    if (first_hypothesis == output) {
      ++freq[0];
    } else if (second_hypothesis == output) {
      ++freq[1];
    } else {
      // Report through the test framework instead of aborting the whole
      // test binary.
      FAIL() << "Invalid result.";
    }
  }

  // Expected share of the first hypothesis with alpha = 0.5 and scores
  // 1.0 / 0.1.
  const float expected_prob =
      std::exp(0.5 * 1.0) / (std::exp(0.5 * 1.0) + std::exp(0.5 * 0.1));
  const float prob = 1.0 * freq[0] / (freq[0] + freq[1]);
  EXPECT_NEAR(prob, expected_prob, 0.05);

  // A model that returns no hypotheses must surface as an error status.
  auto mock_empty = std::make_unique<MockModel>();
  mock_empty->SetNBestEncodeResult(kInput, {});
  sp.SetModel(std::move(mock_empty));
  EXPECT_FALSE(sp.SampleEncode("ABC DEF", 10, 0.5, &output).ok());
}
// Exercises SentencePieceProcessor::Decode() on a piece sequence containing
// control pieces, an unknown piece and whitespace-carrying pieces, and checks
// how the trainer spec's unk_surface and the normalizer spec flags change the
// decoded text.
TEST(SentencepieceProcessorTest, DecodeTest) {
  // Mock vocabulary of 7 pieces (ids 0..6); any other piece maps to id 0.
  class DecodeMockModel : public ModelInterface {
   public:
    EncodeResult Encode(absl::string_view normalized) const override {
      return {};
    }

    int GetPieceSize() const override { return 7; }

    int PieceToId(absl::string_view piece) const override {
      static const absl::flat_hash_map<absl::string_view, int> kMap = {
          {"<unk>", 0}, {"<s>", 1}, {"</s>", 2},    {WS "ABC", 3},
          {WS "DE", 4}, {"F", 5},   {"G" WS "H", 6}};
      return port::FindWithDefault(kMap, piece, 0);
    }

    const std::string &IdToPiece(int id) const override {
      static const std::vector<std::string> kMap = {
          "<unk>", "<s>", "</s>", WS "ABC", WS "DE", "F", "G" WS "H"};
      return kMap[id];
    }

    bool IsUnknown(int id) const override { return (id == 0); }

    bool IsControl(int id) const override { return (id == 1 || id == 2); }

    bool IsByte(int id) const override { return false; }

    float GetScore(int id) const override { return 0.0; }
  };

  const std::vector<std::string> input = {"<s>",    WS "ABC",   "<unk>", WS "DE",
                                          "F",      "G" WS "H", "I",     "</s>"};

  // The unknown piece decodes to " \xE2\x81\x87 " and the next piece starts
  // with its own space, so the full text contains TWO consecutive spaces
  // after the unknown surface (16 bytes, matching the asserted offsets). The
  // literal is split so that the double space cannot be collapsed by tooling.
  const std::string kTextWithDefaultUnk =
      "ABC \xE2\x81\x87 "
      " DEFG HI";

  {
    SentencePieceProcessor sp;
    auto mock = std::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;

    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ(kTextWithDefaultUnk, spt.text());
    // Fatal on mismatch: the checks below index pieces 0..7.
    ASSERT_EQ(8, spt.pieces_size());

    // Only the first six pieces are compared against the input here.
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(input[i], spt.pieces(i).piece());
    }

    EXPECT_EQ("", spt.pieces(0).surface());
    EXPECT_EQ("ABC", spt.pieces(1).surface());
    EXPECT_EQ(" \xE2\x81\x87 ", spt.pieces(2).surface());
    EXPECT_EQ(" DE", spt.pieces(3).surface());
    EXPECT_EQ("F", spt.pieces(4).surface());
    EXPECT_EQ("G H", spt.pieces(5).surface());
    EXPECT_EQ("I", spt.pieces(6).surface());
    EXPECT_EQ("", spt.pieces(7).surface());

    EXPECT_EQ(0, spt.pieces(0).begin());
    EXPECT_EQ(0, spt.pieces(0).end());
    EXPECT_EQ(0, spt.pieces(1).begin());
    EXPECT_EQ(3, spt.pieces(1).end());
    EXPECT_EQ(3, spt.pieces(2).begin());
    EXPECT_EQ(8, spt.pieces(2).end());
    EXPECT_EQ(8, spt.pieces(3).begin());
    EXPECT_EQ(11, spt.pieces(3).end());
    EXPECT_EQ(11, spt.pieces(4).begin());
    EXPECT_EQ(12, spt.pieces(4).end());
    EXPECT_EQ(12, spt.pieces(5).begin());
    EXPECT_EQ(15, spt.pieces(5).end());
    EXPECT_EQ(15, spt.pieces(6).begin());
    EXPECT_EQ(16, spt.pieces(6).end());
    EXPECT_EQ(16, spt.pieces(7).begin());
    EXPECT_EQ(16, spt.pieces(7).end());

    SentencePieceText spt2;
    EXPECT_TRUE(spt2.ParseFromString(sp.DecodePiecesAsSerializedProto(input)));
    EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString());
  }

  // unk_surface is not defined: the decoded text matches the default case.
  {
    SentencePieceProcessor sp;
    auto proto = std::make_unique<ModelProto>();
    sp.Load(std::move(proto)).IgnoreError();

    auto mock = std::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;

    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ(kTextWithDefaultUnk, spt.text());
    EXPECT_EQ(8, spt.pieces_size());
  }

  // unk_surface is the empty string: the unknown piece contributes no text.
  {
    SentencePieceProcessor sp;
    auto proto = std::make_unique<ModelProto>();
    proto->mutable_trainer_spec()->set_unk_surface("");
    sp.Load(std::move(proto)).IgnoreError();

    auto mock = std::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;

    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ("ABC DEFG HI", spt.text());
    EXPECT_EQ(8, spt.pieces_size());
  }

  // unk_surface is a custom string.
  {
    SentencePieceProcessor sp;
    auto proto = std::make_unique<ModelProto>();
    proto->mutable_trainer_spec()->set_unk_surface("<UNK>");
    sp.Load(std::move(proto)).IgnoreError();

    auto mock = std::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;

    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ("ABC<UNK> DEFG HI", spt.text());
    EXPECT_EQ(8, spt.pieces_size());
  }

  // Dummy prefix and extra-whitespace removal disabled: the leading space is
  // expected to be kept.
  {
    SentencePieceProcessor sp;
    auto proto = std::make_unique<ModelProto>();
    proto->mutable_trainer_spec()->set_unk_surface("");
    proto->mutable_normalizer_spec()->set_add_dummy_prefix(false);
    proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(false);
    sp.Load(std::move(proto)).IgnoreError();

    auto mock = std::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;

    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ(" ABC DEFG HI", spt.text());
    EXPECT_EQ(8, spt.pieces_size());
  }
}
// Decodes a sequence that starts with a bare whitespace piece and checks the
// resulting text with add_dummy_prefix enabled, with and without
// remove_extra_whitespaces.
TEST(SentencepieceProcessorTest, DummyPrefixDecodeTest) {
  // Mock vocabulary of 8 pieces (ids 0..7, the last one being the bare
  // whitespace piece); any other piece maps to id 0.
  class DecodeMockModel : public ModelInterface {
   public:
    EncodeResult Encode(absl::string_view normalized) const override {
      return {};
    }

    // Must agree with the 8 entries of the lookup tables below.
    int GetPieceSize() const override { return 8; }

    int PieceToId(absl::string_view piece) const override {
      static const absl::flat_hash_map<absl::string_view, int> kMap = {
          {"<unk>", 0}, {"<s>", 1}, {"</s>", 2},     {WS "ABC", 3},
          {WS "DE", 4}, {"F", 5},   {"G" WS "H", 6}, {WS, 7}};
      return port::FindWithDefault(kMap, piece, 0);
    }

    const std::string &IdToPiece(int id) const override {
      static const std::vector<std::string> kMap = {
          "<unk>", "<s>", "</s>", WS "ABC", WS "DE", "F", "G" WS "H", WS};
      return kMap[id];
    }

    bool IsUnknown(int id) const override { return (id == 0); }

    bool IsControl(int id) const override { return (id == 1 || id == 2); }

    bool IsByte(int id) const override { return false; }

    float GetScore(int id) const override { return 0.0; }
  };

  // start the sequence with a whitespace token
  const std::vector<std::string> input = {
      "<s>", WS, WS "ABC", "<unk>", WS "DE", "F", "G" WS "H", "I", "</s>"};

  // remove_extra_whitespaces disabled: a leading space is expected.
  {
    SentencePieceProcessor sp;
    auto proto = std::make_unique<ModelProto>();
    proto->mutable_trainer_spec()->set_unk_surface("");
    proto->mutable_normalizer_spec()->set_add_dummy_prefix(true);
    proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(false);
    sp.Load(std::move(proto)).IgnoreError();

    auto mock = std::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;

    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ(" ABC DEFG HI", spt.text());
    EXPECT_EQ(9, spt.pieces_size());
  }

  // remove_extra_whitespaces enabled: no leading space is expected.
  {
    SentencePieceProcessor sp;
    auto proto = std::make_unique<ModelProto>();
    proto->mutable_trainer_spec()->set_unk_surface("");
    proto->mutable_normalizer_spec()->set_add_dummy_prefix(true);
    proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(true);
    sp.Load(std::move(proto)).IgnoreError();

    auto mock = std::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        std::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;

    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ("ABC DEFG HI", spt.text());
    EXPECT_EQ(9, spt.pieces_size());
  }
}
// Decodes a mix of regular pieces and byte pieces and checks that valid UTF-8
// byte runs are reassembled into characters, while invalid bytes each decode
// to U+FFFD REPLACEMENT CHARACTER ("\xEF\xBF\xBD").
TEST(SentencepieceProcessorTest, ByteFallbackDecodeTest) {
  // Mock vocabulary: 6 named pieces (ids 0..5) followed by one byte piece per
  // byte value (ids 6..261).
  class ByteFallbackDecodeMockModel : public ModelInterface {
   public:
    EncodeResult Encode(absl::string_view normalized) const override {
      return {};
    }

    int PieceToId(absl::string_view piece) const override {
      using Map = absl::flat_hash_map<std::string, int>;
      static const Map kMap = []() -> Map {
        Map m = {
            {"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {"A", 3}, {"B", 4}, {"C", 5},
        };
        for (int i = 0; i < 256; ++i) {
          m[ByteToPiece(i)] = 6 + i;
        }
        return m;
      }();
      return port::FindWithDefault(kMap, std::string(piece), 0);
    }

    const std::string &IdToPiece(int id) const override {
      static const std::vector<std::string> kMap =
          []() -> std::vector<std::string> {
        std::vector<std::string> m = {"<unk>", "<s>", "</s>", "A", "B", "C"};
        for (int i = 0; i < 256; ++i) {
          m.push_back(ByteToPiece(i));
        }
        return m;
      }();
      return kMap[id];
    }

    // 6 named pieces + 256 byte pieces, matching the tables above.
    int GetPieceSize() const override { return 262; }

    bool IsUnknown(int id) const override { return (id == 0); }

    bool IsControl(int id) const override { return (id == 1 || id == 2); }

    bool IsByte(int id) const override { return id >= 6; }

    bool ByteFallbackEnabled() const override { return true; }
  };

  SentencePieceProcessor sp;
  auto mock = std::make_unique<ByteFallbackDecodeMockModel>();
  sp.SetModel(std::move(mock));

  const auto normalization_spec = MakeDefaultNormalizerSpec();
  sp.SetNormalizer(
      std::make_unique<normalizer::Normalizer>(normalization_spec));

  {
    const std::vector<std::string> input = {
        "<s>",
        "A",
        "B",
        // "あ" -> 0xE3 0x81 0x82
        "<0xE3>",
        "<0x81>",
        "<0x82>",
        // "Z" -> 0x5A
        "<0x5A>",
        // "Ω" -> 0xCE 0xA9
        "<0xCE>",
        "<0xA9>",
        "C",
        // Invalid UTF-8 bytes.
        "<0xE0>",
        "<0x80>",
        // "い" -> 0xE3 0x81 0x84
        "<0xE3>",
        "<0x81>",
        "<0x84>",
        // REPLACEMENT CHARACTER as byte pieces.
        "<0xEF>",
        "<0xBF>",
        "<0xBD>",
    };

    SentencePieceText spt;
    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ("ABあZΩC\xEF\xBF\xBD\xEF\xBF\xBDい\xEF\xBF\xBD", spt.text());
    // Fatal on mismatch: the checks below index pieces 0..17.
    ASSERT_EQ(18, spt.pieces_size());
    for (int i = 0; i < 18; ++i) {
      EXPECT_EQ(input[i], spt.pieces(i).piece());
    }

    // <s>
    EXPECT_EQ("", spt.pieces(0).surface());
    EXPECT_EQ(0, spt.pieces(0).begin());
    EXPECT_EQ(0, spt.pieces(0).end());

    EXPECT_EQ("A", spt.pieces(1).surface());
    EXPECT_EQ(0, spt.pieces(1).begin());
    EXPECT_EQ(1, spt.pieces(1).end());

    EXPECT_EQ("B", spt.pieces(2).surface());
    EXPECT_EQ(1, spt.pieces(2).begin());
    EXPECT_EQ(2, spt.pieces(2).end());

    // "あ": the character is attached to the last of its three byte pieces.
    EXPECT_EQ("", spt.pieces(3).surface());
    EXPECT_EQ("", spt.pieces(4).surface());
    EXPECT_EQ("あ", spt.pieces(5).surface());
    EXPECT_EQ(2, spt.pieces(3).begin());
    EXPECT_EQ(2, spt.pieces(3).end());
    EXPECT_EQ(2, spt.pieces(4).begin());
    EXPECT_EQ(2, spt.pieces(4).end());
    EXPECT_EQ(2, spt.pieces(5).begin());
    EXPECT_EQ(5, spt.pieces(5).end());

    EXPECT_EQ("Z", spt.pieces(6).surface());
    EXPECT_EQ(5, spt.pieces(6).begin());
    EXPECT_EQ(6, spt.pieces(6).end());

    // "Ω"
    EXPECT_EQ("", spt.pieces(7).surface());
    EXPECT_EQ("Ω", spt.pieces(8).surface());
    EXPECT_EQ(6, spt.pieces(7).begin());
    EXPECT_EQ(6, spt.pieces(7).end());
    EXPECT_EQ(6, spt.pieces(8).begin());
    EXPECT_EQ(8, spt.pieces(8).end());

    EXPECT_EQ("C", spt.pieces(9).surface());
    EXPECT_EQ(8, spt.pieces(9).begin());
    EXPECT_EQ(9, spt.pieces(9).end());

    // Each invalid byte becomes its own REPLACEMENT CHARACTER.
    EXPECT_EQ("\xEF\xBF\xBD", spt.pieces(10).surface());
    EXPECT_EQ(9, spt.pieces(10).begin());
    EXPECT_EQ(12, spt.pieces(10).end());

    EXPECT_EQ("\xEF\xBF\xBD", spt.pieces(11).surface());
    EXPECT_EQ(12, spt.pieces(11).begin());
    EXPECT_EQ(15, spt.pieces(11).end());

    // "い"
    EXPECT_EQ("", spt.pieces(12).surface());
    EXPECT_EQ("", spt.pieces(13).surface());
    EXPECT_EQ("い", spt.pieces(14).surface());
    EXPECT_EQ(15, spt.pieces(12).begin());
    EXPECT_EQ(15, spt.pieces(12).end());
    EXPECT_EQ(15, spt.pieces(13).begin());
    EXPECT_EQ(15, spt.pieces(13).end());
    EXPECT_EQ(15, spt.pieces(14).begin());
    EXPECT_EQ(18, spt.pieces(14).end());

    // A REPLACEMENT CHARACTER spelled out as three valid byte pieces.
    EXPECT_EQ("", spt.pieces(15).surface());
    EXPECT_EQ("", spt.pieces(16).surface());
    EXPECT_EQ("\xEF\xBF\xBD", spt.pieces(17).surface());
    EXPECT_EQ(18, spt.pieces(15).begin());
    EXPECT_EQ(18, spt.pieces(15).end());
    EXPECT_EQ(18, spt.pieces(16).begin());
    EXPECT_EQ(18, spt.pieces(16).end());
    EXPECT_EQ(18, spt.pieces(17).begin());
    EXPECT_EQ(21, spt.pieces(17).end());
  }
}
// Appends a new piece entry to `model_proto` and fills in its text and score.
void AddPiece(ModelProto *model_proto, absl::string_view piece,
              float score = 0.0) {
  auto *entry = model_proto->add_pieces();
  entry->set_score(score);
  entry->set_piece(std::string(piece));
}
// Loading from an empty path or from a file name that is not expected to
// exist must yield a non-OK status.
TEST(SentencePieceProcessorTest, LoadInvalidModelTest) {
  SentencePieceProcessor sp;
  const char *const kBadPaths[] = {"", "__UNKNOWN_FILE__"};
  for (const char *path : kBadPaths) {
    EXPECT_FALSE(sp.Load(path).ok());
  }
}
// A string that is not a serialized model must be rejected; a serialized
// ModelProto must load and round-trip byte-for-byte through model_proto().
TEST(SentencePieceProcessorTest, LoadSerializedProtoTest) {
  ModelProto model_proto;
  auto *unk_piece = model_proto.add_pieces();
  unk_piece->set_type(ModelProto::SentencePiece::UNKNOWN);
  unk_piece->set_piece("<unk>");
  AddPiece(&model_proto, WS, 0.0);
  *(model_proto.mutable_normalizer_spec()) = MakeDefaultNormalizerSpec();

  const std::string serialized = model_proto.SerializeAsString();

  SentencePieceProcessor sp;
  EXPECT_FALSE(sp.LoadFromSerializedProto("__NOT_A_PROTO__").ok());
  EXPECT_TRUE(sp.LoadFromSerializedProto(serialized).ok());
  EXPECT_EQ(serialized, sp.model_proto().SerializeAsString());
}
// Builds an eight-piece model, writes it to a temp file, loads it back and
// exercises the processor end to end: vocabulary lookups, encoding and
// decoding "abc" under each extra-option setting, loading from a copied and
// from a moved ModelProto, and encoding with a restricted vocabulary.
TEST(SentencePieceProcessorTest, EndToEndTest) {
  ModelProto model_proto;
  auto *unk = model_proto.add_pieces();
  unk->set_type(ModelProto::SentencePiece::UNKNOWN);
  unk->set_piece("<unk>");
  auto *bos = model_proto.add_pieces();
  bos->set_type(ModelProto::SentencePiece::CONTROL);
  bos->set_piece("<s>");
  auto *eos = model_proto.add_pieces();
  eos->set_type(ModelProto::SentencePiece::CONTROL);
  eos->set_piece("</s>");
  AddPiece(&model_proto, "a", 0.0);
  AddPiece(&model_proto, "b", 0.3);
  AddPiece(&model_proto, "c", 0.2);
  AddPiece(&model_proto, "ab", 1.0);
  AddPiece(&model_proto, "\xE2\x96\x81", 3.0);  // kSpaceSymbol
  *model_proto.mutable_normalizer_spec() = MakeDefaultNormalizerSpec();

  // Expected properties of every piece; the array index is the piece id.
  struct PieceInfo {
    const char *piece;
    double score;
    bool is_unknown;
    bool is_control;
  };
  const PieceInfo kPieces[] = {
      {"<unk>", 0.0, true, false},
      {"<s>", 0.0, false, true},
      {"</s>", 0.0, false, true},
      {"a", 0.0, false, false},
      {"b", 0.3, false, false},
      {"c", 0.2, false, false},
      {"ab", 1.0, false, false},
      {"\xE2\x96\x81", 3.0, false, false},
  };

  // Checks the piece count and, for each id, the id<->piece mapping and the
  // unknown/control flags.
  auto ExpectVocabulary = [&kPieces](const SentencePieceProcessor &p) {
    EXPECT_EQ(8, p.GetPieceSize());
    int id = 0;
    for (const auto &want : kPieces) {
      EXPECT_EQ(id, p.PieceToId(want.piece));
      EXPECT_EQ(std::string(want.piece), p.IdToPiece(id));
      EXPECT_EQ(want.is_unknown, p.IsUnknown(id));
      EXPECT_EQ(want.is_control, p.IsControl(id));
      ++id;
    }
  };

  // Encodes "abc" both to pieces and to ids and compares with expectations.
  auto ExpectEncode = [](const SentencePieceProcessor &p,
                         const std::vector<std::string> &want_pieces,
                         const std::vector<int> &want_ids) {
    std::vector<std::string> got_pieces;
    EXPECT_TRUE(p.Encode("abc", &got_pieces).ok());
    EXPECT_EQ(want_pieces, got_pieces);
    std::vector<int> got_ids;
    EXPECT_TRUE(p.Encode("abc", &got_ids).ok());
    EXPECT_EQ(want_ids, got_ids);
  };

  // Decodes the pieces {"ab", "c"} and then the ids {3, 4, 5} (a, b, c) into
  // the same output string and compares each result with its expectation.
  auto ExpectDecode = [](const SentencePieceProcessor &p,
                         const std::string &want_from_pieces,
                         const std::string &want_from_ids) {
    std::string output;
    const std::vector<std::string> pieces = {"ab", "c"};
    EXPECT_TRUE(p.Decode(pieces, &output).ok());
    EXPECT_EQ(want_from_pieces, output);
    const std::vector<int> ids = {3, 4, 5};
    EXPECT_TRUE(p.Decode(ids, &output).ok());
    EXPECT_EQ(want_from_ids, output);
  };

  // Round-trip the model through a file in the test temp directory.
  const std::string model_path = util::JoinPath(::testing::TempDir(), "model");
  {
    auto output = filesystem::NewWritableFile(model_path, true);
    output->Write(model_proto.SerializeAsString());
  }

  SentencePieceProcessor sp;
  EXPECT_TRUE(sp.Load(model_path).ok());
  EXPECT_EQ(model_proto.SerializeAsString(),
            sp.model_proto().SerializeAsString());

  ExpectVocabulary(sp);
  {
    int id = 0;
    for (const auto &want : kPieces) {
      EXPECT_NEAR(want.score, sp.GetScore(id), 0.001);
      ++id;
    }
  }

  EXPECT_EQ(0, sp.unk_id());
  EXPECT_EQ(1, sp.bos_id());
  EXPECT_EQ(2, sp.eos_id());
  EXPECT_EQ(-1, sp.pad_id());

  // Encoding: first without extra options, then with each option string. The
  // expectations show the options being applied left to right.
  ExpectEncode(sp, {WS, "ab", "c"}, {7, 6, 5});

  EXPECT_TRUE(sp.SetEncodeExtraOptions("bos").ok());
  ExpectEncode(sp, {"<s>", WS, "ab", "c"}, {1, 7, 6, 5});

  EXPECT_TRUE(sp.SetEncodeExtraOptions("eos").ok());
  ExpectEncode(sp, {WS, "ab", "c", "</s>"}, {7, 6, 5, 2});

  EXPECT_TRUE(sp.SetEncodeExtraOptions("reverse").ok());
  ExpectEncode(sp, {"c", "ab", WS}, {5, 6, 7});

  EXPECT_TRUE(sp.SetEncodeExtraOptions("bos:eos").ok());
  ExpectEncode(sp, {"<s>", WS, "ab", "c", "</s>"}, {1, 7, 6, 5, 2});

  EXPECT_TRUE(sp.SetEncodeExtraOptions("reverse:bos:eos").ok());
  ExpectEncode(sp, {"<s>", "c", "ab", WS, "</s>"}, {1, 5, 6, 7, 2});

  EXPECT_TRUE(sp.SetEncodeExtraOptions("bos:eos:reverse").ok());
  ExpectEncode(sp, {"</s>", "c", "ab", WS, "<s>"}, {2, 5, 6, 7, 1});

  // Decoding: the decoded text only changes when "reverse" is in effect.
  ExpectDecode(sp, "abc", "abc");

  EXPECT_TRUE(sp.SetDecodeExtraOptions("bos").ok());
  ExpectDecode(sp, "abc", "abc");

  EXPECT_TRUE(sp.SetDecodeExtraOptions("eos").ok());
  ExpectDecode(sp, "abc", "abc");

  EXPECT_TRUE(sp.SetDecodeExtraOptions("reverse").ok());
  ExpectDecode(sp, "cab", "cba");

  EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok());
  ExpectDecode(sp, "abc", "abc");

  EXPECT_TRUE(sp.SetDecodeExtraOptions("reverse:bos:eos").ok());
  ExpectDecode(sp, "cab", "cba");

  // Out of range: id 127 does not exist in the eight-piece vocabulary.
  {
    std::string output;
    const std::vector<int> out_of_range_ids = {3, 4, 127};
    EXPECT_FALSE(sp.Decode(out_of_range_ids, &output).ok());
  }

  EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos:reverse").ok());
  ExpectDecode(sp, "cab", "cba");

  // Reversing twice yields the original order.
  EXPECT_TRUE(sp.SetDecodeExtraOptions("reverse:reverse").ok());
  ExpectDecode(sp, "abc", "abc");

  // Empty option strings are accepted; unrecognised ones are rejected.
  EXPECT_TRUE(sp.SetEncodeExtraOptions("").ok());
  EXPECT_TRUE(sp.SetDecodeExtraOptions("").ok());
  EXPECT_FALSE(sp.SetEncodeExtraOptions("foo").ok());
  EXPECT_FALSE(sp.SetDecodeExtraOptions("foo").ok());

  // Baseline checks for a processor freshly loaded from `model_proto`, with
  // no extra options set.
  auto RunTest = [&](const SentencePieceProcessor &loaded) {
    EXPECT_EQ(model_proto.SerializeAsString(),
              loaded.model_proto().SerializeAsString());
    ExpectVocabulary(loaded);
    ExpectEncode(loaded, {WS, "ab", "c"}, {7, 6, 5});
    ExpectDecode(loaded, "abc", "abc");
  };

  // Copies ModelProto.
  {
    SentencePieceProcessor from_copy;
    const ModelProto copied = model_proto;
    EXPECT_TRUE(from_copy.Load(copied).ok());
    RunTest(from_copy);
  }

  // Moves ModelProto: the processor must expose the very same object that was
  // handed over, not a copy of it.
  {
    SentencePieceProcessor from_move;
    auto moved = std::make_unique<ModelProto>();
    *moved = model_proto;
    const ModelProto *const moved_ptr = moved.get();
    EXPECT_TRUE(from_move.Load(std::move(moved)).ok());
    EXPECT_EQ(moved_ptr, &from_move.model_proto());
    RunTest(from_move);
  }

  // Restrict Vocabulary.
  {
    SentencePieceProcessor restricted;
    EXPECT_TRUE(restricted.Load(model_proto).ok());
    EXPECT_TRUE(restricted.SetVocabulary({"a", "b", "c"}).ok());  // remove "ab"
    ExpectEncode(restricted, {WS, "a", "b", "c"}, {7, 3, 4, 5});
  }
}
// With the "nmt_nfkc_cf" normalizer spec, the USER_DEFINED piece "<USER>" is
// expected to come through encoding verbatim. The surrounding text, and the
// differently cased "<uSEr>", is expected to be lower-cased and split into
// single-character pieces.
TEST(SentencePieceProcessorTest, SkipNormalizationTest) {
  ModelProto model_proto;
  auto *unk = model_proto.add_pieces();
  unk->set_type(ModelProto::SentencePiece::UNKNOWN);
  unk->set_piece("<unk>");
  auto *user_defined = model_proto.add_pieces();
  user_defined->set_type(ModelProto::SentencePiece::USER_DEFINED);
  user_defined->set_piece("<USER>");

  AddPiece(&model_proto, "a", 0.0);
  AddPiece(&model_proto, "b", 0.3);
  AddPiece(&model_proto, "c", 0.2);
  const char *const kUserLetters[] = {"u", "s", "e", "r"};
  for (const char *letter : kUserLetters) {
    AddPiece(&model_proto, letter, 0.2);
  }

  *model_proto.mutable_normalizer_spec() =
      SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc_cf");

  SentencePieceProcessor processor;
  EXPECT_TRUE(processor.Load(model_proto).ok());

  std::vector<std::string> pieces;
  EXPECT_TRUE(processor.Encode("AB<USER>C<uSEr>", &pieces).ok());
  for (const auto &piece : pieces) {
    LOG(INFO) << piece;
  }

  const std::vector<std::string> expected(
      {WS, "a", "b", "<USER>", "c", "<", "u", "s", "e", "r", ">"});
  EXPECT_EQ(expected, pieces);
}
// When the model contains no BOS/EOS pieces, requesting "bos" for encoding or
// "eos" for decoding must be rejected.
TEST(SentencePieceProcessorTest, ExtraOptionsUndefinedTest) {
  ModelProto model_proto;

  // The only special piece is <unk>: no BOS/EOS.
  auto *unk = model_proto.add_pieces();
  unk->set_piece("<unk>");
  unk->set_type(ModelProto::SentencePiece::UNKNOWN);

  AddPiece(&model_proto, "a", 0.0);
  AddPiece(&model_proto, "b", 0.3);
  AddPiece(&model_proto, "c", 0.2);
  AddPiece(&model_proto, "ab", 1.0);

  SentencePieceProcessor processor;
  EXPECT_TRUE(processor.Load(model_proto).ok());

  EXPECT_FALSE(processor.SetEncodeExtraOptions("bos").ok());
  EXPECT_FALSE(processor.SetDecodeExtraOptions("eos").ok());
}
// The trainer spec renames the special pieces to __UNK__/__BOS__/__EOS__/
// __PAD__. The vocabulary holds the renamed unk, bos and eos pieces but no
// pad piece, so unk_id/bos_id/eos_id must resolve to 0/1/2 and pad_id to -1.
TEST(SentencePieceProcessorTest, OverrideSpecialPieceTest) {
  ModelProto model_proto;

  auto *trainer_spec = model_proto.mutable_trainer_spec();
  trainer_spec->set_unk_piece("__UNK__");
  trainer_spec->set_bos_piece("__BOS__");
  trainer_spec->set_eos_piece("__EOS__");
  trainer_spec->set_pad_piece("__PAD__");

  // UNK, BOS and EOS are present under their overridden names; PAD is not
  // added to the vocabulary.
  auto *unk = model_proto.add_pieces();
  unk->set_type(ModelProto::SentencePiece::UNKNOWN);
  unk->set_piece("__UNK__");
  auto *bos = model_proto.add_pieces();
  bos->set_type(ModelProto::SentencePiece::CONTROL);
  bos->set_piece("__BOS__");
  auto *eos = model_proto.add_pieces();
  eos->set_type(ModelProto::SentencePiece::CONTROL);
  eos->set_piece("__EOS__");

  AddPiece(&model_proto, "a", 0.0);
  AddPiece(&model_proto, "b", 0.3);

  SentencePieceProcessor processor;
  EXPECT_TRUE(processor.Load(model_proto).ok());

  EXPECT_EQ(0, processor.unk_id());
  EXPECT_EQ(1, processor.bos_id());
  EXPECT_EQ(2, processor.eos_id());
  EXPECT_EQ(-1, processor.pad_id());

  EXPECT_EQ("__UNK__", processor.IdToPiece(processor.unk_id()));
  EXPECT_EQ("__BOS__", processor.IdToPiece(processor.bos_id()));
  EXPECT_EQ("__EOS__", processor.IdToPiece(processor.eos_id()));
}
// Exercises SetVocabulary(), ResetVocabulary() and LoadVocabulary() and checks
// via IsUnused() which pieces are disabled after each call.
//
// Piece ids: 0 <unk>, 1 <s>, 2 </s>, 3 "aa", 4 "bb", 5 "cc", 6 "dd", 7 "e".
TEST(SentencePieceProcessorTest, VocabularyTest) {
  ModelProto model_proto;
  auto *unk = model_proto.add_pieces();
  unk->set_type(ModelProto::SentencePiece::UNKNOWN);
  unk->set_piece("<unk>");
  auto *bos = model_proto.add_pieces();
  bos->set_type(ModelProto::SentencePiece::CONTROL);
  bos->set_piece("<s>");
  auto *eos = model_proto.add_pieces();
  eos->set_type(ModelProto::SentencePiece::CONTROL);
  eos->set_piece("</s>");
  AddPiece(&model_proto, "aa", 0.0);
  AddPiece(&model_proto, "bb", 0.0);
  AddPiece(&model_proto, "cc", 0.0);
  AddPiece(&model_proto, "dd", 0.0);
  AddPiece(&model_proto, "e", 0.0);

  // Writes `content` to a scratch vocabulary file and returns its path. The
  // path is computed once, and `content` is taken by reference so that each
  // call does not copy the string.
  const std::string vocab_path =
      util::JoinPath(::testing::TempDir(), "vocab.txt");
  auto GetInlineFilename = [&vocab_path](const std::string &content) {
    {
      // The writer goes out of scope before the path is returned; presumably
      // this is what flushes and closes the file.
      auto out = filesystem::NewWritableFile(vocab_path);
      out->Write(content);
    }
    return vocab_path;
  };

  SentencePieceProcessor sp;
  EXPECT_TRUE(sp.Load(model_proto).ok());

  constexpr bool kUsed = false;
  constexpr bool kUnused = true;

  // Compares IsUnused(id) with `want` for consecutive ids from `first_id`.
  auto ExpectUnused = [&sp](int first_id, const std::vector<bool> &want) {
    int id = first_id;
    for (const bool want_unused : want) {
      EXPECT_EQ(want_unused, sp.IsUnused(id)) << "piece id " << id;
      ++id;
    }
  };

  // Freshly loaded: every piece is in use.
  ExpectUnused(0, {kUsed, kUsed, kUsed, kUsed, kUsed, kUsed, kUsed, kUsed});

  EXPECT_TRUE(sp.SetVocabulary({"aa", "dd", "e"}).ok());
  ExpectUnused(0, {kUsed, kUsed, kUsed});
  // "e" (id 7): single char "e" is always used.
  ExpectUnused(3, {kUsed, kUnused, kUnused, kUsed, kUsed});

  EXPECT_TRUE(sp.ResetVocabulary().ok());
  ExpectUnused(3, {kUsed, kUsed, kUsed, kUsed, kUsed});

  EXPECT_TRUE(sp.SetVocabulary({"bb"}).ok());
  ExpectUnused(3, {kUnused, kUsed, kUnused, kUnused, kUsed});

  // Vocabulary files with a frequency column; the second argument is the
  // threshold passed to LoadVocabulary().
  EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t2\n"), 2).ok());
  ExpectUnused(3, {kUnused, kUnused, kUnused, kUsed, kUsed});

  EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t1\n"), 2).ok());
  ExpectUnused(3, {kUnused, kUnused, kUnused, kUnused, kUsed});

  EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t1\n"), 1).ok());
  ExpectUnused(3, {kUsed, kUnused, kUnused, kUsed, kUsed});

  EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\t0\ndd\t0\n"), 0).ok());
  ExpectUnused(3, {kUsed, kUnused, kUnused, kUsed, kUsed});

  // No frequency.
  EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\ndd\n"), 1).ok());
  ExpectUnused(3, {kUsed, kUnused, kUnused, kUsed, kUsed});
}
// Checks that ImmutableSentencePieceText mirrors its underlying proto: empty
// by default, reflecting writes made through mutable_proto(), and still
// matching after being copied.
TEST(SentencePieceProcessorTest, ImmutableSentencePieceTextTest) {
  ImmutableSentencePieceText spt;

  // A default-constructed object is empty.
  EXPECT_TRUE(spt.text().empty());
  EXPECT_EQ(spt.score(), 0.0);
  EXPECT_TRUE(spt.SerializeAsString().empty());

  // Populate the underlying proto with a text, a score and ten pieces.
  auto *v = spt.mutable_proto();
  v->set_text("hello world");
  v->set_score(1.0);
  for (int i = 0; i < 10; ++i) {
    auto *p = v->add_pieces();
    p->set_surface(absl::StrCat("surface_", i));
    p->set_piece(absl::StrCat("surface_", i));
    p->set_id(i);
    p->set_begin(i + 10);
    p->set_end(i + 20);
  }

  // Indexed accessors agree with the proto.
  EXPECT_EQ(v->pieces_size(), spt.pieces_size());
  for (int i = 0; i < spt.pieces_size(); ++i) {
    EXPECT_EQ(v->pieces(i).surface(), spt.pieces(i).surface());
    EXPECT_EQ(v->pieces(i).piece(), spt.pieces(i).piece());
    EXPECT_EQ(v->pieces(i).id(), spt.pieces(i).id());
    EXPECT_EQ(v->pieces(i).begin(), spt.pieces(i).begin());
    EXPECT_EQ(v->pieces(i).end(), spt.pieces(i).end());
  }

  // Compares `s` with the proto through the range returned by pieces() and
  // through the text / score / serialization accessors.
  auto check_proto = [v](const ImmutableSentencePieceText &s) {
    int n = 0;
    for (const auto &p : s.pieces()) {
      EXPECT_EQ(v->pieces(n).surface(), p.surface());
      EXPECT_EQ(v->pieces(n).piece(), p.piece());
      EXPECT_EQ(v->pieces(n).id(), p.id());
      EXPECT_EQ(v->pieces(n).begin(), p.begin());
      EXPECT_EQ(v->pieces(n).end(), p.end());
      ++n;
    }
    EXPECT_EQ(v->text(), s.text());
    EXPECT_EQ(v->score(), s.score());
    EXPECT_EQ(v->SerializeAsString(), s.SerializeAsString());
  };

  // The source object itself, matching the NBest variant of this test.
  check_proto(spt);

  // Copy construction via copy-initialization.
  const auto spt2 = spt;
  check_proto(spt2);

  // Copy construction via direct-initialization. This is a second copy
  // construction, not an assignment.
  const ImmutableSentencePieceText spt3(spt);
  check_proto(spt3);

  // A default-constructed piece has empty strings and zero offsets / id.
  const ImmutableSentencePieceText_ImmutableSentencePiece piece;
  EXPECT_TRUE(piece.surface().empty());
  EXPECT_TRUE(piece.piece().empty());
  EXPECT_EQ(piece.begin(), 0);
  EXPECT_EQ(piece.end(), 0);
  EXPECT_EQ(piece.id(), 0);
}
// Checks that ImmutableNBestSentencePieceText mirrors its underlying proto:
// empty by default, reflecting entries added through mutable_proto(), and
// still matching after being copied.
TEST(SentencePieceProcessorTest, ImmutableNBestSentencePieceTextTest) {
  ImmutableNBestSentencePieceText spt;

  // A default-constructed object holds no entries.
  EXPECT_EQ(spt.nbests_size(), 0);
  EXPECT_TRUE(spt.SerializeAsString().empty());

  // Add ten entries with distinct texts and scores.
  auto *proto = spt.mutable_proto();
  for (int i = 0; i < 10; ++i) {
    auto *entry = proto->add_nbests();
    entry->set_score(2.0 * i);
    entry->set_text(absl::StrCat("text_", i));
  }

  // Compares every entry of `candidate`, and its serialization, with the
  // proto.
  auto check_proto = [proto](const ImmutableNBestSentencePieceText &candidate) {
    const int count = proto->nbests_size();
    EXPECT_EQ(count, candidate.nbests_size());
    for (int i = 0; i < count; ++i) {
      EXPECT_EQ(proto->nbests(i).text(), candidate.nbests(i).text());
      EXPECT_EQ(proto->nbests(i).score(), candidate.nbests(i).score());
    }
    EXPECT_EQ(proto->SerializeAsString(), candidate.SerializeAsString());
  };

  check_proto(spt);

  // Copy construction via copy-initialization.
  const auto spt2 = spt;
  check_proto(spt2);

  // Copy construction via direct-initialization. This is a second copy
  // construction, not an assignment.
  const ImmutableNBestSentencePieceText spt3(spt);
  check_proto(spt3);
}
// Pieces are created with begin/end set from std::string sizes (byte counts)
// and then passed through ConvertToUnicodeSpans(). The expected spans below
// are character counts: "これは" ends at 3 although it is longer in bytes.
TEST(SentencePieceProcessorTest, ConvertToUnicodeSpansTest) {
  // Builds a text from `tokens` laid end to end, one piece per token, and
  // converts the result.
  auto make_spt = [](const std::vector<std::string> &tokens) {
    ImmutableSentencePieceText result;
    auto *proto = result.mutable_proto();
    int offset = 0;
    std::string joined;
    for (const std::string &token : tokens) {
      auto *piece = proto->add_pieces();
      piece->set_piece(token);
      piece->set_surface(token);
      piece->set_begin(offset);
      offset += token.size();
      piece->set_end(offset);
      joined += token;
    }
    proto->set_text(joined);
    result.ConvertToUnicodeSpans();
    return result;
  };

  // Expects exactly three pieces whose spans are laid end to end:
  // piece i covers [bounds[i], bounds[i + 1]).
  auto expect_spans = [](const ImmutableSentencePieceText &spt,
                         const std::vector<int> &bounds) {
    EXPECT_EQ(spt.pieces_size(), 3);
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(spt.pieces(i).begin(), bounds[i]);
      EXPECT_EQ(spt.pieces(i).end(), bounds[i + 1]);
    }
  };

  // ASCII only.
  expect_spans(make_spt({"hello", "_world", "."}), {0, 5, 11, 12});

  // Whole pieces are either multi-byte or ASCII.
  expect_spans(make_spt({"これは", "test", "です"}), {0, 3, 7, 9});

  // Multi-byte and ASCII characters mixed inside a piece.
  expect_spans(make_spt({"いABは", "にほCD", "へと"}), {0, 4, 8, 10});
}
} // namespace sentencepiece | |