# BookWorld/modules/world_agent.py
# alienet's picture
# first commit
# e636070
import ast
import csv
import os
import sys
from typing import Any, Dict, List, Optional, Literal

sys.path.append("../")

from bw_utils import *
from modules.embedding import get_embedding_model
class WorldAgent:
    """LLM-backed narrator that owns the world state of a simulation.

    The agent keeps the world description, the known locations and the
    distances between them, a log of everything it generated (``history``)
    and the prompt/response pairs it exchanged with the LLM (``prompts``).
    World settings are indexed in a vector database so they can be retrieved
    as references when resolving character actions.
    """

    # Init
    def __init__(self,
                 world_file_path: str,
                 location_file_path: str,
                 map_file_path: Optional[str] = "",
                 world_description: str = "",
                 llm_name: str = "gpt-4o-mini",
                 llm = None,
                 embedding_name: str = "bge",
                 db_type: str = "chroma",
                 language: str = "zh",
                 ):
        """Load the world, its locations and prompts, and build the settings DB.

        Args:
            world_file_path: JSON file holding at least ``world_name``,
                ``description`` and ``source``.
            location_file_path: JSON file describing the locations.
            map_file_path: Optional CSV distance matrix between locations.
            world_description: Overrides the description from the world file
                when non-empty.
            llm_name: Model name passed to ``get_models`` when ``llm`` is None.
            llm: Pre-built LLM object exposing ``chat(prompt)``.
            embedding_name: Name of the embedding model for the settings DB.
            db_type: Backend passed to ``build_db``.
            language: ``"zh"`` selects the Chinese prompts, anything else the
                English ones.
        """
        # The embedding model is cached in a module-level global so it is
        # created only once per process.
        # NOTE(review): a later agent asking for a different embedding_name
        # still reuses the first model -- confirm this is intended.
        global embedding

        if llm is None:
            llm = get_models(llm_name)
        self.llm = llm

        self.world_info: Dict[str, Any] = load_json_file(world_file_path)
        self.world_name: str = self.world_info["world_name"]
        self.language: str = language
        self.description: str = world_description if world_description != "" else self.world_info["description"]
        source = self.world_info["source"]

        self.locations_info: Dict[str, Any] = {}
        self.locations: List[str] = []
        self.history: List[str] = []
        self.edges: Dict[tuple, int] = {}  # distance between locations
        self.prompts: List[Dict] = []

        self.init_from_file(map_file_path = map_file_path,
                            location_file_path = location_file_path)
        self.init_prompt()

        if 'embedding' not in globals():
            embedding = get_embedding_model(embedding_name, language=language)

        self.world_data, self.world_settings = build_world_agent_data(world_file_path = world_file_path, max_words = 50)
        self.db_name = clean_collection_name(f"settings_{source}_{embedding_name}")
        self.db = build_db(data = list(self.world_data),
                           db_name = self.db_name,
                           db_type = db_type,
                           embedding = embedding)

    def init_from_file(self, map_file_path: str, location_file_path: str, default_distance: int = 1):
        """Populate locations and edges from the location file and optional map.

        When ``map_file_path`` points to an existing CSV distance matrix
        (first row and first column are location codes), only the locations
        listed there are registered and the edges use the distances from the
        matrix. Otherwise every location in the location file is registered
        and all pairs are connected with ``default_distance``.

        Raises:
            ValueError: if a non-empty distance cell is not an integer.
        """
        # Parse the location file once; it is either the mapping itself or a
        # wrapper object with a "locations" key.
        loaded = load_json_file(location_file_path)
        valid_locations = loaded["locations"] if "locations" in loaded else loaded

        if map_file_path and os.path.exists(map_file_path):
            # newline="" is what the csv module expects; explicit UTF-8 keeps
            # non-ASCII location codes readable regardless of platform locale.
            with open(map_file_path, mode='r', encoding="utf-8", newline="") as file:
                csv_reader = csv.reader(file)
                header = next(csv_reader, None)
                column_locations = header[1:] if header else []
                for row in csv_reader:
                    if not row:
                        continue  # tolerate blank lines in the matrix
                    loc1 = row[0]
                    if loc1 not in valid_locations:
                        print(f"Warning: The location {loc1} does not exist")
                        continue
                    self.locations_info[loc1] = valid_locations[loc1]
                    self.locations.append(loc1)
                    # zip() ignores cells beyond the header instead of raising.
                    for loc2, distance in zip(column_locations, row[1:]):
                        if loc2 not in valid_locations:
                            print(f"Warning: The location {loc2} does not exist")
                            continue
                        distance = distance.strip()
                        if distance in ("", "0"):  # Skip self-loops and empty cells
                            continue
                        self._add_edge(loc1, loc2, int(distance))
        else:
            for loc1 in valid_locations:
                self.locations_info[loc1] = valid_locations[loc1]
                self.locations.append(loc1)
                for loc2 in valid_locations:
                    if loc2 != loc1:
                        self._add_edge(loc1, loc2, default_distance)

    def init_prompt(self,):
        """Bind the prompt templates for ``self.language`` to private attributes."""
        if self.language == "zh":
            from modules.prompt.world_agent_prompt_zh import ENVIROMENT_INTERACTION_PROMPT,NPC_INTERACTION_PROMPT,SCRIPT_INSTRUCTION_PROMPT,SCRIPT_ATTENTION_PROMPT,DECIDE_NEXT_ACTOR_PROMPT,GENERATE_INTERVENTION_PROMPT,UPDATE_EVENT_PROMPT,LOCATION_PROLOGUE_PROMPT,SELECT_SCREEN_ACTORS_PROMPT,JUDGE_IF_ENDED_PROMPT,LOG2STORY_PROMPT
        else:
            from modules.prompt.world_agent_prompt_en import ENVIROMENT_INTERACTION_PROMPT,NPC_INTERACTION_PROMPT,SCRIPT_INSTRUCTION_PROMPT,SCRIPT_ATTENTION_PROMPT,DECIDE_NEXT_ACTOR_PROMPT,GENERATE_INTERVENTION_PROMPT,UPDATE_EVENT_PROMPT,LOCATION_PROLOGUE_PROMPT,SELECT_SCREEN_ACTORS_PROMPT,JUDGE_IF_ENDED_PROMPT,LOG2STORY_PROMPT
        self._ENVIROMENT_INTERACTION_PROMPT = ENVIROMENT_INTERACTION_PROMPT
        self._NPC_INTERACTION_PROMPT = NPC_INTERACTION_PROMPT
        self._SCRIPT_INSTRUCTION_PROMPT = SCRIPT_INSTRUCTION_PROMPT
        self._SCRIPT_ATTENTION = SCRIPT_ATTENTION_PROMPT
        self._DECIDE_NEXT_ACTOR_PROMPT = DECIDE_NEXT_ACTOR_PROMPT
        self._LOCATION_PROLOGUE_PROMPT = LOCATION_PROLOGUE_PROMPT
        self._GENERATE_INTERVENTION_PROMPT = GENERATE_INTERVENTION_PROMPT
        self._UPDATE_EVENT_PROMPT = UPDATE_EVENT_PROMPT
        self._SELECT_SCREEN_ACTORS_PROMPT = SELECT_SCREEN_ACTORS_PROMPT
        self._JUDGE_IF_ENDED_PROMPT = JUDGE_IF_ENDED_PROMPT
        self._LOG2STORY_PROMPT = LOG2STORY_PROMPT

    # Agent
    def update_event(self,
                     cur_event: str,
                     intervention: str,
                     history_text: str,
                     script: str = ""):
        """Ask the LLM for an updated event and record the result.

        When ``script`` is non-empty the script-attention prompt is prepended.
        Returns the raw LLM reply.
        """
        prompt = self._UPDATE_EVENT_PROMPT.format(**{
            "event":cur_event,
            "intervention":intervention,
            "history":history_text
        })
        if script:
            prompt = self._SCRIPT_ATTENTION.format(script = script) + prompt
        new_event = self.llm.chat(prompt)
        self.record(new_event, prompt)
        return new_event

    def decide_next_actor(self,
                          history_text: str,
                          roles_info_text: str,
                          script: str = "",
                          event: str = ""):
        """Ask the LLM which role acts next and return its raw reply.

        The call is retried up to three times. If every attempt raises, an
        empty string is returned. ``script`` and ``event`` are accepted for
        interface compatibility but are not used by the prompt.
        """
        prompt = self._DECIDE_NEXT_ACTOR_PROMPT.format(**{
            "roles_info":roles_info_text,
            "history_text":history_text,
        })
        max_tries = 3
        # Pre-initialised so a failing LLM call cannot leave the name unbound.
        response = ""
        for _ in range(max_tries):
            try:
                response = self.llm.chat(prompt)
                break
            except Exception as e:
                print("Parsing failure! Error:", e)
        role_code = response
        self.prompts.append({"prompt":prompt,
                             "response":f"{role_code}"})
        return role_code

    def judge_if_ended(self, history_text):
        """Ask the LLM whether the story has ended.

        Returns a ``(if_end, detail)`` tuple. When no reply can be parsed in
        three tries the defaults ``(True, "")`` are returned.
        """
        prompt = self._JUDGE_IF_ENDED_PROMPT.format(**{
            "history":history_text
        })
        max_tries = 3
        response = {"if_end":True, "detail":""}
        for _ in range(max_tries):
            try:
                response.update(json_parser(self.llm.chat(prompt)))
                break
            except Exception as e:
                print("Parsing failure! Error:", e)
                print(response)
        return response["if_end"], response["detail"]

    def decide_screen_actors(self, roles_info_text, history_text, event, previous_role_codes):
        """Ask the LLM which roles appear in the next scene.

        The reply is expected to be a Python literal (e.g. a list of role
        codes) and is returned after parsing.

        Raises:
            ValueError, SyntaxError: if the reply is not a valid literal.
        """
        prompt = self._SELECT_SCREEN_ACTORS_PROMPT.format(**{
            "roles_info":roles_info_text,
            "history_text":history_text,
            "event":event,
            "previous_role_codes":previous_role_codes
        })
        response = self.llm.chat(prompt)
        # SECURITY: the reply is untrusted model output. It used to be passed
        # to eval(), which would execute arbitrary expressions; literal_eval
        # only accepts plain literals (lists, strings, numbers, ...).
        role_codes = ast.literal_eval(response.strip())
        return role_codes

    def generate_location_prologue(self,
                                   location_code,
                                   history_text,
                                   event,
                                   location_info_text):
        """Generate an opening narration for a location and record it.

        Returns the LLM reply prefixed with a newline.
        """
        location = self.locations_info[location_code]
        # Use the location's own description; fall back to its detail text and
        # finally its name when no description is available.
        location_description = (location.get("description")
                                or location.get("detail")
                                or location["location_name"])
        prompt = self._LOCATION_PROLOGUE_PROMPT.format(**{
            "location_name":location["location_name"],
            "location_description":location_description,
            "location_info":location_info_text,
            "history_text":history_text,
            "event":event,
            "world_description":self.description
        })
        response = self.llm.chat(prompt)
        self.record(detail = response, prompt = prompt)
        return "\n" + response

    def enviroment_interact(self,
                            action_maker_name: str,
                            action: str,
                            action_detail: str,
                            location_code: str):
        """Resolve a character's action against the environment.

        Retries the LLM up to three times and falls back to a localized
        "nothing happens" message when no non-empty reply is obtained. The
        returned text is also recorded.
        """
        references = self.retrieve_references(query = action_detail)
        prompt = self._ENVIROMENT_INTERACTION_PROMPT.format(**
            {
                "role_name":action_maker_name,
                "action":action,
                "action_detail":action_detail,
                "world_description":self.description,
                "location":location_code,
                "location_description":self.locations_info[location_code]["detail"],
                "references":references,
            }
        )
        response = "无事发生。" if self.language == "zh" else "Nothing happens."
        for i in range(3):
            try:
                reply = self.llm.chat(prompt)
            except Exception as e:
                print(f"Enviroment Interaction failed! {i + 1}th tries. Error:", e)
                continue
            if reply:
                # Only a non-empty reply replaces the fallback message.
                response = reply
                break
        self.record(response, prompt)
        return response

    def npc_interact(self,
                     action_maker_name: str,
                     action_detail: str,
                     location_name: str,
                     target_name: str):
        """Resolve a character's interaction with an NPC.

        Returns a dict that always contains ``if_end_interaction`` and
        ``detail``. If the LLM call fails or its reply cannot be parsed into a
        dict with a ``detail`` entry, a localized "nothing happens" result is
        returned and nothing is recorded.
        """
        references = self.retrieve_references(query = action_detail)
        prompt = self._NPC_INTERACTION_PROMPT.format(**
            {
                "role_name":action_maker_name,
                "action_detail":action_detail,
                "world_description":self.description,
                "target":target_name,
                "references":references,
                "location":location_name
            }
        )
        fallback = {"if_end_interaction":True,"detail":"无事发生。"} if self.language == "zh" else {"if_end_interaction":True,"detail":"Nothing happens"}
        try:
            parsed = json_parser(self.llm.chat(prompt))
            detail = parsed["detail"]
        except Exception as e:
            print("NPC Interaction failed!", e)
            return fallback
        # Keep the default for any key the model omitted.
        npc_interaction = {**fallback, **parsed}
        self.record(detail, prompt)
        return npc_interaction

    def get_script_instruction(self,
                               roles_info_text: str,
                               event: str,
                               history_text: str,
                               script: str,
                               last_progress: str):
        """Ask the LLM for script-driven instructions.

        Returns the parsed reply, or an empty dict when no reply could be
        parsed in three tries. The last raw reply is recorded either way.
        """
        prompt = self._SCRIPT_INSTRUCTION_PROMPT.format(**{
            "roles_info":roles_info_text,
            "event":event,
            "history_text":history_text,
            "script":script,
            "last_progress":last_progress
        })
        max_tries = 3
        instruction = {}
        for i in range(max_tries):
            response = self.llm.chat(prompt)
            try:
                instruction = json_parser(response)
                break
            except Exception as e:
                print(f"Parsing failure! {i+1}th tries. Error:", e)
                print(response)
        self.record(response, prompt)
        return instruction

    def _generate_intervention(self, roles_info_text: str, history_text: str):
        """Run the intervention prompt, record the reply and return it."""
        prompt = self._GENERATE_INTERVENTION_PROMPT.format(**{
            "world_description":self.description,
            "roles_info":roles_info_text,
            "history_text":history_text
        })
        response = self.llm.chat(prompt)
        self.record(response, prompt)
        return response

    def generate_event(self, roles_info_text: str, event: str, history_text: str):
        """Generate a new event from the intervention prompt.

        ``event`` is accepted for interface compatibility but is not used by
        the prompt.
        """
        return self._generate_intervention(roles_info_text, history_text)

    def generate_script(self, roles_info_text: str, event: str, history_text: str):
        """Generate a script; currently identical to ``generate_event``.

        NOTE(review): this uses the intervention prompt, exactly like
        ``generate_event`` -- confirm whether a dedicated script prompt was
        intended.
        """
        return self._generate_intervention(roles_info_text, history_text)

    def log2story(self, logs):
        """Turn simulation logs into a story via the LLM (not recorded)."""
        prompt = self._LOG2STORY_PROMPT.format(**{
            "logs":logs
        })
        response = self.llm.chat(prompt)
        return response

    # Other
    def record(self, detail: str, prompt: str = ""):
        """Append ``detail`` to the history and, if given, log its prompt."""
        if prompt:
            self.prompts.append({"prompt":prompt,
                                 "response":detail})
        self.history.append(detail)

    def add_location_during_simulation(self, location: str, detail: str):
        """Register a location created while the simulation is running.

        A new location gets distance 1 to every known location. Re-adding a
        known location only refreshes its ``detail`` text: it is not listed
        twice and its existing distances are kept.
        """
        if location in self.locations_info:
            self.locations_info[location]["detail"] = detail
        else:
            self.locations_info[location] = {
                'location_code': location,
                "location_name": location,
                'description': '',
                'detail': detail
            }
        if location not in self.locations:
            self.locations.append(location)
        for loc in self.locations:
            # _add_edge stores both directions, so one call per pair suffices.
            if loc != location and (loc, location) not in self.edges:
                self._add_edge(loc, location, 1)
        return

    def retrieve_references(self, query: str, top_k = 3, max_words = 100):
        """Return world settings related to ``query`` from the settings DB.

        The joined hits are cut to the first ``max_words`` characters (the
        limit is applied to characters, not whitespace-separated words).
        Returns an empty string when no database is attached.
        """
        # getattr: an instance restored from saved state has no ``db``.
        db = getattr(self, "db", None)
        if db is None:
            return ""
        references = "\n".join(db.search(query, top_k, self.db_name))
        return references[:max_words]

    def find_location_name(self, code: str):
        """Return the display name of the location with the given code."""
        return self.locations_info[code]["location_name"]

    def _add_location(self, code: str, location_info: Dict[str, Any]):
        """Store ``location_info`` under ``code``."""
        self.locations_info[code] = location_info

    def _add_edge(self, code1: str, code2: str, distance: int):
        """Store ``distance`` between two locations in both directions."""
        self.edges[(code1, code2)] = distance
        self.edges[(code2, code1)] = distance

    def get_distance(self, code1: str, code2: str):
        """Return the distance between two locations, or None if unknown."""
        return self.edges.get((code1, code2))

    def __getstate__(self):
        """Return the plain-data attributes that are saved with the agent.

        Runtime objects, bulky derived data and prompt templates are left out.
        """
        excluded = ('llm', 'embedding', 'db', 'locations_info', 'edges',
                    'world_data', 'world_settings',
                    # Prompt template whose name lacks the "PROMPT" marker.
                    '_SCRIPT_ATTENTION')
        state = {key: value for key, value in self.__dict__.items()
                 if isinstance(value, (str, int, list, dict, float, bool, type(None)))
                 and key not in excluded
                 and "PROMPT" not in key}
        return state

    def __setstate__(self, state):
        """Restore attributes from a state dict."""
        self.__dict__.update(state)

    def save_to_file(self, root_dir):
        """Write the agent's state to ``world_agent.json`` under ``root_dir``."""
        filename = os.path.join(root_dir, "./world_agent.json")
        save_json_file(filename, self.__getstate__())

    def load_from_file(self, root_dir):
        """Load the agent's state from ``world_agent.json`` under ``root_dir``."""
        filename = os.path.join(root_dir, "./world_agent.json")
        state = load_json_file(filename)
        self.__setstate__(state)