#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LASER Language-Agnostic SEntence Representations
# is a toolkit to calculate multilingual sentence embeddings
# and to use them for document classification, bitext filtering
# and mining
#
# --------------------------------------------------------
#
# Tool to calculate multilingual similarity error rate
# on various predefined test sets
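#
# Example invocation (the paths, corpus name and language codes below are
# illustrative only):
#
#   python3 eval.py \
#     --base-dir /path/to/eval_data --corpus flores --corpus-part dev \
#     --src-encoder /path/to/encoder.pt --src-spm-model /path/to/spm.model \
#     --src-langs eng,fra,deu --nway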
import os
import argparse
import pandas
import tempfile
import numpy as np
from pathlib import Path
import itertools
import logging
import sys
from typing import List, Tuple, Dict
from tabulate import tabulate
from collections import defaultdict

from xsim import xSIM
from embed import embed_sentences, load_model

logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("eval")


class Eval:
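    """Compute xSIM (multilingual similarity error rate) over the language
    pairs of an evaluation corpus, optionally with separate src/tgt encoders."""
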
    def __init__(self, args):
        self.base_dir = args.base_dir
        self.corpus = args.corpus
        self.split = args.corpus_part
        self.min_sents = args.min_sents
        self.index_comparison = args.index_comparison
        self.emb_dimension = args.embedding_dimension
        self.encoder_args = {
            k: v
            for k, v in args._get_kwargs()
            if k in ["max_sentences", "max_tokens", "cpu", "sort_kind", "verbose"]
        }
        self.src_bpe_codes = args.src_bpe_codes
        self.tgt_bpe_codes = args.tgt_bpe_codes
        self.src_spm_model = args.src_spm_model
        self.tgt_spm_model = args.tgt_spm_model
        logger.info("loading src encoder")
        self.src_encoder = load_model(
            args.src_encoder,
            self.src_spm_model,
            self.src_bpe_codes,
            hugging_face=args.use_hugging_face,
            **self.encoder_args,
        )
        if args.tgt_encoder:
            logger.info("loading tgt encoder")
            self.tgt_encoder = load_model(
                args.tgt_encoder,
                self.tgt_spm_model,
                self.tgt_bpe_codes,
                hugging_face=args.use_hugging_face,
                **self.encoder_args,
            )
        else:
            logger.info("encoding tgt using src encoder")
            self.tgt_encoder = self.src_encoder
            self.tgt_bpe_codes = self.src_bpe_codes
            self.tgt_spm_model = self.src_spm_model
        self.nway = args.nway
        self.buffer_size = args.buffer_size
        self.fp16 = args.fp16
        self.margin = args.margin

    def _embed(
        self, tmpdir, langs, encoder, spm_model, bpe_codes, tgt_aug_langs=[]
    ) -> List[List[str]]:
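        """Embed the test file of each language in `langs` into `tmpdir` and
        return one [lang, text file, embedding file, augmented json or None]
        entry per language."""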
        emb_data = []
        for lang in langs:
            augjson = None
            fname = f"{lang}.{self.split}"
            infile = self.base_dir / self.corpus / self.split / fname
            assert infile.exists(), f"{infile} does not exist"
            outfile = tmpdir / fname
            if lang in tgt_aug_langs:
                fname = f"{lang}_augmented.{self.split}"
                fjname = f"{lang}_errtype.{self.split}.json"
                augment_dir = self.base_dir / self.corpus / (self.split + "_augmented")
                augjson = augment_dir / fjname
                auginfile = augment_dir / fname
                assert augjson.exists(), f"{augjson} does not exist"
                assert auginfile.exists(), f"{auginfile} does not exist"
                # concatenate the original and augmented inputs before encoding
                combined_infile = tmpdir / f"combined_{lang}"
                with open(combined_infile, "w") as newfile:
                    for f in [infile, auginfile]:
                        with open(f) as fin:
                            newfile.write(fin.read())
                infile = combined_infile
            embed_sentences(
                str(infile),
                str(outfile),
                encoder=encoder,
                spm_model=spm_model,
                bpe_codes=bpe_codes,
                token_lang=lang if bpe_codes else "--",
                buffer_size=self.buffer_size,
                fp16=self.fp16,
                **self.encoder_args,
            )
            assert (
                os.path.isfile(outfile) and os.path.getsize(outfile) > 0
            ), f"Error encoding {infile}"
            emb_data.append([lang, infile, outfile, augjson])
        return emb_data

    def _xsim(
        self, src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt, augjson=None
    ) -> Tuple[int, int, Dict[str, int]]:
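        """Score one language direction with xSIM; returns the error count, the
        number of evaluated sentences, and a per-error-type report (populated
        only when an augmented json is given)."""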
        return xSIM(
            src_emb,
            tgt_emb,
            margin=self.margin,
            dim=self.emb_dimension,
            fp16=self.fp16,
            eval_text=tgt_txt if not self.index_comparison else None,
            augmented_json=augjson,
        )

    def calc_xsim(
        self, embdir, src_langs, tgt_langs, tgt_aug_langs, err_sum=0, totl_nbex=0
    ) -> None:
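        """Print a table of xSIM error rates for every src-tgt pair plus an
        overall average; pairs with fewer than `min_sents` sentences are
        reported as skipped."""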
        outputs = []
        src_emb_data = self._embed(
            embdir,
            src_langs,
            self.src_encoder,
            self.src_spm_model,
            self.src_bpe_codes,
        )
        tgt_emb_data = self._embed(
            embdir,
            tgt_langs,
            self.tgt_encoder,
            self.tgt_spm_model,
            self.tgt_bpe_codes,
            tgt_aug_langs,
        )
        aug_df = defaultdict(lambda: defaultdict())
        combs = list(itertools.product(src_emb_data, tgt_emb_data))
        for (src_lang, _, src_emb, _), (tgt_lang, tgt_txt, tgt_emb, augjson) in combs:
            if src_lang == tgt_lang:
                continue
            err, nbex, aug_report = self._xsim(
                src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt, augjson
            )
            result = round(100 * err / nbex, 2)
            if tgt_lang in tgt_aug_langs:
                aug_df[tgt_lang][src_lang] = aug_report
            if nbex < self.min_sents:
                result = "skipped"
            else:
                err_sum += err
                totl_nbex += nbex
            outputs.append(
                [self.corpus, f"{src_lang}-{tgt_lang}", f"{result}", f"{nbex}"]
            )
        outputs.append(
            [
                self.corpus,
                "average",
                f"{round(100 * err_sum / totl_nbex, 2)}",
                f"{len(combs)}",
            ]
        )
        print(
            tabulate(
                outputs,
                tablefmt="psql",
                headers=[
                    "dataset",
                    "src-tgt",
                    "xsim" + ("(++)" if tgt_aug_langs else ""),
                    "nbex",
                ],
            )
        )
        for tgt_aug_lang in tgt_aug_langs:
            df = pandas.DataFrame.from_dict(aug_df[tgt_aug_lang]).fillna(0).T
            print(
                f"\nAbsolute error under augmented transformations for: {tgt_aug_lang}"
            )
            print(f"{tabulate(df, df.columns, floatfmt='.2f', tablefmt='grid')}")

    def calc_xsim_nway(self, embdir, langs) -> None:
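        """Print the full N-way error matrix over `langs` (all encoded with the
        src encoder), with per-language averages excluding the zero diagonal."""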
        err_matrix = np.zeros((len(langs), len(langs)))
        emb_data = self._embed(
            embdir,
            langs,
            self.src_encoder,
            self.src_spm_model,
            self.src_bpe_codes,
        )
        for i1, (src_lang, _, src_emb, _) in enumerate(emb_data):
            for i2, (tgt_lang, tgt_txt, tgt_emb, _) in enumerate(emb_data):
                if src_lang == tgt_lang:
                    err_matrix[i1, i2] = 0
                else:
                    err, nbex, _ = self._xsim(
                        src_emb, src_lang, tgt_emb, tgt_lang, tgt_txt
                    )
                    err_matrix[i1, i2] = 100 * err / nbex
        df = pandas.DataFrame(err_matrix, columns=langs, index=langs)
        df.loc["avg"] = df.sum() / float(df.shape[0] - 1)  # exclude diagonal in average
        print(f"\n{tabulate(df, langs, floatfmt='.2f', tablefmt='grid')}\n\n")
        print(f"Global average: {df.loc['avg'].mean():.2f}")


def run_eval(args) -> None:
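    """Build the evaluator, choose a persistent or temporary embedding
    directory, and run the n-way or pairwise xSIM evaluation."""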
    evaluation = Eval(args)
    tmp_dir = None
    if args.embed_dir:
        os.makedirs(args.embed_dir, exist_ok=True)
        embed_dir = args.embed_dir
    else:
        tmp_dir = tempfile.TemporaryDirectory()
        embed_dir = Path(tmp_dir.name)
    src_langs = sorted(args.src_langs.split(","))
    tgt_aug_langs = sorted(args.tgt_aug_langs.split(",")) if args.tgt_aug_langs else []
    if evaluation.nway:
        evaluation.calc_xsim_nway(embed_dir, src_langs)
    else:
        assert (
            args.tgt_langs
        ), "Please provide tgt langs when not performing n-way comparison"
        tgt_langs = sorted(args.tgt_langs.split(","))
        evaluation.calc_xsim(embed_dir, src_langs, tgt_langs, tgt_aug_langs)
    if tmp_dir:
        tmp_dir.cleanup()  # remove temporary directory


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="LASER: multilingual similarity error evaluation"
    )
    parser.add_argument(
        "--base-dir",
        type=Path,
        default=None,
        help="Base directory for evaluation files",
        required=True,
    )
    parser.add_argument(
        "--corpus",
        type=str,
        default=None,
        help="Name of evaluation corpus",
        required=True,
    )
    parser.add_argument(
        "--corpus-part",
        type=str,
        default=None,
        help="Split of the corpus to use, e.g. dev",
        required=True,
    )
    parser.add_argument(
        "--margin",
        type=str,
        default=None,
        help="Margin for xSIM calculation. See: https://aclanthology.org/P19-1309",
    )
    parser.add_argument(
        "--min-sents",
        type=int,
        default=100,
        help="Only use test sets which have at least N sentences",
    )
    parser.add_argument(
        "--nway", action="store_true", help="Test N-way for corpora which support it"
    )
    parser.add_argument(
        "--embed-dir",
        type=Path,
        default=None,
        help="Store/load embeddings from the specified directory (defaults to a temporary directory)",
    )
    parser.add_argument(
        "--index-comparison",
        action="store_true",
        help="Use index comparison instead of texts (not recommended when test data contains duplicates)",
    )
    parser.add_argument("--src-spm-model", type=str, default=None)
    parser.add_argument("--tgt-spm-model", type=str, default=None)
    parser.add_argument(
        "--src-bpe-codes",
        type=str,
        default=None,
        help="Path to BPE codes for the src model",
    )
    parser.add_argument(
        "--tgt-bpe-codes",
        type=str,
        default=None,
        help="Path to BPE codes for the tgt model",
    )
    parser.add_argument("--src-encoder", type=str, default=None, required=True)
    parser.add_argument("--tgt-encoder", type=str, default=None)
    parser.add_argument(
        "--buffer-size", type=int, default=100, help="Buffer size (sentences)"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=12000,
        help="Maximum number of tokens to process in a batch",
    )
    parser.add_argument(
        "--max-sentences",
        type=int,
        default=None,
        help="Maximum number of sentences to process in a batch",
    )
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    parser.add_argument(
        "--src-langs",
        type=str,
        default=None,
        help="Comma-separated source-side languages for evaluation",
        required=True,
    )
    parser.add_argument(
        "--tgt-langs",
        type=str,
        default=None,
        help="Comma-separated target-side languages for evaluation",
    )
    parser.add_argument(
        "--tgt-aug-langs",
        type=str,
        default=None,
        help="Comma-separated languages with augmented data",
        required=False,
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Store embedding matrices in fp16 instead of fp32",
    )
    parser.add_argument(
        "--sort-kind",
        type=str,
        default="quicksort",
        choices=["quicksort", "mergesort"],
        help="Algorithm used to sort batch by length",
    )
    parser.add_argument(
        "--use-hugging-face",
        action="store_true",
        help="Use a HuggingFace sentence transformer",
    )
    parser.add_argument(
        "--embedding-dimension",
        type=int,
        default=1024,
        help="Embedding dimension for encoders",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Detailed output")
    args = parser.parse_args()
    run_eval(args)