Spaces:
Runtime error
Runtime error
"""Wrapper around Embedchain Retriever.""" | |
from __future__ import annotations | |
from typing import Any, Iterable, List, Optional | |
from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
from langchain_core.documents import Document | |
from langchain_core.retrievers import BaseRetriever | |
class EmbedchainRetriever(BaseRetriever):
    """`Embedchain` retriever.

    Wraps an Embedchain ``Pipeline`` (held in ``client``) so it can be used as a
    LangChain retriever: texts are added through ``client.add`` and queries are
    answered through ``client.search``.
    """

    client: Any
    """Embedchain Pipeline."""

    @classmethod
    def create(cls, yaml_path: Optional[str] = None) -> EmbedchainRetriever:
        """Create an EmbedchainRetriever from a YAML configuration file.

        Args:
            yaml_path: Path to the YAML configuration file. If not provided
                (``None`` or an empty string), a default ``Pipeline()`` is used.

        Returns:
            An instance of EmbedchainRetriever wrapping the new Pipeline.
        """
        # Imported lazily so embedchain is only required when this factory is used.
        from embedchain import Pipeline

        # Truthiness check: both None and "" fall back to the default pipeline.
        if yaml_path:
            client = Pipeline.from_config(yaml_path=yaml_path)
        else:
            client = Pipeline()
        return cls(client=client)

    def add_texts(
        self,
        texts: Iterable[str],
    ) -> List[str]:
        """Run more texts through the embeddings and add to the retriever.

        Args:
            texts: Iterable of strings/URLs to add to the retriever. Each item
                is passed to ``client.add`` individually, in input order.

        Returns:
            List of ids from adding the texts into the retriever: one entry per
            input item, in input order, each being whatever ``client.add``
            returned for that item.
        """
        return [self.client.add(text) for text in texts]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search the Embedchain pipeline and wrap each hit as a Document.

        Args:
            query: Text to search for.
            run_manager: Callback manager for this retriever run (not used here).

        Returns:
            One Document per item returned by ``client.search(query)``. The
            hit's ``context`` becomes ``page_content``, and its metadata
            ``url`` / ``doc_id`` are mapped to ``source`` / ``document_id``.

        Raises:
            KeyError: if a search hit (assumed to be a dict) lacks ``context``,
                ``metadata``, ``metadata["url"]`` or ``metadata["doc_id"]``.
        """
        # NOTE(review): assumes every search hit is shaped like
        # {"context": str, "metadata": {"url": ..., "doc_id": ...}}; confirm
        # against the installed embedchain version.
        return [
            Document(
                page_content=hit["context"],
                metadata={
                    "source": hit["metadata"]["url"],
                    "document_id": hit["metadata"]["doc_id"],
                },
            )
            for hit in self.client.search(query)
        ]