import os
import re
import json
import time
import datetime
import asyncio
import threading
import contextlib
from pathlib import Path
from functools import lru_cache
from typing import Optional, Dict, List, Any, Generator
from concurrent.futures import ThreadPoolExecutor

import httpx
import requests
import cloudscraper
import uvloop
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends, Security
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
from fastapi.security import APIKeyHeader
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN

# Enable uvloop for a faster event loop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Thread pool for CPU-bound operations (increased worker count for better parallelism)
executor = ThreadPoolExecutor(max_workers=16)

# Load environment variables once at startup
load_dotenv()

# API key security scheme
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

# Initialize usage tracker
from usage_tracker import UsageTracker
usage_tracker = UsageTracker()

app = FastAPI()

# Middleware for compression and CORS
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Environment variables (cached)
@lru_cache(maxsize=1)
def get_env_vars():
    return {
        'api_keys': os.getenv('API_KEYS', '').split(','),
        'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
        'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
        'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
        'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
        'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'),  # Added new endpoint
        'mistral_api': "https://api.mistral.ai",
        'mistral_key': os.getenv('MISTRAL_KEY'),
        'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
    }

# Model configuration -- sets give O(1) membership checks
mistral_models = {
    "mistral-large-latest", "pixtral-large-latest", "mistral-moderation-latest",
    "ministral-3b-latest", "ministral-8b-latest", "open-mistral-nemo",
    "mistral-small-latest", "mistral-saba-latest", "codestral-latest"
}

pollinations_models = {
    "openai", "openai-large", "openai-xlarge", "openai-reasoning", "qwen-coder",
    "llama", "mistral", "searchgpt", "deepseek", "claude-hybridspace",
    "deepseek-r1", "deepseek-reasoner", "llamalight", "gemini", "gemini-thinking",
    "hormoz", "phi", "phi-mini", "openai-audio", "llama-scaleway"
}

alternate_models = {
    "o1", "llama-4-scout", "o4-mini", "sonar", "sonar-pro", "sonar-reasoning",
    "sonar-reasoning-pro", "grok-3", "grok-3-fast", "r1-1776", "o3"
}

# Models for the new (fifth) endpoint
claude_3_models = {
    "claude-3-7-sonnet", "claude-3-7-sonnet-thinking",
    "claude 3.5 haiku", "claude 3.5 sonnet",
    "o3-mini-medium", "o3-mini-high",
    "grok-3", "grok-3-thinking", "grok 2"
}

# Supported image generation models
supported_image_models = {
    "Flux Pro Ultra", "grok-2-aurora", "Flux Pro", "Flux Pro Ultra Raw",
    "Flux Dev", "Flux Schnell", "stable-diffusion-3-large-turbo",
    "Flux Realism", "stable-diffusion-ultra", "dall-e-3", "sdxl-lightning-4step"
}

# Request payload model
class Payload(BaseModel):
    model: str
    messages: list
    stream: bool = False

# Image generation payload model
class ImageGenerationPayload(BaseModel):
    model: str
    prompt: str
    size: int
    number: int
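# Example (illustrative only; values are hypothetical): the JSON body a client
# POSTs to /chat/completions, matching the Payload model above.
#
#   {
#       "model": "gpt-4o-mini",
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "stream": true
#   }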
# Server status global variable
server_status = True
available_model_ids: List[str] = []

# Create a reusable httpx client with connection pooling (increased limits)
@lru_cache(maxsize=1)
def get_async_client():
    return httpx.AsyncClient(
        timeout=60.0,
        limits=httpx.Limits(max_keepalive_connections=50, max_connections=200)
    )

# Pool of cloudscraper instances (increased pool size)
scraper_pool = []
MAX_SCRAPERS = 20

def get_scraper():
    if not scraper_pool:
        for _ in range(MAX_SCRAPERS):
            scraper_pool.append(cloudscraper.create_scraper())
    # Simple time-based round-robin over the pool
    return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS]

# API key validation -- avoids string operations when possible
async def verify_api_key(
    request: Request,
    api_key: str = Security(api_key_header)
) -> bool:
    # Allow bypass if the referer is from /playground or /image-playground
    referer = request.headers.get("referer", "")
    if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground",
                           "https://parthsadaria-lokiai.hf.space/image-playground")):
        return True

    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )

    # Strip the 'Bearer ' prefix only when present
    if api_key.startswith('Bearer '):
        api_key = api_key[7:]

    # Get API keys from the environment
    valid_api_keys = get_env_vars().get('api_keys', [])
    if not valid_api_keys or valid_api_keys == ['']:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="API keys not configured on server"
        )

    if api_key not in valid_api_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API key"
        )

    return True

# Pre-load and cache models.json
@lru_cache(maxsize=1)
def load_models_data():
    try:
        file_path = Path(__file__).parent / 'models.json'
        with open(file_path, 'r') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"Error loading models.json: {e}")
        return []

# Async wrapper for models data
async def get_models():
    models_data = load_models_data()
    if not models_data:
        raise HTTPException(status_code=500, detail="Error loading available models")
    return models_data
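# generate_search_async below bridges an upstream SSE stream to callers through
# an asyncio.Queue: a background task puts chunks on the queue, and the caller
# drains it until a None sentinel marks completion. A minimal sketch of the
# pattern (hypothetical names, not part of the server):
#
#   async def producer(queue: asyncio.Queue) -> None:
#       for chunk in ("hello", " world"):
#           await queue.put({"text": chunk})
#       await queue.put(None)  # sentinel: stream finished
#
#   async def consume() -> str:
#       queue: asyncio.Queue = asyncio.Queue()
#       asyncio.create_task(producer(queue))
#       text = ""
#       while (item := await queue.get()) is not None:
#           text += item["text"]
#       return text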
# Enhanced async streaming with real-time SSE support
async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
    # Stream results through an asyncio.Queue
    queue = asyncio.Queue()

    async def _fetch_search_data():
        try:
            headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}

            # Use the provided system prompt, or default to "Be Helpful and Friendly"
            system_message = systemprompt or "Be Helpful and Friendly"

            # Create the prompt history
            prompt = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": query},
            ]

            # Prepare the payload for the API request
            payload = {
                "is_vscode_extension": True,
                "message_history": prompt,
                "requested_model": "searchgpt",
                "user_input": prompt[-1]["content"],
            }

            # Get the endpoint from the environment
            secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
            if not secret_api_endpoint_3:
                await queue.put({"error": "Search API endpoint not configured"})
                return

            # Use AsyncClient for better performance
            async with httpx.AsyncClient(timeout=30.0) as client:
                async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
                    if response.status_code != 200:
                        await queue.put({"error": f"Search API returned status code {response.status_code}"})
                        return

                    # Forward each SSE chunk as soon as it arrives
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            try:
                                json_data = json.loads(line[6:])
                                content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                                if content.strip():
                                    cleaned_response = {
                                        "created": json_data.get("created"),
                                        "id": json_data.get("id"),
                                        "model": "searchgpt",
                                        "object": "chat.completion",
                                        "choices": [
                                            {
                                                "message": {
                                                    "content": content
                                                }
                                            }
                                        ]
                                    }
                                    # Send to the queue immediately for streaming
                                    await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
                            except json.JSONDecodeError:
                                continue

            # Signal completion
            await queue.put(None)
        except Exception as e:
            await queue.put({"error": str(e)})
            await queue.put(None)

    # Start the fetch process and hand the queue back for consumption
    asyncio.create_task(_fetch_search_data())
    return queue

# Cache frequently accessed static files (cached for the process lifetime,
# so edits to these files require a restart)
@lru_cache(maxsize=10)
def read_html_file(file_path):
    try:
        with open(file_path, "r") as file:
            return file.read()
    except FileNotFoundError:
        return None

# Basic routes
@app.get("/favicon.ico")
async def favicon():
    favicon_path = Path(__file__).parent / "favicon.ico"
    return FileResponse(favicon_path, media_type="image/x-icon")

@app.get("/banner.jpg")
async def banner():
    banner_path = Path(__file__).parent / "banner.jpg"
    return FileResponse(banner_path, media_type="image/jpeg")

@app.get("/ping")
async def ping():
    return {"message": "pong"}

@app.get("/", response_class=HTMLResponse)
async def root():
    html_content = read_html_file("index.html")
    if html_content is None:
        return HTMLResponse(content="<h1>index.html not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
", f"" ) return HTMLResponse(content=final_html) # Model routes @app.get("/api/v1/models") @app.get("/models") async def return_models(): return await get_models() # Search routes with enhanced real-time streaming @app.get("/searchgpt") async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None): if not q: raise HTTPException(status_code=400, detail="Query parameter 'q' is required") usage_tracker.record_request(endpoint="/searchgpt") queue = await generate_search_async(q, systemprompt=systemprompt, stream=True) if stream: async def stream_generator(): collected_text = "" while True: item = await queue.get() if item is None: break if "error" in item: yield f"data: {json.dumps({'error': item['error']})}\n\n" break if "data" in item: yield item["data"] collected_text += item.get("text", "") return StreamingResponse( stream_generator(), media_type="text/event-stream" ) else: # For non-streaming, collect all text and return at once collected_text = "" while True: item = await queue.get() if item is None: break if "error" in item: raise HTTPException(status_code=500, detail=item["error"]) collected_text += item.get("text", "") return JSONResponse(content={"response": collected_text}) # Enhanced streaming with direct SSE pass-through for real-time responses header_url = os.getenv('HEADER_URL') @app.post("/chat/completions") @app.post("/api/v1/chat/completions") async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)): # Check server status if not server_status: return JSONResponse( status_code=503, content={"message": "Server is under maintenance. Please try again later."} ) model_to_use = payload.model or "gpt-4o-mini" # Validate model availability - fast lookup with set if available_model_ids and model_to_use not in set(available_model_ids): raise HTTPException( status_code=400, detail=f"Model '{model_to_use}' is not available. Check /models for the available model list." ) # Log request without blocking asyncio.create_task(log_request(request, model_to_use)) usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions") # Prepare payload payload_dict = payload.dict() payload_dict["model"] = model_to_use # Ensure stream is True for real-time streaming (can be overridden by client) stream_enabled = payload_dict.get("stream", True) # Get environment variables env_vars = get_env_vars() # Select the appropriate endpoint (fast lookup with sets) if model_to_use in mistral_models: endpoint = env_vars['mistral_api'] custom_headers = { "Authorization": f"Bearer {env_vars['mistral_key']}" } elif model_to_use in pollinations_models: endpoint = env_vars['secret_api_endpoint_4'] custom_headers = {} elif model_to_use in alternate_models: endpoint = env_vars['secret_api_endpoint_2'] custom_headers = {} elif model_to_use in claude_3_models: # Use the new endpoint endpoint = env_vars['secret_api_endpoint_5'] custom_headers = {} else: endpoint = env_vars['secret_api_endpoint'] custom_headers = { "Origin": header_url, "Priority": "u=1, i", "Referer": header_url } print(f"Using endpoint: {endpoint} for model: {model_to_use}") # Improved real-time streaming handler async def real_time_stream_generator(): try: async with httpx.AsyncClient(timeout=60.0) as client: async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response: if response.status_code >= 400: error_messages = { 422: "Unprocessable entity. Check your payload.", 400: "Bad request. 
# Enhanced streaming with direct SSE pass-through for real-time responses
header_url = os.getenv('HEADER_URL')

@app.post("/chat/completions")
@app.post("/api/v1/chat/completions")
async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
    # Check server status
    if not server_status:
        return JSONResponse(
            status_code=503,
            content={"message": "Server is under maintenance. Please try again later."}
        )

    model_to_use = payload.model or "gpt-4o-mini"

    # Validate model availability -- fast lookup with a set
    if available_model_ids and model_to_use not in set(available_model_ids):
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
        )

    # Log the request without blocking
    asyncio.create_task(log_request(request, model_to_use))
    usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")

    # Prepare the payload
    payload_dict = payload.dict()
    payload_dict["model"] = model_to_use

    # Respect the client's stream flag (defaults to False in the Payload model)
    stream_enabled = payload_dict.get("stream", True)

    # Get environment variables
    env_vars = get_env_vars()

    # Select the appropriate endpoint (fast lookups with sets)
    if model_to_use in mistral_models:
        endpoint = env_vars['mistral_api']
        custom_headers = {
            "Authorization": f"Bearer {env_vars['mistral_key']}"
        }
    elif model_to_use in pollinations_models:
        endpoint = env_vars['secret_api_endpoint_4']
        custom_headers = {}
    elif model_to_use in alternate_models:
        endpoint = env_vars['secret_api_endpoint_2']
        custom_headers = {}
    elif model_to_use in claude_3_models:
        # Use the new endpoint
        endpoint = env_vars['secret_api_endpoint_5']
        custom_headers = {}
    else:
        endpoint = env_vars['secret_api_endpoint']
        custom_headers = {
            "Origin": header_url,
            "Priority": "u=1, i",
            "Referer": header_url
        }

    print(f"Using endpoint: {endpoint} for model: {model_to_use}")

    # Real-time streaming handler
    async def real_time_stream_generator():
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response:
                    if response.status_code >= 400:
                        error_messages = {
                            400: "Bad request. Verify input data.",
                            403: "Forbidden. You do not have access to this resource.",
                            404: "The requested resource was not found.",
                            422: "Unprocessable entity. Check your payload.",
                        }
                        detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
                        raise HTTPException(status_code=response.status_code, detail=detail)

                    # Yield each line immediately with minimal buffering
                    async for line in response.aiter_lines():
                        if line:
                            yield line + "\n"
        except httpx.TimeoutException:
            raise HTTPException(status_code=504, detail="Request timed out")
        except httpx.RequestError as e:
            raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {e}")
        except Exception as e:
            if isinstance(e, HTTPException):
                raise
            raise HTTPException(status_code=500, detail=f"An error occurred: {e}")

    if stream_enabled:
        # Return a streaming response with headers that disable proxy buffering
        return StreamingResponse(
            real_time_stream_generator(),
            media_type="text/event-stream",
            headers={
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"  # Disable proxy buffering for Nginx
            }
        )
    else:
        # For non-streaming requests, collect the entire response body
        response_content = []
        async for chunk in real_time_stream_generator():
            response_content.append(chunk)
        return JSONResponse(content=json.loads(''.join(response_content)))
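# Client-side sketch (assumptions: server at localhost:8000; "sk-test" is a
# placeholder API key): consuming the SSE stream from /chat/completions with httpx.
#
#   import asyncio
#   import httpx
#
#   async def demo() -> None:
#       payload = {
#           "model": "gpt-4o-mini",
#           "messages": [{"role": "user", "content": "Hello"}],
#           "stream": True,
#       }
#       async with httpx.AsyncClient(timeout=60.0) as client:
#           async with client.stream(
#               "POST", "http://localhost:8000/chat/completions",
#               json=payload, headers={"Authorization": "Bearer sk-test"},
#           ) as response:
#               async for line in response.aiter_lines():
#                   if line.startswith("data: "):
#                       print(line[6:])
#
#   asyncio.run(demo())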
# New image generation endpoint
@app.post("/images/generations")
async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
    """Generate images based on a text prompt."""
    # Check server status
    if not server_status:
        return JSONResponse(
            status_code=503,
            content={"message": "Server is under maintenance. Please try again later."}
        )

    # Validate the model
    if payload.model not in supported_image_models:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {supported_image_models}"
        )

    # Log the request
    usage_tracker.record_request(model=payload.model, endpoint="/images/generations")

    # Prepare the payload for the external API
    api_payload = {
        "model": payload.model,
        "prompt": payload.prompt,
        "size": payload.size,
        "number": payload.number
    }

    # Target API endpoint
    target_api_url = os.getenv('NEW_IMG')

    try:
        # Use a timeout for the image generation request
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(target_api_url, json=api_payload)

            if response.status_code != 200:
                error_detail = response.json().get("detail", f"Image generation failed with status code: {response.status_code}")
                raise HTTPException(status_code=response.status_code, detail=error_detail)

            # Return the response from the external API
            return JSONResponse(content=response.json())
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Image generation request timed out.")
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}")
    except Exception as e:
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}")

# Asynchronous logging helper
async def log_request(request, model):
    # IST timestamp (UTC+5:30) and a hashed IP for privacy
    current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
    ip_hash = hash(request.client.host) % 10000
    print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")

# Cache usage statistics (cached per days value for the process lifetime)
@lru_cache(maxsize=10)
def get_usage_summary(days=7):
    return usage_tracker.get_usage_summary(days)

@app.get("/usage")
async def get_usage(days: int = 7):
    """Retrieve usage statistics."""
    return get_usage_summary(days)
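# Shape assumed below for usage_data (inferred from the lookups in
# generate_usage_html; the authoritative structure comes from
# UsageTracker.get_usage_summary):
#
#   {
#       "models": {"gpt-4o-mini": {"total_requests": 42,
#                                  "first_used": "2024-01-01",
#                                  "last_used": "2024-01-07"}},
#       "api_endpoints": {"/chat/completions": {...}},
#       "recent_daily_usage": {"2024-01-07": {"gpt-4o-mini": 10}},
#   }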
""" for model, model_data in usage_data['models'].items() ]) # API Endpoint Usage Table Rows api_usage_rows = "\n".join([ f"""
""" for endpoint, endpoint_data in usage_data['api_endpoints'].items() ]) # Daily Usage Table Rows daily_usage_rows = "\n".join([ "\n".join([ f"""
""" for entity, requests in date_data.items() ]) for date, date_data in usage_data['recent_daily_usage'].items() ]) html_content = f"""