# lokiai / main.py
# Last change: "Update main.py" by ParthSadaria (commit 614a889, verified; 7.9 kB).
import json
import os
import re
from pathlib import Path
from typing import Iterator, Optional

import httpx
import requests
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
load_dotenv()
app = FastAPI()

# Read API keys and upstream endpoints from the environment (load_dotenv above
# also pulls in a local .env file if one exists).
api_keys_str = os.getenv('API_KEYS')
# Comma-separated list; surrounding whitespace and empty entries (e.g. from a
# trailing comma) are dropped so that "" can never count as a valid key.
valid_api_keys = (
    [key.strip() for key in api_keys_str.split(',') if key.strip()]
    if api_keys_str
    else []
)
# NOTE(review): no handler visible in this file checks valid_api_keys — confirm
# whether authentication is meant to be enforced.
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')  # used for alternate_models
secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3')  # used by /searchgpt

# All three endpoints are required. Fail at import time with a plain
# RuntimeError: HTTPException only has meaning inside a request handler.
_missing_endpoints = [
    name
    for name, value in (
        ("SECRET_API_ENDPOINT", secret_api_endpoint),
        ("SECRET_API_ENDPOINT_2", secret_api_endpoint_2),
        ("SECRET_API_ENDPOINT_3", secret_api_endpoint_3),
    )
    if not value
]
if _missing_endpoints:
    raise RuntimeError(
        "API endpoint(s) are not configured in environment variables: "
        + ", ".join(_missing_endpoints)
    )

# Models served by SECRET_API_ENDPOINT_2; every other model goes to SECRET_API_ENDPOINT.
alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
class Payload(BaseModel):
    """Body of a chat-completions request, forwarded upstream via ``.dict()``.

    NOTE(review): pydantic ignores undeclared fields by default, so other
    OpenAI parameters (temperature, max_tokens, ...) are not forwarded.
    """

    model: str  # model name; also selects which upstream endpoint is used
    messages: list  # chat history, passed through without further validation
    # Optional and defaulting to false, as in the OpenAI chat-completions API;
    # previously a request without "stream" was rejected with a 422.
    stream: bool = False
@app.get("/favicon.ico")
async def favicon():
    """Serve the site icon that sits beside this module."""
    icon_file = Path(__file__).with_name("favicon.ico")
    return FileResponse(icon_file, media_type="image/x-icon")
def _sse_delta_content(line: str) -> Optional[str]:
    """Return the text delta carried by one upstream server-sent-event line.

    Returns None for lines to skip entirely (anything that is not a ``data:``
    line, and payloads that are not JSON such as ``[DONE]``), and "" for JSON
    events that carry no text.
    """
    if not line.startswith("data:"):
        return None
    data = line[len("data:"):]
    # The SSE format allows exactly one optional space after the colon.
    if data.startswith(" "):
        data = data[1:]
    try:
        event = json.loads(data)
    except json.JSONDecodeError:
        return None
    # Walk event["choices"][0]["delta"]["content"] defensively: a missing or
    # empty "choices" list, or a null "content", yields "" instead of crashing.
    choices = event.get("choices") if isinstance(event, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    delta = first.get("delta") if isinstance(first, dict) else None
    content = delta.get("content") if isinstance(delta, dict) else None
    return content if isinstance(content, str) else ""


def generate_search(query: str, stream: bool = True, *, timeout: float = 30.0) -> Iterator[str]:
    """Ask the searchgpt upstream (SECRET_API_ENDPOINT_3) to answer *query*.

    Args:
        query: The user's question, sent after a fixed system prompt.
        stream: If True, yield one server-sent-event frame per upstream chunk,
            each a ``data:`` line holding a JSON object with a ``response`` key.
            If False, yield the whole answer as a single string at the end.
        timeout: Seconds allowed for connecting and for each read from the
            upstream (requests' per-operation timeout, not a total deadline).

    Yields:
        SSE frames (stream=True) or exactly one full-answer string (stream=False).

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx upstream status.
    """
    headers = {"User-Agent": ""}
    message_history = [
        {"content": "Be Helpful and Friendly", "role": "system"},
        {"role": "user", "content": query},
    ]
    payload = {
        "is_vscode_extension": True,
        "message_history": message_history,
        "requested_model": "searchgpt",
        "user_input": query,
    }
    pieces = []
    # The context manager releases the pooled connection even if the consumer
    # stops iterating early (e.g. the HTTP client disconnects).
    with requests.post(
        secret_api_endpoint_3,
        headers=headers,
        json=payload,
        stream=True,
        timeout=timeout,
    ) as response:
        response.raise_for_status()
        # Event streams are always UTF-8, but requests falls back to ISO-8859-1
        # for text/* responses without a charset, which garbles non-ASCII text.
        response.encoding = "utf-8"
        for line in response.iter_lines(decode_unicode=True):
            content = _sse_delta_content(line)
            if content is None:
                continue
            # Whitespace-only deltas are kept: they are part of the answer.
            if stream:
                yield f"data: {json.dumps({'response': content})}\n\n"
            pieces.append(content)
    if not stream:
        yield "".join(pieces)
@app.get("/searchgpt")
async def search_gpt(q: str, stream: Optional[bool] = False):
    """Answer query *q* using the searchgpt upstream.

    With ``stream=true`` the answer is relayed as server-sent events; otherwise
    it is collected and returned as ``{"response": "<text>"}``.

    Raises:
        HTTPException: 400 for an empty query; 502 if the upstream request
            fails in non-streaming mode.
    """
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
    if stream:
        # StreamingResponse iterates a sync generator in a worker thread.
        return StreamingResponse(
            generate_search(q, stream=True),
            media_type="text/event-stream",
        )
    # generate_search blocks on network I/O (requests), so drain it in a worker
    # thread instead of stalling the event loop for every other client.
    try:
        response_text = await run_in_threadpool("".join, generate_search(q, stream=False))
    except requests.RequestException as err:
        # requests' error messages embed the upstream URL, which is a secret,
        # so the client only gets a generic message.
        raise HTTPException(status_code=502, detail="Search request failed.") from err
    return JSONResponse(content={"response": response_text})
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve index.html from the directory containing this module.

    Returns a small 404 HTML page when the file is missing.
    """
    # Resolve relative to this file (like favicon.ico and models.json) rather
    # than the process's current working directory.
    index_path = Path(__file__).parent / "index.html"
    try:
        html_content = index_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
async def get_models():
    """Load and return the parsed contents of models.json next to this module.

    Raises:
        HTTPException: 404 if models.json is missing; 500 if it is not valid
            UTF-8 JSON.
    """
    file_path = Path(__file__).parent / 'models.json'
    try:
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="models.json not found") from None
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        raise HTTPException(status_code=500, detail="Error decoding models.json") from err
@app.get("/models")
async def fetch_models():
    """Return the model catalogue produced by get_models()."""
    models = await get_models()
    return models
# Client-facing messages for upstream error statuses that are mapped one-to-one.
_UPSTREAM_ERROR_DETAILS = {
    400: "Bad request. Verify input data.",
    403: "Forbidden. You do not have access to this resource.",
    404: "The requested resource was not found.",
    422: "Unprocessable entity. Check your payload.",
}


async def _raise_for_upstream_status(upstream: httpx.Response) -> None:
    """Raise an HTTPException describing *upstream* unless its status is 2xx."""
    status = upstream.status_code
    if 200 <= status < 300:
        return
    if status in _UPSTREAM_ERROR_DETAILS:
        raise HTTPException(status_code=status, detail=_UPSTREAM_ERROR_DETAILS[status])
    if status >= 500:
        raise HTTPException(status_code=500, detail="Server error. Try again later.")
    # A streamed response has no .text until its body has been read explicitly.
    await upstream.aread()
    raise HTTPException(
        # Other 4xx codes pass through; anything else (e.g. an unfollowed
        # redirect) is reported as a bad gateway.
        status_code=status if 400 <= status < 500 else 502,
        detail=f"HTTP error: {upstream.text}",
    )


@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def get_completion(payload: Payload, request: Request):
    """Proxy an OpenAI-style chat-completions request to the configured upstream.

    Models in ``alternate_models`` go to SECRET_API_ENDPOINT_2, all others to
    SECRET_API_ENDPOINT. The upstream status is checked *before* the streaming
    response starts, so errors reach the client with a real status code. On
    success the upstream body is relayed byte-for-byte, which preserves the
    blank lines separating server-sent events, under the upstream content type.

    ``request`` is unused but kept so the handler signature stays unchanged.

    Raises:
        HTTPException: 504 on an upstream timeout, 500 on other transport
            errors, and a mapped status for non-2xx upstream responses.
    """
    payload_dict = payload.dict()
    endpoint = secret_api_endpoint_2 if payload.model in alternate_models else secret_api_endpoint

    client = httpx.AsyncClient(timeout=10)
    upstream = None

    async def release():
        # Safe to call more than once: httpx ignores repeated aclose() calls.
        if upstream is not None:
            await upstream.aclose()
        await client.aclose()

    async def relay():
        try:
            async for chunk in upstream.aiter_bytes():
                yield chunk
        finally:
            await release()

    try:
        upstream = await client.send(
            client.build_request("POST", f"{endpoint}/v1/chat/completions", json=payload_dict),
            stream=True,
        )
        await _raise_for_upstream_status(upstream)
    except httpx.TimeoutException:
        await release()
        raise HTTPException(
            status_code=504, detail="Request timed out. Please try again later."
        ) from None
    except httpx.RequestError as req_err:
        await release()
        raise HTTPException(status_code=500, detail=f"Request failed: {req_err}") from req_err
    except BaseException:
        # Includes the HTTPException for a non-2xx upstream status.
        await release()
        raise

    # Backstop: also release once the response finishes, in case the body
    # iterator is never started or never finalized.
    cleanup = BackgroundTasks()
    cleanup.add_task(release)
    return StreamingResponse(
        relay(),
        media_type=upstream.headers.get("content-type", "application/json"),
        background=cleanup,
    )
@app.on_event("startup")
async def startup_event():
    """Print the routes this app exposes once the server has started."""
    routes = (
        "GET /",
        "GET /favicon.ico",
        "GET /models",
        "GET /searchgpt",
        "POST /chat/completions",
        "POST /v1/chat/completions",  # alias registered on the same handler
    )
    print("API endpoints:")
    for route in routes:
        print(route)
if __name__ == "__main__":
    # Development entry point: serve the app on every interface, port 8000.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)