Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -17,7 +17,44 @@ import time
|
|
17 |
from usage_tracker import UsageTracker
|
18 |
from starlette.middleware.base import BaseHTTPMiddleware
|
19 |
from collections import defaultdict
|
|
|
|
|
|
|
|
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
class RateLimitMiddleware(BaseHTTPMiddleware):
|
22 |
def __init__(self, app, requests_per_second: int = 2):
|
23 |
super().__init__(app)
|
@@ -62,7 +99,6 @@ app = FastAPI()
|
|
62 |
app.add_middleware(RateLimitMiddleware, requests_per_second=2)
|
63 |
|
64 |
# Get API keys and secret endpoint from environment variables
|
65 |
-
api_keys_str = os.getenv('API_KEYS') #deprecated -_-
|
66 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
67 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
68 |
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
|
@@ -75,7 +111,7 @@ if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoi
|
|
75 |
raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
|
76 |
|
77 |
# Define models that should use the secondary endpoint
|
78 |
-
alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
|
79 |
|
80 |
available_model_ids = []
|
81 |
class Payload(BaseModel):
|
@@ -154,7 +190,7 @@ async def ping():
|
|
154 |
return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
|
155 |
|
156 |
@app.get("/searchgpt")
|
157 |
-
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
|
158 |
if not q:
|
159 |
raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
|
160 |
usage_tracker.record_request(endpoint="/searchgpt")
|
@@ -191,12 +227,12 @@ async def get_models():
|
|
191 |
raise HTTPException(status_code=500, detail="Error decoding models.json")
|
192 |
@app.get("api/v1/models")
|
193 |
@app.get("/models")
|
194 |
-
async def fetch_models():
|
195 |
return await get_models()
|
196 |
server_status = True
|
197 |
@app.post("/chat/completions")
|
198 |
@app.post("api/v1/chat/completions")
|
199 |
-
async def get_completion(payload: Payload, request: Request):
|
200 |
# Check server status
|
201 |
|
202 |
|
@@ -216,7 +252,7 @@ async def get_completion(payload: Payload, request: Request):
|
|
216 |
payload_dict["model"] = model_to_use
|
217 |
# payload_dict["stream"] = payload_dict.get("stream", False)
|
218 |
# Select the appropriate endpoint
|
219 |
-
endpoint =
|
220 |
|
221 |
# Current time and IP logging
|
222 |
current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
|
@@ -286,6 +322,7 @@ async def generate_image(
|
|
286 |
private: Optional[bool] = None,
|
287 |
enhance: Optional[bool] = None,
|
288 |
request: Request = None, # Access raw POST data
|
|
|
289 |
):
|
290 |
"""
|
291 |
Generate an image using the Image Generation API.
|
|
|
17 |
from usage_tracker import UsageTracker
|
18 |
from starlette.middleware.base import BaseHTTPMiddleware
|
19 |
from collections import defaultdict
|
20 |
+
from fastapi import Security #new
|
21 |
+
from fastapi.security import APIKeyHeader
|
22 |
+
from starlette.exceptions import HTTPException
|
23 |
+
from starlette.status import HTTP_403_FORBIDDEN
|
24 |
|
25 |
+
# API key header scheme
|
26 |
+
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
27 |
+
|
28 |
+
# Function to validate API key
|
29 |
+
async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
|
30 |
+
if not api_key:
|
31 |
+
raise HTTPException(
|
32 |
+
status_code=HTTP_403_FORBIDDEN,
|
33 |
+
detail="No API key provided"
|
34 |
+
)
|
35 |
+
|
36 |
+
# Clean the API key by removing 'Bearer ' if present
|
37 |
+
if api_key.startswith('Bearer '):
|
38 |
+
api_key = api_key[7:] # Remove 'Bearer ' prefix
|
39 |
+
|
40 |
+
# Get API keys from environment
|
41 |
+
api_keys_str = os.getenv('API_KEYS')
|
42 |
+
if not api_keys_str:
|
43 |
+
raise HTTPException(
|
44 |
+
status_code=HTTP_403_FORBIDDEN,
|
45 |
+
detail="API keys not configured on server"
|
46 |
+
)
|
47 |
+
|
48 |
+
valid_api_keys = api_keys_str.split(',')
|
49 |
+
|
50 |
+
# Check if the provided key is valid
|
51 |
+
if api_key not in valid_api_keys:
|
52 |
+
raise HTTPException(
|
53 |
+
status_code=HTTP_403_FORBIDDEN,
|
54 |
+
detail="Invalid API key"
|
55 |
+
)
|
56 |
+
|
57 |
+
return True
|
58 |
class RateLimitMiddleware(BaseHTTPMiddleware):
|
59 |
def __init__(self, app, requests_per_second: int = 2):
|
60 |
super().__init__(app)
|
|
|
99 |
app.add_middleware(RateLimitMiddleware, requests_per_second=2)
|
100 |
|
101 |
# Get API keys and secret endpoint from environment variables
|
|
|
102 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
103 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
104 |
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
|
|
|
111 |
raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
|
112 |
|
113 |
# Define models that should use the secondary endpoint
|
114 |
+
# alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
|
115 |
|
116 |
available_model_ids = []
|
117 |
class Payload(BaseModel):
|
|
|
190 |
return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
|
191 |
|
192 |
@app.get("/searchgpt")
|
193 |
+
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None,authenticated: bool = Depends(verify_api_key)):
|
194 |
if not q:
|
195 |
raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
|
196 |
usage_tracker.record_request(endpoint="/searchgpt")
|
|
|
227 |
raise HTTPException(status_code=500, detail="Error decoding models.json")
|
228 |
# FIX: Starlette route paths must begin with "/"; "api/v1/models" (without the
# leading slash) never matches an incoming request, so /api/v1/models was
# effectively unreachable.
@app.get("/api/v1/models")
@app.get("/models")
async def fetch_models(authenticated: bool = Depends(verify_api_key)):
    """Return the model list; requires a valid API key.

    Thin alias route that delegates to get_models() so both /models and
    /api/v1/models serve the same payload.
    """
    return await get_models()
|
232 |
server_status = True
|
233 |
@app.post("/chat/completions")
|
234 |
@app.post("api/v1/chat/completions")
|
235 |
+
async def get_completion(payload: Payload, request: Request,authenticated: bool = Depends(verify_api_key)):
|
236 |
# Check server status
|
237 |
|
238 |
|
|
|
252 |
payload_dict["model"] = model_to_use
|
253 |
# payload_dict["stream"] = payload_dict.get("stream", False)
|
254 |
# Select the appropriate endpoint
|
255 |
+
endpoint = secret_api_endpoint
|
256 |
|
257 |
# Current time and IP logging
|
258 |
current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
|
|
|
322 |
private: Optional[bool] = None,
|
323 |
enhance: Optional[bool] = None,
|
324 |
request: Request = None, # Access raw POST data
|
325 |
+
authenticated: bool = Depends(verify_api_key)
|
326 |
):
|
327 |
"""
|
328 |
Generate an image using the Image Generation API.
|