File size: 5,360 Bytes
36b7c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import logging
import json
import os
import time
import tiktoken
from datetime import datetime
from typing import Dict, Any, Optional, Tuple

# Logging configuration
def setup_logging():
    """Configure the root logger and return the application's ``2api`` logger.

    Settings come from environment variables:

    * ``LOG_PATH``   -- log file path (default ``/tmp/2api.log``)
    * ``LOG_LEVEL``  -- level name such as ``DEBUG`` or ``INFO`` (case-insensitive;
      unknown or non-level names fall back to ``INFO``)
    * ``LOG_FORMAT`` -- ``logging`` format string

    Records are written to stderr and, when the file can be opened, to
    ``LOG_PATH``. If the file cannot be opened, a warning is logged and only
    the console handler is used instead of failing the import.

    If the root logger already has handlers, ``logging.basicConfig`` would be a
    no-op. In that case nothing is reconfigured and no file handle is opened
    (the original version opened and then leaked one).

    Returns:
        logging.Logger: the ``2api`` logger.
    """
    root = logging.getLogger()
    app_logger = logging.getLogger('2api')
    if root.handlers:
        # basicConfig() would ignore our handlers anyway; don't open a file we'd leak.
        return app_logger

    log_path = os.environ.get("LOG_PATH", "/tmp/2api.log")
    log_level_str = os.environ.get("LOG_LEVEL", "INFO").upper()
    log_level = getattr(logging, log_level_str, logging.INFO)
    if not isinstance(log_level, int):
        # e.g. LOG_LEVEL=basic_format resolves to a str constant, not a level.
        log_level = logging.INFO
    log_format = os.environ.get("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    handlers = [logging.StreamHandler()]
    file_error = None
    try:
        handlers.append(logging.FileHandler(log_path, encoding='utf-8'))
    except OSError as e:
        # Missing directory / no permission: keep console logging instead of crashing at import.
        file_error = e

    logging.basicConfig(
        level=log_level,
        format=log_format,
        handlers=handlers
    )
    if file_error is not None:
        app_logger.warning(f"无法打开日志文件 {log_path},仅输出到控制台: {file_error}")
    return app_logger


logger = setup_logging()

def load_config():
    """Load settings from a JSON config file, if one exists.

    The path comes from the ``CONFIG_FILE_PATH`` environment variable and
    defaults to ``config.json`` next to this module. No other environment
    variables are consulted here.

    A missing file silently yields ``{}``. The following are logged and also
    yield ``{}``:

    * a file that cannot be read;
    * a file that is not valid UTF-8;
    * a file that is not valid JSON;
    * a file whose top-level value is not a JSON object.

    Returns:
        dict: the parsed configuration, or an empty dict.
    """
    default_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json')
    config_file = os.environ.get("CONFIG_FILE_PATH", default_config_path)

    # EAFP: open directly instead of exists()+open(), which races with deletion.
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError:
        return {}
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError AND UnicodeDecodeError (bad bytes).
        logger.error(f"加载配置文件失败: {e}")
        return {}

    if not isinstance(config, dict):
        # Callers treat the config as a mapping; a list/scalar would break them.
        logger.error(f"加载配置文件失败: {config_file} 顶层不是 JSON 对象")
        return {}

    logger.info(f"已从 {config_file} 加载配置")
    return config

def mask_email(email: str) -> str:
    """Hide the middle of an email's local part to protect privacy.

    The local part keeps its first character, plus its last character when it
    is longer than 3 characters; everything in between becomes ``*``. The
    domain is kept intact.

    The address is split at the LAST ``@``, since a (quoted) local part may
    itself contain ``@`` while the domain cannot.

    Returns:
        str: the masked address. Returns ``"无效邮箱"`` ("invalid email") for
        empty input, input without ``@``, or an empty local part.
    """
    if not email or '@' not in email:
        return "无效邮箱"

    username, _, domain = email.rpartition('@')
    if not username:
        # e.g. "@example.com": nothing to mask, and username[0] would raise IndexError.
        return "无效邮箱"

    if len(username) <= 3:
        masked_username = username[0] + '*' * (len(username) - 1)
    else:
        masked_username = username[0] + '*' * (len(username) - 2) + username[-1]

    return f"{masked_username}@{domain}"

def generate_request_id() -> str:
    """Create a new, random chat-completion request identifier.

    The result is the fixed prefix ``chatcmpl-`` followed by 32 lowercase hex
    digits encoding 16 bytes from ``os.urandom``.
    """
    entropy = os.urandom(16)
    return "chatcmpl-" + entropy.hex()

def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
    """Count the tokens in *text* using the tokenizer for *model*.

    The encoding is resolved with ``tiktoken.encoding_for_model`` so models
    tiktoken knows get their own tokenizer (e.g. ``gpt-4o`` uses
    ``o200k_base``; the previous substring match sent it to ``gpt-4``'s
    ``cl100k_base``). Names tiktoken does not recognise fall back to
    ``cl100k_base``; this includes Claude models and custom aliases.

    Special-token strings such as ``<|endoftext|>`` inside *text* are
    tokenized as ordinary text rather than raising.

    Args:
        text: Text to measure. ``None`` counts as ``""``; other non-str values
            are converted with ``str()``.
        model: Model name; defaults to ``gpt-3.5-turbo``.

    Returns:
        int: the token count. If tokenization fails, an error is logged and an
        estimate of one token per 4 characters, rounded up, is returned.
    """
    # Type guard: accept None or non-string input.
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)
    model_name = model if isinstance(model, str) else ""

    try:
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown to tiktoken (aliases, Claude, ...): cl100k_base is what
            # gpt-3.5/gpt-4 use and was the previous default for everything.
            encoding = tiktoken.get_encoding("cl100k_base")
        # disallowed_special=() counts special-token text as plain text instead
        # of raising ValueError on user content like "<|endoftext|>".
        return len(encoding.encode(text, disallowed_special=()))
    except Exception as e:
        logger.error(f"计算token数量时出错: {e}")
        # Rough estimate (~4 chars per token), rounded up so short text isn't 0.
        return (len(text) + 3) // 4

def _content_to_text(content: Any) -> Any:
    """Flatten an OpenAI message ``content`` value for token counting.

    For a list of content parts, the ``text`` of each ``{"type": "text"}``
    part and any bare string part are joined with newlines. Non-text parts
    (e.g. images) are ignored so their payloads are not tokenized. Any other
    value is returned unchanged; ``count_tokens`` handles its conversion.
    """
    if not isinstance(content, list):
        return content
    pieces = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and part.get('type') == 'text':
            pieces.append(str(part.get('text') or ''))
    return "\n".join(pieces)


def count_message_tokens(messages: list, model: str = "gpt-3.5-turbo") -> Tuple[int, int, int]:
    """Estimate token usage for a list of OpenAI-format chat messages.

    Each message with a non-empty ``role`` and ``content`` adds the following
    to the prompt count:

    * 4 framing tokens;
    * 1 token for the role;
    * its content tokens.

    A further 2 tokens are added at the end. List-type (multimodal) content
    is reduced to its text parts first.

    Args:
        messages: OpenAI-format message dicts. ``None`` or a non-list is
            treated as empty; non-dict entries are skipped with a warning.
        model: Model name passed to ``count_tokens``; defaults to
            ``gpt-3.5-turbo``.

    Returns:
        Tuple[int, int, int]: ``(prompt_tokens, completion_tokens,
        total_tokens)``. Returns ``(0, 0, 0)`` if counting raises.
    """
    # Type guard: accept None or a non-list.
    if messages is None:
        messages = []
    elif not isinstance(messages, list):
        logger.warning(f"count_message_tokens 收到非列表类型的消息: {type(messages)}")
        messages = []

    prompt_tokens = 0
    completion_tokens = 0

    try:
        for message in messages:
            if not isinstance(message, dict):
                logger.warning(f"跳过非字典类型的消息: {type(message)}")
                continue

            role = message.get('role', '')
            content = message.get('content', '')
            if not (role and content):
                continue

            # Tokenize once and reuse for both counters.
            content_tokens = count_tokens(_content_to_text(content), model)
            # 4 tokens of per-message framing + 1 for the role name.
            prompt_tokens += 4 + 1 + content_tokens

            # NOTE(review): assistant content is counted in BOTH prompt and
            # completion, so total double-counts it. Kept unchanged because
            # callers may depend on these numbers -- confirm intended semantics.
            if role == 'assistant':
                completion_tokens += content_tokens

        # End-of-conversation overhead.
        prompt_tokens += 2

        total_tokens = prompt_tokens + completion_tokens
        return prompt_tokens, completion_tokens, total_tokens
    except Exception as e:
        logger.error(f"计算消息token数量时出错: {e}")
        # Safe default when counting fails.
        return 0, 0, 0