Update main.py
Browse files
main.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
# main.py
|
2 |
import os
|
3 |
import re
|
4 |
import random
|
@@ -17,25 +16,13 @@ from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
17 |
from slowapi.util import get_remote_address
|
18 |
from slowapi.errors import RateLimitExceeded
|
19 |
from slowapi.middleware import SlowAPIMiddleware
|
20 |
-
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
21 |
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
|
|
|
22 |
from pydantic import BaseModel
|
23 |
|
24 |
-
from sqlalchemy.orm import Session
|
25 |
-
|
26 |
from aiohttp import ClientSession, ClientTimeout, ClientError
|
27 |
|
28 |
-
from database import SessionLocal, engine, get_db
|
29 |
-
from models import Base, Image, Log
|
30 |
-
|
31 |
-
from dotenv import load_dotenv
|
32 |
-
|
33 |
-
# Load environment variables from .env file
|
34 |
-
load_dotenv()
|
35 |
-
|
36 |
-
# Create all tables
|
37 |
-
Base.metadata.create_all(bind=engine)
|
38 |
-
|
39 |
# Define a context variable for client_ip
|
40 |
client_ip_var = contextvars.ContextVar("client_ip", default="N/A")
|
41 |
|
@@ -57,21 +44,21 @@ logger = logging.getLogger("main") # Use a specific logger name if needed
|
|
57 |
logger.setLevel(logging.INFO)
|
58 |
|
59 |
# Create handlers
|
60 |
-
|
61 |
-
|
62 |
|
63 |
# Create and set the custom formatter
|
64 |
formatter = SafeFormatter(
|
65 |
fmt="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
|
66 |
datefmt="%Y-%m-%d %H:%M:%S"
|
67 |
)
|
68 |
-
|
69 |
|
70 |
-
# Add the custom filter to the
|
71 |
-
|
72 |
|
73 |
# Add handlers to the logger
|
74 |
-
logger.addHandler(
|
75 |
|
76 |
# Initialize the limiter with slowapi
|
77 |
limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"])
|
@@ -84,106 +71,50 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
|
84 |
# Add SlowAPI middleware
|
85 |
app.add_middleware(SlowAPIMiddleware)
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
)
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
client_ip_var.set(client_ip)
|
124 |
-
|
125 |
-
# Enforce that POST requests to sensitive endpoints must have a valid Content-Type
|
126 |
-
if request.method == "POST" and request.url.path == "/v1/chat/completions":
|
127 |
-
content_type = request.headers.get("Content-Type")
|
128 |
-
if content_type != "application/json":
|
129 |
-
logger.warning("Invalid Content-Type for /v1/chat/completions")
|
130 |
-
return JSONResponse(
|
131 |
-
status_code=400,
|
132 |
-
content={
|
133 |
-
"error": {
|
134 |
-
"message": "Content-Type must be application/json",
|
135 |
-
"type": "invalid_request_error",
|
136 |
-
"param": None,
|
137 |
-
"code": None
|
138 |
-
}
|
139 |
-
},
|
140 |
-
)
|
141 |
-
|
142 |
-
# Log the incoming request
|
143 |
-
logger.info(f"Incoming request: {request.method} {request.url.path}")
|
144 |
-
|
145 |
-
response = await call_next(request)
|
146 |
-
|
147 |
-
# Log the response status
|
148 |
-
logger.info(f"Response status: {response.status_code}")
|
149 |
-
|
150 |
-
return response
|
151 |
-
|
152 |
-
class Message(BaseModel):
|
153 |
-
role: str
|
154 |
-
content: str
|
155 |
-
|
156 |
-
class ChatRequest(BaseModel):
|
157 |
-
model: str
|
158 |
-
messages: List[Message]
|
159 |
-
temperature: Optional[float] = 1.0
|
160 |
-
top_p: Optional[float] = 1.0
|
161 |
-
n: Optional[int] = 1
|
162 |
-
stream: Optional[bool] = False
|
163 |
-
stop: Optional[Union[str, List[str]]] = None
|
164 |
-
max_tokens: Optional[int] = None
|
165 |
-
presence_penalty: Optional[float] = 0.0
|
166 |
-
frequency_penalty: Optional[float] = 0.0
|
167 |
-
logit_bias: Optional[Dict[str, float]] = None
|
168 |
-
user: Optional[str] = None
|
169 |
-
webSearchMode: Optional[bool] = False # Custom parameter
|
170 |
-
|
171 |
-
def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
|
172 |
-
return {
|
173 |
-
"id": f"chatcmpl-{uuid.uuid4()}",
|
174 |
-
"object": "chat.completion.chunk",
|
175 |
-
"created": int(datetime.now().timestamp()),
|
176 |
-
"model": model,
|
177 |
-
"choices": [
|
178 |
-
{
|
179 |
-
"index": 0,
|
180 |
-
"delta": {"content": content, "role": "assistant"},
|
181 |
-
"finish_reason": finish_reason,
|
182 |
-
}
|
183 |
-
],
|
184 |
-
"usage": None,
|
185 |
-
}
|
186 |
-
|
187 |
class ImageResponse:
|
188 |
def __init__(self, url: str, alt: str):
|
189 |
self.url = url
|
@@ -230,9 +161,7 @@ class Blackbox:
|
|
230 |
]
|
231 |
|
232 |
# Filter models based on AVAILABLE_MODELS
|
233 |
-
AVAILABLE_MODELS = os.getenv("AVAILABLE_MODELS", "")
|
234 |
if AVAILABLE_MODELS:
|
235 |
-
AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
|
236 |
models = [model for model in models if model in AVAILABLE_MODELS]
|
237 |
|
238 |
agentMode = {
|
@@ -472,8 +401,81 @@ class Blackbox:
|
|
472 |
if attempt == retry_attempts - 1:
|
473 |
raise HTTPException(status_code=500, detail=str(e))
|
474 |
|
475 |
-
|
476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
# Redact user messages only for logging purposes
|
478 |
redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
|
479 |
|
@@ -501,14 +503,6 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
501 |
if isinstance(chunk, ImageResponse):
|
502 |
image_markdown = f""
|
503 |
response_chunk = create_response(image_markdown, request.model)
|
504 |
-
|
505 |
-
# Store image in the database
|
506 |
-
image_entry = Image(
|
507 |
-
image_url=chunk.url,
|
508 |
-
description=request.messages[-1].get('content', '')
|
509 |
-
)
|
510 |
-
db.add(image_entry)
|
511 |
-
db.commit()
|
512 |
else:
|
513 |
response_chunk = create_response(chunk, request.model)
|
514 |
|
@@ -529,14 +523,6 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
529 |
async for chunk in async_generator:
|
530 |
if isinstance(chunk, ImageResponse):
|
531 |
response_content += f"\n"
|
532 |
-
|
533 |
-
# Store image in the database
|
534 |
-
image_entry = Image(
|
535 |
-
image_url=chunk.url,
|
536 |
-
description=request.messages[-1].get('content', '')
|
537 |
-
)
|
538 |
-
db.add(image_entry)
|
539 |
-
db.commit()
|
540 |
else:
|
541 |
response_content += chunk
|
542 |
|
|
|
|
|
1 |
import os
|
2 |
import re
|
3 |
import random
|
|
|
16 |
from slowapi.util import get_remote_address
|
17 |
from slowapi.errors import RateLimitExceeded
|
18 |
from slowapi.middleware import SlowAPIMiddleware
|
19 |
+
from fastapi import FastAPI, HTTPException, Request, Depends, Header, Security
|
20 |
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
|
21 |
+
from fastapi.security.api_key import APIKeyHeader
|
22 |
from pydantic import BaseModel
|
23 |
|
|
|
|
|
24 |
from aiohttp import ClientSession, ClientTimeout, ClientError
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
# Define a context variable for client_ip
|
27 |
client_ip_var = contextvars.ContextVar("client_ip", default="N/A")
|
28 |
|
|
|
44 |
logger.setLevel(logging.INFO)
|
45 |
|
46 |
# Create handlers
|
47 |
+
handler = logging.StreamHandler()
|
48 |
+
handler.setLevel(logging.INFO)
|
49 |
|
50 |
# Create and set the custom formatter
|
51 |
formatter = SafeFormatter(
|
52 |
fmt="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
|
53 |
datefmt="%Y-%m-%d %H:%M:%S"
|
54 |
)
|
55 |
+
handler.setFormatter(formatter)
|
56 |
|
57 |
+
# Add the custom filter to the handler
|
58 |
+
handler.addFilter(ContextFilter())
|
59 |
|
60 |
# Add handlers to the logger
|
61 |
+
logger.addHandler(handler)
|
62 |
|
63 |
# Initialize the limiter with slowapi
|
64 |
limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"])
|
|
|
71 |
# Add SlowAPI middleware
|
72 |
app.add_middleware(SlowAPIMiddleware)
|
73 |
|
74 |
+
# Load environment variables
|
75 |
+
API_KEYS = os.getenv('API_KEYS', '').split(',') # Comma-separated API keys
|
76 |
+
RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60')) # Requests per minute
|
77 |
+
AVAILABLE_MODELS = os.getenv('AVAILABLE_MODELS', '') # Comma-separated available models
|
78 |
+
|
79 |
+
if not API_KEYS or API_KEYS == ['']:
|
80 |
+
logger.error("No API keys found. Please set the API_KEYS environment variable.")
|
81 |
+
raise Exception("API_KEYS environment variable not set.")
|
82 |
+
|
83 |
+
# Define API key security
|
84 |
+
API_KEY_NAME = "X-API-Key"
|
85 |
+
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
86 |
+
|
87 |
+
async def get_api_key(api_key: str = Security(api_key_header)):
|
88 |
+
if not api_key:
|
89 |
+
logger.warning("API key missing")
|
90 |
+
raise HTTPException(
|
91 |
+
status_code=401,
|
92 |
+
detail="API key missing",
|
93 |
+
headers={"WWW-Authenticate": "API Key"},
|
94 |
)
|
95 |
+
if api_key not in API_KEYS:
|
96 |
+
logger.warning(f"Invalid API key: {api_key}")
|
97 |
+
raise HTTPException(
|
98 |
+
status_code=403,
|
99 |
+
detail="Invalid API key",
|
100 |
+
headers={"WWW-Authenticate": "API Key"},
|
101 |
+
)
|
102 |
+
return api_key
|
103 |
+
|
104 |
+
# Process available models
|
105 |
+
if AVAILABLE_MODELS:
|
106 |
+
AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
|
107 |
+
else:
|
108 |
+
AVAILABLE_MODELS = [] # If empty, all models are available
|
109 |
+
|
110 |
+
# Custom exception for model not working
|
111 |
+
class ModelNotWorkingException(Exception):
|
112 |
+
def __init__(self, model: str):
|
113 |
+
self.model = model
|
114 |
+
self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
|
115 |
+
super().__init__(self.message)
|
116 |
+
|
117 |
+
# Mock implementations for ImageResponse and to_data_uri
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
class ImageResponse:
|
119 |
def __init__(self, url: str, alt: str):
|
120 |
self.url = url
|
|
|
161 |
]
|
162 |
|
163 |
# Filter models based on AVAILABLE_MODELS
|
|
|
164 |
if AVAILABLE_MODELS:
|
|
|
165 |
models = [model for model in models if model in AVAILABLE_MODELS]
|
166 |
|
167 |
agentMode = {
|
|
|
401 |
if attempt == retry_attempts - 1:
|
402 |
raise HTTPException(status_code=500, detail=str(e))
|
403 |
|
404 |
+
# FastAPI app setup with slowapi rate limiting
|
405 |
+
@app.middleware("http")
|
406 |
+
async def security_middleware(request: Request, call_next):
|
407 |
+
client_ip = request.client.host
|
408 |
+
# Set the client_ip in the context variable
|
409 |
+
client_ip_var.set(client_ip)
|
410 |
+
|
411 |
+
# Enforce that POST requests to sensitive endpoints must have a valid Content-Type
|
412 |
+
if request.method == "POST" and request.url.path == "/v1/chat/completions":
|
413 |
+
content_type = request.headers.get("Content-Type")
|
414 |
+
if content_type != "application/json":
|
415 |
+
logger.warning("Invalid Content-Type for /v1/chat/completions")
|
416 |
+
return JSONResponse(
|
417 |
+
status_code=400,
|
418 |
+
content={
|
419 |
+
"error": {
|
420 |
+
"message": "Content-Type must be application/json",
|
421 |
+
"type": "invalid_request_error",
|
422 |
+
"param": None,
|
423 |
+
"code": None
|
424 |
+
}
|
425 |
+
},
|
426 |
+
)
|
427 |
+
|
428 |
+
# Log the incoming request
|
429 |
+
logger.info(f"Incoming request: {request.method} {request.url.path}")
|
430 |
+
|
431 |
+
response = await call_next(request)
|
432 |
+
|
433 |
+
# Log the response status
|
434 |
+
logger.info(f"Response status: {response.status_code}")
|
435 |
+
|
436 |
+
return response
|
437 |
+
|
438 |
+
class Message(BaseModel):
|
439 |
+
role: str
|
440 |
+
content: str
|
441 |
+
|
442 |
+
class ChatRequest(BaseModel):
|
443 |
+
model: str
|
444 |
+
messages: List[Message]
|
445 |
+
temperature: Optional[float] = 1.0
|
446 |
+
top_p: Optional[float] = 1.0
|
447 |
+
n: Optional[int] = 1
|
448 |
+
stream: Optional[bool] = False
|
449 |
+
stop: Optional[Union[str, List[str]]] = None
|
450 |
+
max_tokens: Optional[int] = None
|
451 |
+
presence_penalty: Optional[float] = 0.0
|
452 |
+
frequency_penalty: Optional[float] = 0.0
|
453 |
+
logit_bias: Optional[Dict[str, float]] = None
|
454 |
+
user: Optional[str] = None
|
455 |
+
webSearchMode: Optional[bool] = False # Custom parameter
|
456 |
+
|
457 |
+
def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
|
458 |
+
return {
|
459 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
460 |
+
"object": "chat.completion.chunk",
|
461 |
+
"created": int(datetime.now().timestamp()),
|
462 |
+
"model": model,
|
463 |
+
"choices": [
|
464 |
+
{
|
465 |
+
"index": 0,
|
466 |
+
"delta": {"content": content, "role": "assistant"},
|
467 |
+
"finish_reason": finish_reason,
|
468 |
+
}
|
469 |
+
],
|
470 |
+
"usage": None,
|
471 |
+
}
|
472 |
+
|
473 |
+
@app.post("/v1/chat/completions", dependencies=[Depends(limiter.limit("60/minute"))])
|
474 |
+
async def chat_completions(
|
475 |
+
request: ChatRequest,
|
476 |
+
req: Request,
|
477 |
+
api_key: str = Depends(get_api_key)
|
478 |
+
):
|
479 |
# Redact user messages only for logging purposes
|
480 |
redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
|
481 |
|
|
|
503 |
if isinstance(chunk, ImageResponse):
|
504 |
image_markdown = f""
|
505 |
response_chunk = create_response(image_markdown, request.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
506 |
else:
|
507 |
response_chunk = create_response(chunk, request.model)
|
508 |
|
|
|
523 |
async for chunk in async_generator:
|
524 |
if isinstance(chunk, ImageResponse):
|
525 |
response_content += f"\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
else:
|
527 |
response_content += chunk
|
528 |
|