Spaces:
Sleeping
Sleeping
#!/bin/bash | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
# | |
# LASER Language-Agnostic SEntence Representations | |
# is a toolkit to calculate multilingual sentence embeddings | |
# and to use them for document classification, bitext filtering | |
# and mining | |
# | |
# -------------------------------------------------------- | |
# | |
# Python tools for BUCC bitext mining | |
import argparse | |
############################################################################### | |
# | |
# Find the optimal threshold given gold alignments | |
# | |
############################################################################### | |
def BuccOptimize(candidate2score, gold):
    """Return the score threshold that maximizes F1 against gold alignments.

    Candidates are visited from highest to lowest score; after each one the
    precision/recall/F1 of "extract everything seen so far" is computed.

    Args:
        candidate2score: mapping from a tuple of ID strings (joined with a
            tab to form the alignment key) to a numeric score.
        gold: container of gold alignment keys ("src_id<TAB>trg_id").

    Returns:
        The midpoint between the score of the best cut-off candidate and the
        score of the next candidate. When the best cut-off is the last
        candidate there is no next score, so that candidate's own score is
        returned. Returns 0 when no candidate is found in ``gold``.
    """
    items = sorted(candidate2score.items(), key=lambda x: -x[1])
    ngold = len(gold)
    ncorrect = 0
    threshold = 0
    best_f1 = 0
    for i, (pair, score) in enumerate(items):
        nextract = i + 1
        if '\t'.join(pair) in gold:
            ncorrect += 1
        if ncorrect > 0:
            precision = ncorrect / nextract
            recall = ncorrect / ngold
            f1 = 2 * precision * recall / (precision + recall)
            if f1 > best_f1:
                best_f1 = f1
                if i + 1 < len(items):
                    threshold = (score + items[i + 1][1]) / 2
                else:
                    # Best cut-off keeps every candidate: there is no lower
                    # score to average with, so use the last score itself.
                    threshold = score
    return threshold
############################################################################### | |
# | |
# Main | |
# | |
############################################################################### | |
# Command-line interface.
# --gold selects threshold optimization against reference alignments,
# --threshold selects plain extraction with a fixed cut-off.
parser = argparse.ArgumentParser(description='LASER: tools for BUCC bitext mining')
parser.add_argument('--encoding', default='utf-8',
                    help='character encoding for input/output')
parser.add_argument('--src-lang', required=True,
                    help='the source language id')
parser.add_argument('--trg-lang', required=True,
                    help='the target language id')
parser.add_argument('--bucc-texts', required=True,
                    help='Base name of the text files (language added)')
parser.add_argument('--bucc-ids', required=True,
                    help='Base name of the ID files (language added)')
parser.add_argument('--candidates', required=True,
                    help='File name of candidate alignments')
parser.add_argument('--gold', default=None,
                    help='File name of gold alignments')
parser.add_argument('--threshold', type=float, default=-1,
                    help='Threshold (used with --output)')
# The extraction step keeps pairs whose score is >= the threshold, so the
# help text says "at or above" rather than "below".
parser.add_argument('--output', default=None,
                    help='File name of output alignments which are at or above threshold')
parser.add_argument('--verbose', action='store_true',
                    help='Detailed output')
args = parser.parse_args()
print('LASER: tools for BUCC bitext mining')

# Exactly one mode must be selected: --gold (optimize the threshold) or a
# positive --threshold (extract with a fixed cut-off). Both or neither is an
# error. parser.error() is used instead of assert so the check is not
# stripped when Python runs with -O; it prints usage and exits with status 2.
if bool(args.gold) == (args.threshold > 0):
    parser.error('Either "--gold" or "--threshold" must be specified')
if args.verbose:
    print(' - reading sentences and IDs')

# Build one sentence -> ID lookup per language. A sentence that occurs more
# than once within a language is ambiguous and is removed from the lookup.
src_sent2id, trg_sent2id = {}, {}
for lang, sent2id in ((args.src_lang, src_sent2id), (args.trg_lang, trg_sent2id)):
    text_path = args.bucc_texts + '.' + lang
    with open(text_path, encoding=args.encoding, errors='surrogateescape') as text_file:
        sentences = [line.strip() for line in text_file]

    id_path = args.bucc_ids + '.' + lang
    with open(id_path, encoding=args.encoding, errors='surrogateescape') as id_file:
        sent_ids = [line.strip() for line in id_file]

    duplicates = set()
    for sent_id, sentence in zip(sent_ids, sentences):
        if sentence not in sent2id:
            sent2id[sentence] = sent_id
        else:
            duplicates.add(sentence)

    # Drop every sentence seen more than once (its first ID was stored above).
    for sentence in duplicates:
        del sent2id[sentence]
if args.verbose:
    print(' - reading candidates {}'.format(args.candidates))

# Map (src_id, trg_id) -> score. Each candidate line holds three
# tab-separated fields: score, source sentence, target sentence.
# Pairs whose sentences are not in the lookups are skipped.
candidate2score = {}
with open(args.candidates, encoding=args.encoding, errors='surrogateescape') as cand_file:
    for row in cand_file:
        raw_score, src_text, trg_text = row.split('\t')
        value = float(raw_score)
        src_text = src_text.strip()
        trg_text = trg_text.strip()
        if src_text not in src_sent2id or trg_text not in trg_sent2id:
            continue
        pair = (src_sent2id[src_text], trg_sent2id[trg_text])
        # Keep the larger score when the same ID pair appears again.
        earlier = candidate2score.get(pair, value)
        candidate2score[pair] = max(value, earlier)
def BuccExtract(cand2score, th, fname):
    """Select the candidate pairs whose score is at least ``th``.

    Args:
        cand2score: mapping from (src_id, trg_id) string tuples to a score.
        th: minimum score (inclusive) for a pair to be kept.
        fname: output file name; when falsy nothing is written.

    Returns:
        List of "src_id<TAB>trg_id" strings for the kept pairs, in the
        iteration order of ``cand2score``.
    """
    bitexts = [src + '\t' + trg
               for (src, trg), score in cand2score.items()
               if score >= th]
    if fname:
        # The context manager closes the file even if a write fails.
        # IDs are read with errors='surrogateescape' elsewhere in this
        # script, so the same handler is used here to write them back.
        with open(fname, 'w', encoding=args.encoding,
                  errors='surrogateescape') as of:
            of.writelines(pair + '\n' for pair in bitexts)
    return bitexts
# Mode 1: optimize the threshold on gold alignments, then report
# precision / recall / F1 (as percentages) of the resulting extraction.
if args.gold:
    if args.verbose:
        print(' - optimizing threshold on gold alignments {}'.format(args.gold))
        if args.output:
            print(' - extracted bitext are written into {:s}'.format(args.output))
    # Read the gold file with the same encoding settings as the other
    # inputs, and close it deterministically.
    with open(args.gold, encoding=args.encoding, errors='surrogateescape') as gold_file:
        gold = {line.strip() for line in gold_file}
    threshold = BuccOptimize(candidate2score, gold)
    bitexts = BuccExtract(candidate2score, threshold, args.output)
    ncorrect = len(gold.intersection(bitexts))
    if ncorrect > 0:
        precision = ncorrect / len(bitexts)
        recall = ncorrect / len(gold)
        f1 = 2*precision*recall / (precision + recall)
    else:
        # No correct pair: avoid dividing by zero and report all zeros.
        precision = recall = f1 = 0
    print(' - best threshold={:f}: precision={:.2f}, recall={:.2f}, F1={:.2f}'
          .format(threshold, 100*precision, 100*recall, 100*f1))
# Mode 2: extract with a fixed, user-supplied threshold.
if args.threshold > 0:
    if args.verbose:
        # '{}' rather than '{:s}': args.output defaults to None, which the
        # 's' format code rejects with a TypeError.
        print(' - extracting bitexts for threshold {:f} into {}'.format(args.threshold, args.output))
    BuccExtract(candidate2score, args.threshold, args.output)