|
import sys |
|
sys.path.append("../") |
|
import csv |
|
from typing import Any, Dict, List, Optional, Literal |
|
from bw_utils import * |
|
from modules.embedding import get_embedding_model |
|
|
|
class WorldAgent:
    """LLM-driven "world" narrator for a multi-role simulation.

    The agent keeps three kinds of state:

    * the location graph (``locations``, ``locations_info``, ``edges``),
    * a log of what it generated (``history``) and of every prompt/response
      pair it recorded (``prompts``),
    * a retrieval database (``db``) built from the world settings file.

    All text generation goes through ``self.llm.chat(prompt)``; prompt
    templates are loaded per language by :meth:`init_prompt`.
    """

    def __init__(self,
                 world_file_path: str,
                 location_file_path: str,
                 map_file_path: Optional[str] = "",
                 world_description: str = "",
                 llm_name: str = "gpt-4o-mini",
                 llm = None,
                 embedding_name: str = "bge",
                 db_type: str = "chroma",
                 language: str = "zh",
                 ):
        """Load the world files, build the location graph and the settings DB.

        Args:
            world_file_path: JSON file that must provide the keys
                ``world_name``, ``description`` and ``source``.
            location_file_path: JSON file describing the locations
                (see :meth:`init_from_file`).
            map_file_path: Optional CSV distance matrix. When empty or
                missing on disk, every pair of locations is connected.
            world_description: Overrides the description from the world file
                when non-empty.
            llm_name: Name passed to ``get_models`` when ``llm`` is not given.
            llm: Pre-built model object exposing ``chat(prompt)``.
            embedding_name: Name passed to ``get_embedding_model``.
            db_type: Backend name forwarded to ``build_db``.
            language: ``"zh"`` selects the Chinese prompts, anything else the
                English ones.
        """
        # The embedding model is cached in a module-level global so that it is
        # created only once per process, whichever agent is built first.
        global embedding

        if llm is None:
            llm = get_models(llm_name)
        self.llm = llm
        self.world_info: Dict[str, Any] = load_json_file(world_file_path)
        self.world_name: str = self.world_info["world_name"]
        self.language: str = language
        self.description: str = (
            self.world_info["description"] if world_description == "" else world_description
        )
        source = self.world_info["source"]

        self.locations_info: Dict[str, Any] = {}
        self.locations: List[str] = []
        self.history: List[str] = []
        self.edges: Dict[tuple, int] = {}
        self.prompts: List[Dict] = []

        self.init_from_file(map_file_path = map_file_path,
                            location_file_path = location_file_path)
        self.init_prompt()

        if 'embedding' not in globals():
            embedding = get_embedding_model(embedding_name, language=language)

        self.world_data, self.world_settings = build_world_agent_data(
            world_file_path = world_file_path, max_words = 50)
        self.db_name = clean_collection_name(f"settings_{source}_{embedding_name}")
        self.db = build_db(data = list(self.world_data),
                           db_name = self.db_name,
                           db_type = db_type,
                           embedding = embedding)

    def init_from_file(self, map_file_path: str, location_file_path: str, default_distance: int = 1):
        """Populate ``locations``, ``locations_info`` and ``edges``.

        The location file is either a mapping ``code -> info`` or a dict that
        wraps that mapping under a ``"locations"`` key.

        If ``map_file_path`` exists it is read as a CSV distance matrix: the
        first row holds the column location codes (first cell ignored), every
        following row starts with a location code followed by distances.
        A distance of ``0`` (or an empty cell) means "no edge". Codes that are
        not in the location file are reported and skipped.

        Without a map file every pair of distinct locations is connected with
        ``default_distance``.
        """
        # Load the JSON once; the original re-read the file for every access.
        location_data = load_json_file(location_file_path)
        valid_locations = location_data["locations"] if "locations" in location_data else location_data

        if map_file_path and os.path.exists(map_file_path):
            # newline='' is what the csv module asks for so that embedded
            # newlines inside quoted fields are handled correctly.
            with open(map_file_path, mode='r', newline='') as file:
                csv_reader = csv.reader(file)
                locations = next(csv_reader)[1:]
                for row in csv_reader:
                    if not row:
                        # csv.reader yields [] for blank lines.
                        continue
                    loc1 = row[0]
                    if loc1 not in valid_locations:
                        print(f"Warning: The location {loc1} does not exist")
                        continue
                    self.locations_info[loc1] = valid_locations[loc1]
                    self.locations.append(loc1)
                    # zip() tolerates rows that are longer/shorter than the header.
                    for loc2, distance in zip(locations, row[1:]):
                        if loc2 not in valid_locations:
                            print(f"Warning: The location {loc2} does not exist")
                            continue
                        distance = distance.strip()
                        if distance and distance != '0':
                            self._add_edge(loc1, loc2, int(distance))
        else:
            for loc1 in valid_locations:
                self.locations_info[loc1] = valid_locations[loc1]
                self.locations.append(loc1)
                for loc2 in valid_locations:
                    if loc2 != loc1:
                        self._add_edge(loc1, loc2, default_distance)

    def init_prompt(self,):
        """Import the prompt templates for ``self.language`` and store them on the instance."""
        if self.language == "zh":
            from modules.prompt.world_agent_prompt_zh import (
                ENVIROMENT_INTERACTION_PROMPT,
                NPC_INTERACTION_PROMPT,
                SCRIPT_INSTRUCTION_PROMPT,
                SCRIPT_ATTENTION_PROMPT,
                DECIDE_NEXT_ACTOR_PROMPT,
                GENERATE_INTERVENTION_PROMPT,
                UPDATE_EVENT_PROMPT,
                LOCATION_PROLOGUE_PROMPT,
                SELECT_SCREEN_ACTORS_PROMPT,
                JUDGE_IF_ENDED_PROMPT,
                LOG2STORY_PROMPT,
            )
        else:
            from modules.prompt.world_agent_prompt_en import (
                ENVIROMENT_INTERACTION_PROMPT,
                NPC_INTERACTION_PROMPT,
                SCRIPT_INSTRUCTION_PROMPT,
                SCRIPT_ATTENTION_PROMPT,
                DECIDE_NEXT_ACTOR_PROMPT,
                GENERATE_INTERVENTION_PROMPT,
                UPDATE_EVENT_PROMPT,
                LOCATION_PROLOGUE_PROMPT,
                SELECT_SCREEN_ACTORS_PROMPT,
                JUDGE_IF_ENDED_PROMPT,
                LOG2STORY_PROMPT,
            )

        self._ENVIROMENT_INTERACTION_PROMPT = ENVIROMENT_INTERACTION_PROMPT
        self._NPC_INTERACTION_PROMPT = NPC_INTERACTION_PROMPT
        self._SCRIPT_INSTRUCTION_PROMPT = SCRIPT_INSTRUCTION_PROMPT
        # Attribute name kept (no "_PROMPT" suffix) for backward compatibility.
        self._SCRIPT_ATTENTION = SCRIPT_ATTENTION_PROMPT
        self._DECIDE_NEXT_ACTOR_PROMPT = DECIDE_NEXT_ACTOR_PROMPT
        self._LOCATION_PROLOGUE_PROMPT = LOCATION_PROLOGUE_PROMPT
        self._GENERATE_INTERVENTION_PROMPT = GENERATE_INTERVENTION_PROMPT
        self._UPDATE_EVENT_PROMPT = UPDATE_EVENT_PROMPT
        self._SELECT_SCREEN_ACTORS_PROMPT = SELECT_SCREEN_ACTORS_PROMPT
        self._JUDGE_IF_ENDED_PROMPT = JUDGE_IF_ENDED_PROMPT
        self._LOG2STORY_PROMPT = LOG2STORY_PROMPT

    def update_event(self,
                     cur_event: str,
                     intervention: str,
                     history_text: str,
                     script: str = ""):
        """Ask the LLM for an updated event and record the exchange.

        When ``script`` is non-empty the script-attention template is
        prepended to the prompt. Returns the raw LLM reply.
        """
        prompt = self._UPDATE_EVENT_PROMPT.format(**{
            "event": cur_event,
            "intervention": intervention,
            "history": history_text
        })
        if script:
            prompt = self._SCRIPT_ATTENTION.format(script = script) + prompt
        new_event = self.llm.chat(prompt)
        self.record(new_event, prompt)
        return new_event

    def decide_next_actor(self,
                          history_text: str,
                          roles_info_text: str,
                          script: str = "",
                          event: str = ""):
        """Ask the LLM which role acts next.

        ``script`` and ``event`` are accepted for interface compatibility but
        are not used by the prompt. The LLM call is attempted up to three
        times; if every attempt raises, an empty string is returned.
        The prompt/response pair is appended to ``self.prompts`` (but not to
        ``self.history``).
        """
        prompt = self._DECIDE_NEXT_ACTOR_PROMPT.format(**{
            "roles_info": roles_info_text,
            "history_text": history_text,
        })

        max_tries = 3
        # Pre-bind so the failure branch can print it; previously the first
        # failure raised UnboundLocalError and defeated the retry loop.
        response = ""
        for _ in range(max_tries):
            try:
                response = self.llm.chat(prompt)
                break
            except Exception as e:
                print("Parsing failure! Error:", e)
                print(response)
        role_code = response
        self.prompts.append({"prompt": prompt,
                             "response": f"{role_code}"})

        return role_code

    def judge_if_ended(self, history_text):
        """Ask the LLM whether the story has ended.

        Returns ``(if_end, detail)``. The reply is parsed with ``json_parser``
        and merged over the default ``{"if_end": True, "detail": ""}``; the
        default is what is returned if all three attempts fail.
        """
        prompt = self._JUDGE_IF_ENDED_PROMPT.format(**{
            "history": history_text
        })
        max_tries = 3
        response = {"if_end": True, "detail": ""}
        for _ in range(max_tries):
            try:
                response.update(json_parser(self.llm.chat(prompt)))
                break
            except Exception as e:
                print("Parsing failure! Error:", e)
                print(response)

        return response["if_end"], response["detail"]

    def decide_screen_actors(self, roles_info_text, history_text, event, previous_role_codes):
        """Ask the LLM which roles are on screen and return the parsed value.

        The reply is expected to be a Python literal (typically a list of role
        codes).

        Raises:
            ValueError / SyntaxError: if the reply is not a valid literal.
        """
        import ast

        prompt = self._SELECT_SCREEN_ACTORS_PROMPT.format(**{
            "roles_info": roles_info_text,
            "history_text": history_text,
            "event": event,
            "previous_role_codes": previous_role_codes
        })
        response = self.llm.chat(prompt)
        # SECURITY: this used to be eval(response). Model output is untrusted
        # text, so only literals are accepted now; arbitrary expressions that
        # eval() would have executed are rejected with an exception.
        role_codes = ast.literal_eval(response.strip())
        return role_codes

    def generate_location_prologue(self,
                                   location_code,
                                   history_text,
                                   event,
                                   location_info_text):
        """Generate an opening narration for a location and record it.

        Returns the LLM reply prefixed with a newline.
        """
        location = self.locations_info[location_code]
        # The original passed the location *name* as the description.
        # NOTE(review): "description" is preferred, then "detail" (the field
        # enviroment_interact uses), then the name as a last resort - confirm
        # which field the prompt authors intended.
        location_description = (
            location.get("description")
            or location.get("detail")
            or location["location_name"]
        )
        prompt = self._LOCATION_PROLOGUE_PROMPT.format(**{
            "location_name": location["location_name"],
            "location_description": location_description,
            "location_info": location_info_text,
            "history_text": history_text,
            "event": event,
            "world_description": self.description
        })
        response = self.llm.chat(prompt)
        self.record(detail = response, prompt = prompt)
        return "\n" + response

    def enviroment_interact(self,
                            action_maker_name: str,
                            action: str,
                            action_detail: str,
                            location_code: str):
        """Narrate the environment's reaction to a role's action.

        Retrieved world settings for ``action_detail`` are injected as
        references. The LLM is tried up to three times; if it never returns a
        non-empty reply the language-specific "nothing happens" text is used.
        The result is recorded and returned.
        """
        references = self.retrieve_references(query = action_detail)
        prompt = self._ENVIROMENT_INTERACTION_PROMPT.format(**{
            "role_name": action_maker_name,
            "action": action,
            "action_detail": action_detail,
            "world_description": self.description,
            "location": location_code,
            "location_description": self.locations_info[location_code]["detail"],
            "references": references,
        })
        response = "无事发生。" if self.language == "zh" else "Nothing happens."
        for i in range(3):
            try:
                candidate = self.llm.chat(prompt)
            except Exception as e:
                print(f"Enviroment Interaction failed! {i+1}th tries. Error:", e)
                continue
            # Only a non-empty reply may replace the fallback text.
            if candidate:
                response = candidate
                break
        self.record(response, prompt)
        return response

    def npc_interact(self,
                     action_maker_name: str,
                     action_detail: str,
                     location_name: str,
                     target_name: str):
        """Generate an NPC's reaction to a role's action.

        Returns a dict with at least ``"if_end_interaction"`` and ``"detail"``.
        If the LLM call fails, or its reply cannot be parsed into a mapping
        containing ``"detail"``, the default "nothing happens" dict is
        returned and nothing is recorded.
        """
        references = self.retrieve_references(query = action_detail)
        prompt = self._NPC_INTERACTION_PROMPT.format(**{
            "role_name": action_maker_name,
            "action_detail": action_detail,
            "world_description": self.description,
            "target": target_name,
            "references": references,
            "location": location_name
        })

        if self.language == "zh":
            npc_interaction = {"if_end_interaction": True, "detail": "无事发生。"}
        else:
            npc_interaction = {"if_end_interaction": True, "detail": "Nothing happens"}

        try:
            parsed = json_parser(self.llm.chat(prompt))
            response = parsed["detail"]
            # Merge over the defaults so a reply that omits
            # "if_end_interaction" still yields a complete dict.
            merged = {**npc_interaction, **parsed}
        except Exception as e:
            print("NPC Interaction failed!", e)
        else:
            npc_interaction = merged
            self.record(response, prompt)

        return npc_interaction

    def get_script_instruction(self,
                               roles_info_text: str,
                               event: str,
                               history_text: str,
                               script: str,
                               last_progress: str):
        """Ask the LLM for per-role instructions that follow ``script``.

        The reply is parsed with ``json_parser`` (up to three LLM calls).
        Returns the parsed instruction, or ``{}`` if parsing never succeeds.
        The last raw reply is recorded either way.
        """
        prompt = self._SCRIPT_INSTRUCTION_PROMPT.format(**{
            "roles_info": roles_info_text,
            "event": event,
            "history_text": history_text,
            "script": script,
            "last_progress": last_progress
        })
        max_tries = 3
        instruction = {}
        for i in range(max_tries):
            response = self.llm.chat(prompt)
            try:
                instruction = json_parser(response)
                break
            except Exception as e:
                print(f"Parsing failure! {i+1}th tries. Error:", e)
                print(response)
        self.record(response, prompt)
        return instruction

    def _generate_intervention(self, roles_info_text: str, history_text: str):
        """Format the intervention prompt, query the LLM, record and return the reply."""
        prompt = self._GENERATE_INTERVENTION_PROMPT.format(**{
            "world_description": self.description,
            "roles_info": roles_info_text,
            "history_text": history_text
        })
        response = self.llm.chat(prompt)
        self.record(response, prompt)
        return response

    def generate_event(self, roles_info_text: str, event: str, history_text: str):
        """Generate a new intervention/event. ``event`` is currently unused."""
        return self._generate_intervention(roles_info_text, history_text)

    def generate_script(self, roles_info_text: str, event: str, history_text: str):
        """Generate a script.

        NOTE(review): this uses the same intervention prompt as
        :meth:`generate_event` (no dedicated script prompt is loaded);
        ``event`` is currently unused.
        """
        return self._generate_intervention(roles_info_text, history_text)

    def log2story(self, logs):
        """Turn simulation logs into story text via the LLM (not recorded)."""
        prompt = self._LOG2STORY_PROMPT.format(**{
            "logs": logs
        })
        response = self.llm.chat(prompt)
        return response

    def record(self, detail: str, prompt: str = ""):
        """Append ``detail`` to ``history``; also log the prompt pair when a prompt is given."""
        if prompt:
            self.prompts.append({"prompt": prompt,
                                 "response": detail})
        self.history.append(detail)

    def add_location_during_simulation(self, location: str, detail: str):
        """Register a new location at runtime and connect it to all others.

        The location code doubles as its name, the description is left empty
        and every edge gets distance 1. Calling this again for an existing
        location replaces its info without duplicating it in ``locations``.
        """
        if location not in self.locations:
            self.locations.append(location)
        self.locations_info[location] = {
            'location_code': location,
            "location_name": location,
            'description': '',
            'detail': detail
        }
        for loc in self.locations:
            if loc != location:
                # _add_edge already stores both directions.
                self._add_edge(loc, location, 1)
        return

    def retrieve_references(self, query: str, top_k = 3, max_words = 100):
        """Search the settings DB and return the hits joined by newlines.

        Returns ``""`` when no DB is available. Note that ``max_words``
        truncates by *characters* (string slice), not by words.
        """
        if self.db is None:
            return ""
        references = "\n".join(self.db.search(query, top_k, self.db_name))
        references = references[:max_words]
        return references

    def find_location_name(self, code: str):
        """Return the display name stored for location ``code`` (KeyError if unknown)."""
        return self.locations_info[code]["location_name"]

    def _add_location(self, code: str, location_info: Dict[str, Any]):
        """Store ``location_info`` under ``code`` (does not touch ``locations`` or ``edges``)."""
        self.locations_info[code] = location_info

    def _add_edge(self, code1: str, code2: str, distance: int):
        """Store a symmetric edge between the two location codes."""
        self.edges[(code1, code2)] = distance
        self.edges[(code2, code1)] = distance

    def get_distance(self, code1: str, code2: str):
        """Return the stored distance between two locations, or ``None`` if there is no edge."""
        return self.edges.get((code1, code2))

    def __getstate__(self):
        """Return the JSON-friendly subset of the instance state.

        Only plain values are kept. The model, embedding, DB, location graph,
        world data and all prompt templates are excluded; they are rebuilt by
        ``__init__``.
        """
        excluded = {'llm', 'embedding', 'db', 'locations_info', 'edges',
                    'world_data', 'world_settings',
                    # A prompt template whose attribute name lacks "PROMPT".
                    '_SCRIPT_ATTENTION'}
        state = {key: value for key, value in self.__dict__.items()
                 if isinstance(value, (str, int, list, dict, float, bool, type(None)))
                 and key not in excluded
                 and "PROMPT" not in key}
        return state

    def __setstate__(self, state):
        """Merge ``state`` into the instance attributes."""
        self.__dict__.update(state)

    def save_to_file(self, root_dir):
        """Write the serialisable state to ``<root_dir>/world_agent.json``."""
        filename = os.path.join(root_dir, "./world_agent.json")
        save_json_file(filename, self.__getstate__())

    def load_from_file(self, root_dir):
        """Restore the state previously written by :meth:`save_to_file`."""
        filename = os.path.join(root_dir, "./world_agent.json")
        state = load_json_file(filename)
        self.__setstate__(state)
|
|
|
|