Bahodir Nematjonov commited on
Commit
1392014
·
1 Parent(s): a52b206

feat: Rate Limiter Per Users

Browse files
Files changed (1) hide show
  1. main.py +32 -34
main.py CHANGED
@@ -1,26 +1,43 @@
1
- from fastapi import FastAPI, Depends, HTTPException, status, Query
2
  from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from jose import JWTError
5
  from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
6
- from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
7
  from utils import generate_stream, generate_response, shutdown_event
8
  from fastapi.security import OAuth2PasswordRequestForm
9
  from pathlib import Path
10
  from datetime import timedelta
11
  import os
12
  import logging
 
 
 
 
 
 
 
13
 
14
  logging.basicConfig(level=logging.INFO)
15
 
16
- SECRET_KEY = os.getenv("SECRET_KEY", 'def6nQHONW99pOPyba9DShny6FB1CJJBigZault')
17
- REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY", SECRET_KEY)
18
  ALGORITHM = "HS256"
19
  ACCESS_TOKEN_EXPIRE_MINUTES = 30
20
  REFRESH_TOKEN_EXPIRE_DAYS = 7
21
 
22
  app = FastAPI()
23
 
 
 
 
 
 
 
 
 
 
 
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
 
26
  # Entry Endpoint
@@ -29,14 +46,18 @@ def index() -> FileResponse:
29
  file_path = Path(__file__).parent / 'static' / 'index.html'
30
  return FileResponse(path=str(file_path), media_type='text/html')
31
 
 
32
  @app.post("/register")
33
- async def register(user: UserRegister, db: Session = Depends(get_db)):
 
34
  """Registers a new user."""
35
  new_user = register_user(user.username, user.password, db)
36
  return {"message": "User registered successfully", "user": new_user.username}
37
 
 
38
  @app.post("/login", response_model=TokenResponse)
39
- async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
 
40
  try:
41
  user = authenticate_user(form_data.username, form_data.password, db)
42
 
@@ -70,38 +91,15 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session =
70
  logging.error(f"Login error: {str(e)}")
71
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
72
 
73
- @app.post("/refresh", response_model=TokenResponse)
74
- async def refresh(refresh_request: RefreshTokenRequest):
75
- try:
76
- # Verify the refresh token
77
- username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
78
-
79
- # Create new access token
80
- access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
81
- access_token = create_token(
82
- data={"sub": username},
83
- expires_delta=access_token_expires,
84
- secret_key=SECRET_KEY
85
- )
86
-
87
- return {
88
- "access_token": access_token,
89
- "refresh_token": refresh_request.refresh_token, # Return the same refresh token
90
- "token_type": "bearer"
91
- }
92
-
93
- except JWTError:
94
- raise HTTPException(
95
- status_code=status.HTTP_401_UNAUTHORIZED,
96
- detail="Could not validate credentials",
97
- headers={"WWW-Authenticate": "Bearer"},
98
- )
99
-
100
  @app.post("/generate")
 
101
  async def generate(
 
102
  query_input: QueryInput,
103
  username: str = Depends(verify_access_token),
104
  stream: bool = Query(False, description="Enable streaming response"),
 
105
  ):
106
  """Handles both streaming and non-streaming responses, with shutdown detection."""
107
  if shutdown_event.is_set():
@@ -123,4 +121,4 @@ async def startup_event():
123
 
124
  if __name__ == "__main__":
125
  import uvicorn
126
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, Depends, HTTPException, status, Query, Request
2
  from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from jose import JWTError
5
  from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
6
+ from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
7
  from utils import generate_stream, generate_response, shutdown_event
8
  from fastapi.security import OAuth2PasswordRequestForm
9
  from pathlib import Path
10
  from datetime import timedelta
11
  import os
12
  import logging
13
+ from dotenv import load_dotenv
14
+
15
+ load_dotenv()
16
+ # Import SlowAPI for Rate Limiting
17
+ from slowapi import Limiter, _rate_limit_exceeded_handler
18
+ from slowapi.util import get_remote_address
19
+ from slowapi.middleware import SlowAPIMiddleware
20
 
21
  logging.basicConfig(level=logging.INFO)
22
 
23
+ SECRET_KEY = os.getenv("SECRET_KEY")
24
+ REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY")
25
  ALGORITHM = "HS256"
26
  ACCESS_TOKEN_EXPIRE_MINUTES = 30
27
  REFRESH_TOKEN_EXPIRE_DAYS = 7
28
 
29
  app = FastAPI()
30
 
31
+ # Initialize Rate Limiter
32
+ limiter = Limiter(key_func=get_remote_address)
33
+ app.state.limiter = limiter
34
+
35
+ # Attach Rate Limit Exceeded Handler
36
+ app.add_exception_handler(429, _rate_limit_exceeded_handler)
37
+
38
+ # Add Middleware for Rate Limiting
39
+ app.add_middleware(SlowAPIMiddleware)
40
+
41
  app.mount("/static", StaticFiles(directory="static"), name="static")
42
 
43
  # Entry Endpoint
 
46
  file_path = Path(__file__).parent / 'static' / 'index.html'
47
  return FileResponse(path=str(file_path), media_type='text/html')
48
 
49
+ # Apply Rate Limiting on Register API (Limit: 5 requests per minute)
50
  @app.post("/register")
51
+ @limiter.limit("5/minute")
52
+ async def register(request: Request, user: UserRegister, db: Session = Depends(get_db)):
53
  """Registers a new user."""
54
  new_user = register_user(user.username, user.password, db)
55
  return {"message": "User registered successfully", "user": new_user.username}
56
 
57
+ # Apply Rate Limiting on Login API (Limit: 10 requests per minute)
58
  @app.post("/login", response_model=TokenResponse)
59
+ @limiter.limit("10/minute")
60
+ async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
61
  try:
62
  user = authenticate_user(form_data.username, form_data.password, db)
63
 
 
91
  logging.error(f"Login error: {str(e)}")
92
  raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
93
 
94
+ # Apply Rate Limiting on Generate API (Limit: 3 requests per 10 seconds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  @app.post("/generate")
96
+ @limiter.limit("3/10seconds")
97
  async def generate(
98
+ request: Request,
99
  query_input: QueryInput,
100
  username: str = Depends(verify_access_token),
101
  stream: bool = Query(False, description="Enable streaming response"),
102
+
103
  ):
104
  """Handles both streaming and non-streaming responses, with shutdown detection."""
105
  if shutdown_event.is_set():
 
121
 
122
  if __name__ == "__main__":
123
  import uvicorn
124
+ uvicorn.run(app, host="0.0.0.0", port=7860)