import datetime
import hashlib
import hmac
import html
import json
import math
import os
import re
import time
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Iterator
from typing import Optional
from urllib.parse import quote

import cloudscraper
import httpx
import requests
from dotenv import load_dotenv
from fastapi import Depends
from fastapi import FastAPI, HTTPException, Request
from fastapi import Security
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
# Imported after the fastapi names on purpose: this (Starlette) HTTPException
# is the binding the rest of the module uses.
from starlette.exceptions import HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.status import HTTP_403_FORBIDDEN

from usage_tracker import UsageTracker
# API key header scheme: read the raw "Authorization" header. auto_error=False
# lets verify_api_key produce its own 403 message when the header is missing.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)


async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
    """Validate the caller's API key against the ``API_KEYS`` environment variable.

    ``API_KEYS`` is a comma-separated list; whitespace around entries and empty
    entries are ignored. The key may be sent bare or as ``Bearer <key>`` (the
    scheme name is matched case-insensitively).

    Returns:
        True when the key is valid.

    Raises:
        HTTPException: 403 when no key is sent, when no keys are configured
            on the server, or when the key does not match any configured key.
    """
    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )
    # Strip an optional "Bearer " prefix; auth scheme names are case-insensitive.
    if api_key[:7].lower() == "bearer ":
        api_key = api_key[7:]
    api_key = api_key.strip()

    # Read on every call so the key list can change without a restart. Empty
    # entries are dropped so e.g. "key1," cannot make an empty key valid.
    api_keys_str = os.getenv('API_KEYS') or ""
    valid_api_keys = [key.strip() for key in api_keys_str.split(',') if key.strip()]
    if not valid_api_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="API keys not configured on server"
        )

    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Compare bytes: compare_digest rejects
    # non-ASCII str arguments.
    provided = api_key.encode("utf-8")
    if not any(hmac.compare_digest(provided, key.encode("utf-8")) for key in valid_api_keys):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API key"
        )
    return True
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP token-bucket rate limiter.

    Each client IP owns a bucket holding at most ``requests_per_second``
    tokens, refilled at ``requests_per_second`` tokens per second. Every
    request consumes one token; when less than one token is available the
    request is rejected with HTTP 429 and a ``Retry-After`` hint.

    NOTE(review): buckets are never evicted, so memory grows with the number
    of distinct client IPs seen.
    """

    def __init__(self, app, requests_per_second: int = 2):
        super().__init__(app)
        self.requests_per_second = requests_per_second
        # Current token count per client; unseen clients start with a full bucket.
        self.tokens = defaultdict(lambda: requests_per_second)
        # Monotonic timestamp of each client's last refill.
        self.last_update = defaultdict(float)

    async def dispatch(self, request: Request, call_next):
        # request.client can be None (e.g. some ASGI servers/test clients);
        # share one bucket for those instead of raising AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: wall-clock jumps must not drain or overfill buckets.
        current_time = time.monotonic()

        # Refill proportionally to elapsed time, capped at the bucket size.
        time_passed = current_time - self.last_update[client_ip]
        self.last_update[client_ip] = current_time
        self.tokens[client_ip] = min(
            self.requests_per_second,
            self.tokens[client_ip] + time_passed * self.requests_per_second
        )

        if self.tokens[client_ip] < 1:
            # Whole seconds until one token is available. Rounding up (not
            # round()) guarantees a positive value: with 2 req/s the wait is
            # at most 0.5 s, which round() always turned into 0.
            retry_after = math.ceil((1 - self.tokens[client_ip]) / self.requests_per_second)
            return JSONResponse(
                status_code=429,
                content={
                    "detail": "Too many requests. Please try again later.",
                    "retry_after": retry_after
                },
                headers={"Retry-After": str(retry_after)}
            )

        # Consume a token and let the request through.
        self.tokens[client_ip] -= 1
        return await call_next(request)
# Load variables from a local .env file (if present) before anything below
# reads the environment.
load_dotenv()

usage_tracker = UsageTracker()
app = FastAPI()
app.add_middleware(RateLimitMiddleware, requests_per_second=2)

# Upstream endpoints, configured through environment variables.
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')        # chat completions backend
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3')    # searchgpt backend
image_endpoint = os.getenv("IMAGE_ENDPOINT")
ENDPOINT_ORIGIN = os.getenv('ENDPOINT_ORIGIN')

# Fail fast at import time when the required endpoints are missing. There is
# no request to answer yet, so this is a plain RuntimeError, not an HTTP error.
if not (secret_api_endpoint and secret_api_endpoint_2 and secret_api_endpoint_3):
    raise RuntimeError("API endpoint(s) are not configured in environment variables.")

# Model ids accepted by /chat/completions; filled in by the startup hook.
available_model_ids = []
class Payload(BaseModel):
    """Request body accepted by the chat-completions endpoints.

    Field declaration order is kept as-is: it determines the key order of
    ``payload.dict()``, which the endpoint forwards to the upstream API.
    """
    model: str  # model id; the endpoint rejects ids not in available_model_ids
    messages: list  # chat history, forwarded upstream as-is
    stream: bool = False  # forwarded upstream as part of the payload
@app.get("/favicon.ico")
async def favicon():
    """Serve the favicon.ico file stored next to this module."""
    icon_file = Path(__file__).with_name("favicon.ico")
    return FileResponse(icon_file, media_type="image/x-icon")
def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool = True) -> Iterator[str]:
    """Send ``query`` to the searchgpt backend and yield its answer.

    The request goes to ``SECRET_API_ENDPOINT_3`` with a system message
    (``systemprompt`` or "Be Helpful and Friendly") followed by the user
    query. The backend's ``data: {...}`` lines are parsed and the
    ``choices[0].delta.content`` of each is extracted.

    Args:
        query: The user's question.
        systemprompt: Optional system prompt; a falsy value selects the default.
        stream: When True, yield one ``data: {json}\\n\\n`` chunk per content
            delta. When False, yield the complete answer once at the end.

    Yields:
        SSE-formatted chunks (``stream=True``) or a single string (``stream=False``).

    Raises:
        requests.exceptions.RequestException: if the backend cannot be reached
            or a timeout expires.
    """
    headers = {"User-Agent": ""}
    system_message = systemprompt or "Be Helpful and Friendly"
    prompt = [
        {"content": system_message, "role": "system"},
        {"role": "user", "content": query},
    ]
    payload = {
        "is_vscode_extension": True,
        "message_history": prompt,
        "requested_model": "searchgpt",
        "user_input": query,
    }

    collected = []
    # The context manager closes the streamed connection both when the loop
    # finishes and when the consumer abandons this generator early.
    with requests.post(
        secret_api_endpoint_3,
        headers=headers,
        json=payload,
        stream=True,
        # (connect, read) timeouts so a stalled backend cannot hang forever.
        timeout=(10, 300),
    ) as response:
        for raw_line in response.iter_lines():
            # Decode as UTF-8 ourselves. With decode_unicode=True, requests
            # assumes ISO-8859-1 for text/* responses without a charset
            # (mojibake), and yields raw bytes when there is no content type.
            line = raw_line.decode("utf-8", errors="replace")
            if not line.startswith("data: "):
                continue
            try:
                event = json.loads(line[6:])
            except json.JSONDecodeError:
                continue  # data lines that are not JSON are ignored
            try:
                content = event["choices"][0]["delta"].get("content")
            except (KeyError, IndexError, TypeError, AttributeError):
                continue  # no usable choices[0].delta in this event
            # Skip null/empty deltas, but keep whitespace-only ones: they carry
            # the spaces and newlines of the answer.
            if not isinstance(content, str) or not content:
                continue
            if stream:
                cleaned_response = {
                    "created": event.get("created"),
                    "id": event.get("id"),
                    "model": "searchgpt",
                    "object": "chat.completion",
                    "choices": [
                        {
                            "message": {
                                "content": content
                            }
                        }
                    ]
                }
                yield f"data: {json.dumps(cleaned_response)}\n\n"
            else:
                collected.append(content)

    if not stream:
        yield "".join(collected)
@app.get("/ping")
async def ping():
    """Liveness probe that answers with "pong".

    ``response_time`` is the gap between two back-to-back clock reads, so it
    is effectively a near-zero constant rather than a real latency figure.
    """
    began = datetime.datetime.now()
    elapsed_seconds = (datetime.datetime.now() - began).total_seconds()
    return {
        "message": "pong",
        "response_time": "%.6f seconds" % elapsed_seconds,
    }
@app.get("/searchgpt")
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None, authenticated: bool = Depends(verify_api_key)):
    """Answer ``q`` using the searchgpt backend.

    With ``stream=True`` the answer is relayed as ``text/event-stream``.
    Otherwise the full answer is returned as ``{"response": "..."}``.

    Raises:
        HTTPException: 400 if ``q`` is empty.
    """
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")

    usage_tracker.record_request(endpoint="/searchgpt")

    if stream:
        # generate_search is a plain (sync) generator; StreamingResponse
        # iterates sync iterators in a worker thread.
        return StreamingResponse(
            generate_search(q, systemprompt=systemprompt, stream=True),
            media_type="text/event-stream"
        )

    def collect_answer():
        return "".join(generate_search(q, systemprompt=systemprompt, stream=False))

    # generate_search does blocking network I/O (requests). Run it in the
    # threadpool so it does not stall the event loop for every other client.
    response_text = await run_in_threadpool(collect_answer)
    return JSONResponse(content={"response": response_text})
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve index.html from the directory containing this module.

    Returns a 404 HTML response when the file does not exist.
    """
    # Resolve next to this module (as favicon/models.json do) rather than the
    # current working directory, so the page works regardless of launch dir.
    file_path = Path(__file__).parent / "index.html"
    try:
        html_content = file_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
async def get_models():
    """Return the parsed contents of the models.json next to this module.

    Raises:
        HTTPException: 404 when models.json is missing, 500 when it is not
            valid JSON.
    """
    models_file = Path(__file__).parent / 'models.json'
    try:
        raw_text = models_file.read_text()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="models.json not found")
    try:
        return json.loads(raw_text)
    except json.JSONDecodeError:
        raise HTTPException(status_code=500, detail="Error decoding models.json")
# Route paths must begin with "/"; without it the /api/v1 alias never matches.
@app.get("/api/v1/models")
@app.get("/models")
async def return_models():
    """Return the model catalogue from models.json (also served at /api/v1/models)."""
    return await get_models()
# Maintenance switch: while False, /chat/completions answers 503.
server_status = True


# Route paths must begin with "/"; without it the /api/v1 alias never matches.
@app.post("/chat/completions")
@app.post("/api/v1/chat/completions")
async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
    """Proxy a chat-completion request to ``SECRET_API_ENDPOINT``.

    The payload is forwarded to ``{SECRET_API_ENDPOINT}/v1/chat/completions``
    and the upstream body is streamed back line by line.

    Returns:
        A 503 JSON response while ``server_status`` is False; otherwise a
        streaming response of the upstream body's non-empty lines.

    Raises:
        HTTPException: 400 for a model not in ``available_model_ids``; the
            upstream status for upstream 4xx responses (400/403/404/422 with
            specific messages); 500 for upstream 5xx responses or when the
            upstream cannot be reached.
    """
    # Check server status first so maintenance-mode requests are not recorded as usage.
    if not server_status:
        return JSONResponse(
            status_code=503,
            content={"message": "Server is under maintenance. Please try again later."}
        )

    model_to_use = payload.model if payload.model else "gpt-4o-mini"
    if model_to_use not in available_model_ids:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_to_use}' is not available. Check /models for the available model list."
        )

    usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions")

    payload_dict = payload.dict()
    payload_dict["model"] = model_to_use
    endpoint = secret_api_endpoint

    # Log time (UTC+05:30), client IP and model for each accepted request.
    current_time = (
        datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=5, minutes=30)
    ).strftime("%Y-%m-%d %I:%M:%S %p")
    client_ip = request.client.host if request.client else "unknown"
    print(f"Time: {current_time}, {client_ip} , {model_to_use}, server status :- {server_status}")
    print(payload_dict)

    scraper = cloudscraper.create_scraper()
    custom_headers = {
        'DNT': '1',
        'Priority': 'u=1, i',
    }

    # Send the request before building the StreamingResponse. Upstream
    # failures then become real HTTP error statuses instead of an aborted
    # 200 stream. scraper.post blocks, so it runs in the threadpool.
    try:
        response = await run_in_threadpool(
            scraper.post,
            f"{endpoint}/v1/chat/completions",
            json=payload_dict,
            headers=custom_headers,
            stream=True,
            # (connect, read) timeouts so a stalled upstream cannot pin a thread forever.
            timeout=(10, 300),
        )
    except requests.exceptions.RequestException as req_err:
        raise HTTPException(status_code=500, detail=f"Request failed: {req_err}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")

    upstream_errors = {
        422: "Unprocessable entity. Check your payload.",
        400: "Bad request. Verify input data.",
        403: "Forbidden. You do not have access to this resource.",
        404: "The requested resource was not found.",
    }
    status = response.status_code
    if status >= 400:
        response.close()
        if status >= 500:
            raise HTTPException(status_code=500, detail="Server error. Try again later.")
        raise HTTPException(
            status_code=status,
            detail=upstream_errors.get(status, f"Upstream request failed with status code {status}.")
        )

    def stream_generator():
        # Sync generator: StreamingResponse iterates it in a worker thread, so
        # the blocking reads do not stall the event loop.
        try:
            for line in response.iter_lines():
                if line:
                    yield line.decode('utf-8') + "\n"
        finally:
            response.close()

    return StreamingResponse(stream_generator(), media_type="application/json")
@app.api_route("/images/generations", methods=["GET", "POST"])  # Support both GET and POST
async def generate_image(
    prompt: Optional[str] = None,
    model: str = "flux",  # Default model
    seed: Optional[int] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    nologo: Optional[bool] = True,
    private: Optional[bool] = None,
    enhance: Optional[bool] = None,
    request: Request = None,  # Access raw POST data
    authenticated: bool = Depends(verify_api_key)
):
    """Generate an image through ``IMAGE_ENDPOINT`` and stream it back.

    For GET the prompt comes from the ``prompt`` query parameter. For POST it
    comes from a JSON body ``{"prompt": "..."}``. All other options come from
    query parameters in both cases. The prompt becomes one percent-encoded
    path segment of the upstream URL.

    Raises:
        HTTPException: 500 if IMAGE_ENDPOINT is unset; 400 for a malformed
            JSON body or an empty prompt; 404/400/429 (specific messages) or
            the upstream status for other non-200 responses; 500 for a
            non-image response or network failures; 504 on timeout.
    """
    if not image_endpoint:
        raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")

    usage_tracker.record_request(endpoint="/images/generations")

    if request.method == "POST":
        # Body-decoding errors are handled here, separately from the
        # empty-prompt check below, so neither message masks the other.
        try:
            body = await request.json()
        except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
            raise HTTPException(status_code=400, detail="Invalid JSON payload")
        body_prompt = body.get("prompt", "") if isinstance(body, dict) else None
        if not isinstance(body_prompt, str):
            raise HTTPException(status_code=400, detail="Invalid JSON payload")
        prompt = body_prompt

    if not prompt or not prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    prompt = prompt.strip()

    # Percent-encode the prompt as a single path segment (safe="") so that
    # "/", "?", "#" and "%" inside it cannot change the URL's structure.
    url = f"{image_endpoint.rstrip('/')}/{quote(prompt, safe='')}"

    # Forward only the options that are set and within range.
    params = {}
    if model and isinstance(model, str):
        params['model'] = model
    if seed is not None and isinstance(seed, int):
        params['seed'] = seed
    if width is not None and isinstance(width, int) and 64 <= width <= 2048:
        params['width'] = width
    if height is not None and isinstance(height, int) and 64 <= height <= 2048:
        params['height'] = height
    if nologo is not None:
        params['nologo'] = str(nologo).lower()
    if private is not None:
        params['private'] = str(private).lower()
    if enhance is not None:
        params['enhance'] = str(enhance).lower()

    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
            response = await client.get(url, params=params, follow_redirects=True)
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Image generation request timed out")
    except httpx.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Failed to contact image service: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error during image generation: {str(e)}")

    # Status checks live outside the try block above, so these HTTPExceptions
    # are no longer caught by `except Exception` and re-wrapped as 500s.
    if response.status_code == 404:
        raise HTTPException(status_code=404, detail="Image generation service not found")
    if response.status_code == 400:
        raise HTTPException(status_code=400, detail="Invalid parameters provided to image service")
    if response.status_code == 429:
        raise HTTPException(status_code=429, detail="Too many requests to image service")
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Image generation failed with status code {response.status_code}"
        )

    content_type = response.headers.get('content-type', '')
    if not content_type.startswith('image/'):
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected content type received: {content_type}"
        )

    # client.get (non-streaming) has already read the whole body, so
    # iter_bytes still works after the client has been closed.
    return StreamingResponse(
        response.iter_bytes(),
        media_type=content_type,
        headers={
            'Cache-Control': 'no-cache',
            'Pragma': 'no-cache'
        }
    )
@app.get("/playground", response_class=HTMLResponse)
async def playground():
    """Serve playground.html from the directory containing this module.

    Returns a 404 HTML response when the file does not exist.
    """
    # Resolve next to this module (as favicon/models.json do) rather than the
    # current working directory, so the page works regardless of launch dir.
    file_path = Path(__file__).parent / "playground.html"
    try:
        html_content = file_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
def load_model_ids(json_file_path):
    """Return the ``id`` of every model entry in a models JSON file.

    The file may contain either a list of model objects or an object with a
    ``"data"`` list of them. Entries that are not JSON objects or have no
    ``id`` are skipped.

    Args:
        json_file_path: Path (str or path-like) of the JSON file.

    Returns:
        list: Model ids in file order. The list is empty, and an error is
        printed, when the file is missing or not valid UTF-8 JSON. It is also
        empty when the JSON has no list of models.
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            models_data = json.load(f)
    except FileNotFoundError:
        print("Error: models.json file not found.")
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        print("Error: Invalid JSON format in models.json.")
        return []

    # Accept a {"data": [...]} wrapper as well as a bare list.
    if isinstance(models_data, dict):
        models_data = models_data.get("data", [])
    if not isinstance(models_data, list):
        return []
    # isinstance guard: `'id' in model` on a string entry would be a substring
    # test, followed by a crash when the string is indexed with 'id'.
    return [model['id'] for model in models_data if isinstance(model, dict) and 'id' in model]
@app.get("/usage")
async def get_usage(days: int = 7):
    """Return the tracker's usage summary, passing ``days`` (default 7) through."""
    summary = usage_tracker.get_usage_summary(days)
    return summary
@app.get("/usage/page", response_class=HTMLResponse)
async def usage_page():
    """Serve an HTML page showing usage statistics.

    Renders ``usage_tracker.get_usage_summary()`` as three tables: usage per
    model, per API endpoint, and per day and entity. Every value is
    HTML-escaped before it is inserted into the markup.
    """
    usage_data = usage_tracker.get_usage_summary()

    def table_row(*cells):
        # One <tr>; each value is escaped so stored names cannot inject markup.
        return "<tr>" + "".join(f"<td>{html.escape(str(cell))}</td>" for cell in cells) + "</tr>"

    model_usage_rows = "\n".join(
        table_row(model, stats['total_requests'], stats['first_used'], stats['last_used'])
        for model, stats in usage_data['models'].items()
    )
    api_usage_rows = "\n".join(
        table_row(endpoint, stats['total_requests'], stats['first_used'], stats['last_used'])
        for endpoint, stats in usage_data['api_endpoints'].items()
    )
    # Flatten {date: {entity: count}} into one row per (date, entity) pair.
    daily_usage_rows = "\n".join(
        table_row(date, entity, count)
        for date, date_data in usage_data['recent_daily_usage'].items()
        for entity, count in date_data.items()
    )
    total_requests = html.escape(str(usage_data['total_requests']))

    # NOTE(review): the "Last 7 Days" label assumes get_usage_summary()
    # defaults to a 7-day window; confirm against UsageTracker.
    html_content = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Lokiai AI - Usage Statistics</title>
</head>
<body>
<h1>Lokiai AI</h1>
<p>Total API Requests: {total_requests}</p>

<h2>Model Usage</h2>
<table border="1">
<tr><th>Model</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr>
{model_usage_rows}
</table>

<h2>API Endpoint Usage</h2>
<table border="1">
<tr><th>Endpoint</th><th>Total Requests</th><th>First Used</th><th>Last Used</th></tr>
{api_usage_rows}
</table>

<h2>Daily Usage (Last 7 Days)</h2>
<table border="1">
<tr><th>Date</th><th>Entity</th><th>Requests</th></tr>
{daily_usage_rows}
</table>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
@app.get("/meme")
async def get_meme():
    """Fetch a random meme from meme-api.com and stream the image back.

    Raises:
        HTTPException: 404 when the meme API returns no ``url``; 500 for any
            other failure while contacting the meme API or the image host.
    """
    try:
        # requests is blocking; run it in the threadpool so the event loop
        # keeps serving other clients while we wait.
        meta_response = await run_in_threadpool(
            requests.get, "https://meme-api.com/gimme", timeout=10
        )
        meme_url = meta_response.json().get("url")
        if not meme_url:
            raise HTTPException(status_code=404, detail="No mimi found :(")
        image_response = await run_in_threadpool(
            requests.get, meme_url, stream=True, timeout=10
        )
    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    # The URL need not point to a PNG; use the host's image content type when
    # it provides one, and fall back to image/png otherwise.
    content_type = image_response.headers.get("content-type", "")
    media_type = content_type if content_type.startswith("image/") else "image/png"

    def stream_image():
        try:
            yield from image_response.iter_content(chunk_size=1024)
        finally:
            image_response.close()

    return StreamingResponse(stream_image(), media_type=media_type)
@app.on_event("startup")
async def startup_event():
    """Load the accepted model ids and print a summary of the main routes."""
    global available_model_ids
    # Read the same models.json that get_models() serves (next to this
    # module), so the /models listing and chat validation cannot disagree
    # when the server is launched from another working directory.
    available_model_ids = load_model_ids(Path(__file__).parent / "models.json")
    print(f"Loaded model IDs: {available_model_ids}")
    print("API endpoints:")
    print("GET /")
    print("GET /models")
    print("GET /searchgpt")
    print("POST /chat/completions")
    print("GET /images/generations")
if __name__ == "__main__":
    # Development entry point: serve the app on every interface, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")