# seekr/src/utils/api_key_manager.py
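"""Thread-safe API key manager with round-robin rotation across multiple LLM
providers (OpenAI, Anthropic, Google, xAI). Keys are loaded from environment
variables, and the `with_api_manager` decorator transparently rotates keys and
retries on rate-limit and token-limit errors."""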
from threading import Lock
import os
from typing import List, Optional, Literal, Union, Dict
from dotenv import load_dotenv
import re
from langchain_xai import ChatXAI
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from functools import wraps
import time
from openai import RateLimitError, OpenAIError
from anthropic import RateLimitError as AnthropicRateLimitError, APIError as AnthropicAPIError
from google.api_core.exceptions import ResourceExhausted, BadRequest, InvalidArgument
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_exception_type
import asyncio
ModelProvider = Literal["openai", "anthropic", "google", "xai"]
class APIKeyManager:
_instance = None
_lock = Lock()
# Define supported models
SUPPORTED_MODELS = {
"openai": [
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0125",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4",
"gpt-4-1106-preview",
"gpt-4-0125-preview",
"gpt-4-turbo-preview",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo",
"o1-mini-2024-09-12",
"o1-mini",
"o1-preview-2024-09-12",
"o1-preview",
"o1",
"gpt-4o-mini-2024-07-18",
"gpt-4o-mini",
"chatgpt-4o-latest",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"gpt-4o-2024-11-20",
"gpt-4o"
],
"google": [
"gemini-1.5-flash",
"gemini-1.5-flash-latest",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-001",
"gemini-1.5-flash-002",
"gemini-1.5-flash-8b-exp-0924",
"gemini-1.5-flash-8b-exp-0827",
"gemini-1.5-flash-8b-001",
"gemini-1.5-flash-8b",
"gemini-1.5-flash-8b-latest",
"gemini-1.5-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-pro-001",
"gemini-1.5-pro-002",
"gemini-1.5-pro-exp-0827",
"gemini-1.0-pro",
"gemini-1.0-pro-latest",
"gemini-1.0-pro-001",
"gemini-pro",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-2.0-pro-exp-02-05",
"gemini-2.0-flash-lite-preview-02-05",
"gemini-2.0-flash-exp",
"gemini-2.0-flash",
"gemini-2.0-flash-thinking-exp-1219",
],
"xai": [
"grok-beta",
"grok-vision-beta",
"grok-2-vision-1212",
"grok-2-1212"
],
"anthropic": [
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest",
"claude-3-5-haiku-20241022",
"claude-3-5-haiku-latest",
"claude-3-opus-20240229",
"claude-3-opus-latest",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307"
]
}
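    # Thread-safe singleton: a class-level lock guards first-time creation so
    # every caller shares one key-rotation state per process.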
def __new__(cls):
with cls._lock:
if cls._instance is None:
cls._instance = super(APIKeyManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self._initialized = True
# 1) Always load env
load_dotenv(override=True)
self._current_indices = {
"openai": 0,
"anthropic": 0,
"google": 0,
"xai": 0
}
self._lock = Lock()
# 2) load all provider keys from environment
self._api_keys = self._load_api_keys()
self._llm = None
self._current_provider = None
            # 3) read the user's chosen model, temperature, and top_p from env
            # (the provider itself is re-read on demand in _get_provider_for_model)
self.model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo").strip()
temp_str = os.getenv("MODEL_TEMPERATURE", "0")
topp_str = os.getenv("MODEL_TOP_P", "1")
try:
self.temperature = float(temp_str)
except ValueError:
self.temperature = 0.0
try:
self.top_p = float(topp_str)
except ValueError:
self.top_p = 1.0
    def _reinit(self):
        """Force re-initialization, e.g. after the .env file has changed."""
        self._initialized = False
        self.__init__()
    def _load_api_keys(self) -> Dict[str, List[str]]:
        """Load API keys from environment variables dynamically.

        Numbered keys (e.g. OPENAI_API_KEY_1, OPENAI_API_KEY_2) are collected
        in numeric order for each provider; if none are set, the unnumbered
        default (e.g. OPENAI_API_KEY) is used instead.
        """
        env_prefixes: Dict[str, str] = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "google": "GOOGLE_API_KEY",
            "xai": "XAI_API_KEY",
        }
        api_keys: Dict[str, List[str]] = {provider: [] for provider in env_prefixes}
        env_vars = dict(os.environ)
        for provider, prefix in env_prefixes.items():
            pattern = re.compile(rf'{prefix}_\d+$')
            numbered = {k: v for k, v in env_vars.items() if pattern.match(k) and v.strip()}
            if numbered:
                # Sort KEY_1, KEY_2, ... numerically so rotation order is stable
                for key_name in sorted(numbered, key=lambda x: int(x.split('_')[-1])):
                    api_keys[provider].append(numbered[key_name])
            else:
                default_key = os.getenv(prefix)
                if default_key and default_key.strip():
                    api_keys[provider].append(default_key)
        if not any(api_keys.values()):
            raise Exception("No valid API keys found in environment variables")
        for provider, keys in api_keys.items():
            if keys:
                print(f"Loaded {len(keys)} {provider} API keys for rotation")
        return api_keys
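    # Example .env layout consumed above (values are placeholders, not real keys):
    #   MODEL_PROVIDER=openai
    #   MODEL_NAME=gpt-4o-mini
    #   MODEL_TEMPERATURE=0
    #   MODEL_TOP_P=1
    #   OPENAI_API_KEY_1=sk-...
    #   OPENAI_API_KEY_2=sk-...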
def get_next_api_key(self, provider: ModelProvider) -> str:
"""Get the next API key in round-robin fashion for the specified provider."""
with self._lock:
if not self._api_keys.get(provider) or len(self._api_keys[provider]) == 0:
raise Exception(f"No API key found for {provider}")
if provider not in self._current_indices:
self._current_indices[provider] = 0
current_key = self._api_keys[provider][self._current_indices[provider]]
self._current_indices[provider] = (self._current_indices[provider] + 1) % len(self._api_keys[provider])
return current_key
def _get_provider_for_model(self) -> ModelProvider:
"""Determine the provider based on the model name."""
load_dotenv(override=True) # to refresh in case .env changed
provider_env = os.getenv("MODEL_PROVIDER", "openai").lower().strip()
if provider_env not in self.SUPPORTED_MODELS:
raise Exception(
f"Invalid or missing MODEL_PROVIDER in env: '{provider_env}'. "
f"Must be one of: {list(self.SUPPORTED_MODELS.keys())}"
)
# check if user-chosen model is in that provider’s list
if self.model_name not in self.SUPPORTED_MODELS[provider_env]:
available = self.SUPPORTED_MODELS[provider_env]
raise Exception(
f"Model '{self.model_name}' is not available under provider '{provider_env}'. "
f"Available: {available}"
)
return provider_env
def _initialize_llm(
self,
model_name: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
streaming: bool = False
):
"""Initialize LLM with the next API key in rotation."""
load_dotenv(override=True) # refresh .env in case it changed
provider = self._get_provider_for_model()
        model_name = model_name if model_name else self.model_name
        # Compare against None so an explicit 0.0 is honored (0 is falsy)
        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p
api_key = self.get_next_api_key(provider)
print(f"Using provider={provider}, model_name={model_name}, "
f"temperature={temperature}, top_p={top_p}, key={api_key}")
kwargs = {
"model": model_name,
"temperature": temperature,
"top_p": top_p,
"max_retries": 0,
"streaming": streaming,
"api_key": api_key,
}
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if provider == "openai":
self._llm = ChatOpenAI(**kwargs)
elif provider == "google":
self._llm = ChatGoogleGenerativeAI(**kwargs)
elif provider == "anthropic":
self._llm = ChatAnthropic(**kwargs)
else:
self._llm = ChatXAI(**kwargs)
self._current_provider = provider
def get_llm(
self,
model_name: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
streaming: bool = False
) -> Union[ChatOpenAI, ChatGoogleGenerativeAI, ChatAnthropic, ChatXAI]:
"""Get LLM instance with the current API key."""
provider = self._get_provider_for_model()
        model_name = model_name if model_name else self.model_name
        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p
        # Re-initialize only when no LLM exists yet or the provider changed;
        # otherwise the cached instance (and its current key) is reused.
        if self._llm is None or provider != self._current_provider:
            self._initialize_llm(model_name, temperature, top_p, max_tokens, streaming)
return self._llm
    def rotate_key(self, provider: Optional[ModelProvider] = None, streaming: bool = False) -> None:
        """Manually rotate to the next API key by re-initializing the LLM.

        The active provider is re-derived from the environment inside
        _initialize_llm, so `provider` is accepted only for API compatibility.
        """
        self._initialize_llm(streaming=streaming)
def get_all_api_keys(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, List[str]], List[str]]:
"""Get all available API keys."""
if provider:
return self._api_keys[provider].copy()
return {k: v.copy() for k, v in self._api_keys.items()}
def get_key_count(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, int], int]:
"""Get the total number of available API keys."""
if provider:
return len(self._api_keys[provider])
return {k: len(v) for k, v in self._api_keys.items()}
    def __len__(self) -> int:
        """Total number of active API keys across all providers (len() must return an int)."""
        return sum(len(keys) for keys in self._api_keys.values())
def __bool__(self) -> bool:
"""Check if there are any API keys available."""
return any(bool(keys) for keys in self._api_keys.values())
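# Usage sketch: APIKeyManager is a singleton, so rotation state is shared
# process-wide:
#   manager = APIKeyManager()
#   llm = manager.get_llm()   # cached instance for the current provider/key
#   manager.rotate_key()      # force a switch to the next key in rotation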
def with_api_manager(
model_name: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
streaming: bool = False,
delay_on_timeout: int = 20,
max_token_reduction_attempts: int = 0
):
"""Decorator for automatic key rotation on error with delay on timeout."""
manager = APIKeyManager()
provider = manager._get_provider_for_model()
    model_name = model_name if model_name else manager.model_name
    temperature = temperature if temperature is not None else manager.temperature
    top_p = top_p if top_p is not None else manager.top_p
key_count = manager.get_key_count(provider)
def decorator(func):
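        # The async and sync paths below mirror each other: rotate through all
        # keys on rate limits, optionally shrink max_tokens on context-length
        # errors, and fall back to exponential-backoff retries (tenacity) when
        # only a single key is available.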
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def wrapper(*args, **kwargs):
if key_count > 1:
all_keys = manager.get_all_api_keys(provider)
tried_keys = set()
current_max_tokens = max_tokens
token_reduction_attempts = 0
while len(tried_keys) < len(all_keys):
try:
llm = manager.get_llm(
model_name=model_name,
temperature=temperature,
top_p=top_p,
max_tokens=current_max_tokens,
streaming=streaming
)
result = await func(*args, **kwargs, llm=llm)
return result
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e:
current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)]
print(f"Rate limit error with {provider} API key {current_key}: {str(e)}")
tried_keys.add(current_key)
if len(tried_keys) < len(all_keys):
manager.rotate_key(provider=provider, streaming=streaming)
print(f"Using next available {provider} API key")
else:
if delay_on_timeout > 0:
print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...")
time.sleep(delay_on_timeout)
manager._current_indices[provider] = 0
else:
print(f"All {provider} API keys failed due to rate limits: {str(e)}")
raise
except (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
error_str = str(e)
if "token" in error_str.lower() or "context length" in error_str.lower():
print(f"Token limit error encountered: {error_str}")
if max_token_reduction_attempts > 0 and max_tokens is not None and token_reduction_attempts < max_token_reduction_attempts:
                                    current_max_tokens = int(current_max_tokens * 0.8)  # shrink by 20% and retry
token_reduction_attempts += 1
print(f"Retrying with reduced max_tokens: {current_max_tokens}")
continue # Retry with reduced max_tokens
else:
print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.")
current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)]
tried_keys.add(current_key)
if len(tried_keys) < len(all_keys):
manager.rotate_key(provider=provider, streaming=streaming)
print(f"Using next available {provider} API key after token limit error.")
else:
raise # All keys tried, raise the token limit error
else:
# Re-raise other API errors
raise
# Attempt one final time after trying all keys (for rate limits with delay)
try:
llm = manager.get_llm(
model_name=model_name,
temperature=temperature,
top_p=top_p,
max_tokens=current_max_tokens, # Use the current value
streaming=streaming
)
result = await func(*args, **kwargs, llm=llm)
return result
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError,
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
print(f"Error after retrying all {provider} API keys: {str(e)}")
raise
elif key_count == 1:
@retry(
wait=wait_random_exponential(min=10, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type((
RateLimitError, ResourceExhausted, AnthropicRateLimitError,
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument))
)
async def attempt_function_call():
llm = manager.get_llm(
model_name=model_name,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
streaming=streaming
)
return await func(*args, **kwargs, llm=llm)
try:
return await attempt_function_call()
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError,
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
print(f"Error encountered for {provider} after multiple retries: {str(e)}")
raise
                else:
                    raise Exception(f"No API keys found for provider: {provider}")
else:
@wraps(func)
def wrapper(*args, **kwargs):
if key_count > 1:
all_keys = manager.get_all_api_keys(provider)
tried_keys = set()
current_max_tokens = max_tokens
token_reduction_attempts = 0
while len(tried_keys) < len(all_keys):
try:
llm = manager.get_llm(
model_name=model_name,
temperature=temperature,
top_p=top_p,
max_tokens=current_max_tokens,
streaming=streaming
)
result = func(*args, **kwargs, llm=llm)
return result
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e:
current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)]
print(f"Rate limit error with {provider} API key {current_key}: {str(e)}")
tried_keys.add(current_key)
if len(tried_keys) < len(all_keys):
manager.rotate_key(provider=provider, streaming=streaming)
print(f"Using next available {provider} API key")
else:
if delay_on_timeout > 0:
print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...")
time.sleep(delay_on_timeout)
manager._current_indices[provider] = 0
else:
print(f"All {provider} API keys failed due to rate limits: {str(e)}")
raise
except (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
error_str = str(e)
if "token" in error_str.lower() or "context length" in error_str.lower():
print(f"Token limit error encountered: {error_str}")
if max_token_reduction_attempts > 0 and max_tokens is not None and token_reduction_attempts < max_token_reduction_attempts:
current_max_tokens = int(current_max_tokens * 0.8)
token_reduction_attempts += 1
print(f"Retrying with reduced max_tokens: {current_max_tokens}")
continue # Retry with reduced max_tokens
else:
print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.")
current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)]
tried_keys.add(current_key)
if len(tried_keys) < len(all_keys):
manager.rotate_key(provider=provider, streaming=streaming)
print(f"Using next available {provider} API key after token limit error.")
else:
raise # All keys tried, raise the token limit error
else:
# Re-raise other API errors
raise
# Attempt one final time after trying all keys (for rate limits with delay)
try:
llm = manager.get_llm(
model_name=model_name,
temperature=temperature,
top_p=top_p,
max_tokens=current_max_tokens,
streaming=streaming
)
result = func(*args, **kwargs, llm=llm)
return result
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError,
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
print(f"Error after retrying all {provider} API keys: {str(e)}")
raise
elif key_count == 1:
@retry(
wait=wait_random_exponential(min=10, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type((
RateLimitError, ResourceExhausted, AnthropicRateLimitError,
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument))
)
def attempt_function_call():
llm = manager.get_llm(
model_name=model_name,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
streaming=streaming
)
return func(*args, **kwargs, llm=llm)
try:
return attempt_function_call()
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError,
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e:
print(f"Error encountered for {provider} after multiple retries: {str(e)}")
raise
                else:
                    raise Exception(f"No API keys found for provider: {provider}")
return wrapper
return decorator
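# Usage sketch: the decorator injects the managed `llm` instance as a keyword
# argument, so decorated functions must declare it (names are illustrative):
#
#   @with_api_manager(max_tokens=1024)
#   async def ask(question: str, *, llm):
#       return (await llm.ainvoke(question)).content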
if __name__ == "__main__":
prompt = "What is the capital of France?"
# Test key rotation
async def test_load_balancing(prompt: str, test_count: int = 10, stream: bool = False):
@with_api_manager(streaming=stream)
async def test(prompt: str, test_count: int = 10, *, llm):
print("="*50)
for i in range(test_count):
try:
print(f"\nTest {i+1} of {test_count}")
if stream:
async for chunk in llm.astream(prompt):
print(chunk.content, end="", flush=True)
print("\n" + "-"*50 if i != test_count - 1 else "\n" + "="*50)
else:
response = await llm.ainvoke(prompt)
print(f"Response: {response.content.strip()}")
print("-"*50) if i != test_count - 1 else print("="*50)
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e:
print(f"Error: {str(e)}")
raise
await test(prompt, test_count=test_count)
# Test without load balancing
def test_without_load_balancing(model_name: str, prompt: str, test_count: int = 10):
manager = APIKeyManager()
print(f"Using model: {model_name}")
print("="*50)
        for i in range(test_count):
            try:
                print(f"Test {i+1} of {test_count}")
                llm = manager.get_llm(model_name=model_name)
                response = llm.invoke(prompt)
                print(f"Response: {response.content.strip()}")
                print("-"*50 if i != test_count - 1 else "="*50)
            except Exception as e:
                raise Exception(f"Error with {model_name}: {str(e)}")
# test_without_load_balancing(model_name="gemini-exp-1121", prompt=prompt, test_count=50)
asyncio.run(test_load_balancing(prompt=prompt, test_count=100, stream=True))