#!/usr/bin/python3 # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # # LASER Language-Agnostic SEntence Representations # is a toolkit to calculate multilingual sentence embeddings # and to use them for document classification, bitext filtering # and mining # # -------------------------------------------------------- import logging import os import re import sys import warnings from collections import namedtuple from pathlib import Path import numpy as np import torch import torch.nn as nn from fairseq.data.dictionary import Dictionary from fairseq.models.transformer import Embedding, TransformerEncoder from fairseq.modules import LayerNorm from laser_encoders.download_models import LaserModelDownloader from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE from laser_encoders.laser_tokenizer import LaserTokenizer, initialize_tokenizer SPACE_NORMALIZER = re.compile(r"\s+") Batch = namedtuple("Batch", "srcs tokens lengths") logging.basicConfig( stream=sys.stdout, level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", ) logger = logging.getLogger("embed") class SentenceEncoder: def __init__( self, model_path, max_sentences=None, max_tokens=None, spm_vocab=None, spm_model=None, cpu=False, fp16=False, verbose=False, sort_kind="quicksort", ): if verbose: logger.info(f"loading encoder: {model_path}") self.spm_model = spm_model if self.spm_model: self.tokenizer = LaserTokenizer(spm_model=Path(self.spm_model)) self.use_cuda = torch.cuda.is_available() and not cpu self.max_sentences = max_sentences self.max_tokens = max_tokens if self.max_tokens is None and self.max_sentences is None: self.max_sentences = 1 state_dict = torch.load(model_path) if "params" in state_dict: self.encoder = LaserLstmEncoder(**state_dict["params"]) self.encoder.load_state_dict(state_dict["model"]) self.dictionary = state_dict["dictionary"] self.prepend_bos = False self.left_padding = False else: self.encoder = LaserTransformerEncoder(state_dict, spm_vocab) self.dictionary = self.encoder.dictionary.indices self.prepend_bos = state_dict["cfg"]["model"].prepend_bos self.left_padding = state_dict["cfg"]["model"].left_pad_source del state_dict self.bos_index = self.dictionary[""] = 0 self.pad_index = self.dictionary[""] = 1 self.eos_index = self.dictionary[""] = 2 self.unk_index = self.dictionary[""] = 3 if fp16: self.encoder.half() if self.use_cuda: if verbose: logger.info("transfer encoder to GPU") self.encoder.cuda() self.encoder.eval() self.sort_kind = sort_kind def __call__(self, text_or_batch): if self.spm_model: text_or_batch = self.tokenizer(text_or_batch) if isinstance(text_or_batch, str): text_or_batch = [text_or_batch] return self.encode_sentences(text_or_batch) else: raise ValueError( "Either initialize the encoder with an spm_model or pre-tokenize and use the encode_sentences method." ) def _process_batch(self, batch): tokens = batch.tokens lengths = batch.lengths if self.use_cuda: tokens = tokens.cuda() lengths = lengths.cuda() with torch.no_grad(): sentemb = self.encoder(tokens, lengths)["sentemb"] embeddings = sentemb.detach().cpu().numpy() return embeddings def _tokenize(self, line): tokens = SPACE_NORMALIZER.sub(" ", line).strip().split() ntokens = len(tokens) if self.prepend_bos: ids = torch.LongTensor(ntokens + 2) ids[0] = self.bos_index for i, token in enumerate(tokens): ids[i + 1] = self.dictionary.get(token, self.unk_index) ids[ntokens + 1] = self.eos_index else: ids = torch.LongTensor(ntokens + 1) for i, token in enumerate(tokens): ids[i] = self.dictionary.get(token, self.unk_index) ids[ntokens] = self.eos_index return ids def _make_batches(self, lines): tokens = [self._tokenize(line) for line in lines] lengths = np.array([t.numel() for t in tokens]) indices = np.argsort(-lengths, kind=self.sort_kind) def batch(tokens, lengths, indices): toks = tokens[0].new_full((len(tokens), tokens[0].shape[0]), self.pad_index) if not self.left_padding: for i in range(len(tokens)): toks[i, : tokens[i].shape[0]] = tokens[i] else: for i in range(len(tokens)): toks[i, -tokens[i].shape[0] :] = tokens[i] return ( Batch(srcs=None, tokens=toks, lengths=torch.LongTensor(lengths)), indices, ) batch_tokens, batch_lengths, batch_indices = [], [], [] ntokens = nsentences = 0 for i in indices: if nsentences > 0 and ( (self.max_tokens is not None and ntokens + lengths[i] > self.max_tokens) or (self.max_sentences is not None and nsentences == self.max_sentences) ): yield batch(batch_tokens, batch_lengths, batch_indices) ntokens = nsentences = 0 batch_tokens, batch_lengths, batch_indices = [], [], [] batch_tokens.append(tokens[i]) batch_lengths.append(lengths[i]) batch_indices.append(i) ntokens += tokens[i].shape[0] nsentences += 1 if nsentences > 0: yield batch(batch_tokens, batch_lengths, batch_indices) def encode_sentences(self, sentences, normalize_embeddings=False): indices = [] results = [] for batch, batch_indices in self._make_batches(sentences): indices.extend(batch_indices) encoded_batch = self._process_batch(batch) if normalize_embeddings: # Perform L2 normalization on the embeddings norms = np.linalg.norm(encoded_batch, axis=1, keepdims=True) encoded_batch = encoded_batch / norms results.append(encoded_batch) return np.vstack(results)[np.argsort(indices, kind=self.sort_kind)] class LaserTransformerEncoder(TransformerEncoder): def __init__(self, state_dict, vocab_path): self.dictionary = Dictionary.load(vocab_path) if any( k in state_dict["model"] for k in ["encoder.layer_norm.weight", "layer_norm.weight"] ): self.dictionary.add_symbol("") cfg = state_dict["cfg"]["model"] self.sentemb_criterion = cfg.sentemb_criterion self.pad_idx = self.dictionary.pad_index self.bos_idx = self.dictionary.bos_index embed_tokens = Embedding( len(self.dictionary), cfg.encoder_embed_dim, self.pad_idx, ) super().__init__(cfg, self.dictionary, embed_tokens) if "decoder.version" in state_dict["model"]: self._remove_decoder_layers(state_dict) if "layer_norm.weight" in state_dict["model"]: self.layer_norm = LayerNorm(cfg.encoder_embed_dim) self.load_state_dict(state_dict["model"]) def _remove_decoder_layers(self, state_dict): for key in list(state_dict["model"].keys()): if not key.startswith( ( "encoder.layer_norm", "encoder.layers", "encoder.embed", "encoder.version", ) ): del state_dict["model"][key] else: renamed_key = key.replace("encoder.", "") state_dict["model"][renamed_key] = state_dict["model"].pop(key) def forward(self, src_tokens, src_lengths): encoder_out = super().forward(src_tokens, src_lengths) if isinstance(encoder_out, dict): x = encoder_out["encoder_out"][0] # T x B x C else: x = encoder_out[0] if self.sentemb_criterion == "cls": cls_indices = src_tokens.eq(self.bos_idx).t() sentemb = x[cls_indices, :] else: padding_mask = src_tokens.eq(self.pad_idx).t().unsqueeze(-1) if padding_mask.any(): x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) sentemb = x.max(dim=0)[0] return {"sentemb": sentemb} class LaserLstmEncoder(nn.Module): def __init__( self, num_embeddings, padding_idx, embed_dim=320, hidden_size=512, num_layers=1, bidirectional=False, left_pad=True, padding_value=0.0, ): super().__init__() self.num_layers = num_layers self.bidirectional = bidirectional self.hidden_size = hidden_size self.padding_idx = padding_idx self.embed_tokens = nn.Embedding( num_embeddings, embed_dim, padding_idx=self.padding_idx ) self.lstm = nn.LSTM( input_size=embed_dim, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, ) self.left_pad = left_pad self.padding_value = padding_value self.output_units = hidden_size if bidirectional: self.output_units *= 2 def forward(self, src_tokens, src_lengths): bsz, seqlen = src_tokens.size() # embed tokens x = self.embed_tokens(src_tokens) # B x T x C -> T x B x C x = x.transpose(0, 1) # pack embedded source tokens into a PackedSequence packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) # apply LSTM if self.bidirectional: state_size = 2 * self.num_layers, bsz, self.hidden_size else: state_size = self.num_layers, bsz, self.hidden_size h0 = x.data.new(*state_size).zero_() c0 = x.data.new(*state_size).zero_() packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) # unpack outputs and apply dropout x, _ = nn.utils.rnn.pad_packed_sequence( packed_outs, padding_value=self.padding_value ) assert list(x.size()) == [seqlen, bsz, self.output_units] if self.bidirectional: def combine_bidir(outs): return torch.cat( [ torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view( 1, bsz, self.output_units ) for i in range(self.num_layers) ], dim=0, ) final_hiddens = combine_bidir(final_hiddens) final_cells = combine_bidir(final_cells) encoder_padding_mask = src_tokens.eq(self.padding_idx).t() # Set padded outputs to -inf so they are not selected by max-pooling padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) if padding_mask.any(): x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) # Build the sentence embedding by max-pooling over the encoder outputs sentemb = x.max(dim=0)[0] return { "sentemb": sentemb, "encoder_out": (x, final_hiddens, final_cells), "encoder_padding_mask": encoder_padding_mask if encoder_padding_mask.any() else None, } def initialize_encoder( lang: str = None, model_dir: str = None, spm: bool = True, laser: str = None, ): downloader = LaserModelDownloader(model_dir) if laser is not None: if laser == "laser3": lang = downloader.get_language_code(LASER3_LANGUAGE, lang) downloader.download_laser3(lang=lang, spm=spm) file_path = f"laser3-{lang}.v1" elif laser == "laser2": downloader.download_laser2() file_path = "laser2" else: raise ValueError( f"Unsupported laser model: {laser}. Choose either laser2 or laser3." ) else: if lang in LASER3_LANGUAGE: lang = downloader.get_language_code(LASER3_LANGUAGE, lang) downloader.download_laser3(lang=lang, spm=spm) file_path = f"laser3-{lang}.v1" elif lang in LASER2_LANGUAGE: downloader.download_laser2() file_path = "laser2" else: raise ValueError( f"Unsupported language name: {lang}. Please specify a supported language name." ) model_dir = downloader.model_dir model_path = os.path.join(model_dir, f"{file_path}.pt") spm_vocab = os.path.join(model_dir, f"{file_path}.cvocab") if not os.path.exists(spm_vocab): # if there is no cvocab for the laser3 lang use laser2 cvocab spm_vocab = os.path.join(model_dir, "laser2.cvocab") return SentenceEncoder(model_path=model_path, spm_vocab=spm_vocab, spm_model=None) class LaserEncoderPipeline: def __init__( self, lang: str = None, model_dir: str = None, spm: bool = True, laser: str = None, ): if laser == "laser2" and lang is not None: warnings.warn( "Warning: The 'lang' parameter is optional when using 'laser2'. It will be ignored." ) if laser == "laser3" and lang is None: raise ValueError("For 'laser3', the 'lang' parameter is required.") if laser is None and lang is None: raise ValueError("Either 'laser' or 'lang' should be provided.") self.tokenizer = initialize_tokenizer( lang=lang, model_dir=model_dir, laser=laser ) self.encoder = initialize_encoder( lang=lang, model_dir=model_dir, spm=spm, laser=laser ) def encode_sentences( self, sentences: list, normalize_embeddings: bool = False ) -> list: """ Tokenizes and encodes a list of sentences. Args: - sentences (list of str): List of sentences to tokenize and encode. Returns: - List of embeddings for each sentence. """ tokenized_sentences = [ self.tokenizer.tokenize(sentence) for sentence in sentences ] return self.encoder.encode_sentences(tokenized_sentences, normalize_embeddings)