Spaces:
Running
Running
File size: 7,895 Bytes
4986fe4 d38c2eb ef215d3 9fa0f10 378f2c3 8834a20 fc764d5 a111cf9 378f2c3 d38c2eb 4986fe4 378f2c3 b955cc1 4986fe4 b955cc1 a111cf9 4986fe4 a111cf9 b955cc1 378f2c3 b955cc1 006f05b a111cf9 afe2e88 a111cf9 4d88866 fa65dba a111cf9 4d88866 fa65dba 4d88866 a111cf9 fa65dba a111cf9 ae066fd a111cf9 d5656a9 a111cf9 2c1c62a 7ef5d89 c613f2b 7ef5d89 a68045e 614a889 7ef5d89 ac4bad0 3109050 a0270ea ef215d3 0881536 b955cc1 ba11b8c a0270ea 3109050 8834a20 b955cc1 ba11b8c 114ca84 3109050 ba11b8c 3109050 ba11b8c 3109050 ba11b8c 3109050 a0270ea 3109050 7ef5d89 a111cf9 6a84e5c 8e4491b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import json
import logging
import os
import re
from pathlib import Path
from typing import AsyncIterator, Iterator, Optional

import httpx
import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse, FileResponse
from pydantic import BaseModel
load_dotenv()
app = FastAPI()

# API keys and upstream endpoints come from the environment; load_dotenv()
# above lets them be supplied through a .env file as well.
api_keys_str = os.getenv('API_KEYS')
# Comma-separated list; whitespace around entries and empty entries are ignored.
valid_api_keys = (
    [key.strip() for key in api_keys_str.split(',') if key.strip()] if api_keys_str else []
)
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')  # Used for models in alternate_models
secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3')  # Used by /searchgpt

# Fail fast when an upstream endpoint is missing. This runs at import time,
# outside of any request, so HTTPException (meaningful only inside a request
# handler) is the wrong signal; a RuntimeError stops startup with a clear message.
_missing_endpoint_vars = [
    name
    for name, value in (
        ('SECRET_API_ENDPOINT', secret_api_endpoint),
        ('SECRET_API_ENDPOINT_2', secret_api_endpoint_2),
        ('SECRET_API_ENDPOINT_3', secret_api_endpoint_3),
    )
    if not value
]
if _missing_endpoint_vars:
    raise RuntimeError(
        "API endpoint(s) are not configured in environment variables: "
        + ", ".join(_missing_endpoint_vars)
    )

# Models that are routed to the secondary endpoint (SECRET_API_ENDPOINT_2).
alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
class Payload(BaseModel):
    """Request body accepted by the chat-completion routes.

    The validated model is converted back to a dict and forwarded upstream.
    """

    model: str  # Also selects which upstream endpoint handles the request.
    messages: list  # Forwarded as-is; individual messages are not validated here.
    # Optional and False by default, as in the OpenAI chat-completions API, so
    # clients that omit the field are accepted instead of rejected with a 422.
    stream: bool = False
@app.get("/favicon.ico")
async def favicon():
    """Serve favicon.ico from the directory that contains this module.

    Raises:
        HTTPException: 404 if the icon file does not exist.
    """
    favicon_path = Path(__file__).parent / "favicon.ico"
    # FileResponse only checks the path when the response is sent, where a
    # missing file surfaces as a server error; answer with a clean 404 instead.
    if not favicon_path.is_file():
        raise HTTPException(status_code=404, detail="favicon.ico not found")
    return FileResponse(favicon_path, media_type="image/x-icon")
def _delta_content(data: str) -> Optional[str]:
    """Extract ``choices[0].delta.content`` from one SSE ``data:`` payload.

    Returns None when *data* is not valid JSON (the caller skips such lines),
    and "" when the JSON is valid but holds no string content at that path.
    """
    try:
        chunk = json.loads(data)
    except json.JSONDecodeError:
        return None
    choices = chunk.get("choices") if isinstance(chunk, dict) else None
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    delta = choices[0].get("delta")
    content = delta.get("content") if isinstance(delta, dict) else None
    # Some chunks carry "content": null; treat anything non-string as empty.
    return content if isinstance(content, str) else ""


def generate_search(query: str, stream: bool = True, *, timeout: float = 60.0) -> Iterator[str]:
    """Ask the searchgpt upstream (SECRET_API_ENDPOINT_3) to answer *query*.

    The upstream reply is read line by line; each ``data: <json>`` line
    contributes the text found at ``choices[0].delta.content``. Lines whose
    payload is not valid JSON are skipped.

    Args:
        query: The user's question, sent after a fixed system prompt.
        stream: If True, yield one server-sent-event frame per upstream chunk,
            carrying ``{"response": <text>}`` as JSON. If False, yield a single
            string with the concatenated text once the upstream has finished.
        timeout: Seconds passed as ``timeout`` to ``requests.post`` (applies to
            connecting and to each wait for data, not to the whole stream).

    Raises:
        requests.HTTPError: If the upstream answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    headers = {"User-Agent": ""}
    payload = {
        "is_vscode_extension": True,
        # System prompt first, followed by the user's query.
        "message_history": [
            {"content": "Be Helpful and Friendly", "role": "system"},
            {"role": "user", "content": query},
        ],
        "requested_model": "searchgpt",
        "user_input": query,
    }
    collected = []
    # The context manager closes the streamed connection, including when the
    # consumer stops iterating early.
    with requests.post(
        secret_api_endpoint_3, headers=headers, json=payload, stream=True, timeout=timeout
    ) as response:
        response.raise_for_status()
        for raw_line in response.iter_lines():
            # Event streams are UTF-8; decoding explicitly avoids requests'
            # ISO-8859-1 fallback for text/* responses without a charset.
            line = raw_line.decode("utf-8", errors="replace")
            if not line.startswith("data: "):
                continue
            content = _delta_content(line[len("data: "):])
            if content is None:
                continue
            if stream:
                # Forward every chunk, even empty or whitespace-only ones.
                yield f"data: {json.dumps({'response': content})}\n\n"
            else:
                collected.append(content)
    if not stream:
        yield "".join(collected)
@app.get("/searchgpt")
async def search_gpt(q: str, stream: Optional[bool] = False):
    """Answer the search query *q* using generate_search.

    With ``stream=true`` the answer is returned as a text/event-stream of
    ``{"response": ...}`` frames; otherwise the whole answer is collected and
    returned as JSON ``{"response": <full text>}``.

    Raises:
        HTTPException: 400 if *q* is empty or only whitespace.
    """
    if not q.strip():
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
    if stream:
        # StreamingResponse iterates a synchronous generator in a worker
        # thread, so generate_search's blocking reads stay off the event loop.
        return StreamingResponse(
            generate_search(q, stream=True),
            media_type="text/event-stream"
        )

    # generate_search does blocking I/O with requests; running it directly in
    # this coroutine would stall the event loop (and every other request) until
    # the upstream finishes, so collect the answer in the threadpool instead.
    from fastapi.concurrency import run_in_threadpool

    def collect_answer() -> str:
        return "".join(generate_search(q, stream=False))

    response_text = await run_in_threadpool(collect_answer)
    return JSONResponse(content={"response": response_text})
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve index.html from the directory that contains this module.

    Returns a small 404 HTML page when the file is missing.
    """
    # Resolve next to this file rather than against the process's working
    # directory, consistent with the module's other file lookups.
    file_path = Path(__file__).parent / "index.html"
    try:
        html_content = file_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(content="<h1>File not found</h1>", status_code=404)
    return HTMLResponse(content=html_content)
async def get_models():
    """Load and return the parsed contents of models.json next to this module.

    Raises:
        HTTPException: 404 if models.json is missing; 500 if it is not valid
            UTF-8 JSON.
    """
    file_path = Path(__file__).parent / 'models.json'
    try:
        # Explicit encoding: the platform default varies (e.g. on Windows).
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="models.json not found") from None
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        raise HTTPException(status_code=500, detail="Error decoding models.json") from err
@app.get("/models")
async def fetch_models():
    """Return the contents of models.json, as loaded by get_models()."""
    models = await get_models()
    return models
# Client-facing messages for upstream error statuses that are mapped one-to-one.
_UPSTREAM_ERROR_DETAILS = {
    400: "Bad request. Verify input data.",
    403: "Forbidden. You do not have access to this resource.",
    404: "The requested resource was not found.",
    422: "Unprocessable entity. Check your payload.",
}


async def _open_upstream(url: str, body: dict) -> tuple:
    """POST *body* as JSON to *url*; return ``(client, response)`` once headers arrive.

    The response body is left unread so it can be streamed. If sending fails
    for any reason, the client is closed before the exception propagates.
    """
    client = httpx.AsyncClient(timeout=10)
    try:
        response = await client.send(client.build_request("POST", url, json=body), stream=True)
    except BaseException:
        await client.aclose()
        raise
    return client, response


async def _raise_for_upstream_status(client: "httpx.AsyncClient", response: "httpx.Response") -> None:
    """Return for a 2xx response; otherwise close the upstream and raise HTTPException."""
    status = response.status_code
    if 200 <= status < 300:
        return
    try:
        if status in _UPSTREAM_ERROR_DETAILS:
            detail = _UPSTREAM_ERROR_DETAILS[status]
        elif status >= 500:
            status, detail = 500, "Server error. Try again later."
        else:
            # A streamed response must be read before .text is available.
            await response.aread()
            detail = f"HTTP error: {response.text}"
    finally:
        await response.aclose()
        await client.aclose()
    raise HTTPException(status_code=status, detail=detail)


async def _relay_lines(client: "httpx.AsyncClient", response: "httpx.Response") -> AsyncIterator[str]:
    """Yield every upstream body line with a trailing newline, then close the upstream."""
    try:
        async for line in response.aiter_lines():
            # Blank lines are forwarded too: in a server-sent-event stream they
            # are what separates one event from the next.
            yield f"{line}\n"
    finally:
        await response.aclose()
        await client.aclose()


@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def get_completion(payload: Payload, request: Request):
    """Forward a chat-completion request upstream and stream the reply back.

    Models in ``alternate_models`` go to SECRET_API_ENDPOINT_2, all others to
    SECRET_API_ENDPOINT (both at ``/v1/chat/completions``). The upstream body
    is relayed line by line.

    The upstream status is checked *before* the StreamingResponse is built:
    once streaming starts the 200 status has already been sent, so an error
    raised inside the body generator can no longer become an error response.
    Failures after that point (e.g. a read timeout mid-stream) abort the stream.

    Raises:
        HTTPException: 504 if the upstream times out; 500 for other connection
            failures or upstream 5xx; the mapped status for upstream
            400/403/404/422; otherwise the upstream status with its body text.
    """
    # ``request`` is unused; it stays in the signature for compatibility.
    payload_dict = payload.dict()
    endpoint = secret_api_endpoint_2 if payload.model in alternate_models else secret_api_endpoint
    # Log only the model name: the messages are user content.
    logging.getLogger(__name__).debug("Forwarding chat completion for model %r", payload.model)

    try:
        client, response = await _open_upstream(f"{endpoint}/v1/chat/completions", payload_dict)
        await _raise_for_upstream_status(client, response)
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Request timed out. Please try again later.") from None
    except httpx.RequestError as req_err:
        raise HTTPException(status_code=500, detail=f"Request failed: {req_err}") from req_err

    media_type = "text/event-stream" if payload.stream else "application/json"
    return StreamingResponse(_relay_lines(client, response), media_type=media_type)
@app.on_event("startup")
async def startup_event():
    """Print a summary of the app's main routes when the server starts."""
    print("API endpoints:")
    print("GET /")
    print("GET /models")
    print("GET /searchgpt")
    print("POST /chat/completions")
    print("POST /v1/chat/completions")  # Alias registered on the same handler.
if __name__ == "__main__":
    # Development entry point for `python <this file>`. uvicorn is imported
    # here, so it is only needed when the module is run directly.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|