File size: 5,123 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations

import os
from copy import deepcopy
from typing import Dict, Optional, Sequence, Union

import voyageai  # type: ignore
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from voyageai.object import RerankingObject  # type: ignore


class VoyageAIRerank(BaseDocumentCompressor):
    """Document compressor that uses `VoyageAI Rerank API`.

    The documents and the query are sent to the VoyageAI rerank endpoint.
    The returned documents are copies of the inputs, in the order the API
    returns them (relevance_score order). Each copy carries the API's score
    in ``metadata["relevance_score"]``. The input documents are not mutated.
    """

    client: voyageai.Client = None
    aclient: voyageai.AsyncClient = None
    """VoyageAI clients to use for compressing documents."""
    voyage_api_key: Optional[SecretStr] = None
    """VoyageAI API key. Must be specified directly or via environment variable
        VOYAGE_API_KEY."""
    model: str
    """Model to use for reranking."""
    top_k: Optional[int] = None
    """Number of documents to return."""
    # Passed unchanged as the ``truncation`` argument of the rerank call.
    truncation: bool = True

    class Config:
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and build the VoyageAI clients.

        The key is taken from ``voyage_api_key`` if given, otherwise from the
        ``VOYAGE_API_KEY`` environment variable. When found, it is stored back
        as a ``SecretStr``. Sync/async clients are created only when the
        caller did not pass them explicitly, so user-supplied clients are
        preserved.
        """
        voyage_api_key = values.get("voyage_api_key") or os.getenv(
            "VOYAGE_API_KEY", None
        )
        api_key_str: Optional[str] = None
        if voyage_api_key:
            api_key_secretstr = convert_to_secret_str(voyage_api_key)
            values["voyage_api_key"] = api_key_secretstr
            api_key_str = api_key_secretstr.get_secret_value()

        # Previously these were overwritten unconditionally, which silently
        # discarded any client the caller injected.
        if values.get("client") is None:
            values["client"] = voyageai.Client(api_key=api_key_str)
        if values.get("aclient") is None:
            values["aclient"] = voyageai.AsyncClient(api_key=api_key_str)

        return values

    @staticmethod
    def _to_texts(documents: Sequence[Union[str, Document]]) -> Sequence[str]:
        """Return the text to send to the API for each document or raw string."""
        return [
            doc.page_content if isinstance(doc, Document) else doc for doc in documents
        ]

    @staticmethod
    def _with_scores(
        documents: Sequence[Document], reranking: RerankingObject
    ) -> Sequence[Document]:
        """Build scored copies of the documents selected by a rerank response.

        Args:
            documents: The documents that were sent to the API. Each result's
                ``index`` refers to a position in this sequence.
            reranking: The API response; its ``results`` order is kept.

        Returns:
            New ``Document`` objects with deep-copied metadata plus a
            ``relevance_score`` entry. The inputs are left unmodified.
        """
        compressed = []
        for res in reranking.results:
            original = documents[res.index]
            ranked = Document(
                original.page_content, metadata=deepcopy(original.metadata)
            )
            ranked.metadata["relevance_score"] = res.relevance_score
            compressed.append(ranked)
        return compressed

    def _rerank(
        self,
        documents: Sequence[Union[str, Document]],
        query: str,
    ) -> RerankingObject:
        """Call the synchronous rerank endpoint for the given documents.

        Args:
            documents: A sequence of documents (or raw strings) to rerank.
            query: The query to use for reranking.

        Returns:
            The raw ``RerankingObject`` response from the VoyageAI client.
        """
        return self.client.rerank(
            query=query,
            documents=self._to_texts(documents),
            model=self.model,
            top_k=self.top_k,
            truncation=self.truncation,
        )

    async def _arerank(
        self,
        documents: Sequence[Union[str, Document]],
        query: str,
    ) -> RerankingObject:
        """Call the asynchronous rerank endpoint for the given documents.

        Args:
            documents: A sequence of documents (or raw strings) to rerank.
            query: The query to use for reranking.

        Returns:
            The raw ``RerankingObject`` response from the VoyageAI async client.
        """
        return await self.aclient.rerank(
            query=query,
            documents=self._to_texts(documents),
            model=self.model,
            top_k=self.top_k,
            truncation=self.truncation,
        )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using VoyageAI's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
                Accepted for interface compatibility; not used here.

        Returns:
            A sequence of compressed documents in relevance_score order.
        """
        # Skip the API call entirely when there is nothing to rank.
        if not documents:
            return []
        return self._with_scores(documents, self._rerank(documents, query))

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using VoyageAI's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
                Accepted for interface compatibility; not used here.

        Returns:
            A sequence of compressed documents in relevance_score order.
        """
        # Skip the API call entirely when there is nothing to rank.
        if not documents:
            return []
        reranking = await self._arerank(documents, query)
        return self._with_scores(documents, reranking)