#!/usr/bin/python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LASER Language-Agnostic SEntence Representations
# is a toolkit to calculate multilingual sentence embeddings
# and to use them for document classification, bitext filtering
# and mining
#
# --------------------------------------------------------
#
# Helper functions for tokenization and BPE
import os
import sys
import logging
from pathlib import Path
import numpy as np
from subprocess import run, check_output, CalledProcessError, DEVNULL
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("preprocess")
# get environment
assert os.environ.get('LASER'), 'Please set the environment variable LASER'
LASER = os.environ['LASER']
FASTBPE = LASER + '/tools-external/fastBPE/fast'
MOSES_BDIR = LASER + '/tools-external/moses-tokenizer/tokenizer/'
MOSES_TOKENIZER = MOSES_BDIR + 'tokenizer.perl -q -no-escape -threads 20 -l '
MOSES_LC = MOSES_BDIR + 'lowercase.perl'
NORM_PUNC = MOSES_BDIR + 'normalize-punctuation.perl -l '
DESCAPE = MOSES_BDIR + 'deescape-special-chars.perl'
REM_NON_PRINT_CHAR = MOSES_BDIR + 'remove-non-printing-char.perl'
SPM_DIR = LASER + '/tools-external/sentencepiece-master/build/src/'
SPM = 'LD_LIBRARY_PATH=' + SPM_DIR + ' ' + SPM_DIR + 'spm_encode --output_format=piece'
# Romanization (and lower casing)
ROMAN_LC = 'python3 ' + LASER + '/source/lib/romanize_lc.py -l '
# Mecab tokenizer for Japanese
MECAB = LASER + '/tools-external/mecab'
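# The paths above assume a LASER checkout whose external tools were fetched
# by the install scripts, roughly (a sketch; the exact layout may differ):
#
#   $LASER/tools-external/fastBPE/fast
#   $LASER/tools-external/moses-tokenizer/tokenizer/*.perl
#   $LASER/tools-external/sentencepiece-master/build/src/spm_encode
#   $LASER/tools-external/mecab/{bin,lib}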
###############################################################################
#
# Tokenize a line of text
#
###############################################################################
def TokenLine(line, lang='en', lower_case=True, romanize=False):
assert lower_case, 'lower case is needed by all the models'
roman = lang if romanize else 'none'
tok = check_output(
REM_NON_PRINT_CHAR
+ '|' + NORM_PUNC + lang
+ '|' + DESCAPE
+ '|' + MOSES_TOKENIZER + lang
+ ('| python3 -m jieba -d ' if lang == 'zh' else '')
+ ('|' + MECAB + '/bin/mecab -O wakati -b 50000 ' if lang == 'ja' else '')
+ '|' + ROMAN_LC + roman,
input=line,
encoding='UTF-8',
shell=True)
return tok.strip()
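# Example (a minimal sketch; the exact output depends on the installed
# Moses scripts):
#
#   tok = TokenLine('Hello, World!', lang='en')
#   # expected to be roughly: 'hello , world !'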
###############################################################################
#
# Tokenize a file
#
###############################################################################
def Token(inp_fname, out_fname, lang='en',
lower_case=True, romanize=False, descape=False,
verbose=False, over_write=False, gzip=False):
assert lower_case, 'lower case is needed by all the models'
assert not over_write, 'over-write is not yet implemented'
if not os.path.isfile(out_fname):
cat = 'zcat ' if gzip else 'cat '
roman = lang if romanize else 'none'
        # map some ISO 639-3 language codes to their ISO 639-1 equivalents
        if lang in ('cmn', 'wuu', 'yue'):
            lang = 'zh'
        if lang == 'jpn':
            lang = 'ja'
        if verbose:
            logger.info('tokenizing {} in language {} {} {} {}'
                        .format(os.path.basename(inp_fname), lang,
                                '(gzip)' if gzip else '',
                                '(de-escaped)' if descape else '',
                                '(romanized)' if romanize else ''))
run(cat + inp_fname
+ '|' + REM_NON_PRINT_CHAR
+ '|' + NORM_PUNC + lang
+ ('|' + DESCAPE if descape else '')
+ '|' + MOSES_TOKENIZER + lang
+ ('| python3 -m jieba -d ' if lang == 'zh' else '')
+ ('|' + MECAB + '/bin/mecab -O wakati -b 50000 ' if lang == 'ja' else '')
+ '|' + ROMAN_LC + roman
+ '>' + out_fname,
env=dict(os.environ, LD_LIBRARY_PATH=MECAB + '/lib'),
shell=True)
    elif not over_write and verbose:
        logger.info('tokenized file {} exists already'
                    .format(os.path.basename(out_fname)))
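# Example (hypothetical file names; the input is read, piped through the
# Moses/jieba/Mecab tools and the tokenized text is written to the output):
#
#   Token('input.txt', 'input.tok', lang='fr', romanize=False, verbose=True)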
###############################################################################
#
# Apply SPM on a whole file
#
###############################################################################
def SPMApply(inp_fname, out_fname, spm_model, lang='en',
lower_case=True, descape=False,
verbose=False, over_write=False, gzip=False):
assert lower_case, 'lower case is needed by all the models'
if not os.path.isfile(out_fname):
cat = 'zcat ' if gzip else 'cat '
if verbose:
logger.info('SPM processing {} {} {}'
.format(os.path.basename(inp_fname),
'(gzip)' if gzip else '',
'(de-escaped)' if descape else ''))
assert os.path.isfile(spm_model), f'SPM model {spm_model} not found'
command = (cat + inp_fname
+ '|' + REM_NON_PRINT_CHAR
+ '|' + NORM_PUNC + lang
+ ('|' + DESCAPE if descape else '')
+ '|' + ROMAN_LC + 'none'
+ '|' + SPM + " --model=" + spm_model
+ ' > ' + out_fname)
try:
run(["/bin/bash", "-o", "pipefail", "-c", command], check=True, capture_output=True)
except CalledProcessError as e:
logger.error(e.stderr.decode().strip())
sys.exit(1)
elif not over_write and verbose:
logger.info('SPM encoded file {} exists already'
.format(os.path.basename(out_fname)))
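# Example (hypothetical paths; 'laser2.spm' stands in for a trained
# SentencePiece model):
#
#   SPMApply('input.txt', 'input.spm', spm_model=LASER + '/laser2.spm',
#            lang='en', verbose=True)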
###############################################################################
#
# Apply FastBPE on a whole file
#
###############################################################################
def BPEfastApply(inp_fname, out_fname, bpe_codes,
verbose=False, over_write=False):
if not os.path.isfile(out_fname):
if verbose:
logger.info('fastBPE: processing {}'
.format(os.path.basename(inp_fname)))
bpe_vocab = bpe_codes.replace('fcodes', 'fvocab')
assert os.path.isfile(bpe_vocab), f'fastBPE: vocab file {bpe_vocab} not found'
run(FASTBPE + ' applybpe '
+ out_fname + ' ' + inp_fname
+ ' ' + bpe_codes
+ ' ' + bpe_vocab, shell=True, stderr=DEVNULL)
elif not over_write and verbose:
logger.info('fastBPE: {} exists already'
.format(os.path.basename(out_fname)))
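# Example (hypothetical paths; the codes file is expected to have a matching
# vocabulary file with 'fcodes' replaced by 'fvocab' in its name):
#
#   BPEfastApply('input.tok', 'input.bpe', bpe_codes='/models/93langs.fcodes',
#                verbose=True)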
###############################################################################
#
# Split long lines into multiple sentences at "."
#
###############################################################################
def SplitLines(ifname, of_txt, of_sid):
if os.path.isfile(of_txt):
print(' - SplitLines: {} already exists'.format(of_txt))
return
nl = 0
nl_sp = 0
maxw = 0
maxw_sp = 0
fp_sid = open(of_sid, 'w')
fp_txt = open(of_txt, 'w')
with open(ifname, 'r') as ifp:
for line in ifp:
print('{:d}'.format(nl), file=fp_sid) # store current sentence ID
nw = 0
words = line.strip().split()
maxw = max(maxw, len(words))
for i, word in enumerate(words):
if word == '.' and i != len(words)-1:
if nw > 0:
print(' {}'.format(word), file=fp_txt)
else:
print('{}'.format(word), file=fp_txt)
# store current sentence ID
print('{:d}'.format(nl), file=fp_sid)
nl_sp += 1
maxw_sp = max(maxw_sp, nw+1)
nw = 0
else:
if nw > 0:
print(' {}'.format(word), end='', file=fp_txt)
else:
print('{}'.format(word), end='', file=fp_txt)
nw += 1
            # terminate the last sentence of the line; done unconditionally so
            # that empty input lines stay aligned with the ID written above
            if nw > 0:
                maxw_sp = max(maxw_sp, nw+1)
            print('', file=fp_txt)
            nl_sp += 1
nl += 1
print(' - Split sentences: {}'.format(ifname))
print(' - lines/max words: {:d}/{:d} -> {:d}/{:d}'
.format(nl, maxw, nl_sp, maxw_sp))
fp_sid.close()
fp_txt.close()
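# Illustration (a sketch): an input line 'a . b c . d' is split at the inner
# periods into three lines 'a .', 'b c .' and 'd' in of_txt, while of_sid
# records the original line index (here 0) once per output sentence, e.g.
#
#   SplitLines('corpus.txt', 'corpus.split', 'corpus.sid')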
###############################################################################
#
# Join embeddings of previously split lines (average)
#
###############################################################################
def JoinEmbed(if_embed, sid_fname, of_embed, dim=1024):
if os.path.isfile(of_embed):
print(' - JoinEmbed: {} already exists'.format(of_embed))
return
# read the input embeddings
em_in = np.fromfile(if_embed, dtype=np.float32, count=-1).reshape(-1, dim)
ninp = em_in.shape[0]
print(' - Combine embeddings:')
print(' input: {:s} {:d} sentences'.format(if_embed, ninp))
# get all sentence IDs
sid = np.empty(ninp, dtype=np.int32)
i = 0
with open(sid_fname, 'r') as fp_sid:
for line in fp_sid:
sid[i] = int(line)
i += 1
nout = sid.max() + 1
print(' IDs: {:s}, {:d} sentences'.format(sid_fname, nout))
# combining
em_out = np.zeros((nout, dim), dtype=np.float32)
cnt = np.zeros(nout, dtype=np.int32)
for i in range(ninp):
idx = sid[i]
em_out[idx] += em_in[i] # cumulate sentence vectors
cnt[idx] += 1
    if (cnt == 0).any():
print('ERROR: missing lines')
sys.exit(1)
    # normalize: average the accumulated vectors by their part counts
    em_out /= cnt.reshape(-1, 1)
print(' output: {:s}'.format(of_embed))
em_out.tofile(of_embed)
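# Example (hypothetical paths, following on from SplitLines above): embeddings
# computed on 'corpus.split' are averaged back to one vector per original line:
#
#   JoinEmbed('corpus.split.emb', 'corpus.sid', 'corpus.emb', dim=1024)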