# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple invertible tokenizer.

Converts from a unicode string to a list of tokens
(represented as Unicode strings).

This tokenizer has the following desirable properties:
 - It is invertible.
 - Alphanumeric characters are broken away from non-alphanumeric characters.
 - A single space between words does not produce an extra token.
 - The full Unicode punctuation and separator set is recognized.

The tokenization algorithm is as follows:

1.  Split the text into a list of tokens, splitting at every boundary of an
    alphanumeric character and a non-alphanumeric character.  This produces
    a list which alternates between "alphanumeric tokens"
    (strings of alphanumeric characters) and "non-alphanumeric tokens"
    (strings of non-alphanumeric characters).

2.  Remove every token consisting of a single space, unless it is
    the very first or very last token in the list.  These tokens are now
    implied by the fact that there are two adjacent alphanumeric tokens.

e.g.  u"Dude - that's so cool."
        -> [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import glob
import logging
import sys
import time
import unicodedata

import six
from six.moves import range  # pylint: disable=redefined-builtin
# from tensor2tensor.utils import mlperf_log

# Conversion between Unicode and UTF-8, if required (on Python2)
_native_to_unicode = (lambda s: s.decode("utf-8")) if six.PY2 else (lambda s: s)

logger = logging.getLogger(__name__)
# This set contains all letter, number and punctuation characters.
# (Symbol characters, category "S", are deliberately excluded.)
_ALPHANUMERIC_CHAR_SET = set(
    six.unichr(i) for i in range(sys.maxunicode)
    if (unicodedata.category(six.unichr(i)).startswith("L") or
        unicodedata.category(six.unichr(i)).startswith("N") or
        unicodedata.category(six.unichr(i)).startswith("P")))


def encode(text):
  """Encode a unicode string as a list of tokens.

  Args:
    text: a unicode string
  Returns:
    a list of tokens as Unicode strings
  """
  if not text:
    return []
  ret = []
  token_start = 0
  # Classify each character: True if it is kept (letter/number/punctuation).
  is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
  for pos in range(1, len(text)):
    if is_alnum[pos] != is_alnum[pos - 1]:
      if not is_alnum[pos]:
        # End of a kept run: emit it.  Runs of excluded characters
        # (spaces, symbols) are dropped entirely.
        ret.append(text[token_start:pos])
      else:
        # Start of a new kept run.
        token_start = pos
  # Emit the trailing run if the text ends inside one.
  if is_alnum[-1]:
    ret.append(text[token_start:])

  # Split each kept run so that punctuation characters become separate tokens.
  final_tokens = []
  for token in ret:
    final_tokens.extend(_run_split_on_punc(token))

  return final_tokens
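
# Example of encode() on the module docstring sample (for illustration only):
#   encode(u"Dude - that's so cool.")
#     -> [u"Dude", u"-", u"that", u"'", u"s", u"so", u"cool", u"."]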


def _run_split_on_punc(text, never_split=None):
  """Splits punctuation on a piece of text."""
  if never_split is not None and text in never_split:
    return [text]
  chars = list(text)
  i = 0
  start_new_word = True
  output = []
  while i < len(chars):
    char = chars[i]
    if _is_punctuation(char):
      output.append([char])
      start_new_word = True
    else:
      if start_new_word:
        output.append([])
      start_new_word = False
      output[-1].append(char)
    i += 1

  return ["".join(x) for x in output]


def _is_punctuation(char):
  """Checks whether `char` is a punctuation character."""
  cp = ord(char)
  # We treat all non-letter/number ASCII as punctuation.  Characters such as
  # "^", "$", and "`" are not in the Unicode Punctuation class but we treat
  # them as punctuation anyway, for consistency.
  if ((33 <= cp <= 47) or (58 <= cp <= 64) or
      (91 <= cp <= 96) or (123 <= cp <= 126)):
    return True
  cat = unicodedata.category(char)
  if cat.startswith("P"):
    return True
  return False
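
# Example of _is_punctuation() (for illustration only): ASCII symbols such as
# "$" are treated as punctuation even though their Unicode category is "S":
#   _is_punctuation(u"$") -> True, _is_punctuation(u"a") -> False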


def decode(tokens):
  """Decode a list of tokens to a unicode string.

  Args:
    tokens: a list of Unicode strings
  Returns:
    a unicode string
  """
  token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
  ret = []
  for i, token in enumerate(tokens):
    if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
      ret.append(u" ")
    ret.append(token)
  return "".join(ret)


def _read_filepattern(filepattern, max_lines=None, split_on_newlines=True,
                      do_lower_case=False):
  """Reads files matching a wildcard pattern, yielding the contents.

  Args:
    filepattern: A wildcard pattern matching one or more files.
    max_lines: If set, stop reading after reading this many lines.
    split_on_newlines: A boolean. If true, then split files by lines and strip
        leading and trailing whitespace from each line. Otherwise, treat each
        file as a single string.
    do_lower_case: A boolean. If true, lower-case the text before yielding it.

  Yields:
    The contents of the files as lines, if split_on_newlines is True, or
    the entire contents of each file if False.
  """
  filenames = sorted(glob.glob(filepattern))
  logger.info("reading %s (do_lower_case=%s)", filenames, do_lower_case)
  lines_read = 0
  for filename in filenames:
    start = time.time()
    with open(filename) as f:
      if split_on_newlines:
        for line in f:
          if do_lower_case:
            line = line.lower()
          yield line.strip()
          lines_read += 1
          if max_lines and lines_read >= max_lines:
            return
          if lines_read % 100000 == 0:
            print("read", lines_read, "lines,", time.time() - start, "secs elapsed")

      else:
        if max_lines:
          doc = []
          for line in f:
            if do_lower_case:
              line = line.lower()
            doc.append(line)
            lines_read += 1
            if max_lines and lines_read >= max_lines:
              yield "".join(doc)
              return
          yield "".join(doc)

        else:
          contents = f.read()
          if do_lower_case:
            contents = contents.lower()
          yield contents

    logger.info("read file %s in %.2f secs", filename, time.time() - start)


def corpus_token_counts(text_filepattern, corpus_max_lines,
                        split_on_newlines=True, additional_chars="",
                        do_lower_case=False):
  """Read the corpus and compute a dictionary of token counts.

  Args:
    text_filepattern: A pattern matching one or more files.
    corpus_max_lines: An integer; maximum total lines to read.
    split_on_newlines: A boolean. If true, then split files by lines and strip
        leading and trailing whitespace from each line. Otherwise, treat each
        file as a single string.
    additional_chars: A string. Each character in it is added to the
        alphanumeric set, so it is kept by the tokenizer (and can appear in
        the vocabulary) instead of being dropped.
    do_lower_case: A boolean. If true, lower-case the text before counting.

  Returns:
    a dictionary mapping token to count.
  """
  if additional_chars:
    # update() adds each character individually; add() would insert the whole
    # string as a single set element and the characters would never match.
    _ALPHANUMERIC_CHAR_SET.update(additional_chars)

  counts = collections.Counter()
  for doc in _read_filepattern(
      text_filepattern,
      max_lines=corpus_max_lines,
      split_on_newlines=split_on_newlines,
      do_lower_case=do_lower_case):
    counts.update(encode(_native_to_unicode(doc)))
  print("read all files")
  return counts
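
# Usage sketch (hypothetical file pattern, shown only for illustration):
#   counts = corpus_token_counts("data/corpus-*.txt", corpus_max_lines=None)
#   print(counts.most_common(10))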


def vocab_token_counts(text_filepattern, max_lines, do_lower_case=False):
  """Read a vocab file and return a dictionary of token counts.

  Reads a two-column CSV file of tokens and their frequency in a dataset. The
  tokens are presumed to be generated by encode() or the equivalent.

  Args:
    text_filepattern: A pattern matching one or more files.
    max_lines: An integer; maximum total lines to read.
    do_lower_case: A boolean. If true, lower-case each line before parsing it.

  Returns:
    a dictionary mapping token to count.
  """
  ret = {}
  for i, line in enumerate(
      _read_filepattern(text_filepattern, max_lines=max_lines)):
    if "," not in line:
      logger.warning("Malformed vocab line #%d '%s'", i, line)
      continue
    
    if do_lower_case:
      line = line.lower()

    token, count = line.rsplit(",", 1)
    ret[_native_to_unicode(token)] = int(count)

  return ret
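

if __name__ == "__main__":
  # Minimal smoke test (not part of the original module): tokenize the
  # docstring example and show the approximate round trip through decode().
  sample = u"Dude - that's so cool."
  tokens = encode(sample)
  print(tokens)
  print(decode(tokens))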