Spaces:
Runtime error
Runtime error
File size: 2,684 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from typing import Any, List, Optional
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class ChaindeskRetriever(BaseRetriever):
    """`Chaindesk API` retriever.

    Sends the query as a JSON ``POST`` request to ``datastore_url`` and
    converts each entry of the response's ``"results"`` list into a
    ``Document``.

    Attributes are documented inline below.
    """

    # URL the query is POSTed to.
    datastore_url: str
    # Sent as ``"topK"`` in the request body; omitted when ``None``.
    top_k: Optional[int] = None
    # Sent as a ``Bearer`` token in the ``Authorization`` header; the header
    # is omitted when ``None``.
    api_key: Optional[str] = None

    def __init__(
        self,
        datastore_url: str,
        top_k: Optional[int] = None,
        api_key: Optional[str] = None,
    ):
        """Create a retriever bound to one datastore endpoint.

        Args:
            datastore_url: URL the query is POSTed to.
            top_k: Optional value for the ``"topK"`` request field.
            api_key: Optional bearer token for the ``Authorization`` header.
        """
        # The base class is a pydantic model: fields have to be handed to the
        # base initializer. Assigning them on ``self`` before the model has
        # been initialised fails at construction time.
        super().__init__(
            datastore_url=datastore_url,
            top_k=top_k,
            api_key=api_key,
        )

    def _build_payload(self, query: str) -> dict:
        """Return the JSON request body for ``query``."""
        payload: dict = {"query": query}
        if self.top_k is not None:
            payload["topK"] = self.top_k
        return payload

    def _build_headers(self) -> dict:
        """Return the request headers, with bearer auth when a key is set."""
        headers: dict = {"Content-Type": "application/json"}
        if self.api_key is not None:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @staticmethod
    def _to_documents(data: Any) -> List[Document]:
        """Convert a decoded response body into a list of ``Document``.

        ``data["results"]`` is iterated and every entry must provide the
        ``"text"``, ``"source"`` and ``"score"`` keys; a missing key raises
        ``KeyError``.
        """
        return [
            Document(
                page_content=r["text"],
                metadata={"source": r["source"], "score": r["score"]},
            )
            for r in data["results"]
        ]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Query the datastore synchronously with ``requests``.

        Args:
            query: Text sent as the ``"query"`` field of the request body.
            run_manager: Callback manager for this run (not used here).
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            One ``Document`` per entry in the response's ``"results"`` list.
        """
        # NOTE: the HTTP status is not checked; an error response surfaces as
        # a JSON decoding error or a ``KeyError`` while parsing the body.
        response = requests.post(
            self.datastore_url,
            json=self._build_payload(query),
            headers=self._build_headers(),
        )
        return self._to_documents(response.json())

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Query the datastore asynchronously with ``aiohttp``.

        Args:
            query: Text sent as the ``"query"`` field of the request body.
            run_manager: Async callback manager for this run (not used here).
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            One ``Document`` per entry in the response's ``"results"`` list.
        """
        async with aiohttp.ClientSession() as session:
            async with session.request(
                "POST",
                self.datastore_url,
                json=self._build_payload(query),
                headers=self._build_headers(),
            ) as response:
                # Read the body while the response is still open.
                data = await response.json()
        return self._to_documents(data)
|