# Source: llm_host / main.py
# Author: Bahodir Nematjonov
# Commit: 1392014 - "feat: Rate Limiter Per Users"
# Size: 4.58 kB
from fastapi import FastAPI, Depends, HTTPException, status, Query, Request
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from jose import JWTError
from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
from utils import generate_stream, generate_response, shutdown_event
from fastapi.security import OAuth2PasswordRequestForm
from pathlib import Path
from datetime import timedelta
import os
import logging
from dotenv import load_dotenv
load_dotenv()
# Import SlowAPI for Rate Limiting
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware
# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
# Signing secrets come from the environment (a .env file is loaded by the
# load_dotenv() call above). os.getenv returns None when a variable is unset,
# so a missing secret is not detected here.
SECRET_KEY = os.getenv("SECRET_KEY")
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY")
# NOTE(review): ALGORITHM is not referenced anywhere in the visible part of
# this file - presumably the auth module applies its own setting; confirm.
ALGORITHM = "HS256"
# Token lifetimes; the /login endpoint turns these into timedelta values.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
app = FastAPI()
# Initialize Rate Limiter.
# Limits are keyed by get_remote_address, i.e. per client address rather than
# per authenticated user.
limiter = Limiter(key_func=get_remote_address)
# slowapi looks the limiter up on app.state.
app.state.limiter = limiter
# Attach Rate Limit Exceeded Handler.
# The handler is registered against the 429 status code.
# NOTE(review): slowapi's documentation registers it against the
# RateLimitExceeded exception class instead - confirm this form is intended.
app.add_exception_handler(429, _rate_limit_exceeded_handler)
# Add Middleware for Rate Limiting
app.add_middleware(SlowAPIMiddleware)
# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Entry Endpoint
@app.get('/')
def index() -> FileResponse:
    """Serve the landing page stored at ``static/index.html`` next to this module."""
    module_dir = Path(__file__).parent
    landing_page = module_dir.joinpath('static', 'index.html')
    return FileResponse(path=str(landing_page), media_type='text/html')
# Apply Rate Limiting on Register API (Limit: 5 requests per minute)
@app.post("/register")
@limiter.limit("5/minute")
async def register(
    request: Request,
    user: UserRegister,
    db: Session = Depends(get_db),
):
    """Create a new user account from the submitted credentials.

    Args:
        request: Incoming request; not used in the body, but the rate-limit
            decorator needs it in the endpoint signature.
        user: Payload carrying the ``username`` and ``password`` to register.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with a confirmation ``message`` and the registered ``user`` name.
    """
    created = register_user(user.username, user.password, db)
    payload = {"message": "User registered successfully", "user": created.username}
    return payload
# Apply Rate Limiting on Login API (Limit: 10 requests per minute)
@app.post("/login", response_model=TokenResponse)
@limiter.limit("10/minute")
async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate a user and issue an access/refresh token pair.

    Args:
        request: Incoming request; not used in the body, but the rate-limit
            decorator needs it in the endpoint signature.
        form_data: OAuth2 password form carrying ``username`` and ``password``.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with ``access_token``, ``refresh_token`` and ``token_type``.

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy result;
            500 when an unexpected error occurs while authenticating or
            creating the tokens.
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
        access_token = create_token(
            data={"sub": user.username},
            expires_delta=access_token_expires,
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=refresh_token_expires,
            secret_key=REFRESH_SECRET_KEY,
        )
        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
        }
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # deliberate 401 above would be caught below and reported as a 500.
        raise
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e
# Apply Rate Limiting on Generate API (Limit: 3 requests per 10 seconds)
@app.post("/generate")
@limiter.limit("3/10seconds")
async def generate(
    request: Request,
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
    stream: bool = Query(False, description="Enable streaming response"),
):
    """Answer a query, either as a single JSON body or as an event stream.

    Args:
        request: Incoming request; not used in the body, but the rate-limit
            decorator needs it in the endpoint signature.
        query_input: Payload whose ``query`` field is passed to the generator
            helpers.
        username: Result of the ``verify_access_token`` dependency. It is not
            used in the body; declaring it makes the dependency run for every
            request.
        stream: When true, respond with a ``text/event-stream`` streaming
            response instead of a single JSON document.

    Returns:
        A ``JSONResponse`` (shutdown notice or generated result) or a
        ``StreamingResponse`` wrapping ``generate_stream``.

    Raises:
        HTTPException: Passed through unchanged when raised inside the
            ``try`` block; 500 for any other unexpected error.
    """
    if shutdown_event.is_set():
        return JSONResponse({"message": "Server shutting down..."})
    try:
        if stream:
            # NOTE(review): if generate_stream is a generator, errors raised
            # while the response is being consumed happen after this function
            # has returned and are not caught by the handlers below - confirm.
            return StreamingResponse(generate_stream(query_input.query), media_type="text/event-stream")
        return JSONResponse(await generate_response(query_input.query))
    except HTTPException:
        # Keep the status code of deliberate HTTP errors instead of masking
        # them as a generic 500.
        raise
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception("Error in generate endpoint: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
# Startup hook: prepare the on-disk cache directory
@app.on_event("startup")
async def startup_event():
    """Ensure the ``./cache`` directory exists when the application starts."""
    cache_dir = Path("./cache")
    # parents/exist_ok: create missing parents, and do nothing if it is already there.
    cache_dir.mkdir(parents=True, exist_ok=True)
if __name__ == "__main__":
    # Script entry point: run the app under uvicorn, listening on all
    # interfaces (0.0.0.0) on port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)