|
import logging |
|
import json |
|
import os |
|
import time |
|
import tiktoken |
|
from datetime import datetime |
|
from typing import Dict, Any, Optional, Tuple |
|
|
|
|
|
def setup_logging():
    """Configure the application-wide logging system.

    Reads LOG_PATH, LOG_LEVEL and LOG_FORMAT from the environment
    (with sensible defaults) and installs both a console handler and a
    UTF-8 file handler via ``logging.basicConfig``.

    Returns:
        logging.Logger: the named ``'2api'`` logger.
    """
    path = os.environ.get("LOG_PATH", "/tmp/2api.log")
    level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
    fmt = os.environ.get(
        "LOG_FORMAT",
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Unknown LOG_LEVEL strings silently fall back to INFO.
    logging.basicConfig(
        level=getattr(logging, level_name, logging.INFO),
        format=fmt,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(path, encoding='utf-8'),
        ],
    )
    return logging.getLogger('2api')
|
|
|
# Module-level logger shared by every helper in this file; logging is
# configured once as a side effect of importing this module.
logger = setup_logging()
|
|
|
def load_config():
    """Load configuration from config.json when present.

    The file location defaults to ``config.json`` next to this module and
    can be overridden via the CONFIG_FILE_PATH environment variable. When
    the file is missing or unreadable, callers fall back to environment
    variables.

    Returns:
        dict: the parsed configuration, or ``{}`` when no usable file exists.
    """
    fallback = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json')
    config_file = os.environ.get("CONFIG_FILE_PATH", fallback)

    if not os.path.exists(config_file):
        return {}

    try:
        with open(config_file, 'r', encoding='utf-8') as fh:
            loaded = json.load(fh)
    except (json.JSONDecodeError, IOError) as e:
        # Best-effort: a broken config file degrades to env-var-only mode.
        logger.error(f"加载配置文件失败: {e}")
        return {}

    logger.info(f"已从 {config_file} 加载配置")
    return loaded
|
|
|
def mask_email(email: str) -> str:
    """Mask the local part of an email address for privacy.

    Keeps the first (and, when the username is longer than 3 characters,
    the last) character of the username and replaces the rest with ``*``.
    The domain is left intact.

    Args:
        email: the address to mask.

    Returns:
        str: the masked address, or ``"无效邮箱"`` for anything that is not
        a plausible ``user@domain`` string (empty input, no ``'@'``, or an
        empty username/domain).
    """
    if not email or '@' not in email:
        return "无效邮箱"

    # rpartition splits on the LAST '@': the original split('@') kept only
    # parts[1] as the domain, silently dropping everything after a second
    # '@' in inputs like "a@b@c".
    username, _, domain = email.rpartition('@')

    # "@example.com" used to pass the '@' check and then crash with
    # IndexError on username[0]; treat empty parts as invalid instead.
    if not username or not domain:
        return "无效邮箱"

    if len(username) <= 3:
        masked_username = username[0] + '*' * (len(username) - 1)
    else:
        masked_username = username[0] + '*' * (len(username) - 2) + username[-1]

    return f"{masked_username}@{domain}"
|
|
|
def generate_request_id() -> str:
    """Produce a unique OpenAI-style chat-completion request id.

    Returns:
        str: ``'chatcmpl-'`` followed by 32 hex characters (16 random bytes).
    """
    suffix = os.urandom(16).hex()
    return "chatcmpl-" + suffix
|
|
|
def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
    """Count the number of tokens in *text* using tiktoken.

    Args:
        text: text to tokenize; ``None`` is treated as ``""`` and any
            non-string value is converted with ``str()``.
        model: model name used to select an encoding (default gpt-3.5-turbo).

    Returns:
        int: the token count; if tiktoken fails for any reason, a rough
        ``len(text) // 4`` estimate is returned instead.
    """
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    try:
        if "gpt-4" in model:
            encoding = tiktoken.encoding_for_model("gpt-4")
        elif "gpt-3.5" in model:
            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
        else:
            # Claude and any unrecognized model both fall back to the
            # cl100k_base encoding as a close approximation.
            encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))
    except Exception as e:
        logger.error(f"计算token数量时出错: {e}")
        # Crude heuristic: roughly 4 characters per token.
        return len(text) // 4
|
|
|
def count_message_tokens(messages: list, model: str = "gpt-3.5-turbo") -> Tuple[int, int, int]:
    """
    Count tokens for an OpenAI-format message list.

    Args:
        messages: list of ``{'role': ..., 'content': ...}`` dicts; ``None``
            or non-list input is treated as an empty list.
        model: model name, default gpt-3.5-turbo.

    Returns:
        Tuple[int, int, int]: (prompt tokens, completion tokens, total tokens);
        (0, 0, 0) on any unexpected error.
    """
    # Normalize bad input instead of raising.
    if messages is None:
        messages = []
    elif not isinstance(messages, list):
        logger.warning(f"count_message_tokens 收到非列表类型的消息: {type(messages)}")
        messages = []

    prompt_tokens = 0
    completion_tokens = 0

    try:
        for message in messages:
            # Skip anything that is not a dict.
            if not isinstance(message, dict):
                logger.warning(f"跳过非字典类型的消息: {type(message)}")
                continue

            role = message.get('role', '')
            content = message.get('content', '')

            # Messages with an empty role or empty content contribute nothing.
            if role and content:
                # Per-message framing overhead (start/end markers etc.).
                prompt_tokens += 4
                # One token for the role name.
                prompt_tokens += 1
                prompt_tokens += count_tokens(content, model)
                # Assistant messages also count toward completion tokens.
                if role == 'assistant':
                    completion_tokens += count_tokens(content, model)

        # Fixed overhead for priming the reply.
        # NOTE(review): placing this after the loop follows the usual OpenAI
        # token-counting recipe — confirm against the original formatting.
        prompt_tokens += 2

        total_tokens = prompt_tokens + completion_tokens

        return prompt_tokens, completion_tokens, total_tokens
    except Exception as e:
        logger.error(f"计算消息token数量时出错: {e}")
        return 0, 0, 0