Spaces:
Runtime error
Runtime error
File size: 5,123 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
from __future__ import annotations
import os
from copy import deepcopy
from typing import Dict, Optional, Sequence, Union
import voyageai # type: ignore
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from voyageai.object import RerankingObject # type: ignore
class VoyageAIRerank(BaseDocumentCompressor):
    """Document compressor that uses `VoyageAI Rerank API`.

    The documents are sent to the VoyageAI rerank endpoint together with the
    query; the returned results are mapped back onto copies of the input
    documents, each annotated with a ``relevance_score`` metadata entry.
    """

    client: voyageai.Client = None
    """Synchronous VoyageAI client used by `compress_documents`. Created in
    `validate_environment` when not supplied by the caller."""
    aclient: voyageai.AsyncClient = None
    """Asynchronous VoyageAI client used by `acompress_documents`. Created in
    `validate_environment` when not supplied by the caller."""
    voyage_api_key: Optional[SecretStr] = None
    """VoyageAI API key. Must be specified directly or via environment variable
    VOYAGE_API_KEY."""
    model: str
    """Model to use for reranking."""
    top_k: Optional[int] = None
    """Number of documents to return. Forwarded unchanged to the rerank call;
    ``None`` leaves the choice to the API."""
    truncation: bool = True
    """Forwarded unchanged to the rerank call as ``truncation`` (see the
    VoyageAI reranker documentation for its exact semantics)."""

    class Config:
        # voyageai.Client / voyageai.AsyncClient are not pydantic types.
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and make sure both VoyageAI clients exist.

        The key is taken from ``values["voyage_api_key"]`` and, failing that,
        from the ``VOYAGE_API_KEY`` environment variable. When a key is found
        it is stored back as a secret string. No error is raised here when no
        key is found: the clients are then built with ``api_key=None`` and key
        handling is left to the ``voyageai`` library.

        Args:
            values: Raw constructor keyword arguments (pre-validation).

        Returns:
            The same mapping with ``voyage_api_key``, ``client`` and
            ``aclient`` filled in.
        """
        voyage_api_key = values.get("voyage_api_key") or os.getenv(
            "VOYAGE_API_KEY", None
        )
        if voyage_api_key:
            api_key_secretstr = convert_to_secret_str(voyage_api_key)
            values["voyage_api_key"] = api_key_secretstr
            api_key_str = api_key_secretstr.get_secret_value()
        else:
            api_key_str = None

        # Only build a client when the caller did not provide one, so that a
        # preconfigured client passed to the constructor is not discarded.
        if values.get("client") is None:
            values["client"] = voyageai.Client(api_key=api_key_str)
        if values.get("aclient") is None:
            values["aclient"] = voyageai.AsyncClient(api_key=api_key_str)
        return values

    def _rerank_kwargs(
        self,
        documents: Sequence[Union[str, Document]],
        query: str,
    ) -> Dict:
        """Build the keyword arguments shared by the sync and async rerank calls.

        Args:
            documents: Documents or raw strings; ``Document`` instances are
                reduced to their ``page_content``.
            query: The query to use for reranking.

        Returns:
            Keyword arguments for ``client.rerank`` / ``aclient.rerank``.
        """
        texts = [
            doc.page_content if isinstance(doc, Document) else doc
            for doc in documents
        ]
        return {
            "query": query,
            "documents": texts,
            "model": self.model,
            "top_k": self.top_k,
            "truncation": self.truncation,
        }

    def _build_compressed(
        self,
        documents: Sequence[Document],
        reranking: RerankingObject,
    ) -> Sequence[Document]:
        """Map rerank results back onto annotated copies of the input documents.

        Args:
            documents: The documents that were sent for reranking.
            reranking: The object returned by the rerank call; each entry of
                its ``results`` carries an ``index`` into ``documents`` and a
                ``relevance_score``.

        Returns:
            New ``Document`` objects in the order of ``reranking.results``,
            each with ``metadata["relevance_score"]`` set.
        """
        compressed = []
        for res in reranking.results:
            doc = documents[res.index]
            # Deep-copy the metadata so the caller's documents are not mutated
            # when the score is attached.
            doc_copy = Document(
                page_content=doc.page_content, metadata=deepcopy(doc.metadata)
            )
            doc_copy.metadata["relevance_score"] = res.relevance_score
            compressed.append(doc_copy)
        return compressed

    def _rerank(
        self,
        documents: Sequence[Union[str, Document]],
        query: str,
    ) -> RerankingObject:
        """Rerank the documents against the query with the synchronous client.

        Args:
            query: The query to use for reranking.
            documents: A sequence of documents to rerank.

        Returns:
            The reranking object returned by ``client.rerank``.
        """
        return self.client.rerank(**self._rerank_kwargs(documents, query))

    async def _arerank(
        self,
        documents: Sequence[Union[str, Document]],
        query: str,
    ) -> RerankingObject:
        """Rerank the documents against the query with the asynchronous client.

        Args:
            query: The query to use for reranking.
            documents: A sequence of documents to rerank.

        Returns:
            The reranking object returned by ``aclient.rerank``.
        """
        return await self.aclient.rerank(**self._rerank_kwargs(documents, query))

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using VoyageAI's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            Copies of the selected documents, in the order returned by the
            rerank API, each carrying a ``relevance_score`` metadata entry.
            An empty list when no documents are given (no API call is made).
        """
        if not documents:
            return []
        return self._build_compressed(documents, self._rerank(documents, query))

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Asynchronously compress documents using VoyageAI's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            Copies of the selected documents, in the order returned by the
            rerank API, each carrying a ``relevance_score`` metadata entry.
            An empty list when no documents are given (no API call is made).
        """
        if not documents:
            return []
        reranking = await self._arerank(documents, query)
        return self._build_compressed(documents, reranking)
|