import os
import json
import time
import threading
from collections import defaultdict
from typing import Dict, List, Any
from datetime import datetime, timedelta

from utils import logger, load_config


class Config:
    """Configuration manager that stores and manages all settings."""

    # Default values, used until overridden by config.json or environment variables.
    _defaults = {
        "ondemand_session_timeout_minutes": 30,
        "session_timeout_minutes": 3600,
        "max_retries": 5,
        "retry_delay": 3,
        "request_timeout": 45,
        "stream_timeout": 180,
        "rate_limit": 30,
        "account_cooldown_seconds": 300,
        "debug_mode": False,
        "api_access_token": "sk-2api-ondemand-access-token-2025",
        "stats_file_path": "stats_data.json",
        "stats_backup_path": "stats_data_backup.json",
        "stats_save_interval": 300,
        "max_history_items": 1000,
        "default_endpoint_id": "predefined-claude-3.7-sonnet"
    }

    # Mapping from public model names to OnDemand endpoint IDs.
    _model_mapping = {
        "gpt-3.5-turbo": "predefined-openai-gpto3-mini",
        "gpto3-mini": "predefined-openai-gpto3-mini",
        "gpt-4o": "predefined-openai-gpt4o",
        "gpt-4o-mini": "predefined-openai-gpt4o-mini",
        "gpt-4-turbo": "predefined-openai-gpt4.1",
        "gpt-4.1": "predefined-openai-gpt4.1",
        "gpt-4.1-mini": "predefined-openai-gpt4.1-mini",
        "gpt-4.1-nano": "predefined-openai-gpt4.1-nano",
        "deepseek-v3": "predefined-deepseek-v3",
        "deepseek-r1": "predefined-deepseek-r1",
        "claude-3.5-sonnet": "predefined-claude-3.5-sonnet",
        "claude-3.7-sonnet": "predefined-claude-3.7-sonnet",
        "claude-3-opus": "predefined-claude-3-opus",
        "claude-3-haiku": "predefined-claude-3-haiku",
        "gemini-1.5-pro": "predefined-gemini-2.0-flash",
        "gemini-2.0-flash": "predefined-gemini-2.0-flash",
    }

    def __init__(self):
        """Initialize the configuration object."""
        self._config = self._defaults.copy()

        # Usage statistics; periodically persisted to stats_file_path.
        self.usage_stats = {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "model_usage": defaultdict(int),
            "account_usage": defaultdict(int),
            "daily_usage": defaultdict(int),
            "hourly_usage": defaultdict(int),
            "request_history": [],
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "model_tokens": defaultdict(int),
            "daily_tokens": defaultdict(int),
            "hourly_tokens": defaultdict(int),
            "last_saved": datetime.now().isoformat()
        }

        # Locks protecting shared mutable state across threads.
        self.usage_stats_lock = threading.Lock()
        self.account_index_lock = threading.Lock()
        self.client_sessions_lock = threading.Lock()

        # Round-robin index of the next account to hand out.
        self.current_account_index = 0

        # Active client sessions, keyed by session identifier.
        self.client_sessions = {}

        # OnDemand account credentials, loaded from config.json or the environment.
        self.accounts = []

        # Accounts temporarily disabled after errors: email -> cooldown end time.
        self.account_cooldowns = {}

    def get(self, key: str, default: Any = None) -> Any:
        """Return a configuration value."""
        return self._config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set a configuration value."""
        self._config[key] = value

    def update(self, config_dict: Dict[str, Any]) -> None:
        """Update multiple configuration values at once."""
        self._config.update(config_dict)

    def get_model_endpoint(self, model_name: str) -> str:
        """Return the endpoint ID for a model name, falling back to the default endpoint."""
        return self._model_mapping.get(model_name, self.get("default_endpoint_id"))

    def load_from_file(self) -> bool:
        """Load configuration from the config file."""
        try:
            config_data = load_config()
            if config_data:
                for key, value in config_data.items():
                    if key != "accounts":
                        self.set(key, value)

                if "accounts" in config_data:
                    self.accounts = config_data["accounts"]

                logger.info("Configuration loaded from config file")
                return True
            return False
        except Exception as e:
            logger.error(f"Error loading config file: {e}")
            return False
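
    # Illustrative shape of the config file consumed above (hypothetical values; the
    # actual file reading lives in utils.load_config). Keys mirror _defaults, plus an
    # optional "accounts" list:
    #
    #   {
    #     "rate_limit": 30,
    #     "debug_mode": false,
    #     "accounts": [
    #       {"email": "user@example.com", "password": "secret"}
    #     ]
    #   }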

    def load_from_env(self) -> None:
        """Load configuration from environment variables."""
        # Accounts from the environment are only used if the config file provided none.
        if not self.accounts:
            accounts_env = os.getenv("ONDEMAND_ACCOUNTS", "")
            if accounts_env:
                try:
                    self.accounts = json.loads(accounts_env).get('accounts', [])
                    logger.info("Accounts loaded from the ONDEMAND_ACCOUNTS environment variable")
                except json.JSONDecodeError:
                    logger.error("Failed to decode the ONDEMAND_ACCOUNTS environment variable")

        # Environment variables that override individual settings.
        env_mappings = {
            "ondemand_session_timeout_minutes": "ONDEMAND_SESSION_TIMEOUT_MINUTES",
            "session_timeout_minutes": "SESSION_TIMEOUT_MINUTES",
            "max_retries": "MAX_RETRIES",
            "retry_delay": "RETRY_DELAY",
            "request_timeout": "REQUEST_TIMEOUT",
            "stream_timeout": "STREAM_TIMEOUT",
            "rate_limit": "RATE_LIMIT",
            "debug_mode": "DEBUG_MODE",
            "api_access_token": "API_ACCESS_TOKEN"
        }

        for config_key, env_key in env_mappings.items():
            env_value = os.getenv(env_key)
            if env_value is not None:
                # Coerce the string from the environment to the type of the current value.
                default_value = self.get(config_key)
                if isinstance(default_value, bool):
                    self.set(config_key, env_value.lower() == 'true')
                elif isinstance(default_value, int):
                    self.set(config_key, int(env_value))
                elif isinstance(default_value, float):
                    self.set(config_key, float(env_value))
                else:
                    self.set(config_key, env_value)
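
    # Illustrative environment setup for the loader above (placeholder credentials):
    #
    #   export ONDEMAND_ACCOUNTS='{"accounts": [{"email": "user@example.com", "password": "secret"}]}'
    #   export MAX_RETRIES=5
    #   export DEBUG_MODE=true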

    def save_stats_to_file(self):
        """Persist usage statistics to the stats file."""
        try:
            with self.usage_stats_lock:
                # Snapshot the stats, converting defaultdicts to plain dicts so the
                # structure is JSON-serializable.
                stats_copy = {
                    "total_requests": self.usage_stats["total_requests"],
                    "successful_requests": self.usage_stats["successful_requests"],
                    "failed_requests": self.usage_stats["failed_requests"],
                    "model_usage": dict(self.usage_stats["model_usage"]),
                    "account_usage": dict(self.usage_stats["account_usage"]),
                    "daily_usage": dict(self.usage_stats["daily_usage"]),
                    "hourly_usage": dict(self.usage_stats["hourly_usage"]),
                    "request_history": list(self.usage_stats["request_history"]),
                    "total_prompt_tokens": self.usage_stats["total_prompt_tokens"],
                    "total_completion_tokens": self.usage_stats["total_completion_tokens"],
                    "total_tokens": self.usage_stats["total_tokens"],
                    "model_tokens": dict(self.usage_stats["model_tokens"]),
                    "daily_tokens": dict(self.usage_stats["daily_tokens"]),
                    "hourly_tokens": dict(self.usage_stats["hourly_tokens"]),
                    "last_saved": datetime.now().isoformat()
                }

                stats_file_path = self.get("stats_file_path")
                stats_backup_path = self.get("stats_backup_path")

                # Write to the backup path first, then swap it in, so a failed write
                # cannot corrupt the existing stats file.
                with open(stats_backup_path, 'w', encoding='utf-8') as f:
                    json.dump(stats_copy, f, ensure_ascii=False, indent=2)

                if os.path.exists(stats_file_path):
                    os.remove(stats_file_path)

                os.rename(stats_backup_path, stats_file_path)

                logger.info(f"Statistics saved to {stats_file_path}")
                self.usage_stats["last_saved"] = datetime.now().isoformat()
        except Exception as e:
            logger.error(f"Error saving statistics: {e}")

    def load_stats_from_file(self):
        """Load usage statistics from the stats file."""
        try:
            stats_file_path = self.get("stats_file_path")
            if os.path.exists(stats_file_path):
                with open(stats_file_path, 'r', encoding='utf-8') as f:
                    saved_stats = json.load(f)

                with self.usage_stats_lock:
                    self.usage_stats["total_requests"] = saved_stats.get("total_requests", 0)
                    self.usage_stats["successful_requests"] = saved_stats.get("successful_requests", 0)
                    self.usage_stats["failed_requests"] = saved_stats.get("failed_requests", 0)
                    self.usage_stats["total_prompt_tokens"] = saved_stats.get("total_prompt_tokens", 0)
                    self.usage_stats["total_completion_tokens"] = saved_stats.get("total_completion_tokens", 0)
                    self.usage_stats["total_tokens"] = saved_stats.get("total_tokens", 0)

                    # Re-populate the defaultdict counters from the saved plain dicts.
                    for model, count in saved_stats.get("model_usage", {}).items():
                        self.usage_stats["model_usage"][model] = count

                    for account, count in saved_stats.get("account_usage", {}).items():
                        self.usage_stats["account_usage"][account] = count

                    for day, count in saved_stats.get("daily_usage", {}).items():
                        self.usage_stats["daily_usage"][day] = count

                    for hour, count in saved_stats.get("hourly_usage", {}).items():
                        self.usage_stats["hourly_usage"][hour] = count

                    for model, tokens in saved_stats.get("model_tokens", {}).items():
                        self.usage_stats["model_tokens"][model] = tokens

                    for day, tokens in saved_stats.get("daily_tokens", {}).items():
                        self.usage_stats["daily_tokens"][day] = tokens

                    for hour, tokens in saved_stats.get("hourly_tokens", {}).items():
                        self.usage_stats["hourly_tokens"][hour] = tokens

                    self.usage_stats["request_history"] = saved_stats.get("request_history", [])

                    # Trim the history to the configured maximum.
                    max_history_items = self.get("max_history_items")
                    if len(self.usage_stats["request_history"]) > max_history_items:
                        self.usage_stats["request_history"] = self.usage_stats["request_history"][-max_history_items:]

                logger.info(f"Statistics loaded from {stats_file_path}")
                return True
            else:
                logger.info(f"Stats file {stats_file_path} not found; using default values")
                return False
        except Exception as e:
            logger.error(f"Error loading statistics: {e}")
            return False

    def start_stats_save_thread(self):
        """Start a daemon thread that saves statistics periodically."""
        def save_stats_periodically():
            while True:
                time.sleep(self.get("stats_save_interval"))
                self.save_stats_to_file()

        save_thread = threading.Thread(target=save_stats_periodically, daemon=True)
        save_thread.start()
        logger.info(f"Stats save thread started; saving every {self.get('stats_save_interval')} seconds")

    def init(self):
        """Initialize configuration from the config file and environment variables."""
        # File values load first; environment variables fill in or override afterwards.
        self.load_from_file()
        self.load_from_env()

        if not self.accounts:
            error_msg = "No account information found in config.json or the ONDEMAND_ACCOUNTS environment variable"
            logger.critical(error_msg)
            logger.warning("Continuing without account information; functionality may be limited")

        logger.info("API access token loaded")

        # Restore persisted statistics and start the periodic save thread.
        self.load_stats_from_file()
        self.start_stats_save_thread()

    def get_next_ondemand_account_details(self):
        """Return the email and password of the next OnDemand account (round-robin).

        Accounts that are still in their cooldown period are skipped."""
        with self.account_index_lock:
            if not self.accounts:
                # Guard against an empty account list (init() tolerates it), which
                # would otherwise raise an IndexError below.
                logger.error("No OnDemand accounts are configured")
                return None, None

            current_time = datetime.now()

            # Drop cooldowns that have expired.
            expired_cooldowns = [email for email, end_time in self.account_cooldowns.items()
                                 if end_time < current_time]
            for email in expired_cooldowns:
                del self.account_cooldowns[email]
                logger.info(f"Cooldown for account {email} has ended; it is available again")

            # Walk the account list at most once, starting at the current index.
            for _ in range(len(self.accounts)):
                account_details = self.accounts[self.current_account_index]
                email = account_details.get('email')

                # Advance the round-robin index regardless of whether this account is usable.
                self.current_account_index = (self.current_account_index + 1) % len(self.accounts)

                if email in self.account_cooldowns:
                    cooldown_end = self.account_cooldowns[email]
                    remaining_seconds = (cooldown_end - current_time).total_seconds()
                    logger.warning(f"Account {email} is still in cooldown; {remaining_seconds:.1f} seconds remaining")
                    continue

                logger.info(f"[system] New session will use account: {email}")
                return email, account_details.get('password')

            # Every account is in cooldown: fall back to the first one even though it
            # may trigger rate limits.
            logger.warning("All accounts are in cooldown! Using the first account, although it may hit rate limits")
            account_details = self.accounts[0]
            return account_details.get('email'), account_details.get('password')


config_instance = Config()


def init_config():
    """Backward-compatible helper that initializes the shared config instance."""
    config_instance.init()


def get_config_value(name: str, default: Any = None) -> Any:
    """
    Return the current value of a configuration variable.

    External callers should prefer config.get_config_value('name').
    For accounts, model_mapping, usage_stats and client_sessions, use the
    dedicated getter functions below.
    """
    return config_instance.get(name, default)


def get_accounts() -> List[Dict[str, str]]:
    """Return the list of account credentials."""
    return config_instance.accounts


def get_model_mapping() -> Dict[str, str]:
    """Return the mapping from model names to endpoint IDs."""
    return config_instance._model_mapping


def get_usage_stats() -> Dict[str, Any]:
    """Return the usage statistics."""
    return config_instance.usage_stats


def get_client_sessions() -> Dict[str, Any]:
    """Return the client session map."""
    return config_instance.client_sessions


def get_next_ondemand_account_details():
    """Backward-compatible helper that fetches the next account."""
    return config_instance.get_next_ondemand_account_details()


def set_account_cooldown(email, cooldown_seconds=None):
    """Place an account in a cooldown period.

    Args:
        email: Account email address.
        cooldown_seconds: Cooldown duration in seconds; if None, the configured
            account_cooldown_seconds value is used.
    """
    if cooldown_seconds is None:
        cooldown_seconds = config_instance.get('account_cooldown_seconds')

    cooldown_end = datetime.now() + timedelta(seconds=cooldown_seconds)
    with config_instance.account_index_lock:
        config_instance.account_cooldowns[email] = cooldown_end
    logger.warning(f"Account {email} placed in cooldown for {cooldown_seconds} seconds, ending at {cooldown_end.strftime('%Y-%m-%d %H:%M:%S')}")
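

# A minimal usage sketch, not part of the original interface: it assumes a valid
# config.json or ONDEMAND_ACCOUNTS environment variable is available and that
# utils.logger / utils.load_config behave as imported above.
if __name__ == "__main__":
    init_config()

    # Read a plain configuration value and resolve a model name to its endpoint ID.
    print("max_retries =", get_config_value("max_retries"))
    print("gpt-4o ->", config_instance.get_model_endpoint("gpt-4o"))

    # Rotate to the next available account and put it on cooldown, as a caller
    # handling a rate-limit response might do.
    email, _password = get_next_ondemand_account_details()
    if email:
        set_account_cooldown(email, cooldown_seconds=60)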