import json
import logging
import os
import shutil
from itertools import islice
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain.storage import create_kv_docstore
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore, ByteStore
from pie_datasets import Dataset, DatasetDict

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class BasicPieDocumentStore(PieDocumentStore):
    """PIE document store that uses a LangChain key-value client to store and retrieve
    documents."""

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
        if byte_store is not None:
            client = create_kv_docstore(byte_store)
        elif client is None:
            raise ValueError("You must pass either a `client` or a `byte_store` parameter.")
        self.client = client

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        return self.client.yield_keys(prefix=prefix)

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        all_doc_ids = []
        all_metadata = []
        all_pie_docs = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove the existing directory so stale files do not mix with the new dump
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        # fetch the documents in batches to keep individual mget calls small
        doc_ids_iter = iter(self.client.yield_keys())
        while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            for doc in docs:
                # split each entry into the wrapped PIE document and the remaining metadata
                all_pie_docs.append(doc.metadata[self.METADATA_KEY_PIE_DOCUMENT])
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
        # write all PIE documents once after the loop; writing per batch would
        # overwrite the previous batches at the same path
        pie_dataset = Dataset.from_documents(all_pie_docs)
        DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
        if len(all_doc_ids) > 0:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if len(all_metadata) > 0:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, not loading any documents."
            )
            return
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, not loading any metadata.")
            all_metadata = [{} for _ in pie_docs]
        # re-wrap each PIE document with its metadata into a LangChain document
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(
                f"File {doc_ids_path} does not exist, using the PIE document ids instead."
            )
            all_doc_ids = [doc.id for doc in pie_docs]
        # mset expects a sequence, so materialize the (id, document) pairs
        self.client.mset(list(zip(all_doc_ids, docs)))
        logger.info(f"Loaded {len(docs)} documents from {path} into docstore")
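
# --- Usage sketch (illustrative only, not part of the store API) ---
# A minimal example of wiring the store to a persistent byte store, saving a
# snapshot to disk, and loading it back. It assumes `LocalFileStore` from
# `langchain.storage` and a PIE text document type (shown here as
# `pytorch_ie.documents.TextDocument`); the concrete document type and the
# `wrap` helper come from the `PieDocumentStore` base class and may differ
# in your setup.
#
#   from langchain.storage import LocalFileStore
#   from pytorch_ie.documents import TextDocument  # assumed document type
#
#   store = BasicPieDocumentStore(byte_store=LocalFileStore("./docstore"))
#   pie_doc = TextDocument(text="Hello, world!", id="doc-1")
#   store.mset([("doc-1", store.wrap(pie_doc))])
#   store._save_to_directory("./snapshot", batch_size=500)
#
#   fresh_store = BasicPieDocumentStore(byte_store=LocalFileStore("./docstore2"))
#   fresh_store._load_from_directory("./snapshot")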