|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Optional |
|
|
|
from camel.memories.base import AgentMemory, BaseContextCreator |
|
from camel.memories.blocks import ChatHistoryBlock, VectorDBBlock |
|
from camel.memories.records import ContextRecord, MemoryRecord |
|
from camel.storages import BaseKeyValueStorage, BaseVectorStorage |
|
from camel.types import OpenAIBackendRole |
|
|
|
|
|
class ChatHistoryMemory(AgentMemory): |
|
r"""An agent memory wrapper of :obj:`ChatHistoryBlock`. |
|
|
|
Args: |
|
context_creator (BaseContextCreator): A model context creator. |
|
storage (BaseKeyValueStorage, optional): A storage backend for storing |
|
chat history. If `None`, an :obj:`InMemoryKeyValueStorage` |
|
will be used. (default: :obj:`None`) |
|
window_size (int, optional): The number of recent chat messages to |
|
retrieve. If not provided, the entire chat history will be |
|
retrieved. (default: :obj:`None`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
context_creator: BaseContextCreator, |
|
storage: Optional[BaseKeyValueStorage] = None, |
|
window_size: Optional[int] = None, |
|
) -> None: |
|
if window_size is not None and not isinstance(window_size, int): |
|
raise TypeError("`window_size` must be an integer or None.") |
|
if window_size is not None and window_size < 0: |
|
raise ValueError("`window_size` must be non-negative.") |
|
self._context_creator = context_creator |
|
self._window_size = window_size |
|
self._chat_history_block = ChatHistoryBlock(storage=storage) |
|
|
|
def retrieve(self) -> List[ContextRecord]: |
|
return self._chat_history_block.retrieve(self._window_size) |
|
|
|
def write_records(self, records: List[MemoryRecord]) -> None: |
|
self._chat_history_block.write_records(records) |
|
|
|
def get_context_creator(self) -> BaseContextCreator: |
|
return self._context_creator |
|
|
|
def clear(self) -> None: |
|
self._chat_history_block.clear() |
|
|
|
|
|
class VectorDBMemory(AgentMemory): |
|
r"""An agent memory wrapper of :obj:`VectorDBBlock`. This memory queries |
|
messages stored in the vector database. Notice that the most recent |
|
messages will not be added to the context. |
|
|
|
Args: |
|
context_creator (BaseContextCreator): A model context creator. |
|
storage (BaseVectorStorage, optional): A vector storage storage. If |
|
`None`, an :obj:`QdrantStorage` will be used. |
|
(default: :obj:`None`) |
|
retrieve_limit (int, optional): The maximum number of messages |
|
to be added into the context. (default: :obj:`3`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
context_creator: BaseContextCreator, |
|
storage: Optional[BaseVectorStorage] = None, |
|
retrieve_limit: int = 3, |
|
) -> None: |
|
self._context_creator = context_creator |
|
self._retrieve_limit = retrieve_limit |
|
self._vectordb_block = VectorDBBlock(storage=storage) |
|
|
|
self._current_topic: str = "" |
|
|
|
def retrieve(self) -> List[ContextRecord]: |
|
return self._vectordb_block.retrieve( |
|
self._current_topic, |
|
limit=self._retrieve_limit, |
|
) |
|
|
|
def write_records(self, records: List[MemoryRecord]) -> None: |
|
|
|
for record in records: |
|
if record.role_at_backend == OpenAIBackendRole.USER: |
|
self._current_topic = record.message.content |
|
self._vectordb_block.write_records(records) |
|
|
|
def get_context_creator(self) -> BaseContextCreator: |
|
return self._context_creator |
|
|
|
|
|
class LongtermAgentMemory(AgentMemory): |
|
r"""An implementation of the :obj:`AgentMemory` abstract base class for |
|
augmenting ChatHistoryMemory with VectorDBMemory. |
|
|
|
Args: |
|
context_creator (BaseContextCreator): A model context creator. |
|
chat_history_block (Optional[ChatHistoryBlock], optional): A chat |
|
history block. If `None`, a :obj:`ChatHistoryBlock` will be used. |
|
(default: :obj:`None`) |
|
vector_db_block (Optional[VectorDBBlock], optional): A vector database |
|
block. If `None`, a :obj:`VectorDBBlock` will be used. |
|
(default: :obj:`None`) |
|
retrieve_limit (int, optional): The maximum number of messages |
|
to be added into the context. (default: :obj:`3`) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
context_creator: BaseContextCreator, |
|
chat_history_block: Optional[ChatHistoryBlock] = None, |
|
vector_db_block: Optional[VectorDBBlock] = None, |
|
retrieve_limit: int = 3, |
|
) -> None: |
|
self.chat_history_block = chat_history_block or ChatHistoryBlock() |
|
self.vector_db_block = vector_db_block or VectorDBBlock() |
|
self.retrieve_limit = retrieve_limit |
|
self._context_creator = context_creator |
|
self._current_topic: str = "" |
|
|
|
def get_context_creator(self) -> BaseContextCreator: |
|
r"""Returns the context creator used by the memory. |
|
|
|
Returns: |
|
BaseContextCreator: The context creator used by the memory. |
|
""" |
|
return self._context_creator |
|
|
|
def retrieve(self) -> List[ContextRecord]: |
|
r"""Retrieves context records from both the chat history and the vector |
|
database. |
|
|
|
Returns: |
|
List[ContextRecord]: A list of context records retrieved from both |
|
the chat history and the vector database. |
|
""" |
|
chat_history = self.chat_history_block.retrieve() |
|
vector_db_retrieve = self.vector_db_block.retrieve( |
|
self._current_topic, self.retrieve_limit |
|
) |
|
return chat_history[:1] + vector_db_retrieve + chat_history[1:] |
|
|
|
def write_records(self, records: List[MemoryRecord]) -> None: |
|
r"""Converts the provided chat messages into vector representations and |
|
writes them to the vector database. |
|
|
|
Args: |
|
records (List[MemoryRecord]): Messages to be added to the vector |
|
database. |
|
""" |
|
self.vector_db_block.write_records(records) |
|
self.chat_history_block.write_records(records) |
|
|
|
for record in records: |
|
if record.role_at_backend == OpenAIBackendRole.USER: |
|
self._current_topic = record.message.content |
|
|
|
def clear(self) -> None: |
|
r"""Removes all records from the memory.""" |
|
self.chat_history_block.clear() |
|
self.vector_db_block.clear() |
|
|