File size: 5,385 Bytes
05d3571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LASER  Language-Agnostic SEntence Representations
# is a toolkit to calculate multilingual sentence embeddings
# and to use them for document classification, bitext filtering
# and mining
#
# -------------------------------------------------------
#
# This Python script downloads the NLLB LASER2 and LASER3 sentence encoders from Amazon S3

import argparse
import logging
import os
import shutil
import sys
import tempfile
from pathlib import Path

import requests
from tqdm import tqdm

from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE, SPM_LANGUAGE

# Configure the root logger when this module is imported: INFO and above are
# written to stdout as "time | level | logger name | message".
# NOTE(review): basicConfig in an importable module also affects any program
# that imports it (it is a no-op if the root logger already has handlers);
# consider moving it under the __main__ guard.
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(__name__)


class LaserModelDownloader:
    """Download LASER2 / LASER3 encoder files into a local model directory.

    Files are fetched from ``base_url`` and saved under ``model_dir`` with the
    same file name. A file that already exists locally is not downloaded again.
    """

    def __init__(self, model_dir: str = None):
        """Set up the downloader.

        Args:
            model_dir: Directory to store downloaded files in. Defaults to
                ``~/.cache/laser_encoders``. It is created if it does not exist.
        """
        if model_dir is None:
            model_dir = os.path.expanduser("~/.cache/laser_encoders")
        # Create the directory for user-supplied paths too: download() writes
        # the finished file into it and would fail if it were missing.
        os.makedirs(model_dir, exist_ok=True)

        self.model_dir = Path(model_dir)
        self.base_url = "https://dl.fbaipublicfiles.com/nllb/laser"

    def download(self, filename: str):
        """Download ``filename`` from ``base_url`` into ``model_dir``.

        Only logs a message if the file already exists locally.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        # Build the URL with an explicit "/": on Windows os.path.join would use
        # "\" instead, giving https://dl.fbaipublicfiles.com/nllb/laser\laser2.pt
        # rather than https://dl.fbaipublicfiles.com/nllb/laser/laser2.pt, and
        # the download would fail.
        url = f"{self.base_url}/{filename}"
        local_file_path = os.path.join(self.model_dir, filename)

        if os.path.exists(local_file_path):
            logger.info(f" - {filename} already downloaded")
            return

        logger.info(f" - Downloading {filename}")
        # Stream into a temporary file and move it into place only once it is
        # complete, so a failed or interrupted download never leaves a truncated
        # file that the exists() check above would later accept. Creating it in
        # the target directory keeps it on the same filesystem, so the final
        # move is a rename rather than a copy.
        tf = tempfile.NamedTemporaryFile(
            delete=False, dir=self.model_dir, prefix=f"{filename}.", suffix=".part"
        )
        temp_file_path = tf.name
        try:
            # The timeout bounds connect/read stalls, not the total download time.
            with tf, requests.get(url, stream=True, timeout=60) as response:
                # Do not save an HTTP error page (e.g. a 404) as a model file.
                response.raise_for_status()
                total_size = int(response.headers.get("Content-Length", 0))
                with tqdm(total=total_size, unit_scale=True, unit="B") as progress_bar:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        tf.write(chunk)
                        progress_bar.update(len(chunk))
            shutil.move(temp_file_path, local_file_path)
        finally:
            # After a successful move the temp file is gone; otherwise remove
            # the partial download before the exception propagates.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

    def get_language_code(self, language_list: dict, lang: str) -> str:
        """Return the model code that ``language_list`` maps ``lang`` to.

        Raises:
            ValueError: if ``lang`` is not a key of ``language_list``, or if it
                maps to a list of several codes (ambiguous name).
        """
        try:
            lang_3_4 = language_list[lang]
        except KeyError:
            # "from None": the KeyError adds nothing beyond this message.
            raise ValueError(
                f"language name: {lang} not found in language list. Specify a supported language name"
            ) from None
        if isinstance(lang_3_4, list):
            options = ", ".join(f"'{opt}'" for opt in lang_3_4)
            raise ValueError(
                f"Language '{lang}' has multiple options: {options}. Please specify using the 'lang' argument."
            )
        return lang_3_4

    def download_laser2(self):
        """Download the LASER2 encoder together with its .spm and .cvocab files."""
        self.download("laser2.pt")
        self.download("laser2.spm")
        self.download("laser2.cvocab")

    def download_laser3(self, lang: str, spm: bool = False):
        """Download the LASER3 encoder for ``lang`` and, optionally, its SPM files.

        Args:
            lang: Key into ``LASER3_LANGUAGE``. It is resolved to a model code
                by ``get_language_code``, which raises ValueError for unknown or
                ambiguous names.
            spm: Also download the .spm/.cvocab files. These are the
                language-specific ones if the code is in ``SPM_LANGUAGE``,
                otherwise the shared LASER2 ones.
        """
        # get_language_code already rejects names that map to a list of codes,
        # so ``lang`` is a single code from here on.
        lang = self.get_language_code(LASER3_LANGUAGE, lang)
        self.download(f"laser3-{lang}.v1.pt")
        if not spm:
            return
        if lang in SPM_LANGUAGE:
            self.download(f"laser3-{lang}.v1.spm")
            self.download(f"laser3-{lang}.v1.cvocab")
        else:
            self.download("laser2.spm")
            self.download("laser2.cvocab")

    def main(self, args):
        """Dispatch parsed command-line arguments to the right download method.

        ``args`` must provide ``laser``, ``lang`` and ``spm`` attributes. If
        ``laser`` is not set, LASER3 is chosen when ``lang`` is in
        ``LASER3_LANGUAGE``, otherwise LASER2 when it is in ``LASER2_LANGUAGE``.

        Raises:
            ValueError: for an unsupported model name or language.
        """
        if args.laser:
            if args.laser == "laser2":
                self.download_laser2()
            elif args.laser == "laser3":
                self.download_laser3(lang=args.lang, spm=args.spm)
            else:
                raise ValueError(
                    f"Unsupported laser model: {args.laser}. Choose either laser2 or laser3."
                )
        elif args.lang in LASER3_LANGUAGE:
            self.download_laser3(lang=args.lang, spm=args.spm)
        elif args.lang in LASER2_LANGUAGE:
            self.download_laser2()
        else:
            raise ValueError(
                f"Unsupported language name: {args.lang}. Please specify a supported language name using --lang."
            )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this download script."""
    parser = argparse.ArgumentParser(description="LASER: Download Laser models")
    parser.add_argument(
        "--laser",
        type=str,
        choices=["laser2", "laser3"],
        help="Laser model to download; if omitted, it is chosen from --lang",
    )
    parser.add_argument(
        "--lang",
        type=str,
        help="The language name in FLORES200 format",
    )
    # store_false: SPM files are downloaded by default and passing --spm turns
    # that off, so args.spm is True unless the flag is given.
    parser.add_argument(
        "--spm",
        action="store_false",
        help="Do not download the SPM model (downloaded by default; only affects laser3)",
    )
    parser.add_argument(
        "--model-dir", type=str, help="The directory to download the models to"
    )
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    downloader = LaserModelDownloader(args.model_dir)
    downloader.main(args)