Niansuh commited on
Commit
8ace20c
·
verified ·
1 Parent(s): 3cfd9e7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +33 -6
main.py CHANGED
@@ -12,8 +12,8 @@ from typing import List, Dict, Any, Optional, AsyncGenerator, Union
12
  from datetime import datetime
13
 
14
  from aiohttp import ClientSession, ClientTimeout, ClientError
15
- from fastapi import FastAPI, HTTPException, Request, Depends, Header
16
- from fastapi.responses import StreamingResponse, JSONResponse
17
  from pydantic import BaseModel
18
 
19
  # Configure logging
@@ -43,9 +43,9 @@ else:
43
  rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
44
  ip_rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
45
 
46
- async def get_api_key(authorization: str = Header(...)) -> str:
47
- if not authorization.startswith('Bearer '):
48
- logger.warning("Invalid authorization header format.")
49
  raise HTTPException(status_code=401, detail='Invalid authorization header format')
50
  api_key = authorization[7:]
51
  if api_key not in API_KEYS:
@@ -202,7 +202,7 @@ class Blackbox:
202
  }
203
 
204
  @classmethod
205
- def get_model(cls, model: str) -> str:
206
  if model in cls.models:
207
  return model
208
  elif model in cls.userSelectedModel and cls.userSelectedModel[model] in cls.models:
@@ -373,6 +373,27 @@ class Blackbox:
373
  # FastAPI app setup
374
  app = FastAPI()
375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  class Message(BaseModel):
377
  role: str
378
  content: str
@@ -492,6 +513,12 @@ async def chat_completions(request: ChatRequest, req: Request, api_key: str = De
492
  logger.exception("An unexpected error occurred while processing the chat completions request.")
493
  raise HTTPException(status_code=500, detail=str(e))
494
 
 
 
 
 
 
 
495
  @app.get("/v1/models")
496
  async def get_models():
497
  logger.info("Fetching available models")
 
12
  from datetime import datetime
13
 
14
  from aiohttp import ClientSession, ClientTimeout, ClientError
15
+ from fastapi import FastAPI, HTTPException, Request, Depends, Header, Response
16
+ from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse, PlainTextResponse
17
  from pydantic import BaseModel
18
 
19
  # Configure logging
 
43
  rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
44
  ip_rate_limit_store = defaultdict(lambda: {"count": 0, "timestamp": time.time()})
45
 
46
+ async def get_api_key(authorization: str = Header(None)) -> str:
47
+ if authorization is None or not authorization.startswith('Bearer '):
48
+ logger.warning("Invalid or missing authorization header.")
49
  raise HTTPException(status_code=401, detail='Invalid authorization header format')
50
  api_key = authorization[7:]
51
  if api_key not in API_KEYS:
 
202
  }
203
 
204
  @classmethod
205
+ def get_model(cls, model: str) -> Optional[str]:
206
  if model in cls.models:
207
  return model
208
  elif model in cls.userSelectedModel and cls.userSelectedModel[model] in cls.models:
 
373
  # FastAPI app setup
374
  app = FastAPI()
375
 
376
+ # Middleware to enhance security
377
+ @app.middleware("http")
378
+ async def security_middleware(request: Request, call_next):
379
+ # Enforce that POST requests to sensitive endpoints must have a valid Content-Type
380
+ if request.method == "POST" and request.url.path in ["/v1/chat/completions", "/v1/completions"]:
381
+ content_type = request.headers.get("Content-Type")
382
+ if content_type != "application/json":
383
+ return JSONResponse(
384
+ status_code=400,
385
+ content={
386
+ "error": {
387
+ "message": "Content-Type must be application/json",
388
+ "type": "invalid_request_error",
389
+ "param": None,
390
+ "code": None
391
+ }
392
+ },
393
+ )
394
+ response = await call_next(request)
395
+ return response
396
+
397
  class Message(BaseModel):
398
  role: str
399
  content: str
 
513
  logger.exception("An unexpected error occurred while processing the chat completions request.")
514
  raise HTTPException(status_code=500, detail=str(e))
515
 
516
+ # Return 'about:blank' when accessing the endpoint via GET
517
+ @app.get("/v1/chat/completions")
518
+ async def chat_completions_get():
519
+ logger.info("GET request made to /v1/chat/completions, returning 'about:blank'")
520
+ return HTMLResponse(content='', status_code=200, media_type='text/html')
521
+
522
  @app.get("/v1/models")
523
  async def get_models():
524
  logger.info("Fetching available models")