Spaces:
Sleeping
Sleeping
// Copyright 2016 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace sentencepiece { | |
inline std::string PrintProto(const TrainerSpec &message, | |
absl::string_view name) { | |
std::ostringstream os; | |
os << name << " {\n"; | |
PRINT_REPEATED_STRING(input); | |
PRINT_PARAM(input_format); | |
PRINT_PARAM(model_prefix); | |
static const std::map<TrainerSpec::ModelType, std::string> kModelType_Map = { | |
{TrainerSpec::UNIGRAM, "UNIGRAM"}, | |
{TrainerSpec::BPE, "BPE"}, | |
{TrainerSpec::WORD, "WORD"}, | |
{TrainerSpec::CHAR, "CHAR"}, | |
}; | |
PRINT_ENUM(model_type, kModelType_Map); | |
PRINT_PARAM(vocab_size); | |
PRINT_REPEATED_STRING(accept_language); | |
PRINT_PARAM(self_test_sample_size); | |
PRINT_PARAM(character_coverage); | |
PRINT_PARAM(input_sentence_size); | |
PRINT_PARAM(shuffle_input_sentence); | |
PRINT_PARAM(seed_sentencepiece_size); | |
PRINT_PARAM(shrinking_factor); | |
PRINT_PARAM(max_sentence_length); | |
PRINT_PARAM(num_threads); | |
PRINT_PARAM(num_sub_iterations); | |
PRINT_PARAM(max_sentencepiece_length); | |
PRINT_PARAM(split_by_unicode_script); | |
PRINT_PARAM(split_by_number); | |
PRINT_PARAM(split_by_whitespace); | |
PRINT_PARAM(split_digits); | |
PRINT_PARAM(pretokenization_delimiter); | |
PRINT_PARAM(treat_whitespace_as_suffix); | |
PRINT_PARAM(allow_whitespace_only_pieces); | |
PRINT_REPEATED_STRING(control_symbols); | |
PRINT_REPEATED_STRING(user_defined_symbols); | |
PRINT_PARAM(required_chars); | |
PRINT_PARAM(byte_fallback); | |
PRINT_PARAM(vocabulary_output_piece_score); | |
PRINT_PARAM(train_extremely_large_corpus); | |
PRINT_PARAM(seed_sentencepieces_file); | |
PRINT_PARAM(hard_vocab_limit); | |
PRINT_PARAM(use_all_vocab); | |
PRINT_PARAM(unk_id); | |
PRINT_PARAM(bos_id); | |
PRINT_PARAM(eos_id); | |
PRINT_PARAM(pad_id); | |
PRINT_PARAM(unk_piece); | |
PRINT_PARAM(bos_piece); | |
PRINT_PARAM(eos_piece); | |
PRINT_PARAM(pad_piece); | |
PRINT_PARAM(unk_surface); | |
PRINT_PARAM(enable_differential_privacy); | |
PRINT_PARAM(differential_privacy_noise_level); | |
PRINT_PARAM(differential_privacy_clipping_threshold); | |
os << "}\n"; | |
return os.str(); | |
} | |
inline std::string PrintProto(const NormalizerSpec &message, | |
absl::string_view name) { | |
std::ostringstream os; | |
os << name << " {\n"; | |
PRINT_PARAM(name); | |
PRINT_PARAM(add_dummy_prefix); | |
PRINT_PARAM(remove_extra_whitespaces); | |
PRINT_PARAM(escape_whitespaces); | |
PRINT_PARAM(normalization_rule_tsv); | |
os << "}\n"; | |
return os.str(); | |
} | |
util::Status SentencePieceTrainer::SetProtoField(absl::string_view name, | |
absl::string_view value, | |
TrainerSpec *message) { | |
CHECK_OR_RETURN(message); | |
PARSE_REPEATED_STRING(input); | |
PARSE_STRING(input_format); | |
PARSE_STRING(model_prefix); | |
static const std::map<std::string, TrainerSpec::ModelType> kModelType_Map = { | |
{"UNIGRAM", TrainerSpec::UNIGRAM}, | |
{"BPE", TrainerSpec::BPE}, | |
{"WORD", TrainerSpec::WORD}, | |
{"CHAR", TrainerSpec::CHAR}, | |
}; | |
PARSE_ENUM(model_type, kModelType_Map); | |
PARSE_INT32(vocab_size); | |
PARSE_REPEATED_STRING(accept_language); | |
PARSE_INT32(self_test_sample_size); | |
PARSE_DOUBLE(character_coverage); | |
PARSE_UINT64(input_sentence_size); | |
PARSE_BOOL(shuffle_input_sentence); | |
PARSE_INT32(seed_sentencepiece_size); | |
PARSE_DOUBLE(shrinking_factor); | |
PARSE_INT32(max_sentence_length); | |
PARSE_INT32(num_threads); | |
PARSE_INT32(num_sub_iterations); | |
PARSE_INT32(max_sentencepiece_length); | |
PARSE_BOOL(split_by_unicode_script); | |
PARSE_BOOL(split_by_number); | |
PARSE_BOOL(split_by_whitespace); | |
PARSE_BOOL(split_digits); | |
PARSE_STRING(pretokenization_delimiter); | |
PARSE_BOOL(treat_whitespace_as_suffix); | |
PARSE_BOOL(allow_whitespace_only_pieces); | |
PARSE_REPEATED_STRING(control_symbols); | |
PARSE_REPEATED_STRING(user_defined_symbols); | |
PARSE_STRING(required_chars); | |
PARSE_BOOL(byte_fallback); | |
PARSE_BOOL(hard_vocab_limit); | |
PARSE_BOOL(vocabulary_output_piece_score); | |
PARSE_BOOL(train_extremely_large_corpus); | |
PARSE_STRING(seed_sentencepieces_file); | |
PARSE_BOOL(use_all_vocab); | |
PARSE_INT32(unk_id); | |
PARSE_INT32(bos_id); | |
PARSE_INT32(eos_id); | |
PARSE_INT32(pad_id); | |
PARSE_STRING(unk_piece); | |
PARSE_STRING(bos_piece); | |
PARSE_STRING(eos_piece); | |
PARSE_STRING(pad_piece); | |
PARSE_STRING(unk_surface); | |
PARSE_BOOL(enable_differential_privacy); | |
PARSE_DOUBLE(differential_privacy_noise_level); | |
PARSE_UINT64(differential_privacy_clipping_threshold); | |
return util::StatusBuilder(util::StatusCode::kNotFound, GTL_LOC) | |
<< "unknown field name \"" << name << "\" in TrainerSpec."; | |
} | |
util::Status SentencePieceTrainer::SetProtoField(absl::string_view name, | |
absl::string_view value, | |
NormalizerSpec *message) { | |
CHECK_OR_RETURN(message); | |
PARSE_STRING(name); | |
PARSE_BYTE(precompiled_charsmap); | |
PARSE_BOOL(add_dummy_prefix); | |
PARSE_BOOL(remove_extra_whitespaces); | |
PARSE_BOOL(escape_whitespaces); | |
PARSE_STRING(normalization_rule_tsv); | |
return util::StatusBuilder(util::StatusCode::kNotFound, GTL_LOC) | |
<< "unknown field name \"" << name << "\" in NormalizerSpec."; | |
} | |
} // namespace sentencepiece | |