|
import time |
|
import logging |
|
import functools |
|
import requests |
|
from abc import ABC, abstractmethod |
|
from typing import Callable, Any, Dict, Optional, Type, Union, TypeVar, cast |
|
|
|
|
|
import config |
|
|
|
|
|
# Generic type variable: lets RetryHandler.retry_operation declare that it
# returns whatever type the wrapped operation returns.
T = TypeVar('T')
|
|
|
class RetryStrategy(ABC):
    """Abstract base class for retry strategies.

    A concrete strategy answers four questions about a failed call: whether
    it should be retried, how long to wait first, how the attempt is logged,
    and what extra work runs right before the retry.
    """

    @abstractmethod
    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """Decide whether the failed call should be attempted again.

        Args:
            exception: The exception that was caught.
            retry_count: Current retry count.
            max_retries: Maximum number of retries.

        Returns:
            bool: Whether a retry should happen.
        """

    @abstractmethod
    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Compute how long to wait before the next attempt.

        Args:
            retry_count: Current retry count.
            base_delay: Base delay in seconds.

        Returns:
            float: Delay before retrying, in seconds.
        """

    @abstractmethod
    def log_retry_attempt(
        self,
        logger: logging.Logger,
        exception: Exception,
        retry_count: int,
        max_retries: int,
        delay: float,
    ) -> None:
        """Record that a retry is about to be attempted.

        Args:
            logger: Logger used to record the attempt.
            exception: The exception that was caught.
            retry_count: Current retry count.
            max_retries: Maximum number of retries.
            delay: Delay before retrying.
        """

    @abstractmethod
    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """Hook invoked before a retry; may perform additional work.

        Args:
            exception: The exception that was caught.
            retry_count: Current retry count.
        """
|
|
|
|
|
class ExponentialBackoffStrategy(RetryStrategy):
    """Exponential back-off retry strategy, intended for connection errors."""

    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """Retry only connection errors, and only while retries remain."""
        if not isinstance(exception, requests.exceptions.ConnectionError):
            return False
        return retry_count < max_retries

    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Return ``base_delay`` doubled once per retry already counted."""
        multiplier = 2 ** retry_count
        return base_delay * multiplier

    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """Emit a warning describing the upcoming retry.

        ``logger`` may be a ``logging.Logger`` or a plain callable taking
        ``(message, level)``; the latter is called with level ``"WARNING"``.
        """
        text = f"连接错误,{delay:.1f}秒后重试 ({retry_count}/{max_retries}): {exception}"
        if isinstance(logger, logging.Logger) or not callable(logger):
            logger.warning(text)
        else:
            logger(text, "WARNING")

    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """No extra work is needed before retrying a connection error."""
        return None
|
|
|
|
|
class LinearBackoffStrategy(RetryStrategy):
    """Linear back-off retry strategy, intended for timeout errors."""

    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """Retry only request timeouts, and only while retries remain."""
        if not isinstance(exception, requests.exceptions.Timeout):
            return False
        return retry_count < max_retries

    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Return a delay that grows linearly: ``base_delay * retry_count``."""
        linear_delay = base_delay * retry_count
        return linear_delay

    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """Emit a warning describing the upcoming retry.

        ``logger`` may be a ``logging.Logger`` or a plain callable taking
        ``(message, level)``; the latter is called with level ``"WARNING"``.
        """
        text = f"请求超时,{delay:.1f}秒后重试 ({retry_count}/{max_retries}): {exception}"
        if isinstance(logger, logging.Logger) or not callable(logger):
            logger.warning(text)
        else:
            logger(text, "WARNING")

    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """No extra work is needed before retrying a timeout."""
        return None
|
|
|
|
|
class ServerErrorStrategy(RetryStrategy):
    """Retry strategy for HTTP 5xx server errors, using a linear delay."""

    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """Retry only ``HTTPError``s that carry a 5xx response, while retries remain.

        Args:
            exception: The exception that was caught.
            retry_count: Current retry count.
            max_retries: Maximum number of retries.

        Returns:
            bool: True for a 5xx ``HTTPError`` when ``retry_count < max_retries``.
        """
        if not isinstance(exception, requests.exceptions.HTTPError):
            return False

        response = getattr(exception, 'response', None)
        if response is None:
            return False

        return (500 <= response.status_code < 600 and retry_count < max_retries)

    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Return a delay that grows linearly: ``base_delay * retry_count``."""
        return base_delay * retry_count

    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """Emit a warning naming the server status code and the upcoming retry.

        ``logger`` may be a ``logging.Logger`` or a plain callable taking
        ``(message, level)``; the latter is called with level ``"WARNING"``.

        Args:
            logger: Logger instance or logging callable.
            exception: The exception that was caught.
            retry_count: Current retry count.
            max_retries: Maximum number of retries.
            delay: Delay before retrying, in seconds.
        """
        response = getattr(exception, 'response', None)
        # Compare with None explicitly: a requests.Response is falsy whenever
        # its status is 4xx/5xx, so a plain truthiness test would report
        # 'unknown' for exactly the responses this strategy handles.
        status_code = response.status_code if response is not None else 'unknown'

        message = f"服务器错误 {status_code},{delay:.1f}秒后重试 ({retry_count}/{max_retries})"
        if callable(logger) and not isinstance(logger, logging.Logger):
            logger(message, "WARNING")
        else:
            logger.warning(message)

    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """No extra work is needed before retrying a server error."""
        pass
|
|
|
|
|
class RateLimitStrategy(RetryStrategy):
    """Retry strategy for rate limiting (HTTP 429), including account switching.

    Tracks how many 429 responses were seen in a row and, on selected
    attempts, switches the client to another account before retrying.
    """

    def __init__(self, client=None):
        """
        Initialise the rate-limit retry strategy.

        Args:
            client: API client instance, used for switching accounts.
        """
        self.client = client
        # Number of consecutive 429 responses observed by should_retry().
        self.consecutive_429_count = 0

    def should_retry(self, exception: Exception, retry_count: int, max_retries: int) -> bool:
        """Return True for an HTTPError whose response status is 429.

        Side effect: increments ``consecutive_429_count`` on a 429 and resets
        it to 0 for any other HTTP status that reaches this method.
        """
        if not isinstance(exception, requests.exceptions.HTTPError):
            return False

        response = getattr(exception, 'response', None)
        if response is None:
            return False

        is_rate_limit = response.status_code == 429
        if is_rate_limit:
            self.consecutive_429_count += 1
        else:
            self.consecutive_429_count = 0

        # NOTE(review): retry_count / max_retries are not consulted here, so a
        # 429 is always retried; together with the zero delay below this can
        # loop for as long as the server keeps answering 429. Confirm intended.
        return is_rate_limit

    def get_retry_delay(self, retry_count: int, base_delay: int) -> float:
        """Return 0: rate-limited requests are retried without waiting."""
        return 0

    def log_retry_attempt(self, logger: logging.Logger, exception: Exception,
                          retry_count: int, max_retries: int, delay: float) -> None:
        """Emit a warning about the rate-limit retry.

        The message differs between the first 429 and later consecutive ones.
        ``logger`` may be a ``logging.Logger`` or a plain callable taking
        ``(message, level)``; the latter is called with level ``"WARNING"``.
        """
        message = ""
        if self.consecutive_429_count > 1:
            message = f"连续第{self.consecutive_429_count}次速率限制错误,尝试立即重试"
        else:
            message = "速率限制错误,尝试切换账号"

        if callable(logger) and not isinstance(logger, logging.Logger):
            logger(message, "WARNING")
        else:
            logger.warning(message)

    def on_retry(self, exception: Exception, retry_count: int) -> None:
        """Switch the client to another account before retrying, when due.

        Runs on the first consecutive 429 and on every third one after that.
        It overwrites the client's credentials, signs in again, creates a new
        session and updates ``config.config_instance.client_sessions``.
        Failures during sign-in, session creation or the registry update are
        logged through ``client._log`` (when present) and not re-raised.

        Args:
            exception: The exception that was caught (unused here).
            retry_count: Current retry count (unused here).
        """
        # Request metadata attached to the client by the caller; None if absent.
        user_identifier = getattr(self.client, '_associated_user_identifier', None)
        request_ip = getattr(self.client, '_associated_request_ip', None)

        # Switch on the 1st consecutive 429 and on every 3rd (3, 6, 9, ...).
        if self.consecutive_429_count == 1 or (self.consecutive_429_count > 0 and self.consecutive_429_count % 3 == 0):
            if self.client and hasattr(self.client, 'email'):
                current_email = self.client.email
                # Presumably marks the rate-limited account as cooling down;
                # verify against config.set_account_cooldown.
                config.set_account_cooldown(current_email)

                new_email, new_password = config.get_next_ondemand_account_details()
                if new_email:
                    # Replace the credentials and clear the auth/session state
                    # that belonged to the previous account.
                    self.client.email = new_email
                    self.client.password = new_password
                    self.client.token = ""
                    self.client.refresh_token = ""
                    self.client.session_id = ""

                    try:
                        # Reuse the request's context hash for re-login and the new session.
                        current_context_hash = getattr(self.client, '_current_request_context_hash', None)

                        self.client.sign_in(context=current_context_hash)
                        if self.client.create_session(external_context=current_context_hash):
                            if hasattr(self.client, '_log'):
                                self.client._log(f"成功切换到账号 {new_email} 并使用上下文哈希 '{current_context_hash}' 重新登录和创建新会话。", "INFO")

                            # Flag read elsewhere; per the log message below, the next
                            # query should send the full history.
                            setattr(self.client, '_new_session_requires_full_history', True)
                            if hasattr(self.client, '_log'):
                                self.client._log(f"已设置 _new_session_requires_full_history = True,下次查询应发送完整历史。", "INFO")
                        else:
                            if hasattr(self.client, '_log'):
                                self.client._log(f"切换到账号 {new_email} 后,创建新会话失败。", "WARNING")

                            setattr(self.client, '_new_session_requires_full_history', False)

                        # NOTE(review): this registry update is placed after the
                        # if/else above, so it runs whether or not create_session
                        # succeeded — confirm that this nesting is intended.
                        if not user_identifier:
                            if hasattr(self.client, '_log'):
                                self.client._log("RateLimitStrategy: _associated_user_identifier not found on client. Cannot update client_sessions.", "ERROR")

                        else:
                            old_email_in_strategy = current_email
                            new_email_in_strategy = self.client.email

                            # The shared session registry is only touched under its lock.
                            with config.config_instance.client_sessions_lock:
                                if user_identifier in config.config_instance.client_sessions:
                                    user_specific_sessions = config.config_instance.client_sessions[user_identifier]

                                    # Drop the entry keyed by the old account's email.
                                    if old_email_in_strategy in user_specific_sessions:
                                        del user_specific_sessions[old_email_in_strategy]
                                        if hasattr(self.client, '_log'):
                                            self.client._log(f"RateLimitStrategy: Removed session for old email '{old_email_in_strategy}' for user '{user_identifier}'.", "INFO")

                                    # Prefer the IP attached to the client; otherwise reuse the IP
                                    # of an existing entry for the new email, else a placeholder.
                                    ip_to_use = request_ip if request_ip else user_specific_sessions.get(new_email_in_strategy, {}).get("ip", "unknown_ip_in_retry_update")

                                    from datetime import datetime

                                    active_hash_for_new_session = getattr(self.client, '_current_request_context_hash', None)

                                    # Register (or overwrite) the entry for the new account's email.
                                    user_specific_sessions[new_email_in_strategy] = {
                                        "client": self.client,
                                        "active_context_hash": active_hash_for_new_session,
                                        "last_time": datetime.now(),
                                        "ip": ip_to_use
                                    }
                                    log_message_hash_part = f"set to '{active_hash_for_new_session}' (from client instance's _current_request_context_hash)" if active_hash_for_new_session is not None else "set to None (_current_request_context_hash not found on client instance)"
                                    if hasattr(self.client, '_log'):
                                        self.client._log(f"RateLimitStrategy: Updated/added session for new email '{new_email_in_strategy}' for user '{user_identifier}'. active_context_hash {log_message_hash_part}.", "INFO")
                                else:
                                    if hasattr(self.client, '_log'):
                                        self.client._log(f"RateLimitStrategy: User '{user_identifier}' not found in client_sessions during update attempt.", "WARNING")

                    except Exception as e:
                        # Best effort: a failed switch is only logged. The client keeps
                        # the new credentials with empty tokens, and the retry proceeds.
                        # NOTE(review): silently swallowed when the client has no _log.
                        if hasattr(self.client, '_log'):
                            self.client._log(f"切换到账号 {new_email} 后登录或创建会话失败: {e}", "WARNING")
|
|
|
|
|
|
|
class RetryHandler:
    """Retry handler that runs an operation under a list of retry strategies.

    On each failure the first strategy whose ``should_retry`` accepts the
    exception decides the delay, logs the attempt and runs its ``on_retry``
    hook; if no strategy accepts it, the exception propagates.
    """

    def __init__(self, client=None, logger=None, max_retries: Optional[int] = None,
                 retry_delay: Optional[float] = None, strategies=None):
        """
        Initialise the retry handler.

        Args:
            client: API client instance, used for switching accounts.
            logger: ``logging.Logger`` or a logging callable taking
                ``(message, level)``; defaults to this module's logger.
            max_retries: Maximum number of retries. ``None`` means: read
                ``config.get_config_value('max_retries')`` when running.
            retry_delay: Base retry delay in seconds. ``None`` means: read
                ``config.get_config_value('retry_delay')`` when running.
            strategies: Optional iterable of strategy objects, consulted in
                order. ``None`` selects the built-in set.
        """
        self.client = client
        self.logger = logger or logging.getLogger(__name__)
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        if strategies is None:
            strategies = [
                ExponentialBackoffStrategy(),
                LinearBackoffStrategy(),
                ServerErrorStrategy(),
                RateLimitStrategy(client),
            ]
        self.strategies = list(strategies)

    def retry_operation(self, operation: Callable[..., T], *args, **kwargs) -> T:
        """
        Execute ``operation`` and retry it as the strategies dictate.

        Args:
            operation: The callable to execute.
            *args: Positional arguments for the operation.
            **kwargs: Keyword arguments for the operation.

        Returns:
            The result of the first successful call.

        Raises:
            Exception: The exception raised by ``operation`` once no strategy
                is willing to retry it.
        """
        # Explicit overrides win; ``is None`` (not truthiness) keeps an
        # explicit 0 meaningful, e.g. max_retries=0 disables retries.
        max_retries = (self.max_retries if self.max_retries is not None
                       else config.get_config_value('max_retries'))
        base_delay = (self.retry_delay if self.retry_delay is not None
                      else config.get_config_value('retry_delay'))
        retry_count = 0

        while True:
            try:
                return operation(*args, **kwargs)
            except Exception as e:
                # First strategy that accepts the exception handles this retry.
                strategy = next(
                    (s for s in self.strategies if s.should_retry(e, retry_count, max_retries)),
                    None,
                )
                if strategy is None:
                    # Not retryable (or retries exhausted): re-raise as-is.
                    raise

                # The count is incremented before the delay is computed, so
                # strategies see 1 on the first retry.
                retry_count += 1
                delay = strategy.get_retry_delay(retry_count, base_delay)
                strategy.log_retry_attempt(self.logger, e, retry_count, max_retries, delay)
                strategy.on_retry(e, retry_count)

                if delay > 0:
                    time.sleep(delay)


def with_retry(max_retries: Optional[int] = None, retry_delay: Optional[int] = None):
    """
    Retry decorator for methods that should be retried on failure.

    The decorated method's instance is passed to the handler as the client,
    and its ``_log`` attribute (if any) is used as the logger.

    Args:
        max_retries: Maximum number of retries; if None the config value is used.
        retry_delay: Base retry delay; if None the config value is used.

    Returns:
        The decorator to apply to a method.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Forward the decorator arguments so they actually take effect;
            # None is resolved to the config value inside the handler.
            handler = RetryHandler(
                client=self,
                logger=getattr(self, '_log', None),
                max_retries=max_retries,
                retry_delay=retry_delay,
            )

            # A closure (rather than passing kwargs through) avoids a clash
            # with retry_operation's own ``operation`` parameter name.
            def operation():
                return func(self, *args, **kwargs)

            return handler.retry_operation(operation)

        return wrapper

    return decorator