#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LASER Language-Agnostic SEntence Representations
# is a toolkit to calculate multilingual sentence embeddings
# and to use them for document classification, bitext filtering
# and mining
#
# --------------------------------------------------------
#
# tools for indexing and search with FAISS

import os.path
import sys

import faiss
import numpy as np

#-------------------------------------------------------------
# Get the list of embedding file names:
# - we loop over the list of given languages
# - for each language, we also check whether there are split files (suffix .%03d)
def SplitFnames(par_fname, langs):
    fnames = []
    for l in langs:
        fname = par_fname + '.' + l
        if os.path.isfile(fname):
            fnames.append(fname)
        for i in range(1000):
            fname = par_fname + '.' + l + '.{:03d}'.format(i)
            if os.path.isfile(fname):
                fnames.append(fname)
    if len(fnames) == 0:
        print("ERROR: no embeddings found in {:s}*".format(par_fname))
        sys.exit(1)
    return fnames

def SplitOpen(par_fname, langs, dim, dtype, verbose=False):
    M = []
    nf = 0
    nc = 0
    print('Reading sentence embeddings')
    print(' - memory mapped files {:s}'.format(par_fname))
    for fname in SplitFnames(par_fname, langs):
        n = int(os.path.getsize(fname) / dim / np.dtype(dtype).itemsize)
        if verbose:
            print(' - {:s}: {:d} x {:d}'.format(fname, n, dim))
        Mi = np.memmap(fname, mode='r', dtype=dtype, shape=(n, dim))
        nc += n
        nf += 1
        M.append(Mi)
    print(' - total of {:d} files: {:d} x {:d}'.format(nf, nc, dim))
    return M

def SplitAccess(M, idx):
    i = idx
    for Mi in M:
        n = Mi.shape[0]
        if i < n:
            return Mi[i, :]
        i -= n
    print('ERROR: index {:d} is too large for memory mapped files'.format(idx))
    sys.exit(1)
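
# Example usage (illustrative sketch only; the file prefix, languages and
# dtype below are hypothetical, not part of the toolkit):
#   M = SplitOpen('embed/europarl', ['de', 'en'], dim=1024, dtype=np.float32)
#   vec = SplitAccess(M, 0)   # embedding of the first sentence across all files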

###############################################################################
# create a FAISS index on the given data
def IndexCreate(dname, idx_type,
                verbose=False, normalize=True, save_index=False, dim=1024):
    assert idx_type == 'FlatL2', 'only FlatL2 index is currently supported'
    x = np.fromfile(dname, dtype=np.float32, count=-1)
    nbex = x.shape[0] // dim
    print(' - embedding: {:s} {:d} examples of dim {:d}'
          .format(dname, nbex, dim))
    x.resize(nbex, dim)
    print(' - creating FAISS index')
    idx = faiss.IndexFlatL2(dim)
    if normalize:
        faiss.normalize_L2(x)
    idx.add(x)
    if save_index:
        iname = 'TODO'
        print(' - saving index into ' + iname)
        faiss.write_index(idx, iname)
    return x, idx
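
# Example usage (illustrative sketch only; the embedding file name is
# hypothetical and must contain raw float32 vectors of dimension 1024):
#   data, index = IndexCreate('embed/dev.de.f32', 'FlatL2')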

###############################################################################
# search the closest vector for all language pairs and calculate the error rate
def IndexSearchMultiple(data, idx, langs, verbose=False, texts=None, print_errors=False):
    nl = len(data)
    nbex = data[0].shape[0]
    err = np.zeros((nl, nl)).astype(float)
    ref = np.linspace(0, nbex-1, nbex).astype(int)  # [0, nbex)
    if verbose:
        if texts is None:
            print('Calculating similarity error (indices):')
        else:
            print('Calculating similarity error (textual):')
    for i1 in range(nl):
        for i2 in range(nl):
            if i1 != i2:
                D, I = idx[i2].search(data[i1], 1)
                if texts:  # do textual comparison
                    e1 = 0
                    for p in range(I.shape[0]):
                        if texts[i2][p] != texts[i2][I[p, 0]]:
                            e1 += 1
                            if print_errors:
                                print('Error {:s}\n {:s}'
                                      .format(texts[i2][p].strip(), texts[i2][I[p, 0]].strip()))
                    err[i1, i2] = e1 / nbex
                else:  # do index based comparison
                    err[i1, i2] \
                        = (nbex - np.equal(I.reshape(nbex), ref)
                           .astype(int).sum()) / nbex
                if verbose:
                    print(' - similarity error {:s}/{:s}: {:5.2f}%'
                          .format(langs[i1], langs[i2],
                                  100.0 * err[i1, i2]))
    return err

###############################################################################
# print confusion matrix
def IndexPrintConfusionMatrix(err, langs):
    nl = len(langs)
    assert nl == err.shape[0], 'size of error matrix does not match'
    print('Confusion matrix:')
    print('{:8s}'.format('langs'), end='')
    for i2 in range(nl):
        print('{:8s} '.format(langs[i2]), end='')
    print('{:8s}'.format('avg'))
    for i1 in range(nl):
        print('{:3s}'.format(langs[i1]), end='')
        for i2 in range(nl):
            print('{:8.2f}%'.format(100 * err[i1, i2]), end='')
        print('{:8.2f}%'.format(100 * err[i1, :].sum() / (nl-1)))

    print('avg', end='')
    for i2 in range(nl):
        print('{:8.2f}%'.format(100 * err[:, i2].sum() / (nl-1)), end='')

    # global average
    print('{:8.2f}%'.format(100 * err.sum() / (nl-1) / nl))
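
# Example usage (illustrative sketch only; assumes two aligned embedding files
# whose names are hypothetical, indexed with IndexCreate above):
#   data_de, idx_de = IndexCreate('embed/dev.de.f32', 'FlatL2')
#   data_en, idx_en = IndexCreate('embed/dev.en.f32', 'FlatL2')
#   err = IndexSearchMultiple([data_de, data_en], [idx_de, idx_en], ['de', 'en'])
#   IndexPrintConfusionMatrix(err, ['de', 'en'])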

###############################################################################
# Load a FAISS index
def IndexLoad(idx_name, nprobe, gpu=False):
    print('Reading FAISS index')
    print(' - index: {:s}'.format(idx_name))
    index = faiss.read_index(idx_name)
    print(' - found {:d} sentences of dim {:d}'.format(index.ntotal, index.d))
    print(' - setting nprobe to {:d}'.format(nprobe))
    if gpu:
        print(' - transfer index to %d GPUs ' % faiss.get_num_gpus())
        # co = faiss.GpuMultipleClonerOptions()
        # co.shard = True
        index = faiss.index_cpu_to_all_gpus(index)  # co=co
        faiss.GpuParameterSpace().set_index_parameter(index, 'nprobe', nprobe)
    return index
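
# Example usage (illustrative sketch only; the index file name and the nprobe
# value are hypothetical):
#   index = IndexLoad('bitexts.fr.idx', nprobe=128, gpu=False)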

###############################################################################
# Open a text file with the sentences corresponding to the indices used
# by a FAISS index
# We also need the reference file with the byte offsets to the beginning
# of each sentence
# optionally: an array with the number of words per sentence
# All arrays are memory mapped
def IndexTextOpen(txt_fname):
    print('Reading text corpus')
    print(' - texts: {:s}'.format(txt_fname))
    txt_mmap = np.memmap(txt_fname, mode='r', dtype=np.uint8)
    fname = txt_fname.replace('.txt', '.ref.bin32')
    if os.path.isfile(fname):
        print(' - sentence start offsets (32 bit): {}'.format(fname))
        ref_mmap = np.memmap(fname, mode='r', dtype=np.uint32)
    else:
        fname = txt_fname.replace('.txt', '.ref.bin64')
        if os.path.isfile(fname):
            print(' - sentence start offsets (64 bit): {}'.format(fname))
            ref_mmap = np.memmap(fname, mode='r', dtype=np.uint64)
        else:
            print('ERROR: no file with sentence start offsets found')
            sys.exit(1)
    print(' - found {:d} sentences'.format(ref_mmap.shape[0]))

    nbw_mmap = None
    fname = txt_fname.replace('.txt', '.nw.bin8')
    if os.path.isfile(fname):
        print(' - word counts: {:s}'.format(fname))
        nbw_mmap = np.memmap(fname, mode='r', dtype=np.uint8)

    M = None
    fname = txt_fname.replace('.txt', '.meta')
    if os.path.isfile(fname):
        M = []
        n = 0
        print(' - metafile: {:s}'.format(fname))
        with open(fname, 'r') as fp:
            for line in fp:
                fields = line.strip().split()
                if len(fields) != 2:
                    print('ERROR: format error in meta file')
                    sys.exit(1)
                n += int(fields[1])
                M.append({'lang': fields[0], 'n': n})
        print(' - found {:d} languages:'.format(len(M)), end='')
        for L in M:
            print(' {:s}'.format(L['lang']), end='')
        print('')

    return txt_mmap, ref_mmap, nbw_mmap, M

###############################################################################
# Return the text for the given sentence index
def IndexTextQuery(txt_mmap, ref_mmap, idx):
    p = int(ref_mmap[idx])  # get starting byte position
    i = 0
    dim = 10000  # max sentence length in bytes
    b = bytearray(dim)
    # copy bytes until EOL (check the bound before reading)
    while i < dim and txt_mmap[p+i] != 10:
        b[i] = txt_mmap[p+i]
        i += 1
    return b[0:i].decode('utf-8')
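
# Example usage (illustrative sketch only; 'corpus.txt' and its companion
# offset file 'corpus.ref.bin32' or 'corpus.ref.bin64' are hypothetical):
#   T, R, nbw, meta = IndexTextOpen('corpus.txt')
#   print(IndexTextQuery(T, R, 42))   # text of the sentence with index 42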

###############################################################################
# Search the [kmax] nearest vectors of [x] in the given index
# and return the text lines
def IndexSearchKNN(index, x, T, R, kmax=1, Dmax=1.0, dedup=True):
    D, I = index.search(x, kmax)
    prev = {}  # for deduplication
    res = []
    for n in range(x.shape[0]):
        for i in range(kmax):
            txt = IndexTextQuery(T, R, I[n, i])
            # keep a hit if it is close enough and, when deduplication is
            # requested, has not been returned before
            if (not dedup or txt not in prev) and D[n, i] <= Dmax:
                prev[txt] = 1
                res.append([txt, D[n, i]])
    return res
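
# Example usage (illustrative sketch only; file names, nprobe and the query
# embedding are hypothetical):
#   index = IndexLoad('bitexts.en.idx', nprobe=128)
#   T, R, nbw, meta = IndexTextOpen('bitexts.en.txt')
#   query = np.zeros((1, index.d), dtype=np.float32)  # replace with a real LASER embedding
#   hits = IndexSearchKNN(index, query, T, R, kmax=5, Dmax=1.0)
#   for txt, dist in hits:
#       print('{:.3f}  {}'.format(dist, txt))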