# BookWorld/modules/memory.py
import sys
sys.path.append("../")
from bw_utils import *
from modules.embedding import get_embedding_model
from langchain_experimental.generative_agents import GenerativeAgentMemory
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import Tongyi, OpenAI
from langchain_community.docstore import InMemoryDocstore
from langchain_community.vectorstores import FAISS
import faiss
import math


def build_role_agent_memory(type="ga", **kwargs):
    """Factory for a role agent's memory backend.

    type == "ga" builds a generative-agents style memory on top of a FAISS
    time-weighted retriever; any other value builds a plain vector-DB memory.
    """
    if type == "ga":
        llm_name = kwargs["llm_name"]
        embedding_name = kwargs["embedding_name"]
        db_name = kwargs["db_name"]
        language = kwargs.get("language", "")
        embedding_model = get_embedding_model(embedding_name, language=language)
        # Size the FAISS index from the embedding dimensionality of a probe query.
        index = faiss.IndexFlatL2(len(embedding_model.embed_query("hello world")))
        vectorstore = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
            # Explicit score rescaling (equivalent to FAISS's default for L2 distance).
            relevance_score_fn=relevance_score_fn,
        )
        # Retrieve by a combination of recency, importance, and relevance.
        memory_retriever = TimeWeightedVectorStoreRetriever(
            vectorstore=vectorstore, other_score_keys=["importance"], k=5
        )
        if llm_name.startswith("qwen"):
            chat_model = Tongyi(
                temperature=0.9,
            )
        else:
            chat_model = OpenAI(
                temperature=0.9,
                model="gpt-3.5-turbo",
            )
        agent_memory = RoleMemory_GA(
            llm=chat_model,
            memory_retriever=memory_retriever,
            embedding_model=embedding_model,
            memory_decay_rate=0.01,
        )
        return agent_memory
    else:
        db_name = kwargs["db_name"]
        embedding = kwargs["embedding"]
        db_type = kwargs.get("db_type", "chromadb")
        capacity = kwargs.get("capacity", 5)
        agent_memory = RoleMemory(
            db_name=db_name,
            embedding=embedding,
            db_type=db_type,
            capacity=capacity,
        )
        return agent_memory


def relevance_score_fn(score: float) -> float:
    """Rescale a FAISS L2 distance into a [0, 1] relevance score (1 = identical)."""
    return 1.0 - score / math.sqrt(2)
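
# Illustration (an added note, not part of the original code): assuming unit-norm
# embeddings whose cosine similarity is non-negative, the L2 distance lies in
# [0, sqrt(2)], so relevance_score_fn(0.0) == 1.0 (identical vectors) and
# relevance_score_fn(math.sqrt(2)) == 0.0 (maximally dissimilar ones).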


class RoleMemory_GA(GenerativeAgentMemory):
    """Role memory backed by LangChain's GenerativeAgentMemory."""

    def init_from_data(self, data):
        # Seed the memory stream from an iterable of text records.
        for text in data:
            self.add_record(text)

    def add_record(self, text):
        self.add_memory(text)

    def search(self, query, top_k):
        fetched_memories = [doc.page_content for doc in self.fetch_memories(query)[:top_k]]
        if len(fetched_memories) >= top_k:
            print("-Memory Searching...")
            print(fetched_memories)
        return fetched_memories

    def delete_record(self, idx):
        # GenerativeAgentMemory exposes no deletion API; kept for interface parity with RoleMemory.
        pass


class RoleMemory:
    """Role memory backed by a plain vector database built via build_db (see bw_utils)."""

    def __init__(self, db_name, embedding, db_type="chroma", capacity=5) -> None:
        self.idx = 0
        self.capacity = capacity
        self.db_name = db_name
        self.db = build_db([], db_name, db_type, embedding, save_type="temporary")

    def init_from_data(self, data):
        for text in data:
            self.add_record(text)

    def add_record(self, text):
        # Use a monotonically increasing index as the record id.
        self.idx += 1
        self.db.add(text, str(self.idx), db_name=self.db_name)

    def search(self, query, top_k):
        return self.db.search(query, top_k, self.db_name)

    def delete_record(self, idx):
        self.db.delete(idx)

    @property
    def len(self):
        return self.db.len
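

# The sketch below is an illustrative usage example added for documentation, not part
# of the original module. The concrete llm_name / embedding_name / db_name values are
# assumptions, and the "ga" branch needs working LLM and embedding credentials to run.
if __name__ == "__main__":
    role_memory = build_role_agent_memory(
        type="ga",
        llm_name="gpt-3.5-turbo",   # hypothetical model name
        embedding_name="bge-m3",    # hypothetical name accepted by get_embedding_model
        db_name="demo_role_memory", # hypothetical database name
        language="en",
    )
    role_memory.init_from_data([
        "Alice met Bob at the harbor.",
        "Alice distrusts the merchant guild.",
    ])
    print(role_memory.search("Who does Alice distrust?", 2))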