import os
import pickle
import json
import torch
import logging
import datetime
import re
import random
import base64
import tiktoken

MODEL_NAME_DICT = {
    "gpt3": "openai/gpt-3.5-turbo",
    "gpt-4": "openai/gpt-4",
    "gpt-4o": "openai/gpt-4o",
    "gpt-4o-mini": "openai/gpt-4o-mini",
    "gpt-3.5-turbo": "openai/gpt-3.5-turbo",
    "deepseek-r1": "deepseek/deepseek-r1",
    "deepseek-v3": "deepseek/deepseek-chat",
    "gemini-2": "google/gemini-2.0-flash-001",
    "gemini-1.5": "google/gemini-flash-1.5",
    "llama3-70b": "meta-llama/llama-3.3-70b-instruct",
    "qwen-turbo": "qwen/qwen-turbo",
    "qwen-plus": "qwen/qwen-plus",
    "qwen-max": "qwen/qwen-max",
    "qwen-2.5-72b": "qwen/qwen-2.5-72b-instruct",
    "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
    "phi-4": "microsoft/phi-4",
}


def get_models(model_name):
    # Prefer OpenRouter whenever an API key is configured and the model is in the routing table.
    if os.getenv("OPENROUTER_API_KEY", default="") and model_name in MODEL_NAME_DICT:
        from modules.llm.OpenRouter import OpenRouter
        return OpenRouter(model=MODEL_NAME_DICT[model_name])
    elif model_name.startswith('gpt-3.5'):
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-3.5-turbo")
    elif model_name == 'gpt-4':
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-4")
    elif model_name == 'gpt-4-turbo':
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-4")
    elif model_name == 'gpt-4o':
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-4o")
    elif model_name == "gpt-4o-mini":
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-4o-mini")
    elif model_name.startswith("claude"):
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="claude-3-5-sonnet-20241022")
    elif model_name.startswith('qwen'):
        from modules.llm.Qwen import Qwen
        return Qwen(model=model_name)
    elif model_name.startswith('deepseek'):
        from modules.llm.DeepSeek import DeepSeek
        return DeepSeek()
    elif model_name.startswith('doubao'):
        from modules.llm.Doubao import Doubao
        return Doubao()
    elif model_name.startswith('gemini'):
        from modules.llm.Gemini import Gemini
        return Gemini()
    else:
        print(f'Warning! Undefined model {model_name}, using gpt-3.5-turbo instead.')
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT()
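# Illustrative usage (sketch): with OPENROUTER_API_KEY set, known model names are routed
# through OpenRouter; otherwise the matching vendor wrapper (or the gpt-3.5-turbo fallback)
# is returned. The exact chat interface depends on the wrapper classes in modules/llm.
#
#     llm = get_models("gpt-4o-mini")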
def build_world_agent_data(world_file_path, max_words=30):
    world_dir = os.path.dirname(world_file_path)
    details_dir = os.path.join(world_dir, "./world_details")
    data = []
    settings = []
    if os.path.exists(details_dir):
        for path in get_child_paths(details_dir):
            if os.path.splitext(path)[-1] == ".txt":
                text = load_text_file(path)
                data += split_text_by_max_words(text, max_words)
            if os.path.splitext(path)[-1] == ".jsonl":
                jsonl = load_jsonl_file(path)
                data += [f"{dic['term']}:{dic['detail']}" for dic in jsonl]
                settings += jsonl
    return data, settings


def build_db(data, db_name, db_type, embedding, save_type="persistent"):
    # Only ChromaDB is currently supported; db_type is kept for interface compatibility.
    from modules.db.ChromaDB import ChromaDB
    db = ChromaDB(embedding, save_type)
    db.init_from_data(data, db_name)
    return db


def get_root_dir():
    current_file_path = os.path.abspath(__file__)
    root_dir = os.path.dirname(current_file_path)
    return root_dir


def create_dir(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(experiment_name):
    logger = logging.getLogger(experiment_name)
    logger.setLevel(logging.INFO)
    current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    create_dir(f"{get_root_dir()}/log/{experiment_name}")
    file_handler = logging.FileHandler(
        os.path.join(get_root_dir(), f"./log/{experiment_name}/{current_time}.log"),
        encoding='utf-8',
    )
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    # Avoid logging duplication
    logger.propagate = False
    return logger


def merge_text_with_limit(text_list, max_words, language='en'):
    """
    Merge a list of text strings into one, stopping once adding another text would exceed the limit.

    Args:
        text_list (list): List of strings to be merged.
        max_words (int): Maximum number of characters (for Chinese) or words (for English).
        language (str): 'zh' to count Chinese characters; anything else counts English words.

    Returns:
        str: The merged text, truncated as needed.
    """
    merged_text = ""
    current_count = 0
    for text in text_list:
        if language == 'zh':
            # Count Chinese characters
            text_length = len(text)
        else:
            # Count English words
            text_length = len(text.split())
        if current_count + text_length > max_words:
            break
        merged_text += text + "\n"
        current_count += text_length
    return merged_text


def normalize_string(text):
    # Strip whitespace and separator characters, then lowercase.
    return re.sub(r'[\s\,\;\t\n]+', '', text).lower()


def fuzzy_match(str1, str2, threshold=0.8):
    # threshold is currently unused: matching is exact after normalization.
    str1_normalized = normalize_string(str1)
    str2_normalized = normalize_string(str2)
    return str1_normalized == str2_normalized


def load_character_card(path):
    from PIL import Image
    import PIL.PngImagePlugin
    image = Image.open(path)
    if isinstance(image, PIL.PngImagePlugin.PngImageFile):
        # Character cards store a base64-encoded JSON payload in the PNG text chunks.
        for key, value in image.text.items():
            try:
                character_info = json.loads(decode_base64(value))
                if character_info:
                    return character_info
            except Exception:
                continue
    return None


def decode_base64(encoded_string):
    # Convert the string to bytes if it's not already
    if isinstance(encoded_string, str):
        encoded_bytes = encoded_string.encode('ascii')
    else:
        encoded_bytes = encoded_string
    # Decode the Base64 bytes
    decoded_bytes = base64.b64decode(encoded_bytes)
    # Try to convert the result to a string, assuming UTF-8 encoding
    try:
        decoded_string = decoded_bytes.decode('utf-8')
        return decoded_string
    except UnicodeDecodeError:
        # If it's not valid UTF-8 text, return the raw bytes
        return decoded_bytes


def remove_list_elements(list1, *args):
    for target in args:
        if isinstance(target, (list, dict)):
            list1 = [i for i in list1 if i not in target]
        else:
            list1 = [i for i in list1 if i != target]
    return list1


def extract_html_content(html):
    from bs4 import BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")
    content_div = soup.find("div", {"id": "content"})
    if not content_div:
        return ""
    paragraphs = []
    for div in content_div.find_all("div"):
        paragraphs.append(div.get_text(strip=True))
    main_content = "\n\n".join(paragraphs)
    return main_content


def load_text_file(path):
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def save_text_file(path, target):
    with open(path, "w", encoding="utf-8") as f:
        f.write(target)


def load_json_file(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_json_file(path, target):
    dir_name = os.path.dirname(path)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(target, f, ensure_ascii=False, indent=True)


def load_jsonl_file(path):
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            data.append(json.loads(line))
    return data


def save_jsonl_file(path, target):
    with open(path, "w", encoding="utf-8") as f:
        for row in target:
            print(json.dumps(row, ensure_ascii=False), file=f)


def split_text_by_max_words(text: str, max_words: int = 30):
    segments = []
    current_segment = []
    current_length = 0

    for line in text.splitlines():
        # len(line) counts characters, which approximates the word count for Chinese text.
        words_in_line = len(line)
        current_segment.append(line + '\n')
        current_length += words_in_line
        if current_length > max_words:
            segments.append(''.join(current_segment))
            current_segment = []
            current_length = 0

    if current_segment:
        segments.append(''.join(current_segment))

    return segments


def lang_detect(text):
    def count_chinese_characters(text):
        # Match all CJK ideographs with a regex.
        chinese_chars = re.findall(r'[\u4e00-\u9fff]', text)
        return len(chinese_chars)

    # Treat the text as Chinese if more than 5% of its characters are CJK ideographs.
    if count_chinese_characters(text) > len(text) * 0.05:
        lang = 'zh'
    else:
        lang = 'en'
    return lang
def dict_to_str(dic):
    res = ""
    for key in dic:
        res += f"{key}: {dic[key]};"
    return res


def count_tokens_num(string, encoding_name="cl100k_base"):
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens


def json_parser(output):
    output = output.replace("\n", "").replace("\t", "")
    if "{" not in output:
        output = "{" + output
    if "}" not in output:
        output += "}"
    pattern = r'\{.*\}'
    matches = re.findall(pattern, output, re.DOTALL)
    if not matches:
        raise ValueError("No valid JSON found in the input string")
    try:
        # eval() also accepts Python-literal dicts (e.g. single-quoted keys) that json.loads rejects.
        parsed_json = eval(matches[0])
    except Exception:
        try:
            parsed_json = json.loads(matches[0])
        except json.JSONDecodeError:
            try:
                # Quote an unquoted "detail" value and retry.
                detail = re.search(r'"detail":\s*(.+?)\s*}', matches[0]).group(1)
                detail = f"\"{detail}\""
                new_output = re.sub(r'"detail":\s*(.+?)\s*}', f"\"detail\":{detail}}}", matches[0])
                parsed_json = json.loads(new_output)
            except Exception:
                raise ValueError("No valid JSON found in the input string")
    return parsed_json


def action_detail_decomposer(detail):
    thoughts = re.findall(r'【(.*?)】', detail)
    # Actions are wrapped in parentheses (full-width or ASCII), alongside the other full-width markers.
    actions = re.findall(r'[（(](.*?)[）)]', detail)
    dialogues = re.findall(r'「(.*?)」', detail)
    return thoughts, actions, dialogues


def conceal_thoughts(detail):
    text = re.sub(r'【.*?】', '', detail)
    text = re.sub(r'\[.*?\]', '', text)
    return text


def extract_first_number(text):
    match = re.search(r'\b\d+(?:\.\d+)?\b', text)
    if not match:
        return None
    num = match.group()
    # int() would fail on decimals such as "3.5", so fall back to float for those.
    return float(num) if "." in num else int(num)


def check_role_code_availability(role_code, role_file_dir):
    for path in get_grandchild_folders(role_file_dir):
        if role_code in path:
            return True
    return False


def get_grandchild_folders(root_folder):
    folders = []
    for resource in os.listdir(root_folder):
        subpath = os.path.join(root_folder, resource)
        for folder_name in os.listdir(subpath):
            folder_path = os.path.join(subpath, folder_name)
            folders.append(folder_path)
    return folders


def get_child_folders(root_folder):
    folders = []
    for resource in os.listdir(root_folder):
        path = os.path.join(root_folder, resource)
        if os.path.isdir(path):
            folders.append(path)
    return folders


def get_child_paths(root_folder):
    paths = []
    for resource in os.listdir(root_folder):
        path = os.path.join(root_folder, resource)
        if os.path.isfile(path):
            paths.append(path)
    return paths


def get_first_directory(path):
    try:
        for item in os.listdir(path):
            full_path = os.path.join(path, item)
            if os.path.isdir(full_path):
                return full_path
        return None
    except Exception as e:
        print(f"Error: {e}")
        return None


def find_files_with_suffix(directory, suffix):
    matched_files = []
    for root, dirs, files in os.walk(directory):  # walk the directory and its subdirectories
        for file in files:
            if file.endswith(suffix):  # check the file suffix
                matched_files.append(os.path.join(root, file))  # collect matching paths
    return matched_files


def remove_element_with_probability(lst, threshold=3, probability=0.2):
    # Only drop an element when the list is long enough and the coin flip succeeds.
    if len(lst) > threshold and random.random() < probability:
        # Pick a random index and remove the element at that position.
        index = random.randint(0, len(lst) - 1)
        lst.pop(index)
    return lst


def count_token_num(text):
    from transformers import GPT2TokenizerFast
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    return len(tokenizer.encode(text))


def get_cost(model_name, prompt, output):
    # Prices are USD per one million tokens.
    input_price = 0
    output_price = 0
    if model_name.startswith("gpt-4"):
        input_price = 10
        output_price = 30
    elif model_name.startswith("gpt-3.5"):
        input_price = 0.5
        output_price = 1.5
    return (input_price * count_token_num(prompt) / 1000000
            + output_price * count_token_num(output) / 1000000)
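# Illustrative behaviour of json_parser, defined above (sketch): it tolerates LLM output
# that is almost, but not quite, valid JSON.
#
#     json_parser("{'term': 'sword', 'detail': 'a sharp blade'}")
#     # -> {'term': 'sword', 'detail': 'a sharp blade'}  (parsed via eval(), which accepts single quotes)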
def is_image(filepath):
    if not os.path.isfile(filepath):
        return False
    valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp']
    file_extension = os.path.splitext(filepath)[1].lower()
    # Check whether the extension is in the list of valid image extensions.
    return file_extension in valid_image_extensions


def clean_collection_name(name: str) -> str:
    cleaned_name = name.replace(' ', '_').replace('.', '_')
    if not all(ord(c) < 128 for c in cleaned_name):
        # Non-ASCII names (e.g. Chinese) are base64-encoded so the collection name stays ASCII.
        encoded = base64.b64encode(cleaned_name.encode('utf-8')).decode('ascii')
        encoded = encoded[:60] if len(encoded) > 60 else encoded
        valid_name = f"mem_{encoded}"
    else:
        valid_name = cleaned_name
    valid_name = re.sub(r'[^a-zA-Z0-9_-]', '-', valid_name)
    valid_name = re.sub(r'\.\.+', '-', valid_name)
    valid_name = re.sub(r'^[^a-zA-Z0-9]+', '', valid_name)  # strip illegal leading characters
    valid_name = re.sub(r'[^a-zA-Z0-9]+$', '', valid_name)  # strip illegal trailing characters
    return valid_name


cache_sign = True
cache = None


def cached(func):
    def wrapper(*args, **kwargs):
        global cache
        cache_path = "bw_cache.pkl"
        if cache is None:
            if not os.path.exists(cache_path):
                cache = {}
            else:
                cache = pickle.load(open(cache_path, 'rb'))
        # The cache key combines the method name, the calling object's identifying attributes
        # (role_code, class, llm_name, history) and the keyword arguments.
        key = (
            func.__name__,
            str([args[0].role_code, args[0].__class__, args[0].llm_name, args[0].history]),
            str(kwargs.items()),
        )
        if cache_sign and key in cache and cache[key] not in [None, '[TOKEN LIMIT]']:
            return cache[key]
        result = func(*args, **kwargs)
        if result != 'busy' and result is not None:
            cache[key] = result
            pickle.dump(cache, open(cache_path, 'wb'))
        return result
    return wrapper
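# Illustrative use of @cached (sketch; `ChatAgent` and its `chat` method are hypothetical).
# The decorator assumes the wrapped method's first argument exposes `role_code`, `llm_name`
# and `history` attributes, which are folded into the on-disk cache key. Note that positional
# arguments other than `self` do not participate in the key, so pass query text as a keyword.
#
#     class ChatAgent:
#         def __init__(self, role_code, llm_name):
#             self.role_code = role_code
#             self.llm_name = llm_name
#             self.history = []
#
#         @cached
#         def chat(self, query=""):
#             ...  # call the underlying LLM here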