"""FastAPI application: user registration/login, JWT refresh, and rate-limited generation."""

import logging
import os
from datetime import timedelta
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, Depends, HTTPException, status, Query, Request
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from jose import JWTError
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address

from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
from auth import (
    register_user,
    get_db,
    authenticate_user,
    create_token,
    verify_token,
    verify_access_token,
    Session,
)
from utils import generate_stream, generate_response, shutdown_event

# Load environment variables (from a .env file, if present) before reading config below.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Resolve bundled assets relative to this file so the app works from any working directory.
BASE_DIR = Path(__file__).parent
STATIC_DIR = BASE_DIR / "static"

# Security keys with default fallbacks.
SECRET_KEY = os.getenv("SECRET_KEY", "default_fallback_256_bit_key")
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY", SECRET_KEY)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

if os.getenv("SECRET_KEY") is None:
    logger.warning(
        "SECRET_KEY is not set; using an insecure built-in default. "
        "Set SECRET_KEY in the environment for any real deployment."
    )
if os.getenv("REFRESH_SECRET_KEY") is None:
    # With identical keys, a token signed for access may also verify at /refresh.
    logger.warning(
        "REFRESH_SECRET_KEY is not set; refresh tokens are signed with SECRET_KEY."
    )

app = FastAPI()


def custom_key_func(request: Request) -> str:
    """Return the rate-limit key for *request*: the client IP.

    Uses the first entry of ``X-Forwarded-For`` when present and non-empty,
    otherwise the socket peer address.

    NOTE(review): ``X-Forwarded-For`` is client-controlled unless a trusted
    proxy overwrites it; a direct caller can rotate this header to evade the
    per-IP limits. Only trust it when deployed behind such a proxy.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        client_ip = forwarded.split(",")[0].strip()
        # An empty first element would collapse unrelated clients into one bucket.
        if client_ip:
            return client_ip
    return get_remote_address(request)


# Initialize the rate limiter; slowapi's middleware/decorators read it from app.state.
limiter = Limiter(key_func=custom_key_func)
app.state.limiter = limiter


# This is the single handler for status 429. Registering slowapi's default
# handler for the same key as well would simply be overwritten by this one.
@app.exception_handler(429)
async def rate_limit_exceeded_handler(request: Request, exc: Exception) -> JSONResponse:
    """Return a uniform JSON body for any HTTP 429 (rate-limit) response."""
    return JSONResponse(
        status_code=429,
        content={"error": "Rate limit exceeded. Please try again later."},
    )


app.add_middleware(SlowAPIMiddleware)

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


@app.get('/')
def index() -> FileResponse:
    """Serve the single-page frontend (static/index.html)."""
    file_path = STATIC_DIR / 'index.html'
    return FileResponse(path=str(file_path), media_type='text/html')


@app.post("/register")
@limiter.limit("5/minute")
async def register(request: Request, user: UserRegister, db: Session = Depends(get_db)):
    """Register a new user (rate limited to 5 requests per minute per client).

    ``request`` is unused in the body but required by ``@limiter.limit``.
    """
    new_user = register_user(user.username, user.password, db)
    return {"message": "User registered successfully", "user": new_user.username}


@app.post("/login", response_model=TokenResponse)
@limiter.limit("10/minute")
async def login(
    request: Request,
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
):
    """Authenticate a user and issue an access token and a refresh token.

    Rate limited to 10 requests per minute per client.

    Raises:
        HTTPException(401): invalid username or password.
        HTTPException(500): any unexpected error while authenticating or
            creating tokens (logged with traceback).
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )

        access_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
            secret_key=REFRESH_SECRET_KEY,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 401 above) must reach the client
        # unchanged instead of being masked as a 500 below.
        raise
    except Exception as exc:
        logger.exception("Login error: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from exc

    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }


def _credentials_exception() -> HTTPException:
    """Build the 401 returned when a refresh token cannot be validated."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )


@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshTokenRequest):
    """Exchange a valid refresh token for a new access token.

    The same refresh token is returned unchanged (no rotation).

    Raises:
        HTTPException(401): the refresh token fails verification (JWTError)
            or yields no subject.
    """
    try:
        # Verify the refresh token against the refresh-signing key.
        username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
        # Defensive: never mint a token for an empty/None subject, in case
        # verify_token signals failure by returning a falsy value.
        if not username:
            raise _credentials_exception()

        access_token = create_token(
            data={"sub": username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
    except JWTError as exc:
        raise _credentials_exception() from exc

    return {
        "access_token": access_token,
        "refresh_token": refresh_request.refresh_token,  # Return the same refresh token
        "token_type": "bearer",
    }


@app.post("/generate")
@limiter.limit("3/10seconds")
async def generate(
    request: Request,
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
    stream: bool = Query(False, description="Enable streaming response"),
):
    """Answer a query, either streamed (text/event-stream) or as one JSON body.

    Requires a valid access token and is rate limited to 3 requests per 10
    seconds per client. While ``shutdown_event`` is set, a shutdown message
    is returned instead of generating.

    Raises:
        HTTPException(500): unexpected error while starting the response.
    """
    if shutdown_event.is_set():
        # NOTE(review): this replies 200 OK during shutdown; 503 may be more
        # accurate — confirm what the frontend expects before changing.
        return JSONResponse({"message": "Server shutting down..."})

    try:
        if stream:
            # Assuming generate_stream returns a generator, errors raised while
            # it is consumed occur after this handler returns and are not
            # caught by the except clauses below.
            return StreamingResponse(
                generate_stream(query_input.query), media_type="text/event-stream"
            )
        return JSONResponse(await generate_response(query_input.query))
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Error in generate endpoint: %s", exc)
        raise HTTPException(status_code=500, detail="Internal server error") from exc


@app.on_event("startup")
async def startup_event():
    """Create the ./cache directory (relative to the working directory) if missing."""
    os.makedirs("./cache", exist_ok=True)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)