import os from dotenv import load_dotenv from fastapi import FastAPI, HTTPException, Request from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse from pydantic import BaseModel import httpx import hashlib from functools import lru_cache from pathlib import Path # Import Path from pathlib import requests import re import cloudscraper import json from typing import Optional import datetime import time from usage_tracker import UsageTracker from starlette.middleware.base import BaseHTTPMiddleware from collections import defaultdict from fastapi import Security #new from fastapi import Depends from fastapi.security import APIKeyHeader from starlette.exceptions import HTTPException from starlette.status import HTTP_403_FORBIDDEN # API key header scheme api_key_header = APIKeyHeader(name="Authorization", auto_error=False) # Function to validate API key async def verify_api_key(api_key: str = Security(api_key_header)) -> bool: if not api_key: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="No API key provided" ) # Clean the API key by removing 'Bearer ' if present if api_key.startswith('Bearer '): api_key = api_key[7:] # Remove 'Bearer ' prefix # Get API keys from environment api_keys_str = os.getenv('API_KEYS') if not api_keys_str: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="API keys not configured on server" ) valid_api_keys = api_keys_str.split(',') # Check if the provided key is valid if api_key not in valid_api_keys: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="Invalid API key" ) return True class RateLimitMiddleware(BaseHTTPMiddleware): def __init__(self, app, requests_per_second: int = 2): super().__init__(app) self.requests_per_second = requests_per_second self.last_request_time = defaultdict(float) self.tokens = defaultdict(lambda: requests_per_second) self.last_update = defaultdict(float) async def dispatch(self, request: Request, call_next): client_ip = request.client.host current_time = time.time() # Update tokens time_passed = current_time - self.last_update[client_ip] self.last_update[client_ip] = current_time self.tokens[client_ip] = min( self.requests_per_second, self.tokens[client_ip] + time_passed * self.requests_per_second ) # Check if request can be processed if self.tokens[client_ip] < 1: return JSONResponse( status_code=429, content={ "detail": "Too many requests. Please try again later.", "retry_after": round((1 - self.tokens[client_ip]) / self.requests_per_second) } ) # Consume a token self.tokens[client_ip] -= 1 # Process the request response = await call_next(request) return response usage_tracker = UsageTracker() load_dotenv() #idk why this shi app = FastAPI() app.add_middleware(RateLimitMiddleware, requests_per_second=2) # Get API keys and secret endpoint from environment variables # valid_api_keys = api_keys_str.split(',') if api_keys_str else [] secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT') secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2') secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3') # New endpoint for searchgpt image_endpoint = os.getenv("IMAGE_ENDPOINT") ENDPOINT_ORIGIN = os.getenv('ENDPOINT_ORIGIN') # Validate if the main secret API endpoints are set if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoint_3: raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.") # Define models that should use the secondary endpoint # alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"} available_model_ids = [] class Payload(BaseModel): model: str messages: list stream: bool = False @app.get("/favicon.ico") async def favicon(): # The favicon.ico file is in the same directory as the app favicon_path = Path(__file__).parent / "favicon.ico" return FileResponse(favicon_path, media_type="image/x-icon") def generate_search(query: str, systemprompt: Optional[str] = None, stream: bool = True) -> str: headers = {"User-Agent": ""} # Use the provided system prompt, or default to "Be Helpful and Friendly" system_message = systemprompt or "Be Helpful and Friendly" # Create the prompt history with the user query and system message prompt = [ {"role": "user", "content": query}, ] prompt.insert(0, {"content": system_message, "role": "system"}) # Prepare the payload for the API request payload = { "is_vscode_extension": True, "message_history": prompt, "requested_model": "searchgpt", "user_input": prompt[-1]["content"], } # Send the request to the chat endpoint response = requests.post(secret_api_endpoint_3, headers=headers, json=payload, stream=True) streaming_text = "" # Process the streaming response for value in response.iter_lines(decode_unicode=True): if value.startswith("data: "): try: json_modified_value = json.loads(value[6:]) content = json_modified_value.get("choices", [{}])[0].get("delta", {}).get("content", "") if content.strip(): # Only process non-empty content cleaned_response = { "created": json_modified_value.get("created"), "id": json_modified_value.get("id"), "model": "searchgpt", "object": "chat.completion", "choices": [ { "message": { "content": content } } ] } if stream: yield f"data: {json.dumps(cleaned_response)}\n\n" streaming_text += content except json.JSONDecodeError: continue if not stream: yield streaming_text @app.get("/ping") async def ping(): start_time = datetime.datetime.now() response_time = (datetime.datetime.now() - start_time).total_seconds() return {"message": "pong", "response_time": f"{response_time:.6f} seconds"} @app.get("/searchgpt") async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None,authenticated: bool = Depends(verify_api_key)): if not q: raise HTTPException(status_code=400, detail="Query parameter 'q' is required") usage_tracker.record_request(endpoint="/searchgpt") if stream: return StreamingResponse( generate_search(q, systemprompt=systemprompt, stream=True), media_type="text/event-stream" ) else: # For non-streaming, collect the text and return as JSON response response_text = "".join([chunk for chunk in generate_search(q, systemprompt=systemprompt, stream=False)]) return JSONResponse(content={"response": response_text}) @app.get("/", response_class=HTMLResponse) async def root(): # Open and read the content of index.html (in the same folder as the app) file_path = "index.html" try: with open(file_path, "r") as file: html_content = file.read() return HTMLResponse(content=html_content) except FileNotFoundError: return HTMLResponse(content="

File not found

", status_code=404) async def get_models(): try: # Load the models from models.json in the same folder file_path = Path(__file__).parent / 'models.json' with open(file_path, 'r') as f: return json.load(f) except FileNotFoundError: raise HTTPException(status_code=404, detail="models.json not found") except json.JSONDecodeError: raise HTTPException(status_code=500, detail="Error decoding models.json") @app.get("api/v1/models") @app.get("/models") async def return_models(): return await get_models() server_status = True @app.post("/chat/completions") @app.post("api/v1/chat/completions") async def get_completion(payload: Payload, request: Request,authenticated: bool = Depends(verify_api_key)): # Check server status model_to_use = payload.model if payload.model else "gpt-4o-mini" # Validate model availability if model_to_use not in available_model_ids: raise HTTPException( status_code=400, detail=f"Model '{model_to_use}' is not available. Check /models for the available model list." ) usage_tracker.record_request(model=model_to_use, endpoint="/chat/completions") # Prepare payload payload_dict = payload.dict() payload_dict["model"] = model_to_use # payload_dict["stream"] = payload_dict.get("stream", False) # Select the appropriate endpoint endpoint = secret_api_endpoint # Current time and IP logging current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p") aaip = request.client.host print(f"Time: {current_time}, {aaip} , {model_to_use}, server status :- {server_status}") print(payload_dict) if not server_status: return JSONResponse( status_code=503, content={"message": "Server is under maintenance. Please try again later."} ) scraper = cloudscraper.create_scraper() async def stream_generator(payload_dict): # Prepare custom headers custom_headers = { 'DNT': '1', # 'Origin': ENDPOINT_ORIGIN, 'Priority': 'u=1, i', # 'Referer': ENDPOINT_ORIGIN } try: # Send POST request using CloudScraper with custom headers response = scraper.post( f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers, stream=True ) # Error handling remains the same as in previous version if response.status_code == 422: raise HTTPException(status_code=422, detail="Unprocessable entity. Check your payload.") elif response.status_code == 400: raise HTTPException(status_code=400, detail="Bad request. Verify input data.") elif response.status_code == 403: raise HTTPException(status_code=403, detail="Forbidden. You do not have access to this resource.") elif response.status_code == 404: raise HTTPException(status_code=404, detail="The requested resource was not found.") elif response.status_code >= 500: raise HTTPException(status_code=500, detail="Server error. Try again later.") # Stream response lines to the client for line in response.iter_lines(): if line: yield line.decode('utf-8') + "\n" except requests.exceptions.RequestException as req_err: # Handle request-specific errors print(response.text) raise HTTPException(status_code=500, detail=f"Request failed: {req_err}") except Exception as e: # Handle unexpected errors print(response.text) raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") return StreamingResponse(stream_generator(payload_dict), media_type="application/json") # Remove the duplicated endpoint and combine the functionality @app.api_route("/images/generations", methods=["GET", "POST"]) # Support both GET and POST async def generate_image( prompt: Optional[str] = None, model: str = "flux", # Default model seed: Optional[int] = None, width: Optional[int] = None, height: Optional[int] = None, nologo: Optional[bool] = True, private: Optional[bool] = None, enhance: Optional[bool] = None, request: Request = None, # Access raw POST data authenticated: bool = Depends(verify_api_key) ): """ Generate an image using the Image Generation API. """ # Validate the image endpoint if not image_endpoint: raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.") usage_tracker.record_request(endpoint="/images/generations") # Handle GET and POST prompts if request.method == "POST": try: body = await request.json() # Parse JSON body prompt = body.get("prompt", "").strip() if not prompt: raise HTTPException(status_code=400, detail="Prompt cannot be empty") except Exception: raise HTTPException(status_code=400, detail="Invalid JSON payload") elif request.method == "GET": if not prompt or not prompt.strip(): raise HTTPException(status_code=400, detail="Prompt cannot be empty") prompt = prompt.strip() # Sanitize and encode the prompt encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt') # Construct the URL with the encoded prompt base_url = image_endpoint.rstrip('/') # Remove trailing slash if present url = f"{base_url}/{encoded_prompt}" # Prepare query parameters with validation params = {} if model and isinstance(model, str): params['model'] = model if seed is not None and isinstance(seed, int): params['seed'] = seed if width is not None and isinstance(width, int) and 64 <= width <= 2048: params['width'] = width if height is not None and isinstance(height, int) and 64 <= height <= 2048: params['height'] = height if nologo is not None: params['nologo'] = str(nologo).lower() if private is not None: params['private'] = str(private).lower() if enhance is not None: params['enhance'] = str(enhance).lower() try: timeout = httpx.Timeout(60.0) # Set a reasonable timeout async with httpx.AsyncClient(timeout=timeout) as client: response = await client.get(url, params=params, follow_redirects=True) # Check for various error conditions if response.status_code == 404: raise HTTPException(status_code=404, detail="Image generation service not found") elif response.status_code == 400: raise HTTPException(status_code=400, detail="Invalid parameters provided to image service") elif response.status_code == 429: raise HTTPException(status_code=429, detail="Too many requests to image service") elif response.status_code != 200: raise HTTPException( status_code=response.status_code, detail=f"Image generation failed with status code {response.status_code}" ) # Verify content type content_type = response.headers.get('content-type', '') if not content_type.startswith('image/'): raise HTTPException( status_code=500, detail=f"Unexpected content type received: {content_type}" ) return StreamingResponse( response.iter_bytes(), media_type=content_type, headers={ 'Cache-Control': 'no-cache', 'Pragma': 'no-cache' } ) except httpx.TimeoutException: raise HTTPException(status_code=504, detail="Image generation request timed out") except httpx.RequestError as e: raise HTTPException(status_code=500, detail=f"Failed to contact image service: {str(e)}") except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error during image generation: {str(e)}") @app.get("/playground", response_class=HTMLResponse) async def playground(): # Open and read the content of playground.html (in the same folder as the app) file_path = "playground.html" try: with open(file_path, "r") as file: html_content = file.read() return HTMLResponse(content=html_content) except FileNotFoundError: return HTMLResponse(content="

playground.html not found

", status_code=404) def load_model_ids(json_file_path): try: with open(json_file_path, 'r') as f: models_data = json.load(f) # Extract 'id' from each model object model_ids = [model['id'] for model in models_data if 'id' in model] return model_ids except FileNotFoundError: print("Error: models.json file not found.") return [] except json.JSONDecodeError: print("Error: Invalid JSON format in models.json.") return [] @app.get("/usage") async def get_usage(days: int = 7): """Retrieve usage statistics""" return usage_tracker.get_usage_summary(days) @app.get("/usage/page", response_class=HTMLResponse) async def usage_page(): """Serve an HTML page showing usage statistics""" # Retrieve usage data usage_data = usage_tracker.get_usage_summary() # Model Usage Table Rows model_usage_rows = "\n".join([ f""" {model} {model_data['total_requests']} {model_data['first_used']} {model_data['last_used']} """ for model, model_data in usage_data['models'].items() ]) # API Endpoint Usage Table Rows api_usage_rows = "\n".join([ f""" {endpoint} {endpoint_data['total_requests']} {endpoint_data['first_used']} {endpoint_data['last_used']} """ for endpoint, endpoint_data in usage_data['api_endpoints'].items() ]) # Daily Usage Table Rows daily_usage_rows = "\n".join([ "\n".join([ f""" {date} {entity} {requests} """ for entity, requests in date_data.items() ]) for date, date_data in usage_data['recent_daily_usage'].items() ]) html_content = f""" Lokiai AI - Usage Statistics
Total API Requests: {usage_data['total_requests']}

Model Usage

{model_usage_rows}
Model Total Requests First Used Last Used

API Endpoint Usage

{api_usage_rows}
Endpoint Total Requests First Used Last Used

Daily Usage (Last 7 Days)

{daily_usage_rows}
Date Entity Requests
""" return HTMLResponse(content=html_content) @app.get("/meme") async def get_meme(): try: response = requests.get("https://meme-api.com/gimme") response_data = response.json() meme_url = response_data.get("url") if meme_url: def stream_image(): with requests.get(meme_url, stream=True) as image_response: for chunk in image_response.iter_content(chunk_size=1024): yield chunk return StreamingResponse(stream_image(), media_type="image/png") else: raise HTTPException(status_code=404, detail="No mimi found :(") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.on_event("startup") async def startup_event(): global available_model_ids available_model_ids = load_model_ids("models.json") print(f"Loaded model IDs: {available_model_ids}") print("API endpoints:") print("GET /") print("GET /models") print("GET /searchgpt") print("POST /chat/completions") print("GET /images/generations") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)