import os
import pickle
import json
import torch
import logging
import datetime
import re
import random
import base64

MODEL_NAME_DICT = {
    "gpt3":"openai/gpt-3.5-turbo",
    "gpt-4":"openai/gpt-4",
    "gpt-4o":"openai/gpt-4o",
    "gpt-4o-mini":"openai/gpt-4o-mini",
    "gpt-3.5-turbo":"openai/gpt-3.5-turbo",
    "deepseek-r1":"deepseek/deepseek-r1",
    "deepseek-v3":"deepseek/deepseek-chat",
    "gemini-2":"google/gemini-2.0-flash-001",
    "gemini-1.5":"google/gemini-flash-1.5",
    "llama3-70b": "meta-llama/llama-3.3-70b-instruct",
    "qwen-turbo":"qwen/qwen-turbo",
    "qwen-plus":"qwen/qwen-plus",
    "qwen-max":"qwen/qwen-max",
    "qwen-2.5-72b":"qwen/qwen-2.5-72b-instruct",
    "claude-3.5-sonnet":"anthropic/claude-3.5-sonnet",
    "phi-4":"microsoft/phi-4",
}

def get_models(model_name):
    if os.getenv("OPENROUTER_API_KEY", default="") and model_name in MODEL_NAME_DICT:
        from modules.llm.OpenRouter import OpenRouter
        return OpenRouter(model=MODEL_NAME_DICT[model_name])
    elif model_name.startswith('gpt-3.5'):
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-3.5-turbo")
    elif model_name == 'gpt-4':
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-4")
    elif model_name == 'gpt-4-turbo':
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-4")
    elif model_name == 'gpt-4o':
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-4o")
    elif model_name == "gpt-4o-mini":
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="gpt-4o-mini")
    elif model_name.startswith("claude"):
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT(model="claude-3-5-sonnet-20241022")
    elif model_name.startswith('qwen'):
        from modules.llm.Qwen import Qwen
        return Qwen(model = model_name)
    elif model_name.startswith('deepseek'):
        from modules.llm.DeepSeek import DeepSeek
        return DeepSeek()
    elif model_name.startswith('doubao'):
        from modules.llm.Doubao import Doubao
        return Doubao()
    elif model_name.startswith('gemini'):
        from modules.llm.Gemini import Gemini
        return Gemini()
    else:
        print(f'Warning: undefined model {model_name}; falling back to gpt-3.5-turbo.')
        from modules.llm.LangChainGPT import LangChainGPT
        return LangChainGPT()
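# Illustrative usage of get_models (a sketch, assuming the relevant API key, e.g.
# OPENROUTER_API_KEY or OPENAI_API_KEY, is set and the modules.llm package is importable):
#
#   llm = get_models("gpt-4o-mini")   # routed through OpenRouter when OPENROUTER_API_KEY is set
#   llm = get_models("qwen-plus")     # otherwise dispatched to the provider-specific wrapper
#
# Any name not handled above falls back to LangChainGPT with gpt-3.5-turbo.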
    
def build_world_agent_data(world_file_path, max_words=30):
    world_dir = os.path.dirname(world_file_path)
    details_dir = os.path.join(world_dir, "world_details")
    data = []
    settings = []
    if os.path.exists(details_dir):
        for path in get_child_paths(details_dir):
            if os.path.splitext(path)[-1] == ".txt":
                text = load_text_file(path)
                data += split_text_by_max_words(text,max_words)
            if os.path.splitext(path)[-1] == ".jsonl":
                jsonl = load_jsonl_file(path)
                data += [f"{dic['term']}:{dic['detail']}" for dic in jsonl]
                settings += jsonl
    return data,settings
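# Illustrative layout assumed by build_world_agent_data (hypothetical paths): the world file
# sits next to a "world_details" folder holding .txt background text and .jsonl glossary entries.
#
#   worlds/my_world/world.json
#   worlds/my_world/world_details/geography.txt
#   worlds/my_world/world_details/glossary.jsonl   # one {"term": ..., "detail": ...} per line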

def build_db(data, db_name, db_type, embedding, save_type="persistent"):
    from modules.db.ChromaDB import ChromaDB
    db = ChromaDB(embedding, save_type)
    db.init_from_data(data, db_name)
    return db

def get_root_dir():
    current_file_path = os.path.abspath(__file__)
    root_dir = os.path.dirname(current_file_path)
    return root_dir

def create_dir(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)

def get_logger(experiment_name):
    logger = logging.getLogger(experiment_name)
    logger.setLevel(logging.INFO)
    current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    create_dir(f"{get_root_dir()}/log/{experiment_name}")
    file_handler = logging.FileHandler(os.path.join(get_root_dir(),f"./log/{experiment_name}/{current_time}.log"),encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    
    logger.addHandler(file_handler)
    
    # Avoid logging duplication
    logger.propagate = False

    return logger
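# Illustrative usage of get_logger (writes to <this module's directory>/log/<experiment_name>/<timestamp>.log):
#
#   logger = get_logger("my_experiment")   # "my_experiment" is a hypothetical name
#   logger.info("run started")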

def merge_text_with_limit(text_list, max_words, language='en'):
    """
    Merge a list of text strings into one, stopping before adding a text would exceed the maximum count.

    Args:
        text_list (list): List of strings to be merged.
        max_words (int): Maximum number of characters (for Chinese) or words (for English).
        language (str): 'zh' to count Chinese characters; any other value counts English words.

    Returns:
        str: The merged text, truncated as needed.
    """
    merged_text = ""
    current_count = 0

    for text in text_list:
        if language == 'zh':
            # Count Chinese characters
            text_length = len(text)
        else:
            # Count English words
            text_length = len(text.split(" "))

        if current_count + text_length > max_words:
            break

        merged_text += text + "\n"
        current_count += text_length

    return merged_text
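# Illustrative example: with max_words=5 and English word counting, only the first
# string fits, so the second is dropped.
#
#   merge_text_with_limit(["one two three", "four five six"], max_words=5)
#   # -> "one two three\n"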

def normalize_string(text):
    # Strip whitespace and common separators, then lowercase
    return re.sub(r'[\s\,\;\t\n]+', '', text).lower()

def fuzzy_match(str1, str2, threshold=0.8):
    from difflib import SequenceMatcher

    str1_normalized = normalize_string(str1)
    str2_normalized = normalize_string(str2)

    if str1_normalized == str2_normalized:
        return True

    # Fall back to a similarity ratio so the `threshold` parameter is honored
    return SequenceMatcher(None, str1_normalized, str2_normalized).ratio() >= threshold
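# Illustrative example: whitespace, separators and case are ignored before comparing.
#
#   fuzzy_match("Harry Potter", "harry potter")   # -> True
#   fuzzy_match("Harry Potter", "Harry")          # -> False (below the similarity threshold)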

def load_character_card(path):
    from PIL import Image
    import PIL.PngImagePlugin
    
    image = Image.open(path)
    if isinstance(image, PIL.PngImagePlugin.PngImageFile):
        # PNG character cards embed base64-encoded JSON in their text chunks
        for key, value in image.text.items():
            try:
                character_info = json.loads(decode_base64(value))
                if character_info:
                    return character_info
            except Exception:
                continue
    return None

def decode_base64(encoded_string):
    # Convert the string to bytes if it's not already
    if isinstance(encoded_string, str):
        encoded_bytes = encoded_string.encode('ascii')
    else:
        encoded_bytes = encoded_string

    # Decode the Base64 bytes
    decoded_bytes = base64.b64decode(encoded_bytes)

    # Try to convert the result to a string, assuming UTF-8 encoding
    try:
        decoded_string = decoded_bytes.decode('utf-8')
        return decoded_string
    except UnicodeDecodeError:
        # If it's not valid UTF-8 text, return the raw bytes
        return decoded_bytes
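# Illustrative example: the two helpers above are typically used together, since character
# cards are PNGs whose text chunks hold base64-encoded JSON ("cards/alice.png" is hypothetical).
#
#   info = load_character_card("cards/alice.png")
#   if info is not None:
#       print(info.get("name"))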
    
def remove_list_elements(list1, *args):
    for target in args:
        if isinstance(target,list) or isinstance(target,dict):
            list1 = [i for i in list1 if i not in target]
        else:
            list1 = [i for i in list1 if i != target]
    return list1

def extract_html_content(html):
    from bs4 import BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")
    
    content_div = soup.find("div", {"id": "content"})
    if not content_div:
        return ""

    paragraphs = []
    for div in content_div.find_all("div"):
        paragraphs.append(div.get_text(strip=True))
    
    main_content = "\n\n".join(paragraphs)
    return main_content

def load_text_file(path):
    with open(path,"r",encoding="utf-8") as f:
        text = f.read()
    return text

def save_text_file(path, target):
    with open(path, "w", encoding="utf-8") as f:
        f.write(target)

def load_json_file(path):
    with open(path,"r",encoding="utf-8") as f:
        return json.load(f)
    
def save_json_file(path, target):
    dir_name = os.path.dirname(path)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(target, f, ensure_ascii=False, indent=4)
        
def load_jsonl_file(path):
    data = []
    with open(path,"r",encoding="utf-8") as f:
        for line in f:
            data.append(json.loads(line))
    return data
        
def save_jsonl_file(path,target):
    with open(path, "w",encoding="utf-8") as f:
        for row in target:
            print(json.dumps(row, ensure_ascii=False), file=f)

def split_text_by_max_words(text: str, max_words: int = 30):
    segments = []
    current_segment = []
    current_length = 0
    
    lines = text.splitlines()

    for line in lines:
        # len(line) counts characters, which approximates the length for Chinese text
        line_length = len(line)
        current_segment.append(line + '\n')
        current_length += line_length

        if current_length > max_words:
            segments.append(''.join(current_segment))
            current_segment = []
            current_length = 0

    if current_segment:
        segments.append(''.join(current_segment))

    return segments
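# Illustrative example: lines are accumulated until the character count passes
# max_words, then a new segment starts.
#
#   split_text_by_max_words("aaaa\nbbbb\ncc", max_words=3)
#   # -> ["aaaa\n", "bbbb\n", "cc\n"]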

def lang_detect(text):
    def count_chinese_characters(text):
        # Match all CJK unified ideographs
        chinese_chars = re.findall(r'[\u4e00-\u9fff]', text)
        return len(chinese_chars)

    if count_chinese_characters(text) > len(text) * 0.05:
        lang = 'zh'
    else:
        lang = 'en'
    return lang
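# Illustrative example: the text counts as Chinese once more than 5% of its
# characters fall in the CJK range.
#
#   lang_detect("Hello world")   # -> 'en'
#   lang_detect("你好，世界")      # -> 'zh'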

def dict_to_str(dic):
    res = ""
    for key in dic:
        res += f"{key}: {dic[key]};"
    return res

def count_tokens_num(string, encoding_name="cl100k_base"):
    import tiktoken
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens

def json_parser(output):
    output = output.replace("\n", "")
    output = output.replace("\t", "")
    if "{" not in output:
        output = "{" + output
    if "}" not in output:
        output += "}"  
    pattern = r'\{.*\}'
    matches = re.findall(pattern, output, re.DOTALL)
    try:
        # eval is tried first because it tolerates single quotes and Python-style literals
        parsed_json = eval(matches[0])
    except Exception:
        try:
            parsed_json = json.loads(matches[0])
        except json.JSONDecodeError:
            try:
                detail = re.search(r'"detail":\s*(.+?)\s*}', matches[0]).group(1)
                detail = f"\"{detail}\"" 
                new_output = re.sub(r'"detail":\s*(.+?)\s*}', f"\"detail\":{detail}}}", matches[0])
                parsed_json = json.loads(new_output)
            except Exception as e:
                raise ValueError("No valid JSON found in the input string")
    return parsed_json
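# Illustrative example: json_parser tolerates missing braces and stray newlines in model output.
#
#   json_parser('"thought": "hide", "detail": "walks away"')
#   # -> {'thought': 'hide', 'detail': 'walks away'}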

def action_detail_decomposer(detail):
    thoughts = re.findall(r'【(.*?)】', detail)
    # Actions are assumed to be wrapped in (fullwidth or ASCII) parentheses
    actions = re.findall(r'[（(](.*?)[）)]', detail)
    dialogues = re.findall(r'「(.*?)」', detail)
    return thoughts, actions, dialogues

def conceal_thoughts(detail):
    text = re.sub(r'【.*?】', '', detail)
    text = re.sub(r'\[.*?\]', '', text)
    return text

def extract_first_number(text):
    match = re.search(r'\b\d+(?:\.\d+)?\b', text)
    # Cast via float so decimal matches such as "3.5" do not raise
    return int(float(match.group())) if match else None

def check_role_code_availability(role_code,role_file_dir):
    for path in get_grandchild_folders(role_file_dir):
        if role_code in path:
            return True
    return False
    
def get_grandchild_folders(root_folder):
    folders = []
    for resource in os.listdir(root_folder):
        subpath = os.path.join(root_folder,resource)
        for folder_name in os.listdir(subpath):
            folder_path = os.path.join(subpath, folder_name)
            folders.append(folder_path)
    
    return folders

def get_child_folders(root_folder):
    folders = []
    for resource in os.listdir(root_folder):
        path = os.path.join(root_folder,resource)
        if os.path.isdir(path):
            folders.append(path)
    return folders

def get_child_paths(root_folder):
    paths = []
    for resource in os.listdir(root_folder):
        path = os.path.join(root_folder,resource)
        if os.path.isfile(path):
            paths.append(path)
    return paths

def get_first_directory(path):
    try:
        for item in os.listdir(path):
            full_path = os.path.join(path, item)
            if os.path.isdir(full_path):
                return full_path
        return None
    except Exception as e:
        print(f"Error: {e}")
        return None
    
def find_files_with_suffix(directory, suffix):
    matched_files = []
    for root, dirs, files in os.walk(directory):  # walk the directory tree recursively
        for file in files:
            if file.endswith(suffix):  # check the file extension
                matched_files.append(os.path.join(root, file))  # collect matching paths

    return matched_files

def remove_element_with_probability(lst, threshold=3, probability=0.2):
    # Only act on sufficiently long lists, and only with the given probability
    if len(lst) > threshold and random.random() < probability:
        # Pick a random index and drop that element
        index = random.randint(0, len(lst) - 1)
        lst.pop(index)
    return lst
  
def count_token_num(text):
    from transformers import GPT2TokenizerFast
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    return len(tokenizer.encode(text))

def get_cost(model_name,prompt,output):
    input_price=0
    output_price=0
    if model_name.startswith("gpt-4"):
        input_price=10
        output_price=30
    elif model_name.startswith("gpt-3.5"):
        input_price=0.5
        output_price=1.5
    
    return input_price*count_token_num(prompt)/1000000 + output_price * count_token_num(output)/1000000
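# Illustrative example: prices are USD per million tokens, so a 1,000-token prompt and a
# 500-token reply on gpt-4 cost roughly
#   10 * 1000/1e6 + 30 * 500/1e6 = 0.01 + 0.015 = 0.025 USD
# (approximate, since count_token_num uses the GPT-2 tokenizer).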

def is_image(filepath):
    if not os.path.isfile(filepath):
        return False

    valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff','.webp']
    file_extension = os.path.splitext(filepath)[1].lower()

    # Check whether the extension is a known image extension
    if file_extension in valid_image_extensions:
        return True

    return False

def clean_collection_name(name: str) -> str:
    cleaned_name = name.replace(' ', '_')
    cleaned_name = cleaned_name.replace('.', '_')
    if not all(ord(c) < 128 for c in cleaned_name):
        encoded = base64.b64encode(cleaned_name.encode('utf-8')).decode('ascii')
        encoded = encoded[:60] if len(encoded) > 60 else encoded
        valid_name = f"mem_{encoded}"
    else:
        valid_name = cleaned_name
    valid_name = re.sub(r'[^a-zA-Z0-9_-]', '-', valid_name)
    valid_name = re.sub(r'\.\.+', '-', valid_name)
    valid_name = re.sub(r'^[^a-zA-Z0-9]+', '', valid_name)  # strip leading illegal characters
    valid_name = re.sub(r'[^a-zA-Z0-9]+$', '', valid_name)  # strip trailing illegal characters
    return valid_name
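# Illustrative example: ASCII names are only lightly sanitized, while non-ASCII names
# are base64-encoded under a "mem_" prefix.
#
#   clean_collection_name("My Novel v1.2")   # -> "My_Novel_v1_2"
#   clean_collection_name("三体")             # -> "mem_<base64 of the UTF-8 name>"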
    
cache_sign = True
cache = None 
def cached(func):
    def wrapper(*args,**kwargs):
        global cache
        cache_path = "bw_cache.pkl"
        if cache is None:
            if not os.path.exists(cache_path):
                cache = {}
            else:
                with open(cache_path, 'rb') as f:
                    cache = pickle.load(f)
        # Cache key: function name, the agent's identity/state, and the keyword arguments
        key = (func.__name__, str([args[0].role_code, args[0].__class__, args[0].llm_name, args[0].history]), str(kwargs.items()))
        if (cache_sign and key in cache and cache[key] not in [None, '[TOKEN LIMIT]']) :
            return cache[key]
        else:
            result = func(*args, **kwargs)
            if result != 'busy' and result is not None:
                cache[key] = result
                pickle.dump(cache, open(cache_path, 'wb'))
            return result
    return wrapper
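
# Illustrative usage of the cached decorator (a sketch with a hypothetical agent-like class,
# since the cache key is built from role_code, llm_name and history on the first argument):
#
#   class Agent:
#       def __init__(self):
#           self.role_code = "alice"
#           self.llm_name = "gpt-4o-mini"
#           self.history = []
#
#       @cached
#       def chat(self, prompt):
#           return call_llm(prompt)   # hypothetical LLM call
#
# Repeated calls with the same role, model, history and arguments return the pickled
# result from bw_cache.pkl instead of hitting the LLM again.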