Niansuh committed on
Commit
a5a718a
·
verified ·
1 Parent(s): 5456d86

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +126 -140
main.py CHANGED
@@ -1,4 +1,3 @@
1
- # main.py
2
  import os
3
  import re
4
  import random
@@ -17,25 +16,13 @@ from slowapi import Limiter, _rate_limit_exceeded_handler
17
  from slowapi.util import get_remote_address
18
  from slowapi.errors import RateLimitExceeded
19
  from slowapi.middleware import SlowAPIMiddleware
20
- from fastapi import FastAPI, HTTPException, Request, Depends, Header
21
  from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
 
22
  from pydantic import BaseModel
23
 
24
- from sqlalchemy.orm import Session
25
-
26
  from aiohttp import ClientSession, ClientTimeout, ClientError
27
 
28
- from database import SessionLocal, engine, get_db
29
- from models import Base, Image, Log
30
-
31
- from dotenv import load_dotenv
32
-
33
- # Load environment variables from .env file
34
- load_dotenv()
35
-
36
- # Create all tables
37
- Base.metadata.create_all(bind=engine)
38
-
39
  # Define a context variable for client_ip
40
  client_ip_var = contextvars.ContextVar("client_ip", default="N/A")
41
 
@@ -57,21 +44,21 @@ logger = logging.getLogger("main") # Use a specific logger name if needed
57
  logger.setLevel(logging.INFO)
58
 
59
  # Create handlers
60
- console_handler = logging.StreamHandler()
61
- console_handler.setLevel(logging.INFO)
62
 
63
  # Create and set the custom formatter
64
  formatter = SafeFormatter(
65
  fmt="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
66
  datefmt="%Y-%m-%d %H:%M:%S"
67
  )
68
- console_handler.setFormatter(formatter)
69
 
70
- # Add the custom filter to the console handler
71
- console_handler.addFilter(ContextFilter())
72
 
73
  # Add handlers to the logger
74
- logger.addHandler(console_handler)
75
 
76
  # Initialize the limiter with slowapi
77
  limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"])
@@ -84,106 +71,50 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
84
  # Add SlowAPI middleware
85
  app.add_middleware(SlowAPIMiddleware)
86
 
87
- from logging import Handler
88
-
89
- class DBLogHandler(Handler):
90
- def __init__(self, db: Session):
91
- super().__init__()
92
- self.db = db
93
-
94
- def emit(self, record):
95
- log_entry = Log(
96
- level=record.levelname,
97
- message=record.getMessage(),
98
- client_ip=getattr(record, 'client_ip', None)
 
 
 
 
 
 
 
 
99
  )
100
- try:
101
- self.db.add(log_entry)
102
- self.db.commit()
103
- except Exception as e:
104
- # Handle exceptions (e.g., rollback)
105
- self.db.rollback()
106
- print(f"Failed to log to database: {e}")
107
-
108
- # Dependency to add DBLogHandler
109
- async def get_db_log_handler(request: Request):
110
- db = next(get_db())
111
- db_log_handler = DBLogHandler(db)
112
- logger.addHandler(db_log_handler)
113
- try:
114
- yield
115
- finally:
116
- logger.removeHandler(db_log_handler)
117
- db.close()
118
-
119
- @app.middleware("http")
120
- async def security_middleware(request: Request, call_next):
121
- client_ip = request.client.host
122
- # Set the client_ip in the context variable
123
- client_ip_var.set(client_ip)
124
-
125
- # Enforce that POST requests to sensitive endpoints must have a valid Content-Type
126
- if request.method == "POST" and request.url.path == "/v1/chat/completions":
127
- content_type = request.headers.get("Content-Type")
128
- if content_type != "application/json":
129
- logger.warning("Invalid Content-Type for /v1/chat/completions")
130
- return JSONResponse(
131
- status_code=400,
132
- content={
133
- "error": {
134
- "message": "Content-Type must be application/json",
135
- "type": "invalid_request_error",
136
- "param": None,
137
- "code": None
138
- }
139
- },
140
- )
141
-
142
- # Log the incoming request
143
- logger.info(f"Incoming request: {request.method} {request.url.path}")
144
-
145
- response = await call_next(request)
146
-
147
- # Log the response status
148
- logger.info(f"Response status: {response.status_code}")
149
-
150
- return response
151
-
152
- class Message(BaseModel):
153
- role: str
154
- content: str
155
-
156
- class ChatRequest(BaseModel):
157
- model: str
158
- messages: List[Message]
159
- temperature: Optional[float] = 1.0
160
- top_p: Optional[float] = 1.0
161
- n: Optional[int] = 1
162
- stream: Optional[bool] = False
163
- stop: Optional[Union[str, List[str]]] = None
164
- max_tokens: Optional[int] = None
165
- presence_penalty: Optional[float] = 0.0
166
- frequency_penalty: Optional[float] = 0.0
167
- logit_bias: Optional[Dict[str, float]] = None
168
- user: Optional[str] = None
169
- webSearchMode: Optional[bool] = False # Custom parameter
170
-
171
- def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
172
- return {
173
- "id": f"chatcmpl-{uuid.uuid4()}",
174
- "object": "chat.completion.chunk",
175
- "created": int(datetime.now().timestamp()),
176
- "model": model,
177
- "choices": [
178
- {
179
- "index": 0,
180
- "delta": {"content": content, "role": "assistant"},
181
- "finish_reason": finish_reason,
182
- }
183
- ],
184
- "usage": None,
185
- }
186
-
187
  class ImageResponse:
188
  def __init__(self, url: str, alt: str):
189
  self.url = url
@@ -230,9 +161,7 @@ class Blackbox:
230
  ]
231
 
232
  # Filter models based on AVAILABLE_MODELS
233
- AVAILABLE_MODELS = os.getenv("AVAILABLE_MODELS", "")
234
  if AVAILABLE_MODELS:
235
- AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
236
  models = [model for model in models if model in AVAILABLE_MODELS]
237
 
238
  agentMode = {
@@ -472,8 +401,81 @@ class Blackbox:
472
  if attempt == retry_attempts - 1:
473
  raise HTTPException(status_code=500, detail=str(e))
474
 
475
- @app.post("/v1/chat/completions", dependencies=[Depends(limiter.limit("60/minute")), Depends(get_db_log_handler)])
476
- async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key), db: Session = Depends(get_db)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  # Redact user messages only for logging purposes
478
  redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
479
 
@@ -501,14 +503,6 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
501
  if isinstance(chunk, ImageResponse):
502
  image_markdown = f"![image]({chunk.url})"
503
  response_chunk = create_response(image_markdown, request.model)
504
-
505
- # Store image in the database
506
- image_entry = Image(
507
- image_url=chunk.url,
508
- description=request.messages[-1].get('content', '')
509
- )
510
- db.add(image_entry)
511
- db.commit()
512
  else:
513
  response_chunk = create_response(chunk, request.model)
514
 
@@ -529,14 +523,6 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
529
  async for chunk in async_generator:
530
  if isinstance(chunk, ImageResponse):
531
  response_content += f"![image]({chunk.url})\n"
532
-
533
- # Store image in the database
534
- image_entry = Image(
535
- image_url=chunk.url,
536
- description=request.messages[-1].get('content', '')
537
- )
538
- db.add(image_entry)
539
- db.commit()
540
  else:
541
  response_content += chunk
542
 
 
 
1
  import os
2
  import re
3
  import random
 
16
  from slowapi.util import get_remote_address
17
  from slowapi.errors import RateLimitExceeded
18
  from slowapi.middleware import SlowAPIMiddleware
19
+ from fastapi import FastAPI, HTTPException, Request, Depends, Header, Security
20
  from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
21
+ from fastapi.security.api_key import APIKeyHeader
22
  from pydantic import BaseModel
23
 
 
 
24
  from aiohttp import ClientSession, ClientTimeout, ClientError
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Define a context variable for client_ip
27
  client_ip_var = contextvars.ContextVar("client_ip", default="N/A")
28
 
 
44
  logger.setLevel(logging.INFO)
45
 
46
  # Create handlers
47
+ handler = logging.StreamHandler()
48
+ handler.setLevel(logging.INFO)
49
 
50
  # Create and set the custom formatter
51
  formatter = SafeFormatter(
52
  fmt="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
53
  datefmt="%Y-%m-%d %H:%M:%S"
54
  )
55
+ handler.setFormatter(formatter)
56
 
57
+ # Add the custom filter to the handler
58
+ handler.addFilter(ContextFilter())
59
 
60
  # Add handlers to the logger
61
+ logger.addHandler(handler)
62
 
63
  # Initialize the limiter with slowapi
64
  limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"])
 
71
  # Add SlowAPI middleware
72
  app.add_middleware(SlowAPIMiddleware)
73
 
74
+ # Load environment variables
75
+ API_KEYS = os.getenv('API_KEYS', '').split(',') # Comma-separated API keys
76
+ RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60')) # Requests per minute
77
+ AVAILABLE_MODELS = os.getenv('AVAILABLE_MODELS', '') # Comma-separated available models
78
+
79
+ if not API_KEYS or API_KEYS == ['']:
80
+ logger.error("No API keys found. Please set the API_KEYS environment variable.")
81
+ raise Exception("API_KEYS environment variable not set.")
82
+
83
+ # Define API key security
84
+ API_KEY_NAME = "X-API-Key"
85
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
86
+
87
+ async def get_api_key(api_key: str = Security(api_key_header)):
88
+ if not api_key:
89
+ logger.warning("API key missing")
90
+ raise HTTPException(
91
+ status_code=401,
92
+ detail="API key missing",
93
+ headers={"WWW-Authenticate": "API Key"},
94
  )
95
+ if api_key not in API_KEYS:
96
+ logger.warning(f"Invalid API key: {api_key}")
97
+ raise HTTPException(
98
+ status_code=403,
99
+ detail="Invalid API key",
100
+ headers={"WWW-Authenticate": "API Key"},
101
+ )
102
+ return api_key
103
+
104
+ # Process available models
105
+ if AVAILABLE_MODELS:
106
+ AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
107
+ else:
108
+ AVAILABLE_MODELS = [] # If empty, all models are available
109
+
110
+ # Custom exception for model not working
111
+ class ModelNotWorkingException(Exception):
112
+ def __init__(self, model: str):
113
+ self.model = model
114
+ self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
115
+ super().__init__(self.message)
116
+
117
+ # Mock implementations for ImageResponse and to_data_uri
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  class ImageResponse:
119
  def __init__(self, url: str, alt: str):
120
  self.url = url
 
161
  ]
162
 
163
  # Filter models based on AVAILABLE_MODELS
 
164
  if AVAILABLE_MODELS:
 
165
  models = [model for model in models if model in AVAILABLE_MODELS]
166
 
167
  agentMode = {
 
401
  if attempt == retry_attempts - 1:
402
  raise HTTPException(status_code=500, detail=str(e))
403
 
404
+ # FastAPI app setup with slowapi rate limiting
405
+ @app.middleware("http")
406
+ async def security_middleware(request: Request, call_next):
407
+ client_ip = request.client.host
408
+ # Set the client_ip in the context variable
409
+ client_ip_var.set(client_ip)
410
+
411
+ # Enforce that POST requests to sensitive endpoints must have a valid Content-Type
412
+ if request.method == "POST" and request.url.path == "/v1/chat/completions":
413
+ content_type = request.headers.get("Content-Type")
414
+ if content_type != "application/json":
415
+ logger.warning("Invalid Content-Type for /v1/chat/completions")
416
+ return JSONResponse(
417
+ status_code=400,
418
+ content={
419
+ "error": {
420
+ "message": "Content-Type must be application/json",
421
+ "type": "invalid_request_error",
422
+ "param": None,
423
+ "code": None
424
+ }
425
+ },
426
+ )
427
+
428
+ # Log the incoming request
429
+ logger.info(f"Incoming request: {request.method} {request.url.path}")
430
+
431
+ response = await call_next(request)
432
+
433
+ # Log the response status
434
+ logger.info(f"Response status: {response.status_code}")
435
+
436
+ return response
437
+
438
+ class Message(BaseModel):
439
+ role: str
440
+ content: str
441
+
442
+ class ChatRequest(BaseModel):
443
+ model: str
444
+ messages: List[Message]
445
+ temperature: Optional[float] = 1.0
446
+ top_p: Optional[float] = 1.0
447
+ n: Optional[int] = 1
448
+ stream: Optional[bool] = False
449
+ stop: Optional[Union[str, List[str]]] = None
450
+ max_tokens: Optional[int] = None
451
+ presence_penalty: Optional[float] = 0.0
452
+ frequency_penalty: Optional[float] = 0.0
453
+ logit_bias: Optional[Dict[str, float]] = None
454
+ user: Optional[str] = None
455
+ webSearchMode: Optional[bool] = False # Custom parameter
456
+
457
+ def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
458
+ return {
459
+ "id": f"chatcmpl-{uuid.uuid4()}",
460
+ "object": "chat.completion.chunk",
461
+ "created": int(datetime.now().timestamp()),
462
+ "model": model,
463
+ "choices": [
464
+ {
465
+ "index": 0,
466
+ "delta": {"content": content, "role": "assistant"},
467
+ "finish_reason": finish_reason,
468
+ }
469
+ ],
470
+ "usage": None,
471
+ }
472
+
473
+ @app.post("/v1/chat/completions", dependencies=[Depends(limiter.limit("60/minute"))])
474
+ async def chat_completions(
475
+ request: ChatRequest,
476
+ req: Request,
477
+ api_key: str = Depends(get_api_key)
478
+ ):
479
  # Redact user messages only for logging purposes
480
  redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
481
 
 
503
  if isinstance(chunk, ImageResponse):
504
  image_markdown = f"![image]({chunk.url})"
505
  response_chunk = create_response(image_markdown, request.model)
 
 
 
 
 
 
 
 
506
  else:
507
  response_chunk = create_response(chunk, request.model)
508
 
 
523
  async for chunk in async_generator:
524
  if isinstance(chunk, ImageResponse):
525
  response_content += f"![image]({chunk.url})\n"
 
 
 
 
 
 
 
 
526
  else:
527
  response_content += chunk
528