import os import re from dotenv import load_dotenv from fastapi import FastAPI, HTTPException, Request, Depends, Security from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse from fastapi.security import APIKeyHeader from pydantic import BaseModel import httpx from functools import lru_cache from pathlib import Path import json import datetime import time import threading from typing import Optional, Dict, List, Any, Generator import asyncio from starlette.status import HTTP_403_FORBIDDEN import cloudscraper from concurrent.futures import ThreadPoolExecutor import uvloop from fastapi.middleware.gzip import GZipMiddleware from starlette.middleware.cors import CORSMiddleware import contextlib import requests # Enable uvloop for faster event loop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) # Thread pool for CPU-bound operations executor = ThreadPoolExecutor(max_workers=16) # Increased thread count for better parallelism # Load environment variables once at startup load_dotenv() # API key security scheme api_key_header = APIKeyHeader(name="Authorization", auto_error=False) # Initialize usage tracker from usage_tracker import UsageTracker usage_tracker = UsageTracker() app = FastAPI() # Add middleware for compression and CORS app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Environment variables (cached) @lru_cache(maxsize=1) def get_env_vars(): return { 'api_keys': os.getenv('API_KEYS', '').split(','), 'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'), 'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'), 'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'), 'secret_api_endpoint_4': "https://text.pollinations.ai/openai", 'secret_api_endpoint_5': os.getenv('SECRET_API_ENDPOINT_5'), # Added new endpoint 'mistral_api': "https://api.mistral.ai", 'mistral_key': os.getenv('MISTRAL_KEY'), 'endpoint_origin': os.getenv('ENDPOINT_ORIGIN') } # Configuration for models - use sets for faster lookups mistral_models = { "mistral-large-latest", "pixtral-large-latest", "mistral-moderation-latest", "ministral-3b-latest", "ministral-8b-latest", "open-mistral-nemo", "mistral-small-latest", "mistral-saba-latest", "codestral-latest" } pollinations_models = { "openai", "openai-large", "openai-xlarge", "openai-reasoning", "qwen-coder", "llama", "mistral", "searchgpt", "deepseek", "claude-hybridspace", "deepseek-r1", "deepseek-reasoner", "llamalight", "gemini", "gemini-thinking", "hormoz", "phi", "phi-mini", "openai-audio", "llama-scaleway" } alternate_models = { "o1", "llama-4-scout", "o4-mini", "sonar", "sonar-pro", "sonar-reasoning", "sonar-reasoning-pro", "grok-3", "grok-3-fast", "r1-1776", "o3" } claude_3_models = { # Models for the new endpoint "claude-3-7-sonnet", "claude-3-7-sonnet-thinking", "claude 3.5 haiku", "claude 3.5 sonnet", "claude 3.5 haiku", "o3-mini-medium", "o3-mini-high", "grok-3", "grok-3-thinking", "grok 2" } # Supported image generation models supported_image_models = { "Flux Pro Ultra", "grok-2-aurora", "Flux Pro", "Flux Pro Ultra Raw", "Flux Dev", "Flux Schnell", "stable-diffusion-3-large-turbo", "Flux Realism", "stable-diffusion-ultra", "dall-e-3", "sdxl-lightning-4step" } # Request payload model class Payload(BaseModel): model: str messages: list stream: bool = False # Image generation payload model class ImageGenerationPayload(BaseModel): model: str prompt: str size: int number: int # Server status global variable server_status = True available_model_ids: List[str] = [] # Create a reusable httpx client pool with connection pooling @lru_cache(maxsize=1) def get_async_client(): return httpx.AsyncClient( timeout=60.0, limits=httpx.Limits(max_keepalive_connections=50, max_connections=200) # Increased limits ) # Create a cloudscraper pool scraper_pool = [] MAX_SCRAPERS = 20 # Increased pool size def get_scraper(): if not scraper_pool: for _ in range(MAX_SCRAPERS): scraper_pool.append(cloudscraper.create_scraper()) return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS] # Simple round-robin # API key validation - optimized to avoid string operations when possible async def verify_api_key( request: Request, api_key: str = Security(api_key_header) ) -> bool: # Allow bypass if the referer is from /playground or /image-playground referer = request.headers.get("referer", "") if referer.startswith(("https://parthsadaria-lokiai.hf.space/playground", "https://parthsadaria-lokiai.hf.space/image-playground")): return True if not api_key: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="No API key provided" ) # Only clean if needed if api_key.startswith('Bearer '): api_key = api_key[7:] # Remove 'Bearer ' prefix # Get API keys from environment valid_api_keys = get_env_vars().get('api_keys', []) if not valid_api_keys or valid_api_keys == ['']: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="API keys not configured on server" ) # Fast check with set operation if api_key not in set(valid_api_keys): raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="Invalid API key" ) return True # Pre-load and cache models.json @lru_cache(maxsize=1) def load_models_data(): try: file_path = Path(__file__).parent / 'models.json' with open(file_path, 'r') as f: return json.load(f) except (FileNotFoundError, json.JSONDecodeError) as e: print(f"Error loading models.json: {str(e)}") return [] # Async wrapper for models data async def get_models(): models_data = load_models_data() if not models_data: raise HTTPException(status_code=500, detail="Error loading available models") return models_data # Enhanced async streaming - now with real-time SSE support async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True): # Create a streaming response channel using asyncio.Queue queue = asyncio.Queue() async def _fetch_search_data(): try: headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"} # Use the provided system prompt, or default to "Be Helpful and Friendly" system_message = systemprompt or "Be Helpful and Friendly" # Create the prompt history prompt = [ {"role": "user", "content": query}, ] prompt.insert(0, {"content": system_message, "role": "system"}) # Prepare the payload for the API request payload = { "is_vscode_extension": True, "message_history": prompt, "requested_model": "searchgpt", "user_input": prompt[-1]["content"], } # Get endpoint from environment secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3'] if not secret_api_endpoint_3: await queue.put({"error": "Search API endpoint not configured"}) return # Use AsyncClient for better performance async with httpx.AsyncClient(timeout=30.0) as client: async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response: if response.status_code != 200: await queue.put({"error": f"Search API returned status code {response.status_code}"}) return # Process the streaming response in real-time buffer = "" async for line in response.aiter_lines(): if line.startswith("data: "): try: json_data = json.loads(line[6:]) content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "") if content.strip(): cleaned_response = { "created": json_data.get("created"), "id": json_data.get("id"), "model": "searchgpt", "object": "chat.completion", "choices": [ { "message": { "content": content } } ] } # Send to queue immediately for streaming await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content}) except json.JSONDecodeError: continue # Signal completion await queue.put(None) except Exception as e: await queue.put({"error": str(e)}) await queue.put(None) # Start the fetch process asyncio.create_task(_fetch_search_data()) # Return the queue for consumption return queue # Cache for frequently accessed static files @lru_cache(maxsize=10) def read_html_file(file_path): try: with open(file_path, "r") as file: return file.read() except FileNotFoundError: return None # Basic routes @app.get("/favicon.ico") async def favicon(): favicon_path = Path(__file__).parent / "favicon.ico" return FileResponse(favicon_path, media_type="image/x-icon") @app.get("/banner.jpg") async def favicon(): favicon_path = Path(__file__).parent / "banner.jpg" return FileResponse(favicon_path, media_type="image/x-icon") @app.get("/ping") async def ping(): return {"message": "pong", "response_time": "0.000000 seconds"} @app.get("/", response_class=HTMLResponse) async def root(): html_content = read_html_file("index.html") if html_content is None: return HTMLResponse(content="

File not found

", status_code=404) return HTMLResponse(content=html_content) @app.get("/script.js", response_class=HTMLResponse) async def root(): html_content = read_html_file("script.js") if html_content is None: return HTMLResponse(content="

File not found

", status_code=404) return HTMLResponse(content=html_content) @app.get("/style.css", response_class=HTMLResponse) async def root(): html_content = read_html_file("style.css") if html_content is None: return HTMLResponse(content="

File not found

", status_code=404) return HTMLResponse(content=html_content) @app.get("/dynamo", response_class=HTMLResponse) async def dynamic_ai_page(request: Request): user_agent = request.headers.get('user-agent', 'Unknown User') client_ip = request.client.host location = f"IP: {client_ip}" prompt = f""" Generate a dynamic HTML page for a user with the following details: with name "LOKI.AI" - User-Agent: {user_agent} - Location: {location} - Style: Cyberpunk, minimalist, or retro Make sure the HTML is clean and includes a heading, also have cool animations a motivational message, and a cool background. Wrap the generated HTML in triple backticks (```). """ payload = { "model": "mistral-small-latest", "messages": [{"role": "user", "content": prompt}] } headers = { "Authorization": "Bearer playground" } response = requests.post("https://parthsadaria-lokiai.hf.space/chat/completions", json=payload, headers=headers) data = response.json() # Extract HTML from ``` blocks html_content = re.search(r"```(.*?)```", data['choices'][0]['message']['content'], re.DOTALL) if html_content: html_content = html_content.group(1).strip() # Remove the first word if html_content: html_content = ' '.join(html_content.split(' ')[1:]) return HTMLResponse(content=html_content) @app.get("/playground", response_class=HTMLResponse) async def playground(): html_content = read_html_file("playground.html") if html_content is None: return HTMLResponse(content="

playground.html not found

", status_code=404) return HTMLResponse(content=html_content) @app.get("/image-playground", response_class=HTMLResponse) async def playground(): html_content = read_html_file("image-playground.html") if html_content is None: return HTMLResponse(content="

image-playground.html not found

", status_code=404) return HTMLResponse(content=html_content) # VETRA GITHUB_BASE = "https://raw.githubusercontent.com/Parthsadaria/Vetra/main" FILES = { "html": "index.html", "css": "style.css", "js": "script.js" } async def get_github_file(filename: str) -> str: url = f"{GITHUB_BASE}/{filename}" async with httpx.AsyncClient() as client: res = await client.get(url) return res.text if res.status_code == 200 else None @app.get("/vetra", response_class=HTMLResponse) async def serve_vetra(): html = await get_github_file(FILES["html"]) css = await get_github_file(FILES["css"]) js = await get_github_file(FILES["js"]) if not html: return HTMLResponse(content="

index.html not found on GitHub

", status_code=404) final_html = html.replace( "", f"" ).replace( "", f"" ) return HTMLResponse(content=final_html) # Model routes @app.get("/api/v1/models") @app.get("/models") async def return_models(): return await get_models() # Search routes with enhanced real-time streaming @app.get("/searchgpt") async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None): if not q: raise HTTPException(status_code=400, detail="Query parameter 'q' is required") usage_tracker.record_request(endpoint="/searchgpt") queue = await generate_search_async(q, systemprompt=systemprompt, stream=True) if stream: async def stream_generator(): collected_text = "" while True: item = await queue.get() if item is None: break if "error" in item: yield f"data: {json.dumps({'error': item['error']})}\n\n" break if "data" in item: yield item["data"] collected_text += item.get("text", "") return StreamingResponse( stream_generator(), media_type="text/event-stream" ) else: # For non-streaming, collect all text and return at once collected_text = "" while True: item = await queue.get() if item is None: break if "error" in item: raise HTTPException(status_code=500, detail=item["error"]) collected_text += item.get("text", "") return JSONResponse(content={"response": collected_text}) # Enhanced streaming with direct SSE pass-through for real-time responses header_url = os.getenv('HEADER_URL') @app.post("/chat/completions") @app.post("/api/v1/chat/completions") async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)): # Check server status if not server_status: return JSONResponse( status_code=503, content={"message": "Server is under maintenance. Please try again later."} ) model_to_use = payload.model or "gpt-4o-mini" # Validate model availability - fast lookup with set if available_model_ids and model_to_use not in set(available_model_ids): raise HTTPException( status_code=400, detail=f"Model '{model_to_use}' is not available. Check /models for the available model list." ) # Log request without blocking asyncio.create_task(log_request(request, model_to_use)) usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions") # Prepare payload payload_dict = payload.dict() payload_dict["model"] = model_to_use # Ensure stream is True for real-time streaming (can be overridden by client) stream_enabled = payload_dict.get("stream", True) # Get environment variables env_vars = get_env_vars() # Select the appropriate endpoint (fast lookup with sets) if model_to_use in mistral_models: endpoint = env_vars['mistral_api'] custom_headers = { "Authorization": f"Bearer {env_vars['mistral_key']}" } elif model_to_use in pollinations_models: endpoint = env_vars['secret_api_endpoint_4'] custom_headers = {} elif model_to_use in alternate_models: endpoint = env_vars['secret_api_endpoint_2'] custom_headers = {} elif model_to_use in claude_3_models: # Use the new endpoint endpoint = env_vars['secret_api_endpoint_5'] custom_headers = {} else: endpoint = env_vars['secret_api_endpoint'] custom_headers = { "Origin": header_url, "Priority": "u=1, i", "Referer": header_url } print(f"Using endpoint: {endpoint} for model: {model_to_use}") # Improved real-time streaming handler async def real_time_stream_generator(): try: async with httpx.AsyncClient(timeout=60.0) as client: async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response: if response.status_code >= 400: error_messages = { 422: "Unprocessable entity. Check your payload.", 400: "Bad request. Verify input data.", 403: "Forbidden. You do not have access to this resource.", 404: "The requested resource was not found.", } detail = error_messages.get(response.status_code, f"Error code: {response.status_code}") raise HTTPException(status_code=response.status_code, detail=detail) # Stream the response in real-time with minimal buffering async for line in response.aiter_lines(): if line: # Yield immediately for faster streaming yield line + "\n" except httpx.TimeoutException: raise HTTPException(status_code=504, detail="Request timed out") except httpx.RequestError as e: raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {str(e)}") except Exception as e: if isinstance(e, HTTPException): raise e raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") # Return streaming response with proper headers if stream_enabled: return StreamingResponse( real_time_stream_generator(), media_type="text/event-stream", headers={ "Content-Type": "text/event-stream", "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" # Disable proxy buffering for Nginx } ) else: # For non-streaming requests, collect the entire response response_content = [] async for chunk in real_time_stream_generator(): response_content.append(chunk) return JSONResponse(content=json.loads(''.join(response_content))) # New image generation endpoint @app.post("/images/generations") async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)): """ Endpoint for generating images based on a text prompt. """ # Check server status if not server_status: return JSONResponse( status_code=503, content={"message": "Server is under maintenance. Please try again later."} ) # Validate model if payload.model not in supported_image_models: raise HTTPException( status_code=400, detail=f"Model '{payload.model}' is not supported for image generation. Supported models are: {supported_image_models}" ) # Log the request usage_tracker.record_request(model=payload.model, endpoint="/images/generations") # Prepare the payload for the external API api_payload = { "model": payload.model, "prompt": payload.prompt, "size": payload.size, "number": payload.number } # Target API endpoint target_api_url = os.getenv('NEW_IMG') try: # Use a timeout for the image generation request async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post(target_api_url, json=api_payload) if response.status_code != 200: error_detail = response.json().get("detail", f"Image generation failed with status code: {response.status_code}") raise HTTPException(status_code=response.status_code, detail=error_detail) # Return the response from the external API return JSONResponse(content=response.json()) except httpx.TimeoutException: raise HTTPException(status_code=504, detail="Image generation request timed out.") except httpx.RequestError as e: raise HTTPException(status_code=502, detail=f"Error connecting to image generation service: {e}") except Exception as e: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during image generation: {e}") # Asynchronous logging function async def log_request(request, model): # Get minimal data for logging current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p") ip_hash = hash(request.client.host) % 10000 # Hash the IP for privacy print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}") # Cache usage statistics @lru_cache(maxsize=10) def get_usage_summary(days=7): return usage_tracker.get_usage_summary(days) @app.get("/usage") async def get_usage(days: int = 7): """Retrieve usage statistics""" return get_usage_summary(days) # Generate HTML for usage page def generate_usage_html(usage_data): # Model Usage Table Rows model_usage_rows = "\n".join([ f""" {model} {model_data['total_requests']} {model_data['first_used']} {model_data['last_used']} """ for model, model_data in usage_data['models'].items() ]) # API Endpoint Usage Table Rows api_usage_rows = "\n".join([ f""" {endpoint} {endpoint_data['total_requests']} {endpoint_data['first_used']} {endpoint_data['last_used']} """ for endpoint, endpoint_data in usage_data['api_endpoints'].items() ]) # Daily Usage Table Rows daily_usage_rows = "\n".join([ "\n".join([ f""" {date} {entity} {requests} """ for entity, requests in date_data.items() ]) for date, date_data in usage_data['recent_daily_usage'].items() ]) html_content = f""" Lokiai AI - Usage Statistics
Total API Requests: {usage_data['total_requests']}

Model Usage

{model_usage_rows}
Model Total Requests First Used Last Used

API Endpoint Usage

{api_usage_rows}
Endpoint Total Requests First Used Last Used

Daily Usage (Last 7 Days)

{daily_usage_rows}
Date Entity Requests
""" return html_content # Cache the usage page HTML @lru_cache(maxsize=1) def get_usage_page_html(): usage_data = get_usage_summary() return generate_usage_html(usage_data) @app.get("/usage/page", response_class=HTMLResponse) async def usage_page(): """Serve an HTML page showing usage statistics""" # Use cached HTML if available, regenerate if not html_content = get_usage_page_html() return HTMLResponse(content=html_content) # Meme endpoint with optimized networking @app.get("/meme") async def get_meme(): try: # Use the shared client for connection pooling client = get_async_client() response = await client.get("https://meme-api.com/gimme") response_data = response.json() meme_url = response_data.get("url") if not meme_url: raise HTTPException(status_code=404, detail="No meme found") image_response = await client.get(meme_url, follow_redirects=True) # Use larger chunks for streaming async def stream_with_larger_chunks(): chunks = [] size = 0 async for chunk in image_response.aiter_bytes(chunk_size=16384): chunks.append(chunk) size += len(chunk) if size >= 65536: yield b''.join(chunks) chunks = [] size = 0 if chunks: yield b''.join(chunks) return StreamingResponse( stream_with_larger_chunks(), media_type=image_response.headers.get("content-type", "image/png"), headers={'Cache-Control': 'max-age=3600'} # Add caching ) except Exception: raise HTTPException(status_code=500, detail="Failed to retrieve meme") # Utility function for loading model IDs - optimized to run once at startup def load_model_ids(json_file_path): try: with open(json_file_path, 'r') as f: models_data = json.load(f) # Extract 'id' from each model object and use a set for fast lookups return [model['id'] for model in models_data if 'id' in model] except Exception as e: print(f"Error loading model IDs: {str(e)}") return [] @app.on_event("startup") async def startup_event(): global available_model_ids available_model_ids = load_model_ids("models.json") print(f"Loaded {len(available_model_ids)} model IDs") # Add all pollinations models to available_model_ids available_model_ids.extend(list(pollinations_models)) # Add alternate models to available_model_ids available_model_ids.extend(list(alternate_models)) # Add mistral models to available_model_ids available_model_ids.extend(list(mistral_models)) # Add claude models available_model_ids.extend(list(claude_3_models)) available_model_ids = list(set(available_model_ids)) # Remove duplicates print(f"Total available models: {len(available_model_ids)}") # Preload scrapers for _ in range(MAX_SCRAPERS): scraper_pool.append(cloudscraper.create_scraper()) # Validate critical environment variables env_vars = get_env_vars() missing_vars = [] if not env_vars['api_keys'] or env_vars['api_keys'] == ['']: missing_vars.append('API_KEYS') if not env_vars['secret_api_endpoint']: missing_vars.append('SECRET_API_ENDPOINT') if not env_vars['secret_api_endpoint_2']: missing_vars.append('SECRET_API_ENDPOINT_2') if not env_vars['secret_api_endpoint_3']: missing_vars.append('SECRET_API_ENDPOINT_3') if not env_vars['secret_api_endpoint_4']: missing_vars.append('SECRET_API_ENDPOINT_4') if not env_vars['secret_api_endpoint_5']: # Check the new endpoint missing_vars.append('SECRET_API_ENDPOINT_5') if not env_vars['mistral_api'] and any(model in mistral_models for model in available_model_ids): missing_vars.append('MISTRAL_API') if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids): missing_vars.append('MISTRAL_KEY') if missing_vars: print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}") print("Some functionality may be limited.") print("Server started successfully!") @app.on_event("shutdown") async def shutdown_event(): # Close the httpx client client = get_async_client() await client.aclose() # Clear scraper pool scraper_pool.clear() # Persist usage data usage_tracker.save_data() print("Server shutdown complete!") # Health check endpoint # Health check endpoint @app.get("/health") async def health_check(): """Health check endpoint for monitoring""" env_vars = get_env_vars() missing_critical_vars = [] # Check critical environment variables if not env_vars['api_keys'] or env_vars['api_keys'] == ['']: missing_critical_vars.append('API_KEYS') if not env_vars['secret_api_endpoint']: missing_critical_vars.append('SECRET_API_ENDPOINT') if not env_vars['secret_api_endpoint_2']: missing_critical_vars.append('SECRET_API_ENDPOINT_2') if not env_vars['secret_api_endpoint_3']: missing_critical_vars.append('SECRET_API_ENDPOINT_3') if not env_vars['secret_api_endpoint_4']: missing_critical_vars.append('SECRET_API_ENDPOINT_4') if not env_vars['secret_api_endpoint_5']: # Check the new endpoint missing_critical_vars.append('SECRET_API_ENDPOINT_5') if not env_vars['mistral_api']: missing_critical_vars.append('MISTRAL_API') if not env_vars['mistral_key']: missing_critical_vars.append('MISTRAL_KEY') health_status = { "status": "healthy" if not missing_critical_vars else "unhealthy", "missing_env_vars": missing_critical_vars, "server_status": server_status, "message": "Everything's lit! 🚀" if not missing_critical_vars else "Uh oh, some env vars are missing. 😬" } return JSONResponse(content=health_status) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)