Spaces:
Sleeping
Sleeping
#!/bin/bash | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
# | |
# LASER Language-Agnostic SEntence Representations | |
# is a toolkit to calculate multilingual sentence embeddings | |
# and to use them for document classification, bitext filtering | |
# and mining | |
# | |
# ------------------------------------------------------- | |
# | |
# This python script installs NLLB LASER2 and LASER3 sentence encoders from Amazon s3 | |
import argparse | |
import logging | |
import os | |
import shutil | |
import sys | |
import tempfile | |
from pathlib import Path | |
import requests | |
from tqdm import tqdm | |
from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE, SPM_LANGUAGE | |
logging.basicConfig( | |
stream=sys.stdout, | |
level=logging.INFO, | |
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", | |
) | |
logger = logging.getLogger(__name__) | |
class LaserModelDownloader: | |
def __init__(self, model_dir: str = None): | |
if model_dir is None: | |
model_dir = os.path.expanduser("~/.cache/laser_encoders") | |
os.makedirs(model_dir, exist_ok=True) | |
self.model_dir = Path(model_dir) | |
self.base_url = "https://dl.fbaipublicfiles.com/nllb/laser" | |
def download(self, filename: str): | |
# Because on windows os.path.join will use "\" insted of "/", so link would be: | |
# https://dl.fbaipublicfiles.com/nllb/laser\laser2.pt instead of https://dl.fbaipublicfiles.com/nllb/laser/laser2.pt | |
# which results in a failed download. | |
url = f"{self.base_url}/{filename}" | |
local_file_path = os.path.join(self.model_dir, filename) | |
if os.path.exists(local_file_path): | |
logger.info(f" - {filename} already downloaded") | |
else: | |
logger.info(f" - Downloading {filename}") | |
tf = tempfile.NamedTemporaryFile(delete=False) | |
temp_file_path = tf.name | |
with tf: | |
response = requests.get(url, stream=True) | |
total_size = int(response.headers.get("Content-Length", 0)) | |
progress_bar = tqdm(total=total_size, unit_scale=True, unit="B") | |
for chunk in response.iter_content(chunk_size=1024): | |
tf.write(chunk) | |
progress_bar.update(len(chunk)) | |
progress_bar.close() | |
shutil.move(temp_file_path, local_file_path) | |
def get_language_code(self, language_list: dict, lang: str) -> str: | |
try: | |
lang_3_4 = language_list[lang] | |
if isinstance(lang_3_4, list): | |
options = ", ".join(f"'{opt}'" for opt in lang_3_4) | |
raise ValueError( | |
f"Language '{lang}' has multiple options: {options}. Please specify using the 'lang' argument." | |
) | |
return lang_3_4 | |
except KeyError: | |
raise ValueError( | |
f"language name: {lang} not found in language list. Specify a supported language name" | |
) | |
def download_laser2(self): | |
self.download("laser2.pt") | |
self.download("laser2.spm") | |
self.download("laser2.cvocab") | |
def download_laser3(self, lang: str, spm: bool = False): | |
result = self.get_language_code(LASER3_LANGUAGE, lang) | |
if isinstance(result, list): | |
raise ValueError( | |
f"There are script-specific models available for {lang}. Please choose one from the following: {result}" | |
) | |
lang = result | |
self.download(f"laser3-{lang}.v1.pt") | |
if spm: | |
if lang in SPM_LANGUAGE: | |
self.download(f"laser3-{lang}.v1.spm") | |
self.download(f"laser3-{lang}.v1.cvocab") | |
else: | |
self.download(f"laser2.spm") | |
self.download(f"laser2.cvocab") | |
def main(self, args): | |
if args.laser: | |
if args.laser == "laser2": | |
self.download_laser2() | |
elif args.laser == "laser3": | |
self.download_laser3(lang=args.lang, spm=args.spm) | |
else: | |
raise ValueError( | |
f"Unsupported laser model: {args.laser}. Choose either laser2 or laser3." | |
) | |
else: | |
if args.lang in LASER3_LANGUAGE: | |
self.download_laser3(lang=args.lang, spm=args.spm) | |
elif args.lang in LASER2_LANGUAGE: | |
self.download_laser2() | |
else: | |
raise ValueError( | |
f"Unsupported language name: {args.lang}. Please specify a supported language name using --lang." | |
) | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="LASER: Download Laser models") | |
parser.add_argument( | |
"--laser", | |
type=str, | |
help="Laser model to download", | |
) | |
parser.add_argument( | |
"--lang", | |
type=str, | |
help="The language name in FLORES200 format", | |
) | |
parser.add_argument( | |
"--spm", | |
action="store_false", | |
help="Do not download the SPM model?", | |
) | |
parser.add_argument( | |
"--model-dir", type=str, help="The directory to download the models to" | |
) | |
args = parser.parse_args() | |
downloader = LaserModelDownloader(args.model_dir) | |
downloader.main(args) | |