from typing import Any, Dict, List, Optional

import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

LASER_MULTILINGUAL_MODEL: str = "laser2"


class LaserEmbeddings(BaseModel, Embeddings):
    """LASER Language-Agnostic SEntence Representations.
    LASER is a Python library developed by the Meta AI Research team
    and used for creating multilingual sentence embeddings for over 147 languages
    as of 2/25/2024
    See more documentation at:
    * https://github.com/facebookresearch/LASER/
    * https://github.com/facebookresearch/LASER/tree/main/laser_encoders
    * https://arxiv.org/abs/2205.12654

    To use this class, you must install the `laser_encoders` Python package.

    `pip install laser_encoders`
    Example:
        from laser_encoders import LaserEncoderPipeline
        encoder = LaserEncoderPipeline(lang="eng_Latn")
        embeddings = encoder.encode_sentences(["Hello", "World"])
    """

    lang: Optional[str]
    """The language or language code you'd like to use
    If empty, this implementation will default
    to using a multilingual earlier LASER encoder model (called laser2)
    Find the list of supported languages at
    https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200
    """

    _encoder_pipeline: Any  # : :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that laser_encoders has been installed."""
        try:
            from laser_encoders import LaserEncoderPipeline

            lang = values.get("lang")
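            # Use a language-specific encoder when `lang` is provided;
            # otherwise fall back to the multilingual "laser2" model.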
            if lang:
                encoder_pipeline = LaserEncoderPipeline(lang=lang)
            else:
                encoder_pipeline = LaserEncoderPipeline(laser=LASER_MULTILINGUAL_MODEL)
            values["_encoder_pipeline"] = encoder_pipeline

        except ImportError as e:
            raise ImportError(
                "Could not import 'laser_encoders' Python package. "
                "Please install it with `pip install laser_encoders`."
            ) from e
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for documents using LASER.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings: np.ndarray
        embeddings = self._encoder_pipeline.encode_sentences(texts)

        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Generate single query text embeddings using LASER.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        query_embeddings: np.ndarray
        query_embeddings = self._encoder_pipeline.encode_sentences([text])
        return query_embeddings.tolist()[0]
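

# Minimal usage sketch, assuming the `laser_encoders` package is installed and
# its model weights can be downloaded on first use. It only relies on the names
# defined above: LaserEmbeddings, embed_documents, and embed_query.
if __name__ == "__main__":
    # Language-specific encoder; omit `lang` to fall back to the multilingual
    # "laser2" model instead.
    embedder = LaserEmbeddings(lang="eng_Latn")

    # Embed a batch of documents: one vector (list of floats) per input text.
    doc_vectors = embedder.embed_documents(["Hello", "World"])
    print(len(doc_vectors), len(doc_vectors[0]))

    # Embed a single query string: a single vector.
    query_vector = embedder.embed_query("Hello")
    print(len(query_vector))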