#!/usr/bin/python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LASER Language-Agnostic SEntence Representations
# is a toolkit to calculate multilingual sentence embeddings
# and to use them for document classification, bitext filtering
# and mining
#
# --------------------------------------------------------
#
# Helper functions for tokenization and BPE

import os
import sys
import logging
from pathlib import Path

import numpy as np
from subprocess import run, check_output, CalledProcessError, DEVNULL

logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("preprocess")

# get environment
assert os.environ.get('LASER'), 'Please set the environment variable LASER'
LASER = os.environ['LASER']

FASTBPE = LASER + '/tools-external/fastBPE/fast'
MOSES_BDIR = LASER + '/tools-external/moses-tokenizer/tokenizer/'
MOSES_TOKENIZER = MOSES_BDIR + 'tokenizer.perl -q -no-escape -threads 20 -l '
MOSES_LC = MOSES_BDIR + 'lowercase.perl'
NORM_PUNC = MOSES_BDIR + 'normalize-punctuation.perl -l '
DESCAPE = MOSES_BDIR + 'deescape-special-chars.perl'
REM_NON_PRINT_CHAR = MOSES_BDIR + 'remove-non-printing-char.perl'
SPM_DIR = LASER + '/tools-external/sentencepiece-master/build/src/'
SPM = 'LD_LIBRARY_PATH=' + SPM_DIR + ' ' + SPM_DIR + '/spm_encode --output_format=piece'

# Romanization (and lower casing)
ROMAN_LC = 'python3 ' + LASER + '/source/lib/romanize_lc.py -l '

# Mecab tokenizer for Japanese
MECAB = LASER + '/tools-external/mecab'
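
# Illustrative setup, assuming a local checkout (the path is hypothetical):
#   export LASER=/path/to/LASER
# The external tools above (Moses scripts, fastBPE, SentencePiece, Mecab) are
# then expected under $LASER/tools-external/.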


###############################################################################
#
# Tokenize a line of text
#
###############################################################################

def TokenLine(line, lang='en', lower_case=True, romanize=False):
    assert lower_case, 'lower case is needed by all the models'
    roman = lang if romanize else 'none'
    tok = check_output(
            REM_NON_PRINT_CHAR
            + '|' + NORM_PUNC + lang
            + '|' + DESCAPE
            + '|' + MOSES_TOKENIZER + lang
            + ('| python3 -m jieba -d ' if lang == 'zh' else '')
            + ('|' + MECAB + '/bin/mecab -O wakati -b 50000 ' if lang == 'ja' else '')
            + '|' + ROMAN_LC + roman,
            input=line,
            encoding='UTF-8',
            shell=True)
    return tok.strip()
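
# Minimal usage sketch (illustrative only; the exact output depends on the
# installed Moses scripts, so the result shown is an assumption):
#   TokenLine('Hello, World!', lang='en')   # e.g. 'hello , world !'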


###############################################################################
#
# Tokenize a file
#
###############################################################################

def Token(inp_fname, out_fname, lang='en',
          lower_case=True, romanize=False, descape=False,
          verbose=False, over_write=False, gzip=False):
    assert lower_case, 'lower case is needed by all the models'
    assert not over_write, 'over-write is not yet implemented'
    if not os.path.isfile(out_fname):
        cat = 'zcat ' if gzip else 'cat '
        roman = lang if romanize else 'none'
        # handle some iso3 language codes
        if lang in ('cmn', 'wuu', 'yue'):
            lang = 'zh'
        if lang in ('jpn', ):
            lang = 'ja'
        if verbose:
            logger.info('tokenizing {} in language {} {} {} {}'
                        .format(os.path.basename(inp_fname), lang,
                                '(gzip)' if gzip else '',
                                '(de-escaped)' if descape else '',
                                '(romanized)' if romanize else ''))
        run(cat + inp_fname
            + '|' + REM_NON_PRINT_CHAR
            + '|' + NORM_PUNC + lang
            + ('|' + DESCAPE if descape else '')
            + '|' + MOSES_TOKENIZER + lang
            + ('| python3 -m jieba -d ' if lang == 'zh' else '')
            + ('|' + MECAB + '/bin/mecab -O wakati -b 50000 ' if lang == 'ja' else '')
            + '|' + ROMAN_LC + roman
            + '>' + out_fname,
            env=dict(os.environ, LD_LIBRARY_PATH=MECAB + '/lib'),
            shell=True)
    elif not over_write and verbose:
        logger.info('tokenized file {} exists already'
                    .format(os.path.basename(out_fname)))
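
# Minimal usage sketch (file names are hypothetical):
#   Token('corpus.fr', 'corpus.tok.fr', lang='fr', verbose=True)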


###############################################################################
#
# Apply SPM on a whole file
#
###############################################################################

def SPMApply(inp_fname, out_fname, spm_model, lang='en',
             lower_case=True, descape=False,
             verbose=False, over_write=False, gzip=False):
    assert lower_case, 'lower case is needed by all the models'
    if not os.path.isfile(out_fname):
        cat = 'zcat ' if gzip else 'cat '
        if verbose:
            logger.info('SPM processing {} {} {}'
                        .format(os.path.basename(inp_fname),
                                '(gzip)' if gzip else '',
                                '(de-escaped)' if descape else ''))
        assert os.path.isfile(spm_model), f'SPM model {spm_model} not found'
        command = (cat + inp_fname
                   + '|' + REM_NON_PRINT_CHAR
                   + '|' + NORM_PUNC + lang
                   + ('|' + DESCAPE if descape else '')
                   + '|' + ROMAN_LC + 'none'
                   + '|' + SPM + " --model=" + spm_model
                   + ' > ' + out_fname)
        try:
            run(["/bin/bash", "-o", "pipefail", "-c", command], check=True, capture_output=True)
        except CalledProcessError as e:
            logger.error(e.stderr.decode().strip())
            sys.exit(1)
    elif not over_write and verbose:
        logger.info('SPM encoded file {} exists already'
                    .format(os.path.basename(out_fname)))
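
# Minimal usage sketch (file and model paths are hypothetical):
#   SPMApply('corpus.en', 'corpus.spm.en', '/path/to/model.spm', lang='en', verbose=True)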


###############################################################################
#
# Apply FastBPE on a whole file
#
###############################################################################

def BPEfastApply(inp_fname, out_fname, bpe_codes,
                 verbose=False, over_write=False):
    if not os.path.isfile(out_fname):
        if verbose:
            logger.info('fastBPE: processing {}'
                        .format(os.path.basename(inp_fname)))
        bpe_vocab = bpe_codes.replace('fcodes', 'fvocab')
        assert os.path.isfile(bpe_vocab), f'fastBPE: vocab file {bpe_vocab} not found'
        run(FASTBPE + ' applybpe '
            + out_fname + ' ' + inp_fname
            + ' ' + bpe_codes
            + ' ' + bpe_vocab, shell=True, stderr=DEVNULL)
    elif not over_write and verbose:
        logger.info('fastBPE: {} exists already'
                    .format(os.path.basename(out_fname)))
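
# Minimal usage sketch (file names are hypothetical; the vocabulary file is
# looked up by replacing 'fcodes' with 'fvocab' in the codes path):
#   BPEfastApply('corpus.tok.en', 'corpus.bpe.en', '/path/to/93langs.fcodes', verbose=True)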


###############################################################################
#
# Split long lines into multiple sentences at "."
#
###############################################################################

def SplitLines(ifname, of_txt, of_sid):
    if os.path.isfile(of_txt):
        print(' - SplitLines: {} already exists'.format(of_txt))
        return
    nl = 0
    nl_sp = 0
    maxw = 0
    maxw_sp = 0
    fp_sid = open(of_sid, 'w')
    fp_txt = open(of_txt, 'w')
    with open(ifname, 'r') as ifp:
        for line in ifp:
            print('{:d}'.format(nl), file=fp_sid)  # store current sentence ID
            nw = 0
            words = line.strip().split()
            maxw = max(maxw, len(words))
            for i, word in enumerate(words):
                if word == '.' and i != len(words)-1:
                    if nw > 0:
                        print(' {}'.format(word), file=fp_txt)
                    else:
                        print('{}'.format(word), file=fp_txt)
                    # store current sentence ID
                    print('{:d}'.format(nl), file=fp_sid)
                    nl_sp += 1
                    maxw_sp = max(maxw_sp, nw+1)
                    nw = 0
                else:
                    if nw > 0:
                        print(' {}'.format(word), end='', file=fp_txt)
                    else:
                        print('{}'.format(word), end='', file=fp_txt)
                    nw += 1
            if nw > 0:
                # handle remainder of sentence
                print('', file=fp_txt)
                nl_sp += 1
                maxw_sp = max(maxw_sp, nw+1)
            nl += 1
    print(' - Split sentences: {}'.format(ifname))
    print(' - lines/max words: {:d}/{:d} -> {:d}/{:d}'
          .format(nl, maxw, nl_sp, maxw_sp))
    fp_sid.close()
    fp_txt.close()
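
# Minimal usage sketch (file names are hypothetical):
#   SplitLines('doc.txt', 'doc.split.txt', 'doc.sid')
# 'doc.split.txt' gets one (sub-)sentence per line and 'doc.sid' stores, for
# each output line, the index of the original input line it came from.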


###############################################################################
#
# Join embeddings of previously split lines (average)
#
###############################################################################

def JoinEmbed(if_embed, sid_fname, of_embed, dim=1024):
    if os.path.isfile(of_embed):
        print(' - JoinEmbed: {} already exists'.format(of_embed))
        return
    # read the input embeddings
    em_in = np.fromfile(if_embed, dtype=np.float32, count=-1).reshape(-1, dim)
    ninp = em_in.shape[0]
    print(' - Combine embeddings:')
    print('   input: {:s} {:d} sentences'.format(if_embed, ninp))
    # get all sentence IDs
    sid = np.empty(ninp, dtype=np.int32)
    i = 0
    with open(sid_fname, 'r') as fp_sid:
        for line in fp_sid:
            sid[i] = int(line)
            i += 1
    nout = sid.max() + 1
    print('   IDs: {:s}, {:d} sentences'.format(sid_fname, nout))
    # combining
    em_out = np.zeros((nout, dim), dtype=np.float32)
    cnt = np.zeros(nout, dtype=np.int32)
    for i in range(ninp):
        idx = sid[i]
        em_out[idx] += em_in[i]  # cumulate sentence vectors
        cnt[idx] += 1
    if (cnt == 0).astype(int).sum() > 0:
        print('ERROR: missing lines')
        sys.exit(1)
    # normalize
    for i in range(nout):
        em_out[i] /= cnt[i]
    print('   output: {:s}'.format(of_embed))
    em_out.tofile(of_embed)
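
# Minimal usage sketch (file names are hypothetical; dim must match the encoder
# output dimension, e.g. 1024 for the LASER encoder):
#   JoinEmbed('doc.split.enc', 'doc.sid', 'doc.enc', dim=1024)
# Embeddings of sentences that came from the same original line are averaged.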