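"""
Download the CCMatrix bitexts released by FAIR and re-extract the underlying
text snippets from Common Crawl.

Two entry points (see the func_argparse hook at the bottom of this file):
  - dl: fetch the pointer files for a given release and materialize the
    pointed-to text from the corresponding Common Crawl WET segments.
  - finalize: split the raw files per language pair, sort each side by line
    number, and validate that the two sides of every pair line up.
"""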
import functools
import gzip
import logging
import multiprocessing
import os
import re
import subprocess
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, Iterable, List, NamedTuple, Sequence, Type

from cc_net.jsonql import open_remote_file, open_write
from cc_net.process_wet_file import CCSegmentsReader

BUFFER_SIZE = "32G"
SORT_PARALLEL = 8

KNOWN_VERSIONS = ["v1.0.0", "v1.0.beta", "v1.0.alpha"]


class NormalizedBitextPtr(NamedTuple):
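    # Points at a span [ptr_start, ptr_end) inside the *cleaned* text of a
    # Common Crawl document (see clean_content), identified by its WET
    # segment and sha1 digest.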
    lang_pair: str
    line_no: int
    segment: str
    digest: str
    ptr_start: int
    ptr_end: int
    score: float


class Bitext(NamedTuple):
    lang_pair: str
    line_no: int
    score: float
    text: str


class SimpleBitext(NamedTuple):
    line_no: int
    score: float
    text: str


WEB_PAT = re.compile(r"https?:[^ \n]* ")
WEB_REPL = "WEB "

WEB2_PAT = re.compile(r"https?:[^ \n]*\n")
WEB2_REPL = "WEB\n"
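
# Both patterns mask URLs with a "WEB" placeholder: WEB_PAT handles a URL
# followed by a space, WEB2_PAT a URL that ends the line.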


def clean_content(raw_content: str) -> str:
    # Always clean the content: the pointers index into cleaned text, and the
    # user has no way of knowing whether a given document needs cleaning.
    par = raw_content
    par = par.replace("</s>", ". ")
    par = par.replace("\t", " ")
    par = WEB_PAT.sub(WEB_REPL, par)
    par = WEB2_PAT.sub(WEB2_REPL, par)
    return par
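
# Example (illustrative):
#   clean_content("see https://example.com/x for details</s>")
#   -> "see WEB for details. "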


def get_typed_parser(cls: Type) -> Callable:
    types = cls.__annotations__.values()

    def parser(line: str) -> NamedTuple:
        parts = line.rstrip("\n").split("\t")
        assert len(parts) == len(
            types
        ), f"Size mismatch: expected columns {cls.__annotations__}, got: {parts}"
        return cls(*(t(p) for t, p in zip(types, parts)))

    return parser
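
# Example (illustrative):
#   parse = get_typed_parser(SimpleBitext)
#   parse("12\t1.05\tHello world\n")
#   -> SimpleBitext(line_no=12, score=1.05, text="Hello world")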


def open_read(file: Path) -> Iterable[str]:
    if file.suffix == ".gz":
        reader = gzip.open(file, "rt")
    else:
        reader = open(file, "rt")
    with reader as f:
        for line in f:
            yield line


def dl(outdir: Path = Path("data"), version: str = KNOWN_VERSIONS[0], parallelism: int = 8):
    """
    Download bitext pointers from FAIR dataset and extract corresponding CC snippets.
    - version: Specific version to download
    - outdir: Directory where the data should go. Files will be in {outdir}/{version}/raw/
    """
    assert version in KNOWN_VERSIONS, f"Unknown version {version}, choose from {KNOWN_VERSIONS}"
    metadata_dir = f"https://dl.fbaipublicfiles.com/laser/CCMatrix/{version}"
    file_list = [l.strip() for l in open_remote_file(metadata_dir + "/list.txt")]
    outdir.mkdir(exist_ok=True)
    outdir = outdir / version / "raw"
    outdir.mkdir(exist_ok=True, parents=True)

    dlf = functools.partial(dl_file, metadata_dir, outdir)
    # For serial debugging, replace the pool below with: list(map(dlf, file_list))
    with multiprocessing.Pool(parallelism) as pool:
        pool.map(dlf, file_list)
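
# Typical invocation via the func_argparse entry point at the bottom of this
# file (the file name here is illustrative):
#   python dl_cc_matrix.py dl --outdir data --version v1.0.0 --parallelism 8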


def get_documents(segment: str) -> Dict[str, str]:
    return {d["digest"]: d["raw_content"] for d in CCSegmentsReader([segment])}
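
# A segment is a full Common Crawl WET file, so fetching one is expensive.
# dl_file therefore keeps a single segment's documents in RAM at a time and
# reports any segment it had to download more than once.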


def dl_file(metadata_dir: str, outdir: Path, file: str):
    metadata = "/".join((metadata_dir, file))
    parser = get_typed_parser(NormalizedBitextPtr)
    found_bitext, missed_bitext, skipped_line = 0, 0, 0
    segment = ""
    segment_downloads: Dict[str, int] = defaultdict(int)
    raw_documents: Dict[str, str] = {}
    cleaned_documents: Dict[str, str] = {}

    outfile = outdir / file
    if outfile.exists():
        return
    o = FileWriterWithTmp(outfile)
    for i, line in enumerate(open_remote_file(metadata)):
        try:
            bitext: NormalizedBitextPtr = parser(line)
            # Extra asserts in case the line is invalid but still parses
            assert bitext.segment.startswith("crawl-data/")
            assert bitext.digest.startswith("sha1:")
        except AssertionError:
            logging.error(f"Skipping line {i}: {line}")
            skipped_line += 1
            continue

        if not segment or bitext.segment != segment:
            segment = bitext.segment
            segment_downloads[segment] += 1
            # Load segment in RAM, purge document cache
            raw_documents = get_documents(segment)
            cleaned_documents = {}

        raw_doc = raw_documents.get(bitext.digest)
        if raw_doc is None:
            logging.error(f"Document not found: {bitext.digest} in {segment}")
            missed_bitext += 1
            continue

        clean_doc = cleaned_documents.get(bitext.digest)
        if clean_doc is None:
            clean_doc = clean_content(raw_doc)
            cleaned_documents[bitext.digest] = clean_doc

        text = clean_doc[bitext.ptr_start : bitext.ptr_end]
        score = getattr(bitext, "score", 0.0)
        bt = Bitext(bitext.lang_pair, bitext.line_no, score, text)
        print(*bt, sep="\t", file=o)
        found_bitext += 1  # counted so the summary log below is accurate

    o.close(True)
    logging.info(f"Found {found_bitext} sentences, missed {missed_bitext} sentences.")
    if skipped_line > 0:
        logging.error(f"Skipped {skipped_line} unparsable lines")
    expected_dl = len(segment_downloads)
    actual_dl = sum(segment_downloads.values())

    if actual_dl != expected_dl:
        logging.error(
            f"Some segments were downloaded twice. Total dl: {actual_dl}, distinct dl: {expected_dl}"
        )


def _tmp(file: Path) -> Path:
    tmp_dir = file.parent
    prefix = file.name.split(".", 1)[0] + "."
    suffix = ".tmp." + file.name[len(prefix) :]
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir, prefix=prefix, suffix=suffix)
    os.close(fd)  # mkstemp returns an open fd: close it so it doesn't leak
    return Path(tmp_path)
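
# e.g. _tmp(Path("data/de-en.txt.gz")) -> Path("data/de-en.XXXXXXXX.tmp.txt.gz"),
# where XXXXXXXX is the random part chosen by mkstemp.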


class FileWriterWithTmp:
    def __init__(self, file: Path):
        self.file = file
        self.tmp_file = _tmp(file)
        # We don't want to make FileWriterWithTmp a ContextManager
        self.handle = open_write(self.tmp_file).__enter__()

    def write(self, data) -> int:
        return self.handle.write(data)

    def close(self, success: bool = False):
        self.handle.close()
        if success:
            self.tmp_file.rename(self.file)
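
# Usage pattern: write(), then close(success=True) to atomically rename the
# temporary file onto its final destination; close(False) closes the handle
# but leaves only the temporary file behind, so a failed run never produces
# a partial output under the real name.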


def transpose_file(outdir: Path, file: Path) -> None:
    sentinel_file = file.with_suffix(".transposed")
    if sentinel_file.exists():
        return
    outputs: Dict[str, FileWriterWithTmp] = {}
    parser = get_typed_parser(Bitext)
    success = False
    try:
        for line in open_read(file):
            bt: Bitext = parser(line)
            lang_pair = bt.lang_pair
            if bt.lang_pair not in outputs:
                assert (
                    "/" in lang_pair
                ), f"Invalid lang pair '{lang_pair}' should be 'src-trg/src' or 'src-trg/trg'"
                # lang_pair is "src-trg/lang": create the "src-trg" directory
                (outdir / lang_pair).parent.mkdir(exist_ok=True, parents=True)
                o = FileWriterWithTmp(outdir / f"{lang_pair}_{file.name}")
                outputs[lang_pair] = o
            simple_bt = SimpleBitext(bt.line_no, bt.score, bt.text)
            print(*simple_bt, sep="\t", file=outputs[lang_pair])
        success = True
    finally:
        for o in outputs.values():
            o.close(success)
        if success:
            sentinel_file.write_text("\n".join(str(o.file) for o in outputs.values()))
            # file.unlink()
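
# Resulting layout, e.g. for lang_pair "de-en/de" and input shard "shard0.gz":
#   {split_dir}/de-en/de_shard0.gz  (one file per language per input shard)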


def sort_files(outdir: Path, lang_pair_dir: Path, lang: str) -> Path:
    out = outdir / lang_pair_dir.name / f"{lang}.txt"
    if out.exists():
        return out

    files: List[Path] = []
    for f in lang_pair_dir.iterdir():
        if f.suffix != ".gz":
            continue
        if f.name.split("_")[0] != lang:
            continue
        files.append(f)

    print(f"Found {len(files)} files for lang '{lang}' in {lang_pair_dir}: {files}")
    assert len(files) > 0

    (outdir / lang_pair_dir.name).mkdir(exist_ok=True, parents=True)
    tmp_out = _tmp(out)

    # sort(1) cannot read gzip directly: decompress next to the originals
    # (-k keeps the .gz, -f overwrites stale copies), sort, then delete the
    # decompressed files even if sort fails.
    unzipped_files = []
    for f in files:
        subprocess.check_call(["gunzip", "-kf", str(f)])
        unzipped_files.append(str(f)[:-3])

    try:
        subprocess.check_call(
            [
                "sort",
                "-nk1",
                f"--parallel={SORT_PARALLEL}",
                f"--buffer-size={BUFFER_SIZE}",
                "--output",
                str(tmp_out),
            ]
            + unzipped_files
        )
    finally:
        for uf in unzipped_files:
            Path(uf).unlink()
    tmp_out.rename(out)
    return out
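
# The resulting {pair}/{lang}.txt files are sorted numerically on the first
# column (line_no), which is what lets validate() walk src and trg in step.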


def finalize(
    outdir: Path = Path("data"), version: str = KNOWN_VERSIONS[0], pairs: Sequence[str] = ()
) -> None:
    """From the downloaded raw text files, extract the bitexts, sorted by language pair.
    Assumes 'dl' has been run with the same outdir and version before.

    - outdir: Directory where the data should go. Files will be in {outdir}/{version}/bitext/
    - version: Specific version to process (must match the one passed to 'dl')
    - pairs: List of language pairs you are interested in. Defaults to all.
    """
    raw_dir = outdir / version / "raw"
    if not raw_dir.is_dir():
        cmd = f"python {__file__} dl --outdir {outdir} --version {version}"
        assert raw_dir.is_dir(), f"Dir not found {raw_dir}. Did you run the following command?\n{cmd}"

    raw_files = list(raw_dir.glob("*.gz"))
    split_dir = outdir / version / "split_by_lang"
    split_dir.mkdir(exist_ok=True, parents=True)
    tr = functools.partial(transpose_file, split_dir)
    with multiprocessing.Pool() as pool:
        pool.map(tr, raw_files)

    bitext_dir = outdir / version / "bitext"
    bitext_dir.mkdir(exist_ok=True, parents=True)
    if pairs:
        pair_dirs = []
        for pair in pairs:
            assert (
                len(pair.split("-")) == 2
            ), f"Invalid pair '{pair}', should be 'src-trg'"
            pair_dir = split_dir / pair
            assert (
                pair_dir.is_dir()
            ), f"Dir {pair_dir} not found for lang pair '{pair}'. Is the pair valid?"
            pair_dirs.append(pair_dir)
    else:
        pair_dirs = [d for d in split_dir.iterdir() if d.is_dir()]

    for pair_dir in pair_dirs:
        src, trg = pair_dir.name.split("-")
        src_file = sort_files(bitext_dir, pair_dir, src)
        trg_file = sort_files(bitext_dir, pair_dir, trg)
        validate(src_file, trg_file)


def validate(src_file: Path, trg_file: Path) -> None:
    """Checks that src and trg files are aligned: every line number present on
    one side should also be present on the other."""
    lines_src, lines_trg, found_pairs = 0, 0, 0
    parser = get_typed_parser(SimpleBitext)
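    # Both files are sorted by line_no (see sort_files), so walk them like a
    # merge join: advance whichever side has the smaller line_no and count a
    # pair whenever the two sides agree.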
    with open(src_file) as src_f, open(trg_file) as trg_f:
        src_l = src_f.readline()
        trg_l = trg_f.readline()
        while src_l and trg_l:
            src: SimpleBitext = parser(src_l)
            trg: SimpleBitext = parser(trg_l)
            if src.line_no <= trg.line_no:
                lines_src += 1
                src_l = src_f.readline()
            if trg.line_no <= src.line_no:
                lines_trg += 1
                trg_l = trg_f.readline()
            if trg.line_no == src.line_no:
                found_pairs += 1

    if found_pairs == lines_src and found_pairs == lines_trg:
        logging.info(
            f"Validated {src_file} and {trg_file}. Found {found_pairs} bitexts."
        )
    else:
        logging.error(
            f"Validated {src_file} and {trg_file}. "
            f"Found {found_pairs} bitexts, from {lines_src} in {src_file} and {lines_trg} in {trg_file}"
        )


if __name__ == "__main__":
    import func_argparse

    func_argparse.main(dl, finalize)
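
    # Illustrative end-to-end usage (exact flag syntax is determined by
    # func_argparse; the script name is a placeholder):
    #   python dl_cc_matrix.py dl --outdir data --version v1.0.0
    #   python dl_cc_matrix.py finalize --outdir data --version v1.0.0 --pairs de-en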