Niansuh committed
Commit bd1b497 · verified · 1 Parent(s): eb9171d

Update main.py

Files changed (1): main.py (+178 -150)
main.py CHANGED
@@ -1,3 +1,4 @@
+# main.py
 import os
 import re
 import random
@@ -7,16 +8,34 @@ import json
 import logging
 import asyncio
 import time
+import contextvars
 from collections import defaultdict
 from typing import List, Dict, Any, Optional, AsyncGenerator, Union
 from datetime import datetime
-import contextvars  # Added for context variables

-from aiohttp import ClientSession, ClientTimeout, ClientError
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.util import get_remote_address
+from slowapi.errors import RateLimitExceeded
+from slowapi.middleware import SlowAPIMiddleware
 from fastapi import FastAPI, HTTPException, Request, Depends, Header
 from fastapi.responses import StreamingResponse, JSONResponse, RedirectResponse
 from pydantic import BaseModel

+from sqlalchemy.orm import Session
+
+from aiohttp import ClientSession, ClientTimeout, ClientError
+
+from database import SessionLocal, engine, get_db
+from models import Base, Image, Log
+
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
+
+# Create all tables
+Base.metadata.create_all(bind=engine)
+
 # Define a context variable for client_ip
 client_ip_var = contextvars.ContextVar("client_ip", default="N/A")

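Note: `database` and `models` are local modules that this diff imports but does not show. A minimal sketch of the wiring `database` implies, purely as an assumption (the engine URL and module body are not part of the commit):

    # database.py -- hypothetical reconstruction; only SessionLocal, engine, get_db are implied
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    engine = create_engine("sqlite:///./app.db", connect_args={"check_same_thread": False})
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

    def get_db():
        # Yield one session per request and always close it afterwards
        db = SessionLocal()
        try:
            yield db
        finally:
            db.close()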
@@ -26,77 +45,145 @@ class ContextFilter(logging.Filter):
         record.client_ip = client_ip_var.get()
         return True

+# Custom logging formatter to handle missing client_ip
+class SafeFormatter(logging.Formatter):
+    def format(self, record):
+        if not hasattr(record, 'client_ip'):
+            record.client_ip = 'N/A'
+        return super().format(record)
+
 # Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
-    handlers=[logging.StreamHandler()]
+logger = logging.getLogger("main")  # Use a specific logger name if needed
+logger.setLevel(logging.INFO)
+
+# Create handlers
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+
+# Create and set the custom formatter
+formatter = SafeFormatter(
+    fmt="%(asctime)s [%(levelname)s] %(name)s [IP: %(client_ip)s]: %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S"
 )
+console_handler.setFormatter(formatter)
+
+# Add the custom filter to the console handler
+console_handler.addFilter(ContextFilter())
+
+# Add handlers to the logger
+logger.addHandler(console_handler)
+
+# Initialize the limiter with slowapi
+limiter = Limiter(key_func=get_remote_address, default_limits=["60/minute"])
+app = FastAPI()
+
+# Register the rate limit exceeded handler
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+# Add SlowAPI middleware
+app.add_middleware(SlowAPIMiddleware)
+
+from logging import Handler
+
+class DBLogHandler(Handler):
+    def __init__(self, db: Session):
+        super().__init__()
+        self.db = db
+
+    def emit(self, record):
+        log_entry = Log(
+            level=record.levelname,
+            message=record.getMessage(),
+            client_ip=getattr(record, 'client_ip', None)
+        )
+        try:
+            self.db.add(log_entry)
+            self.db.commit()
+        except Exception as e:
+            # Handle exceptions (e.g., rollback)
+            self.db.rollback()
+            print(f"Failed to log to database: {e}")
+
+# Dependency to add DBLogHandler
+async def get_db_log_handler(request: Request):
+    db = next(get_db())
+    db_log_handler = DBLogHandler(db)
+    logger.addHandler(db_log_handler)
+    try:
+        yield
+    finally:
+        logger.removeHandler(db_log_handler)
+        db.close()
+
+@app.middleware("http")
+async def security_middleware(request: Request, call_next):
+    client_ip = request.client.host
+    # Set the client_ip in the context variable
+    client_ip_var.set(client_ip)
+
+    # Enforce that POST requests to sensitive endpoints must have a valid Content-Type
+    if request.method == "POST" and request.url.path == "/v1/chat/completions":
+        content_type = request.headers.get("Content-Type")
+        if content_type != "application/json":
+            logger.warning("Invalid Content-Type for /v1/chat/completions")
+            return JSONResponse(
+                status_code=400,
+                content={
+                    "error": {
+                        "message": "Content-Type must be application/json",
+                        "type": "invalid_request_error",
+                        "param": None,
+                        "code": None
+                    }
+                },
+            )
+
+    # Log the incoming request
+    logger.info(f"Incoming request: {request.method} {request.url.path}")
+
+    response = await call_next(request)
+
+    # Log the response status
+    logger.info(f"Response status: {response.status_code}")
+
+    return response
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+class ChatRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    webSearchMode: Optional[bool] = False  # Custom parameter
+
+def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
+    return {
+        "id": f"chatcmpl-{uuid.uuid4()}",
+        "object": "chat.completion.chunk",
+        "created": int(datetime.now().timestamp()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": content, "role": "assistant"},
+                "finish_reason": finish_reason,
+            }
+        ],
+        "usage": None,
+    }

-
-# Get the root logger and add the ContextFilter
-logger = logging.getLogger(__name__)
-logger.addFilter(ContextFilter())
-
-# Load environment variables
-API_KEYS = os.getenv('API_KEYS', '').split(',')  # Comma-separated API keys
-RATE_LIMIT = int(os.getenv('RATE_LIMIT', '60'))  # Requests per minute
-AVAILABLE_MODELS = os.getenv('AVAILABLE_MODELS', '')  # Comma-separated available models
-
-if not API_KEYS or API_KEYS == ['']:
-    logger.error("No API keys found. Please set the API_KEYS environment variable.")
-    raise Exception("API_KEYS environment variable not set.")
-
-# Process available models
-if AVAILABLE_MODELS:
-    AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
-else:
-    AVAILABLE_MODELS = []  # If empty, all models are available
-
-# Simple in-memory rate limiter
-rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
-ip_rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
-
-async def get_api_key(authorization: str = Header(None)) -> str:
-    if authorization is None or not authorization.startswith('Bearer '):
-        logger.warning("Invalid or missing authorization header.")
-        raise HTTPException(status_code=401, detail='Invalid authorization header format | NiansuhAI')
-    api_key = authorization[7:]
-    if api_key not in API_KEYS:
-        logger.warning(f"Invalid API key attempted: {api_key}")
-        raise HTTPException(status_code=401, detail='Invalid API key | NiansuhAI')
-    return api_key
-
-async def rate_limiter(req: Request, api_key: str = Depends(get_api_key)):
-    current_time = time.time()
-    # Rate limiting per API key
-    window_start = rate_limit_store[api_key]["timestamp"]
-    if current_time - window_start > 60:
-        rate_limit_store[api_key] = {"count": 1, "timestamp": current_time}
-    else:
-        if rate_limit_store[api_key]["count"] >= RATE_LIMIT:
-            logger.warning(f"Rate limit exceeded for API key: {api_key} | NiansuhAI")
-            raise HTTPException(status_code=429, detail='Rate limit exceeded for API key | NiansuhAI')
-        rate_limit_store[api_key]["count"] += 1
-
-    # Rate limiting per IP address
-    client_ip = req.client.host
-    window_start_ip = ip_rate_limit_store[client_ip]["timestamp"]
-    if current_time - window_start_ip > 60:
-        ip_rate_limit_store[client_ip] = {"count": 1, "timestamp": current_time}
-    else:
-        if ip_rate_limit_store[client_ip]["count"] >= RATE_LIMIT:
-            logger.warning(f"Rate limit exceeded for IP address: {client_ip}")
-            raise HTTPException(status_code=429, detail='Rate limit exceeded for IP address | NiansuhAI')
-        ip_rate_limit_store[client_ip]["count"] += 1
-
-# Custom exception for model not working
-class ModelNotWorkingException(Exception):
-    def __init__(self, model: str):
-        self.model = model
-        self.message = f"The model '{model}' is currently not working. Please try another model or wait for it to be fixed."
-        super().__init__(self.message)
-
-# Mock implementations for ImageResponse and to_data_uri
 class ImageResponse:
     def __init__(self, url: str, alt: str):
         self.url = url
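Note: the logging rework above combines a contextvars.ContextVar, a logging.Filter, and a custom Formatter so every log line carries the caller's IP. A minimal standalone sketch of that pattern (logger name and IP value are illustrative):

    import contextvars
    import logging

    client_ip_var = contextvars.ContextVar("client_ip", default="N/A")

    class ContextFilter(logging.Filter):
        def filter(self, record):
            # Copy the per-request context value onto the log record
            record.client_ip = client_ip_var.get()
            return True

    handler = logging.StreamHandler()
    handler.addFilter(ContextFilter())
    handler.setFormatter(logging.Formatter("%(asctime)s [IP: %(client_ip)s] %(message)s"))

    log = logging.getLogger("demo")
    log.setLevel(logging.INFO)
    log.addHandler(handler)

    client_ip_var.set("203.0.113.7")  # normally set by the middleware, per request
    log.info("hello")  # -> ... [IP: 203.0.113.7] hello

Because a ContextVar is task-local, concurrent requests each see their own IP without threading it through every call.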
@@ -143,7 +230,9 @@ class Blackbox:
     ]

     # Filter models based on AVAILABLE_MODELS
+    AVAILABLE_MODELS = os.getenv("AVAILABLE_MODELS", "")
     if AVAILABLE_MODELS:
+        AVAILABLE_MODELS = [model.strip() for model in AVAILABLE_MODELS.split(',') if model.strip()]
         models = [model for model in models if model in AVAILABLE_MODELS]

     agentMode = {
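Note: the parsing above tolerates stray whitespace and empty entries in the comma-separated variable, and an unset or empty AVAILABLE_MODELS leaves the model list unfiltered. For example (the value is hypothetical):

    raw = "gpt-4o, blackboxai ,"  # e.g. os.getenv("AVAILABLE_MODELS", "")
    parsed = [m.strip() for m in raw.split(",") if m.strip()]
    assert parsed == ["gpt-4o", "blackboxai"]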
@@ -383,80 +472,8 @@ class Blackbox:
                 if attempt == retry_attempts - 1:
                     raise HTTPException(status_code=500, detail=str(e))

-# FastAPI app setup
-app = FastAPI()
-
-# Middleware to enhance security and log client IP
-@app.middleware("http")
-async def security_middleware(request: Request, call_next):
-    client_ip = request.client.host
-    # Set the client_ip in the context variable
-    client_ip_var.set(client_ip)
-
-    # Enforce that POST requests to sensitive endpoints must have a valid Content-Type
-    if request.method == "POST" and request.url.path == "/v1/chat/completions":
-        content_type = request.headers.get("Content-Type")
-        if content_type != "application/json":
-            logger.warning("Invalid Content-Type for /v1/chat/completions")
-            return JSONResponse(
-                status_code=400,
-                content={
-                    "error": {
-                        "message": "Content-Type must be application/json",
-                        "type": "invalid_request_error",
-                        "param": None,
-                        "code": None
-                    }
-                },
-            )
-
-    # Log the incoming request
-    logger.info(f"Incoming request: {request.method} {request.url.path}")
-
-    response = await call_next(request)
-
-    # Log the response status
-    logger.info(f"Response status: {response.status_code}")
-
-    return response
-
-class Message(BaseModel):
-    role: str
-    content: str
-
-class ChatRequest(BaseModel):
-    model: str
-    messages: List[Message]
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
-    n: Optional[int] = 1
-    stream: Optional[bool] = False
-    stop: Optional[Union[str, List[str]]] = None
-    max_tokens: Optional[int] = None
-    presence_penalty: Optional[float] = 0.0
-    frequency_penalty: Optional[float] = 0.0
-    logit_bias: Optional[Dict[str, float]] = None
-    user: Optional[str] = None
-    webSearchMode: Optional[bool] = False  # Custom parameter
-
-def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
-    return {
-        "id": f"chatcmpl-{uuid.uuid4()}",
-        "object": "chat.completion.chunk",
-        "created": int(datetime.now().timestamp()),
-        "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": content, "role": "assistant"},
-                "finish_reason": finish_reason,
-            }
-        ],
-        "usage": None,
-    }
-
-@app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter)])
-async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
+# Note: Depends(limiter.limit("60/minute")) is not valid slowapi usage (limiter.limit
+# returns a decorator, not a dependency); SlowAPIMiddleware's default_limits of
+# "60/minute" already enforces the rate limit for this route.
+@app.post("/v1/chat/completions", dependencies=[Depends(get_db_log_handler)])
+async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key), db: Session = Depends(get_db)):
     # Redact user messages only for logging purposes
     redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]

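Note: a minimal client call against the rewritten endpoint, assuming the app runs locally on port 8000 and "sk-test" is one of the configured API keys (both the host and the model name are illustrative assumptions):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        headers={"Authorization": "Bearer sk-test", "Content-Type": "application/json"},
        json={"model": "blackboxai", "messages": [{"role": "user", "content": "Hello"}]},
    )
    print(resp.status_code, resp.json())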
@@ -484,6 +501,14 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
                 if isinstance(chunk, ImageResponse):
                     image_markdown = f"![image]({chunk.url})"
                     response_chunk = create_response(image_markdown, request.model)
+
+                    # Store image in the database
+                    image_entry = Image(
+                        image_url=chunk.url,
+                        description=request.messages[-1].content  # Message is a Pydantic model; attribute access, not .get()
+                    )
+                    db.add(image_entry)
+                    db.commit()
                 else:
                     response_chunk = create_response(chunk, request.model)

@@ -504,6 +529,14 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
         async for chunk in async_generator:
             if isinstance(chunk, ImageResponse):
                 response_content += f"![image]({chunk.url})\n"
+
+                # Store image in the database
+                image_entry = Image(
+                    image_url=chunk.url,
+                    description=request.messages[-1].content  # Message is a Pydantic model; attribute access, not .get()
+                )
+                db.add(image_entry)
+                db.commit()
             else:
                 response_content += chunk

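Note: Image and Log come from the local models module, which the diff does not show. A plausible shape for those tables, inferred only from the attributes used above (table names and column types are assumptions):

    from sqlalchemy import Column, Integer, String, Text
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Image(Base):
        __tablename__ = "images"  # hypothetical name
        id = Column(Integer, primary_key=True)
        image_url = Column(String, nullable=False)
        description = Column(Text)

    class Log(Base):
        __tablename__ = "logs"  # hypothetical name
        id = Column(Integer, primary_key=True)
        level = Column(String)
        message = Column(Text)
        client_ip = Column(String)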
@@ -549,20 +582,17 @@ async def chat_completions_get():

 @app.get("/v1/models")
 async def get_models(req: Request):
-    client_ip = req.client.host
     logger.info(f"Fetching available models")
     return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}

 # Additional endpoints for better functionality
 @app.get("/v1/health")
 async def health_check(req: Request):
-    client_ip = req.client.host
     logger.info(f"Health check requested")
     return {"status": "ok"}

 @app.get("/v1/models/{model}/status")
 async def model_status(model: str, req: Request):
-    client_ip = req.client.host
     logger.info(f"Model status requested for '{model}'")
     if model in Blackbox.models:
         return {"model": model, "status": "available"}
@@ -576,8 +606,7 @@ async def model_status(model: str, req: Request):
 # Custom exception handler to match OpenAI's error format
 @app.exception_handler(HTTPException)
 async def http_exception_handler(request: Request, exc: HTTPException):
-    client_ip = request.client.host
-    logger.error(f"HTTPException: {exc.detail} | Client IP: {client_ip}")
+    logger.error(f"HTTPException: {exc.detail}")
     return JSONResponse(
         status_code=exc.status_code,
         content={
@@ -596,7 +625,6 @@ class TokenizerRequest(BaseModel):

 @app.post("/v1/tokenizer")
 async def tokenizer(request: TokenizerRequest, req: Request):
-    client_ip = req.client.host
     text = request.text
     token_count = len(text.split())
     logger.info(f"Tokenizer called | Tokens: {token_count}")
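Note: the /v1/tokenizer endpoint estimates tokens by whitespace splitting, which is only a rough proxy for real tokenizer counts:

    text = "Hello world, how are you?"
    token_count = len(text.split())  # 5; a BPE tokenizer would typically count more tokens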
 