File size: 4,616 Bytes
2aebc50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BPE_MODEL_TRAINER_H_
#define BPE_MODEL_TRAINER_H_

#include <cstdint>
#include <limits>
#include <string>
#include <vector>

#include "sentencepiece_model.pb.h"
#include "third_party/absl/container/btree_set.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "trainer_interface.h"

namespace sentencepiece {
namespace bpe {

// Trainer class for BPE model.
//
// Holds the working state declared below: a cache of all unique symbols
// (single characters and bigrams), the subset of symbols from which the best
// symbol is picked in each iteration, and the per-sentence symbol sequences.
class Trainer : public TrainerInterface {
 public:
  // Forwards the trainer/normalizer/denormalizer specs to TrainerInterface.
  Trainer(const TrainerSpec &trainer_spec,
          const NormalizerSpec &normalizer_spec,
          const NormalizerSpec &denormalizer_spec)
      : TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec,
                                           denormalizer_spec) {}

  // Runs BPE training. Defined out of line.
  util::Status Train() override;

 private:
  // Symbol represents a character or symbol bigram.
  struct Symbol {
    const Symbol *left;              // left symbol in bigram
    const Symbol *right;             // right symbol in bigram
    string_util::UnicodeText chars;  // all flattened character sequence
    bool is_unk;                     // true if this symbol is unknown.
    uint64_t fp;                     // fingerprint of this symbol.
    uint64_t freq;                   // frequency of this symbol.

    // Position list. Use set so that we can keep the order of occurrence.
    // Each element is a value produced by EncodePos; see EncodePos/DecodePos.
    absl::btree_set<uint64_t> positions;

    // A symbol is a bigram only when both of its halves are set.
    bool IsBigram() const { return left != nullptr && right != nullptr; }
    std::string ToString() const;
    Symbol() : left(nullptr), right(nullptr), is_unk(false), fp(0), freq(0) {}
  };

  // Decoded form of a value produced by EncodePos.
  struct Position {
    int sid;    // sentence id
    int left;   // left symbol index
    int right;  // right symbol index
  };

  // Encodes sid, left and right bigram index into uint64_t.
  // Layout: bits 63..32 hold |sid|, bits 31..16 hold |l|, bits 15..0 hold |r|.
  // Encoded value keeps the order of sid, left and right. That ordering only
  // holds for non-negative inputs, and |l| and |r| must each fit in 16 bits,
  // so all of these preconditions are CHECKed.
  static uint64_t EncodePos(int sid, int l, int r) {
    // A negative sid would sign-extend into the high bits and break the
    // ordering guarantee, so reject it like negative l/r.
    CHECK_GE(sid, 0);
    CHECK_GE(l, 0);
    CHECK_GE(r, 0);
    CHECK_LE(l, std::numeric_limits<uint16_t>::max());
    CHECK_LE(r, std::numeric_limits<uint16_t>::max());
    const uint64_t n = (static_cast<uint64_t>(sid) << 32) |
                       (static_cast<uint64_t>(l) << 16) |
                       static_cast<uint64_t>(r);
    return n;
  }

  // Decodes sid, left and right bigram index from uint64_t.
  // Inverse of EncodePos: the high 32 bits become |sid|, the next 16 bits
  // |left| and the low 16 bits |right|.
  static Position DecodePos(uint64_t n) {
    Position p;
    p.sid = static_cast<int>(n >> 32);
    p.left = static_cast<int>((n >> 16) & 0xffff);
    p.right = static_cast<int>(n & 0xffff);
    return p;
  }

  // Gets unary (character) symbol from the char code |c|.
  // The return value is cached.
  Symbol *GetCharSymbol(char32 c);

  // Gets symbol pair from left/right symbols. The return value is cached.
  Symbol *GetPairSymbol(const Symbol *left, const Symbol *right);

  // Computes the frequency of |symbol| and update symbol->freq field.
  void ComputeFreq(Symbol *symbol) const;

  // Returns the valid index after symbols_[sid][index].
  // NOTE(review): the previous comment said "before", which contradicted the
  // function name; confirm against the definition.
  int GetNextIndex(int sid, int index) const;

  // Returns the valid index before symbols_[sid][index].
  // NOTE(review): the previous comment said "after", which contradicted the
  // function name; confirm against the definition.
  int GetPrevIndex(int sid, int index) const;

  // Makes a new bigram from [symbols_[sid][left], symbols_[sid][right]] and
  // Adds it to symbols_cache_ and active_symbols_.
  void AddNewPair(int sid, int left, int right);

  // Resets the frequency of bigram [symbols_[sid][left] symbols_[sid][right]],
  // if this bigram is not |best|.
  void ResetFreq(int sid, int left, int right, const Symbol *best);

  // Updates |active_symbols_| by copying the top 5% frequent symbols in
  // symbols_cache_.
  void UpdateActiveSymbols();

  // All unique symbols. Key is a fingerprint of Symbol.
  absl::flat_hash_map<uint64_t, Symbol *> symbols_cache_;

  // Set of symbols from which we find the best symbol in each iteration.
  absl::btree_set<Symbol *> active_symbols_;

  // Stores symbols allocated in heap so that we can delete them at once.
  std::vector<Symbol *> allocated_;

  // Sentences. symbols_[sid][index] stores a symbol in sentence_[sid][index].
  std::vector<std::vector<Symbol *>> symbols_;
};
}  // namespace bpe
}  // namespace sentencepiece
#endif  // BPE_MODEL_TRAINER_H_