Update main.py
Browse files
main.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
import re
|
3 |
import random
|
@@ -7,16 +8,34 @@ import json
|
|
7 |
import logging
|
8 |
import asyncio
|
9 |
import time
|
|
|
10 |
from collections import defaultdict
|
11 |
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
|
12 |
from datetime import datetime
|
13 |
-
import contextvars # Added for context variables
|
14 |
|
15 |
-
from
|
|
|
|
|
|
|
16 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
17 |
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
|
18 |
from pydantic import BaseModel
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
# Define a context variable for client_ip
|
21 |
client_ip_var = contextvars.ContextVar("client_ip", default="N/A")
|
22 |
|
@@ -26,77 +45,145 @@ class ContextFilter(logging.Filter):
|
|
26 |
record.client_ip = client_ip_var.get()
|
27 |
return True
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
# Configure logging
|
30 |
-
logging.
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
-
# Get the root logger and add the ContextFilter
|
37 |
-
logger = logging.getLogger(__name__)
|
38 |
-
logger.addFilter(ContextFilter())
|
39 |
-
|
40 |
-
# Load environment variables
|
41 |
-
API_KEYS = os.getenv('API_KEYS', '').split(',') # Comma-separated API keys
|
42 |
-
RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60')) # Requests per minute
|
43 |
-
AVAILABLE_MODELS = os.getenv('AVAILABLE_MODELS', '') # Comma-separated available models
|
44 |
-
|
45 |
-
if not API_KEYS or API_KEYS == ['']:
|
46 |
-
logger.error("No API keys found. Please set the API_KEYS environment variable.")
|
47 |
-
raise Exception("API_KEYS environment variable not set.")
|
48 |
-
|
49 |
-
# Process available models
|
50 |
-
if AVAILABLE_MODELS:
|
51 |
-
AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
|
52 |
-
else:
|
53 |
-
AVAILABLE_MODELS = [] # If empty, all models are available
|
54 |
-
|
55 |
-
# Simple in-memory rate limiter
|
56 |
-
rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
|
57 |
-
ip_rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
|
58 |
-
|
59 |
-
async def get_api_key(authorization: str = Header(None)) -> str:
|
60 |
-
if authorization is None or not authorization.startswith('Bearer '):
|
61 |
-
logger.warning("Invalid or missing authorization header.")
|
62 |
-
raise HTTPException(status_code=401, detail='Invalid authorization header format | NiansuhAI')
|
63 |
-
api_key = authorization[7:]
|
64 |
-
if api_key not in API_KEYS:
|
65 |
-
logger.warning(f"Invalid API key attempted: {api_key}")
|
66 |
-
raise HTTPException(status_code=401, detail='Invalid API key | NiansuhAI')
|
67 |
-
return api_key
|
68 |
-
|
69 |
-
async def rate_limiter(req: Request, api_key: str = Depends(get_api_key)):
|
70 |
-
current_time = time.time()
|
71 |
-
# Rate limiting per API key
|
72 |
-
window_start = rate_limit_store[api_key]["timestamp"]
|
73 |
-
if current_time - window_start > 60:
|
74 |
-
rate_limit_store[api_key] = {"count": 1, "timestamp": current_time}
|
75 |
-
else:
|
76 |
-
if rate_limit_store[api_key]["count"] >= RATE_LIMIT:
|
77 |
-
logger.warning(f"Rate limit exceeded for API key: {api_key} | NiansuhAI")
|
78 |
-
raise HTTPException(status_code=429, detail='Rate limit exceeded for API key | NiansuhAI')
|
79 |
-
rate_limit_store[api_key]["count"] += 1
|
80 |
-
|
81 |
-
# Rate limiting per IP address
|
82 |
-
client_ip = req.client.host
|
83 |
-
window_start_ip = ip_rate_limit_store[client_ip]["timestamp"]
|
84 |
-
if current_time - window_start_ip > 60:
|
85 |
-
ip_rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
|
86 |
-
else:
|
87 |
-
if ip_rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
|
88 |
-
logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
|
89 |
-
raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address | NiansuhAI')
|
90 |
-
ip_rate_limit_store[client_ip]["count"] += 1
|
91 |
-
|
92 |
-
# Custom exception for model not working
|
93 |
-
class ModelNotWorkingException(Exception):
|
94 |
-
def __init__(self, model: str):
|
95 |
-
self.model = model
|
96 |
-
self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
|
97 |
-
super().__init__(self.message)
|
98 |
-
|
99 |
-
# Mock implementations for ImageResponse and to_data_uri
|
100 |
class ImageResponse:
|
101 |
def __init__(self, url: str, alt: str):
|
102 |
self.url = url
|
@@ -143,7 +230,9 @@ class Blackbox:
|
|
143 |
]
|
144 |
|
145 |
# Filter models based on AVAILABLE_MODELS
|
|
|
146 |
if AVAILABLE_MODELS:
|
|
|
147 |
models = [model for model in models if model in AVAILABLE_MODELS]
|
148 |
|
149 |
agentMode = {
|
@@ -383,80 +472,8 @@ class Blackbox:
|
|
383 |
if attempt == retry_attempts - 1:
|
384 |
raise HTTPException(status_code=500, detail=str(e))
|
385 |
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
# Middleware to enhance security and log client IP
|
390 |
-
@app.middleware("http")
|
391 |
-
async def security_middleware(request: Request, call_next):
|
392 |
-
client_ip = request.client.host
|
393 |
-
# Set the client_ip in the context variable
|
394 |
-
client_ip_var.set(client_ip)
|
395 |
-
|
396 |
-
# Enforce that POST requests to sensitive endpoints must have a valid Content-Type
|
397 |
-
if request.method == "POST" and request.url.path == "/v1/chat/completions":
|
398 |
-
content_type = request.headers.get("Content-Type")
|
399 |
-
if content_type != "application/json":
|
400 |
-
logger.warning("Invalid Content-Type for /v1/chat/completions")
|
401 |
-
return JSONResponse(
|
402 |
-
status_code=400,
|
403 |
-
content={
|
404 |
-
"error": {
|
405 |
-
"message": "Content-Type must be application/json",
|
406 |
-
"type": "invalid_request_error",
|
407 |
-
"param": None,
|
408 |
-
"code": None
|
409 |
-
}
|
410 |
-
},
|
411 |
-
)
|
412 |
-
|
413 |
-
# Log the incoming request
|
414 |
-
logger.info(f"Incoming request: {request.method} {request.url.path}")
|
415 |
-
|
416 |
-
response = await call_next(request)
|
417 |
-
|
418 |
-
# Log the response status
|
419 |
-
logger.info(f"Response status: {response.status_code}")
|
420 |
-
|
421 |
-
return response
|
422 |
-
|
423 |
-
class Message(BaseModel):
|
424 |
-
role: str
|
425 |
-
content: str
|
426 |
-
|
427 |
-
class ChatRequest(BaseModel):
|
428 |
-
model: str
|
429 |
-
messages: List[Message]
|
430 |
-
temperature: Optional[float] = 1.0
|
431 |
-
top_p: Optional[float] = 1.0
|
432 |
-
n: Optional[int] = 1
|
433 |
-
stream: Optional[bool] = False
|
434 |
-
stop: Optional[Union[str, List[str]]] = None
|
435 |
-
max_tokens: Optional[int] = None
|
436 |
-
presence_penalty: Optional[float] = 0.0
|
437 |
-
frequency_penalty: Optional[float] = 0.0
|
438 |
-
logit_bias: Optional[Dict[str, float]] = None
|
439 |
-
user: Optional[str] = None
|
440 |
-
webSearchMode: Optional[bool] = False # Custom parameter
|
441 |
-
|
442 |
-
def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
|
443 |
-
return {
|
444 |
-
"id": f"chatcmpl-{uuid.uuid4()}",
|
445 |
-
"object": "chat.completion.chunk",
|
446 |
-
"created": int(datetime.now().timestamp()),
|
447 |
-
"model": model,
|
448 |
-
"choices": [
|
449 |
-
{
|
450 |
-
"index": 0,
|
451 |
-
"delta": {"content": content, "role": "assistant"},
|
452 |
-
"finish_reason": finish_reason,
|
453 |
-
}
|
454 |
-
],
|
455 |
-
"usage": None,
|
456 |
-
}
|
457 |
-
|
458 |
-
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter)])
|
459 |
-
async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
|
460 |
# Redact user messages only for logging purposes
|
461 |
redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
|
462 |
|
@@ -484,6 +501,14 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
484 |
if isinstance(chunk, ImageResponse):
|
485 |
image_markdown = f""
|
486 |
response_chunk = create_response(image_markdown, request.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
else:
|
488 |
response_chunk = create_response(chunk, request.model)
|
489 |
|
@@ -504,6 +529,14 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
|
|
504 |
async for chunk in async_generator:
|
505 |
if isinstance(chunk, ImageResponse):
|
506 |
response_content += f"\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
else:
|
508 |
response_content += chunk
|
509 |
|
@@ -549,20 +582,17 @@ async def chat_completions_get():
|
|
549 |
|
550 |
@app.get("/v1/models")
|
551 |
async def get_models(req: Request):
|
552 |
-
client_ip = req.client.host
|
553 |
logger.info(f"Fetching available models")
|
554 |
return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
|
555 |
|
556 |
# Additional endpoints for better functionality
|
557 |
@app.get("/v1/health")
|
558 |
async def health_check(req: Request):
|
559 |
-
client_ip = req.client.host
|
560 |
logger.info(f"Health check requested")
|
561 |
return {"status": "ok"}
|
562 |
|
563 |
@app.get("/v1/models/{model}/status")
|
564 |
async def model_status(model: str, req: Request):
|
565 |
-
client_ip = req.client.host
|
566 |
logger.info(f"Model status requested for '{model}'")
|
567 |
if model in Blackbox.models:
|
568 |
return {"model": model, "status": "available"}
|
@@ -576,8 +606,7 @@ async def model_status(model: str, req: Request):
|
|
576 |
# Custom exception handler to match OpenAI's error format
|
577 |
@app.exception_handler(HTTPException)
|
578 |
async def http_exception_handler(request: Request, exc: HTTPException):
|
579 |
-
|
580 |
-
logger.error(f"HTTPException: {exc.detail} | Client IP: {client_ip}")
|
581 |
return JSONResponse(
|
582 |
status_code=exc.status_code,
|
583 |
content={
|
@@ -596,7 +625,6 @@ class TokenizerRequest(BaseModel):
|
|
596 |
|
597 |
@app.post("/v1/tokenizer")
|
598 |
async def tokenizer(request: TokenizerRequest, req: Request):
|
599 |
-
client_ip = req.client.host
|
600 |
text = request.text
|
601 |
token_count = len(text.split())
|
602 |
logger.info(f"Tokenizer called | Tokens: {token_count}")
|
|
|
1 |
+
# main.py
|
2 |
import os
|
3 |
import re
|
4 |
import random
|
|
|
8 |
import logging
|
9 |
import asyncio
|
10 |
import time
|
11 |
+
import contextvars
|
12 |
from collections import defaultdict
|
13 |
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
|
14 |
from datetime import datetime
|
|
|
15 |
|
16 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
17 |
+
from slowapi.util import get_remote_address
|
18 |
+
from slowapi.errors import RateLimitExceeded
|
19 |
+
from slowapi.middleware import SlowAPIMiddleware
|
20 |
from fastapi import FastAPI, HTTPException, Request, Depends, Header
|
21 |
from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
|
22 |
from pydantic import BaseModel
|
23 |
|
24 |
+
from sqlalchemy.orm import Session
|
25 |
+
|
26 |
+
from aiohttp import ClientSession, ClientTimeout, ClientError
|
27 |
+
|
28 |
+
from database import SessionLocal, engine, get_db
|
29 |
+
from models import Base, Image, Log
|
30 |
+
|
31 |
+
from dotenv import load_dotenv
|
32 |
+
|
33 |
+
# Load environment variables from .env file
|
34 |
+
load_dotenv()
|
35 |
+
|
36 |
+
# Create all tables
|
37 |
+
Base.metadata.create_all(bind=engine)
|
38 |
+
|
39 |
# Define a context variable for client_ip
|
40 |
client_ip_var = contextvars.ContextVar("client_ip", default="N/A")
|
41 |
|
|
|
45 |
record.client_ip = client_ip_var.get()
|
46 |
return True
|
47 |
|
48 |
+
# Custom logging formatter to handle missing client_ip
class SafeFormatter(logging.Formatter):
    """Formatter that tolerates records lacking a ``client_ip`` attribute.

    Guarantees ``%(client_ip)s`` is always usable in the format string by
    defaulting the attribute to ``'N/A'`` before delegating to the base class.
    """

    def format(self, record):
        # Normalize: reuse an existing client_ip, otherwise fall back to 'N/A'.
        record.client_ip = getattr(record, 'client_ip', 'N/A')
        return super().format(record)
|
54 |
+
|
55 |
# Configure logging
# Console-only setup: a named logger with a formatter that injects the
# client IP (via SafeFormatter + ContextFilter) into every line.
logger = logging.getLogger("main")  # Use a specific logger name if needed
logger.setLevel(logging.INFO)

# Create handlers
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# Create and set the custom formatter
formatter = SafeFormatter(
    fmt="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(formatter)

# Add the custom filter to the console handler
# (pulls the per-request IP out of client_ip_var onto each record)
console_handler.addFilter(ContextFilter())

# Add handlers to the logger
# NOTE(review): propagation to the root logger stays enabled — duplicate
# console lines are possible if the root logger also has handlers; confirm.
logger.addHandler(console_handler)
|
75 |
+
|
76 |
+
# Initialize the limiter with slowapi
# Keyed on the caller's remote address; 60 requests/minute by default.
limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"])
app = FastAPI()

# Register the rate limit exceeded handler
# slowapi looks the limiter up on app.state; the handler turns
# RateLimitExceeded into a 429 response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Add SlowAPI middleware
# Applies the default limits to every route automatically.
app.add_middleware(SlowAPIMiddleware)
|
86 |
+
|
87 |
+
from logging import Handler
|
88 |
+
|
89 |
+
class DBLogHandler(Handler):
    """Logging handler that persists each record as a ``Log`` row.

    NOTE(review): ``emit`` performs synchronous DB I/O for every log record,
    which will block the event loop when logging happens inside async
    handlers — confirm this is acceptable for the expected log volume.
    """

    def __init__(self, db: Session):
        super().__init__()
        # Session is supplied by the caller; this handler does not manage
        # its lifecycle (see get_db_log_handler, which closes it).
        self.db = db

    def emit(self, record):
        # client_ip may be absent on records that did not pass through the
        # request middleware, hence the getattr fallback to None.
        log_entry = Log(
            level=record.levelname,
            message=record.getMessage(),
            client_ip=getattr(record, 'client_ip', None)
        )
        try:
            self.db.add(log_entry)
            self.db.commit()
        except Exception as e:
            # Handle exceptions (e.g., rollback)
            # Deliberately best-effort: a failed log write must not take
            # down the request, so we roll back and report to stdout.
            self.db.rollback()
            print(f"Failed to log to database: {e}")
|
107 |
+
|
108 |
+
# Dependency to add DBLogHandler
async def get_db_log_handler(request: Request):
    """FastAPI dependency that mirrors logs into the DB for one request.

    NOTE(review): the handler is attached to the module-global ``logger``,
    so under concurrency other requests' records are also written through
    this request's session — confirm this interleaving is intended.
    """
    # Manually advance the get_db() generator dependency to obtain a session.
    db = next(get_db())
    db_log_handler = DBLogHandler(db)
    logger.addHandler(db_log_handler)
    try:
        yield
    finally:
        # Always detach the handler and release the session, even on errors,
        # so handlers don't accumulate on the shared logger.
        logger.removeHandler(db_log_handler)
        db.close()
|
118 |
+
|
119 |
+
# Middleware to enhance security and log client IP
@app.middleware("http")
async def security_middleware(request: Request, call_next):
    """Attach the client IP to the logging context and guard sensitive POSTs.

    Rejects POSTs to /v1/chat/completions whose Content-Type is not JSON
    with an OpenAI-style 400 error payload; otherwise logs the request and
    the response status around the downstream handler.
    """
    client_ip = request.client.host
    # Set the client_ip in the context variable so log records pick it up.
    client_ip_var.set(client_ip)

    # Enforce that POST requests to sensitive endpoints carry a JSON body.
    if request.method == "POST" and request.url.path == "/v1/chat/completions":
        content_type = request.headers.get("Content-Type", "")
        # Fix: accept parameterized media types such as
        # "application/json; charset=utf-8" — the previous exact string
        # comparison rejected them even though they are valid JSON bodies.
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type != "application/json":
            logger.warning("Invalid Content-Type for /v1/chat/completions")
            return JSONResponse(
                status_code=400,
                content={
                    "error": {
                        "message": "Content-Type must be application/json",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None
                    }
                },
            )

    # Log the incoming request
    logger.info(f"Incoming request: {request.method} {request.url.path}")

    response = await call_next(request)

    # Log the response status
    logger.info(f"Response status: {response.status_code}")

    return response
|
151 |
+
|
152 |
+
class Message(BaseModel):
    """A single chat turn: its role (e.g. "user"/"assistant") and text content."""
    role: str
    content: str
|
155 |
+
|
156 |
+
class ChatRequest(BaseModel):
    """OpenAI-compatible request body for /v1/chat/completions.

    Mirrors the OpenAI Chat Completions parameters; fields beyond ``model``
    and ``messages`` are optional with OpenAI-style defaults.
    """
    model: str
    messages: List[Message]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    webSearchMode: Optional[bool] = False  # Custom parameter
|
170 |
+
|
171 |
+
def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
    """Build one OpenAI-style ``chat.completion.chunk`` payload.

    The chunk carries a fresh uuid-based id, the current timestamp, and a
    single assistant-role delta containing *content*; ``usage`` is always
    ``None`` for streaming chunks.
    """
    delta = {"content": content, "role": "assistant"}
    choice = {
        "index": 0,
        "delta": delta,
        "finish_reason": finish_reason,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [choice],
        "usage": None,
    }
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
class ImageResponse:
|
188 |
def __init__(self, url: str, alt: str):
|
189 |
self.url = url
|
|
|
230 |
]
|
231 |
|
232 |
    # Filter models based on AVAILABLE_MODELS
    # Restrict the class-level model list to the comma-separated names in the
    # AVAILABLE_MODELS environment variable; an empty/unset value leaves all
    # models available.
    AVAILABLE_MODELS = os.getenv("AVAILABLE_MODELS", "")
    if AVAILABLE_MODELS:
        AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
        models = [model for model in models if model in AVAILABLE_MODELS]
|
237 |
|
238 |
agentMode = {
|
|
|
472 |
if attempt == retry_attempts - 1:
|
473 |
raise HTTPException(status_code=500, detail=str(e))
|
474 |
|
475 |
+
@app.post("/v1/chat/completions", dependencies=[Depends(limiter.limit("60/minute")), Depends(get_db_log_handler)])
|
476 |
+
async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key), db: Session = Depends(get_db)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
# Redact user messages only for logging purposes
|
478 |
redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
|
479 |
|
|
|
501 |
if isinstance(chunk, ImageResponse):
|
502 |
image_markdown = f""
|
503 |
response_chunk = create_response(image_markdown, request.model)
|
504 |
+
|
505 |
+
# Store image in the database
|
506 |
+
image_entry = Image(
|
507 |
+
image_url=chunk.url,
|
508 |
+
description=request.messages[-1].get('content', '')
|
509 |
+
)
|
510 |
+
db.add(image_entry)
|
511 |
+
db.commit()
|
512 |
else:
|
513 |
response_chunk = create_response(chunk, request.model)
|
514 |
|
|
|
529 |
async for chunk in async_generator:
|
530 |
if isinstance(chunk, ImageResponse):
|
531 |
response_content += f"\n"
|
532 |
+
|
533 |
+
# Store image in the database
|
534 |
+
image_entry = Image(
|
535 |
+
image_url=chunk.url,
|
536 |
+
description=request.messages[-1].get('content', '')
|
537 |
+
)
|
538 |
+
db.add(image_entry)
|
539 |
+
db.commit()
|
540 |
else:
|
541 |
response_content += chunk
|
542 |
|
|
|
582 |
|
583 |
@app.get("/v1/models")
async def get_models(req: Request):
    """List every model id exposed by the Blackbox provider, OpenAI-style."""
    logger.info(f"Fetching available models")
    return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
|
587 |
|
588 |
# Additional endpoints for better functionality
@app.get("/v1/health")
async def health_check(req: Request):
    """Liveness probe: always reports ok without touching any dependency."""
    logger.info(f"Health check requested")
    return {"status": "ok"}
|
593 |
|
594 |
@app.get("/v1/models/{model}/status")
|
595 |
async def model_status(model: str, req: Request):
|
|
|
596 |
logger.info(f"Model status requested for '{model}'")
|
597 |
if model in Blackbox.models:
|
598 |
return {"model": model, "status": "available"}
|
|
|
606 |
# Custom exception handler to match OpenAI's error format
|
607 |
@app.exception_handler(HTTPException)
|
608 |
async def http_exception_handler(request: Request, exc: HTTPException):
|
609 |
+
logger.error(f"HTTPException: {exc.detail}")
|
|
|
610 |
return JSONResponse(
|
611 |
status_code=exc.status_code,
|
612 |
content={
|
|
|
625 |
|
626 |
@app.post("/v1/tokenizer")
|
627 |
async def tokenizer(request: TokenizerRequest, req: Request):
|
|
|
628 |
text = request.text
|
629 |
token_count = len(text.split())
|
630 |
logger.info(f"Tokenizer called | Tokens: {token_count}")
|