Spaces:
Sleeping
Sleeping
#!/usr/bin/python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
# | |
# LASER Language-Agnostic SEntence Representations | |
# is a toolkit to calculate multilingual sentence embeddings | |
# and to use them for document classification, bitext filtering | |
# and mining | |
# | |
# -------------------------------------------------------- | |
# | |
# Tool to calculate to embed a text file | |
# The functions can be also imported into another Python code | |
import argparse | |
import logging | |
import os | |
import re | |
import sys | |
import tempfile | |
import time | |
from collections import namedtuple | |
from pathlib import Path | |
from subprocess import run | |
from typing import Optional, Union | |
assert os.environ.get("LASER"), "Please set the environment variable LASER" | |
LASER = os.environ["LASER"] | |
sys.path.append(LASER) | |
import numpy as np | |
from lib.text_processing import BPEfastApply, SPMApply, Token | |
from laser_encoders.models import SentenceEncoder | |
SPACE_NORMALIZER = re.compile(r"\s+") | |
Batch = namedtuple("Batch", "srcs tokens lengths") | |
logging.basicConfig( | |
stream=sys.stdout, | |
level=logging.INFO, | |
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", | |
) | |
logger = logging.getLogger("embed") | |
def buffered_read(fp, buffer_size): | |
buffer = [] | |
for src_str in fp: | |
buffer.append(src_str.strip()) | |
if len(buffer) >= buffer_size: | |
yield buffer | |
buffer = [] | |
if len(buffer) > 0: | |
yield buffer | |
class HuggingFaceEncoder: | |
def __init__(self, encoder_name: str, verbose=False): | |
from sentence_transformers import SentenceTransformer | |
encoder = f"sentence-transformers/{encoder_name}" | |
if verbose: | |
logger.info(f"loading HuggingFace encoder: {encoder}") | |
self.encoder = SentenceTransformer(encoder) | |
def encode_sentences(self, sentences): | |
return self.encoder.encode(sentences) | |
def load_model( | |
encoder: str, | |
spm_model: str, | |
bpe_codes: str, | |
hugging_face=False, | |
verbose=False, | |
**encoder_kwargs, | |
) -> Union[SentenceEncoder, HuggingFaceEncoder]: | |
if hugging_face: | |
return HuggingFaceEncoder(encoder, verbose=verbose) | |
if spm_model: | |
spm_vocab = str(Path(spm_model).with_suffix(".cvocab")) | |
if verbose: | |
logger.info(f"spm_model: {spm_model}") | |
logger.info(f"spm_cvocab: {spm_vocab}") | |
else: | |
spm_vocab = None | |
return SentenceEncoder( | |
encoder, spm_vocab=spm_vocab, verbose=verbose, **encoder_kwargs | |
) | |
def EncodeLoad(args): | |
args.buffer_size = max(args.buffer_size, 1) | |
assert ( | |
not args.max_sentences or args.max_sentences <= args.buffer_size | |
), "--max-sentences/--batch-size cannot be larger than --buffer-size" | |
print(" - loading encoder", args.encoder) | |
return SentenceEncoder( | |
args.encoder, | |
max_sentences=args.max_sentences, | |
max_tokens=args.max_tokens, | |
cpu=args.cpu, | |
verbose=args.verbose, | |
) | |
def EncodeTime(t): | |
t = int(time.time() - t) | |
if t < 1000: | |
return "{:d}s".format(t) | |
else: | |
return "{:d}m{:d}s".format(t // 60, t % 60) | |
# Encode sentences (existing file pointers) | |
def EncodeFilep( | |
encoder, inp_file, out_file, buffer_size=10000, fp16=False, verbose=False | |
): | |
n = 0 | |
t = time.time() | |
for sentences in buffered_read(inp_file, buffer_size): | |
encoded = encoder.encode_sentences(sentences) | |
if fp16: | |
encoded = encoded.astype(np.float16) | |
encoded.tofile(out_file) | |
n += len(sentences) | |
if verbose and n % 10000 == 0: | |
logger.info("encoded {:d} sentences".format(n)) | |
if verbose: | |
logger.info(f"encoded {n} sentences in {EncodeTime(t)}") | |
# Encode sentences (file names) | |
def EncodeFile( | |
encoder, | |
inp_fname, | |
out_fname, | |
buffer_size=10000, | |
fp16=False, | |
verbose=False, | |
over_write=False, | |
inp_encoding="utf-8", | |
): | |
# TODO :handle over write | |
if not os.path.isfile(out_fname): | |
if verbose: | |
logger.info( | |
"encoding {} to {}".format( | |
inp_fname if len(inp_fname) > 0 else "stdin", | |
out_fname, | |
) | |
) | |
fin = ( | |
open(inp_fname, "r", encoding=inp_encoding, errors="surrogateescape") | |
if len(inp_fname) > 0 | |
else sys.stdin | |
) | |
fout = open(out_fname, mode="wb") | |
EncodeFilep( | |
encoder, fin, fout, buffer_size=buffer_size, fp16=fp16, verbose=verbose | |
) | |
fin.close() | |
fout.close() | |
elif not over_write and verbose: | |
logger.info("encoder: {} exists already".format(os.path.basename(out_fname))) | |
# Load existing embeddings | |
def EmbedLoad(fname, dim=1024, verbose=False, fp16=False): | |
x = np.fromfile(fname, dtype=(np.float16 if fp16 else np.float32), count=-1) | |
x.resize(x.shape[0] // dim, dim) | |
if verbose: | |
print(" - Embeddings: {:s}, {:d}x{:d}".format(fname, x.shape[0], dim)) | |
return x | |
# Get memory mapped embeddings | |
def EmbedMmap(fname, dim=1024, dtype=np.float32, verbose=False): | |
nbex = int(os.path.getsize(fname) / dim / np.dtype(dtype).itemsize) | |
E = np.memmap(fname, mode="r", dtype=dtype, shape=(nbex, dim)) | |
if verbose: | |
print(" - embeddings on disk: {:s} {:d} x {:d}".format(fname, nbex, dim)) | |
return E | |
def embed_sentences( | |
ifname: str, | |
output: str, | |
encoder: Union[SentenceEncoder, HuggingFaceEncoder] = None, | |
encoder_path: str = None, | |
hugging_face=False, | |
token_lang: Optional[str] = "--", | |
bpe_codes: Optional[str] = None, | |
spm_lang: Optional[str] = "en", | |
spm_model: Optional[str] = None, | |
verbose: bool = False, | |
buffer_size: int = 10000, | |
max_tokens: int = 12000, | |
max_sentences: Optional[int] = None, | |
cpu: bool = False, | |
fp16: bool = False, | |
sort_kind: str = "quicksort", | |
): | |
assert encoder or encoder_path, "Provide initialised encoder or encoder_path" | |
buffer_size = max(buffer_size, 1) | |
assert ( | |
not max_sentences or max_sentences <= buffer_size | |
), "--max-sentences/--batch-size cannot be larger than --buffer-size" | |
assert not (bpe_codes and spm_model), "Cannot specify both spm and bpe" | |
if encoder_path: | |
encoder = load_model( | |
encoder_path, | |
spm_model, | |
bpe_codes, | |
verbose=verbose, | |
hugging_face=hugging_face, | |
max_sentences=max_sentences, | |
max_tokens=max_tokens, | |
sort_kind=sort_kind, | |
cpu=cpu, | |
) | |
if not ifname: | |
ifname = "" # default to stdin | |
with tempfile.TemporaryDirectory() as tmpdir: | |
if token_lang != "--": | |
tok_fname = os.path.join(tmpdir, "tok") | |
Token( | |
ifname, | |
tok_fname, | |
lang=token_lang, | |
romanize=True if token_lang == "el" else False, | |
lower_case=True, | |
gzip=False, | |
verbose=verbose, | |
over_write=False, | |
) | |
ifname = tok_fname | |
if bpe_codes: | |
if ifname == "": # stdin | |
ifname = os.path.join(tmpdir, "no_tok") | |
run(f"cat > {ifname}", shell=True) | |
bpe_fname = os.path.join(tmpdir, "bpe") | |
BPEfastApply( | |
ifname, bpe_fname, bpe_codes, verbose=verbose, over_write=False | |
) | |
ifname = bpe_fname | |
if spm_model: | |
spm_fname = os.path.join(tmpdir, "spm") | |
SPMApply( | |
ifname, | |
spm_fname, | |
spm_model, | |
lang=spm_lang, | |
lower_case=True, | |
verbose=verbose, | |
over_write=False, | |
) | |
ifname = spm_fname | |
EncodeFile( | |
encoder, | |
ifname, | |
output, | |
verbose=verbose, | |
over_write=False, | |
buffer_size=buffer_size, | |
fp16=fp16, | |
) | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="LASER: Embed sentences") | |
parser.add_argument( | |
"-i", | |
"--input", | |
type=str, | |
default=None, | |
help="Input text file", | |
) | |
parser.add_argument("--encoder", type=str, required=True, help="encoder to be used") | |
parser.add_argument( | |
"--token-lang", | |
type=str, | |
default="--", | |
help="Perform tokenization with given language ('--' for no tokenization)", | |
) | |
parser.add_argument( | |
"--bpe-codes", type=str, default=None, help="Apply BPE using specified codes" | |
) | |
parser.add_argument( | |
"--spm-lang", type=str, default="en", help="Apply SPM using specified language" | |
) | |
parser.add_argument( | |
"--spm-model", type=str, default=None, help="Apply SPM using specified model" | |
) | |
parser.add_argument("-v", "--verbose", action="store_true", help="Detailed output") | |
parser.add_argument( | |
"-o", "--output", required=True, help="Output sentence embeddings" | |
) | |
parser.add_argument( | |
"--buffer-size", type=int, default=10000, help="Buffer size (sentences)" | |
) | |
parser.add_argument( | |
"--max-tokens", | |
type=int, | |
default=12000, | |
help="Maximum number of tokens to process in a batch", | |
) | |
parser.add_argument( | |
"--max-sentences", | |
type=int, | |
default=None, | |
help="Maximum number of sentences to process in a batch", | |
) | |
parser.add_argument( | |
"--fp16", | |
action="store_true", | |
help="Store embedding matrices in fp16 instead of fp32", | |
) | |
parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU") | |
parser.add_argument( | |
"--sort-kind", | |
type=str, | |
default="quicksort", | |
choices=["quicksort", "mergesort"], | |
help="Algorithm used to sort batch by length", | |
) | |
parser.add_argument( | |
"--use-hugging-face", | |
action="store_true", | |
help="Use a HuggingFace sentence transformer", | |
) | |
args = parser.parse_args() | |
embed_sentences( | |
ifname=args.input, | |
encoder_path=args.encoder, | |
token_lang=args.token_lang, | |
bpe_codes=args.bpe_codes, | |
spm_lang=args.spm_lang, | |
hugging_face=args.use_hugging_face, | |
spm_model=args.spm_model, | |
verbose=args.verbose, | |
output=args.output, | |
buffer_size=args.buffer_size, | |
max_tokens=args.max_tokens, | |
max_sentences=args.max_sentences, | |
cpu=args.cpu, | |
fp16=args.fp16, | |
sort_kind=args.sort_kind, | |
) | |