Spaces:
Running
Running
Bahodir Nematjonov
commited on
Commit
·
1392014
1
Parent(s):
a52b206
feat: Rate Limiter Per Users
Browse files
main.py
CHANGED
@@ -1,26 +1,43 @@
|
|
1 |
-
from fastapi import FastAPI, Depends, HTTPException, status, Query
|
2 |
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
|
3 |
from fastapi.staticfiles import StaticFiles
|
4 |
from jose import JWTError
|
5 |
from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
|
6 |
-
from auth import
|
7 |
from utils import generate_stream, generate_response, shutdown_event
|
8 |
from fastapi.security import OAuth2PasswordRequestForm
|
9 |
from pathlib import Path
|
10 |
from datetime import timedelta
|
11 |
import os
|
12 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
logging.basicConfig(level=logging.INFO)
|
15 |
|
16 |
-
SECRET_KEY = os.getenv("SECRET_KEY"
|
17 |
-
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY"
|
18 |
ALGORITHM = "HS256"
|
19 |
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
20 |
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
21 |
|
22 |
app = FastAPI()
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
25 |
|
26 |
# Entry Endpoint
|
@@ -29,14 +46,18 @@ def index() -> FileResponse:
|
|
29 |
file_path = Path(__file__).parent / 'static' / 'index.html'
|
30 |
return FileResponse(path=str(file_path), media_type='text/html')
|
31 |
|
|
|
32 |
@app.post("/register")
|
33 |
-
|
|
|
34 |
"""Registers a new user."""
|
35 |
new_user = register_user(user.username, user.password, db)
|
36 |
return {"message": "User registered successfully", "user": new_user.username}
|
37 |
|
|
|
38 |
@app.post("/login", response_model=TokenResponse)
|
39 |
-
|
|
|
40 |
try:
|
41 |
user = authenticate_user(form_data.username, form_data.password, db)
|
42 |
|
@@ -70,38 +91,15 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session =
|
|
70 |
logging.error(f"Login error: {str(e)}")
|
71 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
|
72 |
|
73 |
-
|
74 |
-
async def refresh(refresh_request: RefreshTokenRequest):
|
75 |
-
try:
|
76 |
-
# Verify the refresh token
|
77 |
-
username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
|
78 |
-
|
79 |
-
# Create new access token
|
80 |
-
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
81 |
-
access_token = create_token(
|
82 |
-
data={"sub": username},
|
83 |
-
expires_delta=access_token_expires,
|
84 |
-
secret_key=SECRET_KEY
|
85 |
-
)
|
86 |
-
|
87 |
-
return {
|
88 |
-
"access_token": access_token,
|
89 |
-
"refresh_token": refresh_request.refresh_token, # Return the same refresh token
|
90 |
-
"token_type": "bearer"
|
91 |
-
}
|
92 |
-
|
93 |
-
except JWTError:
|
94 |
-
raise HTTPException(
|
95 |
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
96 |
-
detail="Could not validate credentials",
|
97 |
-
headers={"WWW-Authenticate": "Bearer"},
|
98 |
-
)
|
99 |
-
|
100 |
@app.post("/generate")
|
|
|
101 |
async def generate(
|
|
|
102 |
query_input: QueryInput,
|
103 |
username: str = Depends(verify_access_token),
|
104 |
stream: bool = Query(False, description="Enable streaming response"),
|
|
|
105 |
):
|
106 |
"""Handles both streaming and non-streaming responses, with shutdown detection."""
|
107 |
if shutdown_event.is_set():
|
@@ -123,4 +121,4 @@ async def startup_event():
|
|
123 |
|
124 |
if __name__ == "__main__":
|
125 |
import uvicorn
|
126 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
1 |
+
from fastapi import FastAPI, Depends, HTTPException, status, Query, Request
|
2 |
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
|
3 |
from fastapi.staticfiles import StaticFiles
|
4 |
from jose import JWTError
|
5 |
from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
|
6 |
+
from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
|
7 |
from utils import generate_stream, generate_response, shutdown_event
|
8 |
from fastapi.security import OAuth2PasswordRequestForm
|
9 |
from pathlib import Path
|
10 |
from datetime import timedelta
|
11 |
import os
|
12 |
import logging
|
13 |
+
from dotenv import load_dotenv
|
14 |
+
|
15 |
+
load_dotenv()
|
16 |
+
# Import SlowAPI for Rate Limiting
|
17 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
18 |
+
from slowapi.util import get_remote_address
|
19 |
+
from slowapi.middleware import SlowAPIMiddleware
|
20 |
|
21 |
logging.basicConfig(level=logging.INFO)
|
22 |
|
23 |
+
SECRET_KEY = os.getenv("SECRET_KEY")
|
24 |
+
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY")
|
25 |
ALGORITHM = "HS256"
|
26 |
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
27 |
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
28 |
|
29 |
app = FastAPI()
|
30 |
|
31 |
+
# Initialize Rate Limiter
|
32 |
+
limiter = Limiter(key_func=get_remote_address)
|
33 |
+
app.state.limiter = limiter
|
34 |
+
|
35 |
+
# Attach Rate Limit Exceeded Handler
|
36 |
+
app.add_exception_handler(429, _rate_limit_exceeded_handler)
|
37 |
+
|
38 |
+
# Add Middleware for Rate Limiting
|
39 |
+
app.add_middleware(SlowAPIMiddleware)
|
40 |
+
|
41 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
42 |
|
43 |
# Entry Endpoint
|
|
|
46 |
file_path = Path(__file__).parent / 'static' / 'index.html'
|
47 |
return FileResponse(path=str(file_path), media_type='text/html')
|
48 |
|
49 |
+
# Apply Rate Limiting on Register API (Limit: 5 requests per minute)
|
50 |
@app.post("/register")
|
51 |
+
@limiter.limit("5/minute")
|
52 |
+
async def register(request: Request, user: UserRegister, db: Session = Depends(get_db)):
|
53 |
"""Registers a new user."""
|
54 |
new_user = register_user(user.username, user.password, db)
|
55 |
return {"message": "User registered successfully", "user": new_user.username}
|
56 |
|
57 |
+
# Apply Rate Limiting on Login API (Limit: 10 requests per minute)
|
58 |
@app.post("/login", response_model=TokenResponse)
|
59 |
+
@limiter.limit("10/minute")
|
60 |
+
async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
61 |
try:
|
62 |
user = authenticate_user(form_data.username, form_data.password, db)
|
63 |
|
|
|
91 |
logging.error(f"Login error: {str(e)}")
|
92 |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
|
93 |
|
94 |
+
# Apply Rate Limiting on Generate API (Limit: 3 requests per 10 seconds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
@app.post("/generate")
|
96 |
+
@limiter.limit("3/10seconds")
|
97 |
async def generate(
|
98 |
+
request: Request,
|
99 |
query_input: QueryInput,
|
100 |
username: str = Depends(verify_access_token),
|
101 |
stream: bool = Query(False, description="Enable streaming response"),
|
102 |
+
|
103 |
):
|
104 |
"""Handles both streaming and non-streaming responses, with shutdown detection."""
|
105 |
if shutdown_event.is_set():
|
|
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
import uvicorn
|
124 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|