File size: 5,360 Bytes
36b7c16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import logging
import json
import os
import time
import tiktoken
from datetime import datetime
from typing import Dict, Any, Optional, Tuple
# 配置日志
def setup_logging():
"""配置日志系统"""
log_path = os.environ.get("LOG_PATH", "/tmp/2api.log")
log_level_str = os.environ.get("LOG_LEVEL", "INFO").upper()
log_level = getattr(logging, log_level_str, logging.INFO)
log_format = os.environ.get("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler = logging.FileHandler(log_path, encoding='utf-8')
stream_handler = logging.StreamHandler()
logging.basicConfig(
level=log_level,
format=log_format,
handlers=[stream_handler, file_handler]
)
return logging.getLogger('2api')
logger = setup_logging()
def load_config():
"""从 config.json 加载配置(如果存在),否则使用环境变量"""
default_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json')
CONFIG_FILE = os.environ.get("CONFIG_FILE_PATH", default_config_path)
config = {}
if os.path.exists(CONFIG_FILE):
try:
with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
config = json.load(f)
logger.info(f"已从 {CONFIG_FILE} 加载配置")
except (json.JSONDecodeError, IOError) as e:
logger.error(f"加载配置文件失败: {e}")
config = {}
return config
def mask_email(email: str) -> str:
"""隐藏邮箱中间部分,保护隐私"""
if not email or '@' not in email:
return "无效邮箱"
parts = email.split('@')
username = parts[0]
domain = parts[1]
if len(username) <= 3:
masked_username = username[0] + '*' * (len(username) - 1)
else:
masked_username = username[0] + '*' * (len(username) - 2) + username[-1]
return f"{masked_username}@{domain}"
def generate_request_id() -> str:
"""生成唯一的请求ID"""
return f"chatcmpl-{os.urandom(16).hex()}"
def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
"""
计算文本的token数量
Args:
text: 要计算token数量的文本
model: 模型名称,默认为gpt-3.5-turbo
Returns:
int: token数量
"""
# 类型保护,防止text为None或非字符串类型
if text is None:
text = ""
elif not isinstance(text, str):
text = str(text)
try:
# 根据模型名称获取编码器
if "gpt-4" in model:
encoding = tiktoken.encoding_for_model("gpt-4")
elif "gpt-3.5" in model:
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
elif "claude" in model:
# Claude模型使用cl100k_base编码器
encoding = tiktoken.get_encoding("cl100k_base")
else:
# 默认使用cl100k_base编码器
encoding = tiktoken.get_encoding("cl100k_base")
# 计算token数量
tokens = encoding.encode(text)
return len(tokens)
except Exception as e:
logger.error(f"计算token数量时出错: {e}")
# 如果出错,使用简单的估算方法(每4个字符约为1个token)
return len(text) // 4
def count_message_tokens(messages: list, model: str = "gpt-3.5-turbo") -> Tuple[int, int, int]:
"""
计算OpenAI格式消息列表的token数量
Args:
messages: OpenAI格式的消息列表
model: 模型名称,默认为gpt-3.5-turbo
Returns:
Tuple[int, int, int]: (提示tokens数, 完成tokens数, 总tokens数)
"""
# 类型保护,防止messages为None或非列表类型
if messages is None:
messages = []
elif not isinstance(messages, list):
logger.warning(f"count_message_tokens 收到非列表类型的消息: {type(messages)}")
messages = []
prompt_tokens = 0
completion_tokens = 0
try:
# 计算提示tokens
for message in messages:
# 确保message是字典类型
if not isinstance(message, dict):
logger.warning(f"跳过非字典类型的消息: {type(message)}")
continue
role = message.get('role', '')
content = message.get('content', '')
if role and content:
# 每条消息的基本token开销
prompt_tokens += 4 # 每条消息的基本开销
# 角色名称的token
prompt_tokens += 1 # 角色名称的开销
# 内容的token
prompt_tokens += count_tokens(content, model)
# 如果是assistant角色,计算完成tokens
if role == 'assistant':
completion_tokens += count_tokens(content, model)
# 消息结束的token
prompt_tokens += 2 # 消息结束的开销
# 计算总tokens
total_tokens = prompt_tokens + completion_tokens
return prompt_tokens, completion_tokens, total_tokens
except Exception as e:
logger.error(f"计算消息token数量时出错: {e}")
# 返回安全的默认值
return 0, 0, 0 |