# Spaces: Running
from fastapi import FastAPI, Depends, HTTPException, status, Query, Request | |
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from jose import JWTError | |
from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput | |
from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session | |
from utils import generate_stream, generate_response, shutdown_event | |
from fastapi.security import OAuth2PasswordRequestForm | |
from pathlib import Path | |
from datetime import timedelta | |
import os | |
import logging | |
from dotenv import load_dotenv | |
# Load variables from a .env file (if one exists) into the environment so the
# os.getenv() calls below can see them.
load_dotenv()

# Import SlowAPI for rate limiting
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware

import hashlib
import hmac
import secrets

logging.basicConfig(level=logging.INFO)

# --- Token signing keys -----------------------------------------------------
# SECRET_KEY signs access tokens. There is deliberately no hard-coded fallback:
# a default that is written in the source lets anyone who has read it forge
# valid tokens. When the variable is missing we generate a random key for this
# process only, so tokens stop validating after a restart and are not shared
# between worker processes -- set SECRET_KEY explicitly in any real deployment.
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    logging.warning(
        "SECRET_KEY is not set; using a random per-process signing key. "
        "Issued tokens will be invalid after a restart."
    )
    SECRET_KEY = secrets.token_urlsafe(32)

# REFRESH_SECRET_KEY signs refresh tokens. If it equalled SECRET_KEY, refresh()
# (which verifies with REFRESH_SECRET_KEY) would also accept a short-lived
# access token -- unless the tokens carry a distinguishing claim, which is not
# visible in this file. When it is not configured we therefore derive a
# separate key from SECRET_KEY; the derivation is deterministic, so every
# process sharing SECRET_KEY agrees on the refresh key too.
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY") or hmac.new(
    SECRET_KEY.encode("utf-8"), b"refresh-token-signing-key", hashlib.sha256
).hexdigest()

ALGORITHM = "HS256"  # HMAC with SHA-256
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

app = FastAPI()
# Custom key function for rate limiting (handles proxies).
def custom_key_func(request: Request):
    """Return the client identifier used as the rate-limit bucket key.

    Uses the first (left-most) entry of the ``X-Forwarded-For`` header when it
    is present and non-empty; otherwise falls back to slowapi's
    ``get_remote_address`` for the directly connected peer.

    NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
    overwrites it, so the left-most entry can be spoofed to dodge the limiter.
    If the deployment sits behind a known number of proxies, the entry that
    proxy appended is the trustworthy one -- confirm the proxy setup.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # Entries are comma-separated and may be padded with spaces; strip so
        # the same client always maps to the same bucket.
        client_ip = forwarded.split(",")[0].strip()
        # A blank first entry (e.g. a header of just ",") would otherwise make
        # every such request share the single bucket "".
        if client_ip:
            return client_ip
    return get_remote_address(request)
# Initialize the rate limiter and expose it on app.state, where slowapi's
# decorator and middleware expect to find it.
limiter = Limiter(key_func=custom_key_func)
app.state.limiter = limiter

# Attach the rate-limit-exceeded handler.
# slowapi raises RateLimitExceeded and documents registering its handler by
# that exception class; its middleware looks the handler up by exception type,
# which a registration keyed only on status code 429 does not match.
from slowapi.errors import RateLimitExceeded

app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Kept so any other HTTPException raised with status 429 is still routed to the
# same handler as before.
app.add_exception_handler(429, _rate_limit_exceeded_handler)
# Custom rate-limit response: a 429 with a generic JSON message.
async def rate_limit_exceeded_handler(request: Request, exc):
    """Return a 429 JSON response telling the caller to retry later.

    ``request`` and ``exc`` are accepted to match the exception-handler
    signature but are not read.

    NOTE(review): this coroutine is not registered as a handler anywhere in
    this file (slowapi's ``_rate_limit_exceeded_handler`` is registered
    instead) -- confirm whether it was meant to replace that one.
    """
    payload = {"error": "Rate limit exceeded. Please try again later."}
    return JSONResponse(content=payload, status_code=429)
# Add slowapi's rate-limiting middleware.
app.add_middleware(SlowAPIMiddleware)

# Serve /static from the "static" directory next to this module -- the same
# directory index() reads index.html from -- instead of a path relative to the
# process's working directory, so startup does not depend on where the server
# is launched from.
app.mount(
    "/static",
    StaticFiles(directory=str(Path(__file__).parent / "static")),
    name="static",
)
# Entry endpoint.
# NOTE(review): no route decorator is visible on this function -- confirm how
# it is registered with the app.
def index() -> FileResponse:
    """Serve ``static/index.html`` (resolved next to this module) as text/html."""
    html_page = Path(__file__).parent.joinpath("static", "index.html")
    return FileResponse(path=str(html_page), media_type="text/html")
# Rate limiting is intended on this endpoint (5 requests per minute).
# NOTE(review): no @limiter.limit / route decorator is visible here -- confirm.
async def register(request: Request, user: UserRegister, db: Session = Depends(get_db)):
    """Registers a new user.

    Creates the account via ``register_user`` and echoes back its username.
    ``request`` is not read in the body; presumably it is kept because
    slowapi's limiter decorator needs a ``Request`` parameter.
    """
    created = register_user(user.username, user.password, db)
    return {"message": "User registered successfully", "user": created.username}
# Rate limiting is intended on this endpoint (10 requests per minute).
# NOTE(review): no @limiter.limit / route decorator is visible here -- confirm.
async def login(
    request: Request,
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
):
    """Authenticate a user and issue an access/refresh token pair.

    Returns:
        dict with ``access_token`` (signed with SECRET_KEY, valid
        ACCESS_TOKEN_EXPIRE_MINUTES), ``refresh_token`` (signed with
        REFRESH_SECRET_KEY, valid REFRESH_TOKEN_EXPIRE_DAYS) and
        ``token_type`` = "bearer".

    Raises:
        HTTPException(401): ``authenticate_user`` returned a falsy user.
        HTTPException: any HTTPException raised by the helpers is passed
            through unchanged.
        HTTPException(500): any other error; details are logged, not returned.
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
            secret_key=REFRESH_SECRET_KEY,
        )
    except HTTPException:
        # Deliberate HTTP errors (the 401 above included) must reach the client
        # as-is; otherwise the catch-all below would report every failed
        # login as a 500 Internal Server Error.
        raise
    except Exception as e:
        # logging.exception also records the traceback.
        logging.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
async def refresh(refresh_request: RefreshTokenRequest):
    """Trade a refresh token for a new access token.

    The refresh token is verified against REFRESH_SECRET_KEY, and a new access
    token for the same subject is signed with SECRET_KEY. The submitted
    refresh token is returned unchanged (it is not rotated).

    Raises:
        HTTPException(401): a JWTError was raised while verifying the refresh
            token or creating the new access token.
    """
    submitted = refresh_request.refresh_token
    try:
        subject = verify_token(submitted, REFRESH_SECRET_KEY)
        fresh_access = create_token(
            data={"sub": subject},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {
        "access_token": fresh_access,
        "refresh_token": submitted,
        "token_type": "bearer",
    }
# Rate limiting is intended on this endpoint (3 requests per 10 seconds).
# NOTE(review): no @limiter.limit / route decorator is visible here -- confirm.
async def generate(
    request: Request,
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
    stream: bool = Query(False, description="Enable streaming response"),
):
    """Handles both streaming and non-streaming responses, with shutdown detection.

    Args:
        request: not read in the body (presumably required by the limiter).
        query_input: carries the ``query`` text passed to the generators.
        username: value produced by the ``verify_access_token`` dependency;
            not read in the body.
        stream: when true, return a ``text/event-stream`` StreamingResponse
            fed by ``generate_stream``; otherwise await ``generate_response``
            and return its result as JSON.

    Returns:
        A 503 JSON response ``{"message": "Server shutting down..."}`` while
        ``shutdown_event`` is set; otherwise the generated response.

    Raises:
        HTTPException: raised by the helpers, passed through unchanged.
        HTTPException(500): any other error while building the response.
    """
    if shutdown_event.is_set():
        # 503 signals "try again later"; a 200 would make the shutdown notice
        # look like a successful generation to clients.
        return JSONResponse(
            {"message": "Server shutting down..."},
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
    try:
        if stream:
            # NOTE(review): only errors raised while *creating* the stream are
            # caught below; if generate_stream is a generator, errors inside
            # it surface while the response is being sent and cannot become
            # a 500 here.
            return StreamingResponse(generate_stream(query_input.query), media_type="text/event-stream")
        return JSONResponse(await generate_response(query_input.query))
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as 500s.
        raise
    except Exception as e:
        logging.exception("Error in generate endpoint: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
# Startup hook: ensure the ./cache directory exists.
# NOTE(review): no startup registration (e.g. an on_event/lifespan hook) is
# visible for this coroutine -- confirm it is actually wired up.
async def startup_event():
    """Create ``./cache`` (relative to the current working directory) if missing."""
    cache_dir = Path("./cache")
    cache_dir.mkdir(parents=True, exist_ok=True)
if __name__ == "__main__":
    # uvicorn is imported only when this file is executed directly.
    from uvicorn import run as serve

    # Listen on all interfaces, port 7860.
    serve(app, host="0.0.0.0", port=7860)