import os from dotenv import load_dotenv from fastapi import FastAPI, HTTPException, Request, Depends, Security from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse from fastapi.security import APIKeyHeader from pydantic import BaseModel import httpx from functools import lru_cache from pathlib import Path import json import datetime import time from typing import Optional, Dict, List, Any, Generator import asyncio from starlette.status import HTTP_403_FORBIDDEN import cloudscraper from concurrent.futures import ThreadPoolExecutor import uvloop from fastapi.middleware.gzip import GZipMiddleware from starlette.middleware.cors import CORSMiddleware # Enable uvloop for faster event loop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) # Thread pool for CPU-bound operations executor = ThreadPoolExecutor(max_workers=8) # Load environment variables once at startup load_dotenv() # API key security scheme api_key_header = APIKeyHeader(name="Authorization", auto_error=False) # Initialize usage tracker from usage_tracker import UsageTracker usage_tracker = UsageTracker() app = FastAPI() # Add middleware for compression and CORS app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Environment variables (cached) @lru_cache(maxsize=1) def get_env_vars(): return { 'api_keys': os.getenv('API_KEYS', '').split(','), 'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'), 'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'), 'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'), 'mistral_api': "https://api.mistral.ai", 'mistral_key': os.getenv('MISTRAL_KEY'), 'image_endpoint': os.getenv("IMAGE_ENDPOINT"), 'endpoint_origin': os.getenv('ENDPOINT_ORIGIN') } # Configuration for models - use sets for faster lookups mistral_models = { "mistral-large-latest", "pixtral-large-latest", "mistral-moderation-latest", "ministral-3b-latest", "ministral-8b-latest", "open-mistral-nemo", "mistral-small-latest", "mistral-saba-latest", "codestral-latest" } alternate_models = { "gpt-4o-mini", "deepseek-v3", "llama-3.1-8b-instruct", "searchgpt", "llama-3.1-sonar-small-128k-online", "sonar-reasoning", "deepseek-r1-uncensored", "tinyswallow1.5b", "andy-3.5", "o3-mini-low", "hermes-3-llama-3.2-3b" } # Request payload model class Payload(BaseModel): model: str messages: list stream: bool = False # Server status global variable server_status = True available_model_ids: List[str] = [] # Create a reusable httpx client pool with connection pooling @lru_cache(maxsize=1) def get_async_client(): return httpx.AsyncClient( timeout=60.0, limits=httpx.Limits(max_keepalive_connections=20, max_connections=100) ) # Create a cloudscraper pool scraper_pool = [] MAX_SCRAPERS = 10 def get_scraper(): if not scraper_pool: for _ in range(MAX_SCRAPERS): scraper_pool.append(cloudscraper.create_scraper()) return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS] # Simple round-robin # API key validation - optimized to avoid string operations when possible async def verify_api_key(api_key: str = Security(api_key_header)) -> bool: if not api_key: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="No API key provided" ) # Only clean if needed if api_key.startswith('Bearer '): api_key = api_key[7:] # Remove 'Bearer ' prefix # Get API keys from environment valid_api_keys = get_env_vars()['api_keys'] if not valid_api_keys or valid_api_keys == ['']: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="API keys not configured on server" ) # Fast check with set operation if api_key not in set(valid_api_keys): raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="Invalid API key" ) return True # Pre-load and cache models.json @lru_cache(maxsize=1) def load_models_data(): try: file_path = Path(__file__).parent / 'models.json' with open(file_path, 'r') as f: return json.load(f) except (FileNotFoundError, json.JSONDecodeError) as e: print(f"Error loading models.json: {str(e)}") return [] # Async wrapper for models data async def get_models(): models_data = load_models_data() if not models_data: raise HTTPException(status_code=500, detail="Error loading available models") return models_data # Searcher function with optimized streaming - moved to a separate thread async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True): loop = asyncio.get_running_loop() def _generate_search(): headers = {"User-Agent": ""} # Use the provided system prompt, or default to "Be Helpful and Friendly" system_message = systemprompt or "Be Helpful and Friendly" # Create the prompt history with the user query and system message prompt = [ {"role": "user", "content": query}, ] prompt.insert(0, {"content": system_message, "role": "system"}) # Prepare the payload for the API request payload = { "is_vscode_extension": True, "message_history": prompt, "requested_model": "searchgpt", "user_input": prompt[-1]["content"], } # Get endpoint from environment secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3'] if not secret_api_endpoint_3: raise ValueError("Search API endpoint not configured") # Send the request to the chat endpoint using a scraper from the pool response = get_scraper().post( secret_api_endpoint_3, headers=headers, json=payload, stream=True ) result = [] streaming_text = "" # Process the streaming response for value in response.iter_lines(decode_unicode=True): if value.startswith("data: "): try: json_modified_value = json.loads(value[6:]) content = json_modified_value.get("choices", [{}])[0].get("delta", {}).get("content", "") if content.strip(): # Only process non-empty content cleaned_response = { "created": json_modified_value.get("created"), "id": json_modified_value.get("id"), "model": "searchgpt", "object": "chat.completion", "choices": [ { "message": { "content": content } } ] } if stream: result.append(f"data: {json.dumps(cleaned_response)}\n\n") streaming_text += content except json.JSONDecodeError: continue if not stream: result.append(streaming_text) return result # Run in thread pool to avoid blocking the event loop return await loop.run_in_executor(executor, _generate_search) # Cache for frequently accessed static files @lru_cache(maxsize=10) def read_html_file(file_path): try: with open(file_path, "r") as file: return file.read() except FileNotFoundError: return None # Basic routes @app.get("/favicon.ico") async def favicon(): favicon_path = Path(__file__).parent / "favicon.ico" return FileResponse(favicon_path, media_type="image/x-icon") @app.get("/ping") async def ping(): return {"message": "pong", "response_time": "0.000000 seconds"} @app.get("/", response_class=HTMLResponse) async def root(): html_content = read_html_file("index.html") if html_content is None: return HTMLResponse(content="

File not found

", status_code=404) return HTMLResponse(content=html_content) @app.get("/playground", response_class=HTMLResponse) async def playground(): html_content = read_html_file("playground.html") if html_content is None: return HTMLResponse(content="

playground.html not found

", status_code=404) return HTMLResponse(content=html_content) # Model routes @app.get("/api/v1/models") @app.get("/models") async def return_models(): return await get_models() # Search routes @app.get("/searchgpt") async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None): if not q: raise HTTPException(status_code=400, detail="Query parameter 'q' is required") usage_tracker.record_request(endpoint="/searchgpt") result = await generate_search_async(q, systemprompt=systemprompt, stream=stream) if stream: async def stream_generator(): for chunk in result: yield chunk return StreamingResponse( stream_generator(), media_type="text/event-stream" ) else: # For non-streaming, return the collected text return JSONResponse(content={"response": result[0] if result else ""}) # Chat completion endpoint @app.post("/chat/completions") @app.post("/api/v1/chat/completions") async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)): # Check server status if not server_status: return JSONResponse( status_code=503, content={"message": "Server is under maintenance. Please try again later."} ) model_to_use = payload.model or "gpt-4o-mini" # Validate model availability - fast lookup with set if available_model_ids and model_to_use not in set(available_model_ids): raise HTTPException( status_code=400, detail=f"Model '{model_to_use}' is not available. Check /models for the available model list." ) # Log request without blocking asyncio.create_task(log_request(request, model_to_use)) usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions") # Prepare payload payload_dict = payload.dict() payload_dict["model"] = model_to_use # Get environment variables env_vars = get_env_vars() # Select the appropriate endpoint (fast lookup with sets) if model_to_use in mistral_models: endpoint = env_vars['mistral_api'] custom_headers = { "Authorization": f"Bearer {env_vars['mistral_key']}" } elif model_to_use in alternate_models: endpoint = env_vars['secret_api_endpoint_2'] custom_headers = {} else: endpoint = env_vars['secret_api_endpoint'] custom_headers = {} # Get a scraper from the pool scraper = get_scraper() async def stream_generator(payload_dict): try: # Send POST request with the correct headers response = scraper.post( f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers, stream=True ) # Handle response errors if response.status_code >= 400: error_messages = { 422: "Unprocessable entity. Check your payload.", 400: "Bad request. Verify input data.", 403: "Forbidden. You do not have access to this resource.", 404: "The requested resource was not found.", } detail = error_messages.get(response.status_code, f"Error code: {response.status_code}") raise HTTPException(status_code=response.status_code, detail=detail) # Stream response lines to the client - use buffer for efficiency buffer = [] buffer_size = 0 max_buffer = 8192 # 8KB buffer for line in response.iter_lines(): if line: decoded = line.decode('utf-8') + "\n" buffer.append(decoded) buffer_size += len(decoded) if buffer_size >= max_buffer: yield ''.join(buffer) buffer = [] buffer_size = 0 # Flush remaining buffer if buffer: yield ''.join(buffer) except Exception as e: # Use a generic error message that doesn't expose internal details raise HTTPException(status_code=500, detail="An error occurred while processing your request") return StreamingResponse(stream_generator(payload_dict), media_type="application/json") # Asynchronous logging function async def log_request(request, model): # Get minimal data for logging current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p") ip_hash = hash(request.client.host) % 10000 # Hash the IP for privacy print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}") # Image generation endpoint - optimized to use connection pool @app.api_route("/images/generations", methods=["GET", "POST"]) async def generate_image( prompt: Optional[str] = None, model: str = "flux", seed: Optional[int] = None, width: Optional[int] = None, height: Optional[int] = None, nologo: Optional[bool] = True, private: Optional[bool] = None, enhance: Optional[bool] = None, request: Request = None, authenticated: bool = Depends(verify_api_key) ): # Validate the image endpoint image_endpoint = get_env_vars()['image_endpoint'] if not image_endpoint: raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.") usage_tracker.record_request(endpoint="/images/generations") # Handle GET and POST prompts if request.method == "POST": try: body = await request.json() prompt = body.get("prompt", "").strip() if not prompt: raise HTTPException(status_code=400, detail="Prompt cannot be empty") except Exception: raise HTTPException(status_code=400, detail="Invalid JSON payload") elif request.method == "GET": if not prompt or not prompt.strip(): raise HTTPException(status_code=400, detail="Prompt cannot be empty") prompt = prompt.strip() # Sanitize and encode the prompt encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt') # Construct the URL with the encoded prompt base_url = image_endpoint.rstrip('/') url = f"{base_url}/{encoded_prompt}" # Prepare query parameters with validation params = {} if model and isinstance(model, str): params['model'] = model if seed is not None and isinstance(seed, int): params['seed'] = seed if width is not None and isinstance(width, int) and 64 <= width <= 2048: params['width'] = width if height is not None and isinstance(height, int) and 64 <= height <= 2048: params['height'] = height if nologo is not None: params['nologo'] = str(nologo).lower() if private is not None: params['private'] = str(private).lower() if enhance is not None: params['enhance'] = str(enhance).lower() try: # Use the shared httpx client for connection pooling client = get_async_client() response = await client.get(url, params=params, follow_redirects=True) # Check for various error conditions if response.status_code != 200: error_messages = { 404: "Image generation service not found", 400: "Invalid parameters provided to image service", 429: "Too many requests to image service", } detail = error_messages.get(response.status_code, f"Image generation failed with status code {response.status_code}") raise HTTPException(status_code=response.status_code, detail=detail) # Verify content type content_type = response.headers.get('content-type', '') if not content_type.startswith('image/'): raise HTTPException( status_code=500, detail="Unexpected content type received from image service" ) # Use larger chunks for streaming for better performance async def stream_with_larger_chunks(): chunks = [] size = 0 async for chunk in response.aiter_bytes(chunk_size=16384): # Use 16KB chunks chunks.append(chunk) size += len(chunk) if size >= 65536: # Yield every 64KB yield b''.join(chunks) chunks = [] size = 0 if chunks: yield b''.join(chunks) return StreamingResponse( stream_with_larger_chunks(), media_type=content_type, headers={ 'Cache-Control': 'no-cache, no-store, must-revalidate', 'Pragma': 'no-cache', 'Expires': '0' } ) except httpx.TimeoutException: raise HTTPException(status_code=504, detail="Image generation request timed out") except httpx.RequestError: raise HTTPException(status_code=500, detail="Failed to contact image service") except Exception: raise HTTPException(status_code=500, detail="Unexpected error during image generation") # Meme endpoint with optimized networking @app.get("/meme") async def get_meme(): try: # Use the shared client for connection pooling client = get_async_client() response = await client.get("https://meme-api.com/gimme") response_data = response.json() meme_url = response_data.get("url") if not meme_url: raise HTTPException(status_code=404, detail="No meme found") image_response = await client.get(meme_url, follow_redirects=True) # Use larger chunks for streaming async def stream_with_larger_chunks(): chunks = [] size = 0 async for chunk in image_response.aiter_bytes(chunk_size=16384): chunks.append(chunk) size += len(chunk) if size >= 65536: yield b''.join(chunks) chunks = [] size = 0 if chunks: yield b''.join(chunks) return StreamingResponse( stream_with_larger_chunks(), media_type=image_response.headers.get("content-type", "image/png"), headers={'Cache-Control': 'max-age=3600'} # Add caching ) except Exception: raise HTTPException(status_code=500, detail="Failed to retrieve meme") # Cache usage statistics @lru_cache(maxsize=10) def get_usage_summary(days=7): return usage_tracker.get_usage_summary(days) @app.get("/usage") async def get_usage(days: int = 7): """Retrieve usage statistics""" return get_usage_summary(days) # Generate HTML for usage page def generate_usage_html(usage_data): # Model Usage Table Rows model_usage_rows = "\n".join([ f""" {model} {model_data['total_requests']} {model_data['first_used']} {model_data['last_used']} """ for model, model_data in usage_data['models'].items() ]) # API Endpoint Usage Table Rows api_usage_rows = "\n".join([ f""" {endpoint} {endpoint_data['total_requests']} {endpoint_data['first_used']} {endpoint_data['last_used']} """ for endpoint, endpoint_data in usage_data['api_endpoints'].items() ]) # Daily Usage Table Rows daily_usage_rows = "\n".join([ "\n".join([ f""" {date} {entity} {requests} """ for entity, requests in date_data.items() ]) for date, date_data in usage_data['recent_daily_usage'].items() ]) html_content = f""" Lokiai AI - Usage Statistics
Total API Requests: {usage_data['total_requests']}

Model Usage

{model_usage_rows}
Model Total Requests First Used Last Used

API Endpoint Usage

{api_usage_rows}
Endpoint Total Requests First Used Last Used

Daily Usage (Last 7 Days)

{daily_usage_rows}
Date Entity Requests
""" return html_content # Cache the usage page HTML @lru_cache(maxsize=1) def get_usage_page_html(): usage_data = get_usage_summary() return generate_usage_html(usage_data) @app.get("/usage/page", response_class=HTMLResponse) async def usage_page(): """Serve an HTML page showing usage statistics""" # Use cached HTML if available, regenerate if not html_content = get_usage_page_html() return HTMLResponse(content=html_content) # Utility function for loading model IDs - optimized to run once at startup def load_model_ids(json_file_path): try: with open(json_file_path, 'r') as f: models_data = json.load(f) # Extract 'id' from each model object and use a set for fast lookups return [model['id'] for model in models_data if 'id' in model] except Exception as e: print(f"Error loading model IDs: {str(e)}") return [] @app.on_event("startup") async def startup_event(): global available_model_ids available_model_ids = load_model_ids("models.json") print(f"Loaded {len(available_model_ids)} model IDs") # Preload scrapers for _ in range(MAX_SCRAPERS): scraper_pool.append(cloudscraper.create_scraper()) # Validate critical environment variables env_vars = get_env_vars() missing_vars = [] if not env_vars['secret_api_endpoint']: missing_vars.append('SECRET_API_ENDPOINT') if not env_vars['secret_api_endpoint_2']: missing_vars.append('SECRET_API_ENDPOINT_2') if not env_vars['secret_api_endpoint_3']: missing_vars.append('SECRET_API_ENDPOINT_3') if missing_vars: print(f"WARNING: The following required environment variables are missing: {', '.join(missing_vars)}") print("API started successfully with high-performance optimizations") if __name__ == "__main__": import uvicorn uvicorn.run( app, host="0.0.0.0", port=7860, workers=4, # Multiple workers for better CPU utilization loop="uvloop", # Use uvloop for faster async operations http="httptools", # Faster HTTP parsing log_level="warning", # Reduce logging overhead limit_concurrency=100, # Limit concurrent connections timeout_keep_alive=5 # Reduce idle connection time )