import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends, Security
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
import httpx
from functools import lru_cache
from pathlib import Path
import json
import datetime
import time
from typing import Optional, List
import asyncio
from starlette.status import HTTP_403_FORBIDDEN
import cloudscraper
from concurrent.futures import ThreadPoolExecutor
import uvloop
from fastapi.middleware.gzip import GZipMiddleware
from starlette.middleware.cors import CORSMiddleware
# Enable uvloop for faster event loop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Thread pool for offloading CPU-bound work via loop.run_in_executor (defined but not currently used)
executor = ThreadPoolExecutor(max_workers=16)
# Load environment variables once at startup
load_dotenv()
# API key security scheme
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
# Initialize usage tracker
from usage_tracker import UsageTracker
usage_tracker = UsageTracker()
app = FastAPI()
# Add middleware for compression and CORS
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
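# Note: this is a fully open CORS policy (any origin, with credentials);
# tighten allow_origins before exposing the API publicly.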
# Environment variables (cached)
@lru_cache(maxsize=1)
def get_env_vars():
return {
'api_keys': os.getenv('API_KEYS', '').split(','),
'secret_api_endpoint': os.getenv('SECRET_API_ENDPOINT'),
'secret_api_endpoint_2': os.getenv('SECRET_API_ENDPOINT_2'),
'secret_api_endpoint_3': os.getenv('SECRET_API_ENDPOINT_3'),
'secret_api_endpoint_4': "https://text.pollinations.ai/openai",
'mistral_api': "https://api.mistral.ai",
'mistral_key': os.getenv('MISTRAL_KEY'),
'endpoint_origin': os.getenv('ENDPOINT_ORIGIN')
}
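# Expected .env contents (placeholder values; variable names taken from the lookups above):
#   API_KEYS=sk-key1,sk-key2
#   SECRET_API_ENDPOINT=https://primary-provider.example
#   SECRET_API_ENDPOINT_2=https://alternate-provider.example
#   SECRET_API_ENDPOINT_3=https://search-provider.example
#   MISTRAL_KEY=your-mistral-api-key
#   ENDPOINT_ORIGIN=https://your-origin.example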
# Configuration for models - use sets for faster lookups
mistral_models = {
"mistral-large-latest",
"pixtral-large-latest",
"mistral-moderation-latest",
"ministral-3b-latest",
"ministral-8b-latest",
"open-mistral-nemo",
"mistral-small-latest",
"mistral-saba-latest",
"codestral-latest"
}
pollinations_models = {
"openai",
"openai-large",
"openai-reasoning",
"qwen-coder",
"llama",
"mistral",
"searchgpt",
"deepseek",
"claude-hybridspace",
"deepseek-r1",
"deepseek-reasoner",
"llamalight",
"gemini",
"gemini-thinking",
"hormoz",
"phi",
"openai-audio",
"llama-scaleway"
}
alternate_models = {  # NOTE: this alternate endpoint currently does not work
"gpt-4o-mini",
"deepseek-v3",
"llama-3.1-8b-instruct",
"searchgpt",
"llama-3.1-sonar-small-128k-online",
"sonar-reasoning",
"deepseek-r1-uncensored",
"tinyswallow1.5b",
"andy-3.5",
"o3-mini-low",
"command-a",
"claude-3-7-20250219",
"hermes-3-llama-3.2-3b"
}
# Request payload model
class Payload(BaseModel):
model: str
messages: list
stream: bool = False
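# Example body this model accepts (illustrative values):
#   {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hi"}], "stream": false}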
# Server status global variable
server_status = True
available_model_ids: List[str] = []
# Create a single reusable httpx client with connection pooling
@lru_cache(maxsize=1)
def get_async_client():
    return httpx.AsyncClient(
        timeout=60.0,
        limits=httpx.Limits(max_keepalive_connections=50, max_connections=200)
    )
# Pool of cloudscraper instances, shared across requests
scraper_pool = []
MAX_SCRAPERS = 20
def get_scraper():
    # Lazily fill the pool on first use (startup_event also preloads it)
    if not scraper_pool:
        for _ in range(MAX_SCRAPERS):
            scraper_pool.append(cloudscraper.create_scraper())
    # Select by the current millisecond (a pseudo-random pick, not a true round-robin)
    return scraper_pool[int(time.time() * 1000) % MAX_SCRAPERS]
# API key validation - optimized to avoid string operations when possible
async def verify_api_key(
request: Request,
api_key: str = Security(api_key_header)
) -> bool:
# Allow bypass if the referer is from /playground
referer = request.headers.get("referer", "")
if referer.startswith("https://parthsadaria-lokiai.hf.space/playground"):
return True
if not api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="No API key provided"
)
# Only clean if needed
if api_key.startswith('Bearer '):
api_key = api_key[7:] # Remove 'Bearer ' prefix
# Get API keys from environment
valid_api_keys = get_env_vars().get('api_keys', [])
if not valid_api_keys or valid_api_keys == ['']:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="API keys not configured on server"
)
    # Membership check against the configured keys (the list is tiny, so a
    # linear scan is fine; building a set per request would cost more)
    if api_key not in valid_api_keys:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid API key"
)
return True
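# Clients authenticate with an Authorization header, with or without the
# Bearer prefix (the key value below is a placeholder):
#   Authorization: Bearer sk-your-key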
# Pre-load and cache models.json
@lru_cache(maxsize=1)
def load_models_data():
try:
file_path = Path(__file__).parent / 'models.json'
with open(file_path, 'r') as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Error loading models.json: {str(e)}")
return []
# Async wrapper for models data
async def get_models():
models_data = load_models_data()
if not models_data:
raise HTTPException(status_code=500, detail="Error loading available models")
return models_data
# Enhanced async streaming - now with real-time SSE support
async def generate_search_async(query: str, systemprompt: Optional[str] = None, stream: bool = True):
    # `stream` is accepted for API symmetry but unused here; the caller decides
    # between SSE and collected output when draining the queue.
    # Create a streaming response channel using asyncio.Queue
    queue = asyncio.Queue()
async def _fetch_search_data():
try:
headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
# Use the provided system prompt, or default to "Be Helpful and Friendly"
system_message = systemprompt or "Be Helpful and Friendly"
            # Create the prompt history with the system message first
            prompt = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": query},
            ]
# Prepare the payload for the API request
payload = {
"is_vscode_extension": True,
"message_history": prompt,
"requested_model": "Claude 3.7 Sonnet",
"user_input": prompt[-1]["content"],
}
# Get endpoint from environment
secret_api_endpoint_3 = get_env_vars()['secret_api_endpoint_3']
if not secret_api_endpoint_3:
await queue.put({"error": "Search API endpoint not configured"})
return
# Use AsyncClient for better performance
async with httpx.AsyncClient(timeout=30.0) as client:
async with client.stream("POST", secret_api_endpoint_3, json=payload, headers=headers) as response:
if response.status_code != 200:
await queue.put({"error": f"Search API returned status code {response.status_code}"})
return
# Process the streaming response in real-time
buffer = ""
async for line in response.aiter_lines():
if line.startswith("data: "):
try:
json_data = json.loads(line[6:])
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
if content.strip():
cleaned_response = {
"created": json_data.get("created"),
"id": json_data.get("id"),
"model": "searchgpt",
"object": "chat.completion",
"choices": [
{
"message": {
"content": content
}
}
]
}
# Send to queue immediately for streaming
await queue.put({"data": f"data: {json.dumps(cleaned_response)}\n\n", "text": content})
except json.JSONDecodeError:
continue
# Signal completion
await queue.put(None)
except Exception as e:
await queue.put({"error": str(e)})
await queue.put(None)
# Start the fetch process
asyncio.create_task(_fetch_search_data())
# Return the queue for consumption
return queue
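# The queue implements a producer/consumer handoff: _fetch_search_data (producer)
# pushes SSE chunks as they arrive upstream, while the route handler (consumer)
# drains the queue, so the response can start before the upstream finishes.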
# Cache for frequently accessed static files
@lru_cache(maxsize=10)
def read_html_file(file_path):
try:
with open(file_path, "r") as file:
return file.read()
except FileNotFoundError:
return None
# Basic routes
@app.get("/favicon.ico")
async def favicon():
favicon_path = Path(__file__).parent / "favicon.ico"
return FileResponse(favicon_path, media_type="image/x-icon")
@app.get("/ping")
async def ping():
return {"message": "pong", "response_time": "0.000000 seconds"}
@app.get("/", response_class=HTMLResponse)
async def root():
html_content = read_html_file("index.html")
if html_content is None:
return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
return HTMLResponse(content=html_content)
@app.get("/playground", response_class=HTMLResponse)
async def playground():
html_content = read_html_file("playground.html")
if html_content is None:
return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
return HTMLResponse(content=html_content)
# Model routes
@app.get("/api/v1/models")
@app.get("/models")
async def return_models():
return await get_models()
# Search routes with enhanced real-time streaming
@app.get("/searchgpt")
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
if not q:
raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
usage_tracker.record_request(endpoint="/searchgpt")
queue = await generate_search_async(q, systemprompt=systemprompt, stream=True)
if stream:
        async def stream_generator():
            while True:
                item = await queue.get()
                if item is None:
                    break
                if "error" in item:
                    yield f"data: {json.dumps({'error': item['error']})}\n\n"
                    break
                if "data" in item:
                    yield item["data"]
return StreamingResponse(
stream_generator(),
media_type="text/event-stream"
)
else:
# For non-streaming, collect all text and return at once
collected_text = ""
while True:
item = await queue.get()
if item is None:
break
if "error" in item:
raise HTTPException(status_code=500, detail=item["error"])
collected_text += item.get("text", "")
return JSONResponse(content={"response": collected_text})
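# Example (query value illustrative): GET /searchgpt?q=hello&stream=true streams
# SSE chunks; omitting stream returns the collected text as {"response": "..."}.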
# Enhanced streaming with direct SSE pass-through for real-time responses
@app.post("/chat/completions")
@app.post("/api/v1/chat/completions")
async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
# Check server status
if not server_status:
return JSONResponse(
status_code=503,
content={"message": "Server is under maintenance. Please try again later."}
)
model_to_use = payload.model or "gpt-4o-mini"
    # Validate model availability (a single linear scan beats rebuilding a set per request)
    if available_model_ids and model_to_use not in available_model_ids:
raise HTTPException(
status_code=400,
detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
)
# Log request without blocking
asyncio.create_task(log_request(request, model_to_use))
usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")
# Prepare payload
payload_dict = payload.dict()
payload_dict["model"] = model_to_use
    # Honour the client's stream flag (Payload defaults it to False, so the key
    # is always present after .dict(); the True fallback here is only defensive)
    stream_enabled = payload_dict.get("stream", True)
# Get environment variables
env_vars = get_env_vars()
# Select the appropriate endpoint (fast lookup with sets)
if model_to_use in mistral_models:
endpoint = env_vars['mistral_api']
custom_headers = {
"Authorization": f"Bearer {env_vars['mistral_key']}"
}
elif model_to_use in pollinations_models:
endpoint = env_vars['secret_api_endpoint_4']
custom_headers = {}
elif model_to_use in alternate_models:
endpoint = env_vars['secret_api_endpoint_2']
custom_headers = {}
else:
endpoint = env_vars['secret_api_endpoint']
custom_headers = {}
print(f"Using endpoint: {endpoint} for model: {model_to_use}")
# Improved real-time streaming handler
async def real_time_stream_generator():
try:
async with httpx.AsyncClient(timeout=60.0) as client:
async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response:
if response.status_code >= 400:
error_messages = {
422: "Unprocessable entity. Check your payload.",
400: "Bad request. Verify input data.",
403: "Forbidden. You do not have access to this resource.",
404: "The requested resource was not found.",
}
detail = error_messages.get(response.status_code, f"Error code: {response.status_code}")
raise HTTPException(status_code=response.status_code, detail=detail)
# Stream the response in real-time with minimal buffering
async for line in response.aiter_lines():
if line:
# Yield immediately for faster streaming
yield line + "\n"
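        # NOTE: once the StreamingResponse below has begun sending, exceptions
        # raised here cannot change the HTTP status; the client just sees a
        # truncated stream.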
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="Request timed out")
except httpx.RequestError as e:
raise HTTPException(status_code=502, detail=f"Failed to connect to upstream API: {str(e)}")
except Exception as e:
if isinstance(e, HTTPException):
raise e
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
# Return streaming response with proper headers
if stream_enabled:
return StreamingResponse(
real_time_stream_generator(),
media_type="text/event-stream",
headers={
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # Disable proxy buffering for Nginx
}
)
else:
# For non-streaming requests, collect the entire response
response_content = []
async for chunk in real_time_stream_generator():
response_content.append(chunk)
return JSONResponse(content=json.loads(''.join(response_content)))
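# Example request against this endpoint (key and host are placeholders):
#   curl https://<host>/chat/completions \
#     -H "Authorization: Bearer sk-your-key" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'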
# Asynchronous logging function
async def log_request(request, model):
    # Timestamp in IST (UTC+05:30)
    current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
    # Hash the IP for privacy (Python's hash() is salted per process, so these
    # values are not comparable across restarts)
    ip_hash = hash(request.client.host) % 10000
    print(f"Time: {current_time}, IP Hash: {ip_hash}, Model: {model}")
# Cache usage statistics (lru_cache never expires, so figures are frozen per
# `days` value for the lifetime of the process)
@lru_cache(maxsize=10)
def get_usage_summary(days=7):
    return usage_tracker.get_usage_summary(days)
@app.get("/usage")
async def get_usage(days: int = 7):
"""Retrieve usage statistics"""
return get_usage_summary(days)
# Generate HTML for usage page
def generate_usage_html(usage_data):
# Model Usage Table Rows
model_usage_rows = "\n".join([
f"""
<tr>
<td>{model}</td>
<td>{model_data['total_requests']}</td>
<td>{model_data['first_used']}</td>
<td>{model_data['last_used']}</td>
</tr>
""" for model, model_data in usage_data['models'].items()
])
# API Endpoint Usage Table Rows
api_usage_rows = "\n".join([
f"""
<tr>
<td>{endpoint}</td>
<td>{endpoint_data['total_requests']}</td>
<td>{endpoint_data['first_used']}</td>
<td>{endpoint_data['last_used']}</td>
</tr>
""" for endpoint, endpoint_data in usage_data['api_endpoints'].items()
])
# Daily Usage Table Rows
daily_usage_rows = "\n".join([
"\n".join([
f"""
<tr>
<td>{date}</td>
<td>{entity}</td>
<td>{requests}</td>
</tr>
""" for entity, requests in date_data.items()
]) for date, date_data in usage_data['recent_daily_usage'].items()
])
html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Lokiai AI - Usage Statistics</title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
<style>
:root {{
--bg-dark: #0f1011;
--bg-darker: #070708;
--text-primary: #e6e6e6;
--text-secondary: #8c8c8c;
--border-color: #2c2c2c;
--accent-color: #3a6ee0;
--accent-hover: #4a7ef0;
}}
body {{
font-family: 'Inter', sans-serif;
background-color: var(--bg-dark);
color: var(--text-primary);
max-width: 1200px;
margin: 0 auto;
padding: 40px 20px;
line-height: 1.6;
}}
.logo {{
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 30px;
}}
.logo h1 {{
font-weight: 600;
font-size: 2.5em;
color: var(--text-primary);
margin-left: 15px;
}}
.logo img {{
width: 60px;
height: 60px;
border-radius: 10px;
}}
.container {{
background-color: var(--bg-darker);
border-radius: 12px;
padding: 30px;
box-shadow: 0 15px 40px rgba(0,0,0,0.3);
border: 1px solid var(--border-color);
}}
h2, h3 {{
color: var(--text-primary);
border-bottom: 2px solid var(--border-color);
padding-bottom: 10px;
font-weight: 500;
}}
.total-requests {{
background-color: var(--accent-color);
color: white;
text-align: center;
padding: 15px;
border-radius: 8px;
margin-bottom: 30px;
font-weight: 600;
letter-spacing: -0.5px;
}}
table {{
width: 100%;
border-collapse: separate;
border-spacing: 0;
margin-bottom: 30px;
background-color: var(--bg-dark);
border-radius: 8px;
overflow: hidden;
}}
th, td {{
border: 1px solid var(--border-color);
padding: 12px;
text-align: left;
transition: background-color 0.3s ease;
}}
th {{
background-color: #1e1e1e;
color: var(--text-primary);
font-weight: 600;
text-transform: uppercase;
font-size: 0.9em;
}}
tr:nth-child(even) {{
background-color: rgba(255,255,255,0.05);
}}
tr:hover {{
background-color: rgba(62,100,255,0.1);
}}
@media (max-width: 768px) {{
.container {{
padding: 15px;
}}
table {{
font-size: 0.9em;
}}
}}
</style>
</head>
<body>
<div class="container">
<div class="logo">
<img src="data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjAwIiBoZWlnaHQ9IjIwMCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJNMTAwIDM1TDUwIDkwaDEwMHoiIGZpbGw9IiMzYTZlZTAiLz48Y2lyY2xlIGN4PSIxMDAiIGN5PSIxNDAiIHI9IjMwIiBmaWxsPSIjM2E2ZWUwIi8+PC9zdmc+" alt="Lokai AI Logo">
<h1>Lokiai AI</h1>
</div>
<div class="total-requests">
Total API Requests: {usage_data['total_requests']}
</div>
<h2>Model Usage</h2>
<table>
<tr>
<th>Model</th>
<th>Total Requests</th>
<th>First Used</th>
<th>Last Used</th>
</tr>
{model_usage_rows}
</table>
<h2>API Endpoint Usage</h2>
<table>
<tr>
<th>Endpoint</th>
<th>Total Requests</th>
<th>First Used</th>
<th>Last Used</th>
</tr>
{api_usage_rows}
</table>
<h2>Daily Usage (Last 7 Days)</h2>
<table>
<tr>
<th>Date</th>
<th>Entity</th>
<th>Requests</th>
</tr>
{daily_usage_rows}
</table>
</div>
</body>
</html>
"""
return html_content
# Cache the usage page HTML (generated once per process; restart to refresh)
@lru_cache(maxsize=1)
def get_usage_page_html():
    usage_data = get_usage_summary()
    return generate_usage_html(usage_data)
@app.get("/usage/page", response_class=HTMLResponse)
async def usage_page():
"""Serve an HTML page showing usage statistics"""
# Use cached HTML if available, regenerate if not
html_content = get_usage_page_html()
return HTMLResponse(content=html_content)
# Meme endpoint with optimized networking
@app.get("/meme")
async def get_meme():
try:
# Use the shared client for connection pooling
client = get_async_client()
response = await client.get("https://meme-api.com/gimme")
response_data = response.json()
meme_url = response_data.get("url")
if not meme_url:
raise HTTPException(status_code=404, detail="No meme found")
image_response = await client.get(meme_url, follow_redirects=True)
# Use larger chunks for streaming
async def stream_with_larger_chunks():
chunks = []
size = 0
async for chunk in image_response.aiter_bytes(chunk_size=16384):
chunks.append(chunk)
size += len(chunk)
if size >= 65536:
yield b''.join(chunks)
chunks = []
size = 0
if chunks:
yield b''.join(chunks)
return StreamingResponse(
stream_with_larger_chunks(),
media_type=image_response.headers.get("content-type", "image/png"),
headers={'Cache-Control': 'max-age=3600'} # Add caching
)
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) pass through unchanged
        raise
    except Exception:
        raise HTTPException(status_code=500, detail="Failed to retrieve meme")
# Utility function for loading model IDs - optimized to run once at startup
def load_model_ids(json_file_path):
try:
with open(json_file_path, 'r') as f:
models_data = json.load(f)
        # Extract the 'id' from each model entry (converted to a set at lookup time)
        return [model['id'] for model in models_data if 'id' in model]
except Exception as e:
print(f"Error loading model IDs: {str(e)}")
return []
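# models.json is assumed to be a JSON array of objects that each carry an "id",
# e.g. [{"id": "gpt-4o-mini"}, {"id": "deepseek-v3"}] (shape inferred from the usage above).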
@app.on_event("startup")
async def startup_event():
global available_model_ids
available_model_ids = load_model_ids("models.json")
print(f"Loaded {len(available_model_ids)} model IDs")
# Add all pollinations models to available_model_ids
available_model_ids.extend(list(pollinations_models))
# Add alternate models to available_model_ids
available_model_ids.extend(list(alternate_models))
# Add mistral models to available_model_ids
available_model_ids.extend(list(mistral_models))
available_model_ids = list(set(available_model_ids)) # Remove duplicates
print(f"Total available models: {len(available_model_ids)}")
    # Preload scrapers (get_scraper() also fills lazily; guard against double-filling)
    if not scraper_pool:
        for _ in range(MAX_SCRAPERS):
            scraper_pool.append(cloudscraper.create_scraper())
# Validate critical environment variables
env_vars = get_env_vars()
missing_vars = []
if not env_vars['secret_api_endpoint']:
missing_vars.append('SECRET_API_ENDPOINT')
if not env_vars['secret_api_endpoint_2']:
missing_vars.append('SECRET_API_ENDPOINT_2')
if not env_vars['secret_api_endpoint_3']:
missing_vars.append('SECRET_API_ENDPOINT_3')
if not env_vars['secret_api_endpoint_4']:
missing_vars.append('SECRET_API_ENDPOINT_4')
    # Note: mistral_api is hardcoded in get_env_vars, so the first check never fires
    if not env_vars['mistral_api'] and any(model in mistral_models for model in available_model_ids):
        missing_vars.append('MISTRAL_API')
    if not env_vars['mistral_key'] and any(model in mistral_models for model in available_model_ids):
        missing_vars.append('MISTRAL_KEY')
if missing_vars:
print(f"WARNING: The following environment variables are missing: {', '.join(missing_vars)}")
print("Some functionality may be limited.")
print("Server started successfully!")
@app.on_event("shutdown")
async def shutdown_event():
# Close the httpx client
client = get_async_client()
await client.aclose()
# Clear scraper pool
scraper_pool.clear()
# Persist usage data
usage_tracker.save_data()
print("Server shutdown complete!")
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint for monitoring"""
env_vars = get_env_vars()
missing_critical_vars = []
# Check critical environment variables
if not env_vars['api_keys'] or env_vars['api_keys'] == ['']:
missing_critical_vars.append('API_KEYS')
if not env_vars['secret_api_endpoint']:
missing_critical_vars.append('SECRET_API_ENDPOINT')
if not env_vars['secret_api_endpoint_2']:
missing_critical_vars.append('SECRET_API_ENDPOINT_2')
if not env_vars['secret_api_endpoint_3']:
missing_critical_vars.append('SECRET_API_ENDPOINT_3')
if not env_vars['secret_api_endpoint_4']:
missing_critical_vars.append('SECRET_API_ENDPOINT_4')
if not env_vars['mistral_api']:
missing_critical_vars.append('MISTRAL_API')
if not env_vars['mistral_key']:
missing_critical_vars.append('MISTRAL_KEY')
health_status = {
"status": "healthy" if not missing_critical_vars else "unhealthy",
"missing_env_vars": missing_critical_vars,
"server_status": server_status,
"message": "Everything's lit! πŸš€" if not missing_critical_vars else "Uh oh, some env vars are missing. 😬"
}
return JSONResponse(content=health_status)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)