Spaces:
Running
Running
File size: 6,183 Bytes
1392014 9f8a30b e4f5d4a 1392014 39308e2 546720a e4f5d4a 546720a e4f5d4a 1392014 73ef21f 1392014 73ef21f 1392014 e4f5d4a 73ef21f e4f5d4a 546720a 73ef21f 1392014 73ef21f 1392014 73ef21f 1392014 03afdca 8b883c8 e4f5d4a 546720a 1392014 e4f5d4a 1392014 e4f5d4a 546720a 1392014 e4f5d4a 1392014 e4f5d4a 546720a e4f5d4a 546720a e4f5d4a 63a5fea 1392014 9f8a30b 1392014 9f8a30b 1392014 e4f5d4a 9f8a30b e4f5d4a 39308e2 e4f5d4a 1392014 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
from fastapi import FastAPI, Depends, HTTPException, status, Query, Request
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from jose import JWTError
from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
from utils import generate_stream, generate_response, shutdown_event
from fastapi.security import OAuth2PasswordRequestForm
from pathlib import Path
from datetime import timedelta
import os
import logging
from dotenv import load_dotenv
# Load environment variables (python-dotenv, e.g. from a .env file). This must
# run before the os.getenv() calls below read SECRET_KEY / REFRESH_SECRET_KEY.
load_dotenv()
# Import SlowAPI for Rate Limiting
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware
logging.basicConfig(level=logging.INFO)

# Security Keys with Default Fallbacks.
# The literal fallback below is visible in source control, so anyone can forge
# tokens signed with it; it only exists so local/dev runs start without a .env.
# Set SECRET_KEY (and a distinct REFRESH_SECRET_KEY) for any real deployment.
_INSECURE_FALLBACK_KEY = "default_fallback_256_bit_key"
SECRET_KEY = os.getenv("SECRET_KEY", _INSECURE_FALLBACK_KEY)
# When REFRESH_SECRET_KEY is unset, refresh tokens are signed with the same key
# as access tokens, so the two token kinds cannot be told apart by signature.
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY", SECRET_KEY)

# Make insecure configurations visible instead of silently accepting them.
if SECRET_KEY == _INSECURE_FALLBACK_KEY:
    logging.warning(
        "SECRET_KEY is unset or uses the built-in fallback; tokens can be "
        "forged. Set SECRET_KEY in the environment for any real deployment."
    )
if REFRESH_SECRET_KEY == SECRET_KEY:
    logging.warning(
        "REFRESH_SECRET_KEY is unset or equals SECRET_KEY; access and refresh "
        "tokens share one signing key."
    )

ALGORITHM = "HS256"  # JWT signing algorithm (HMAC-SHA256)
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
# The ASGI application; the limiter, handlers, middleware and routes below
# all attach to this instance.
app = FastAPI()
# Custom Key Function for Rate Limiting (Handles Proxies)
def custom_key_func(request: Request):
    """Return the client identifier used to bucket rate-limit counters.

    Uses the first (left-most) address in ``X-Forwarded-For``, stripped of
    surrounding whitespace, when it is non-empty; otherwise falls back to
    slowapi's ``get_remote_address(request)``.

    NOTE(review): the left-most X-Forwarded-For entry is whatever the client
    sent, so a client can rotate it to evade rate limits unless a trusted
    proxy in front of this app overwrites the header — confirm the deployment.
    """
    forwarded = request.headers.get("X-Forwarded-For", "")
    # "a, b, c" -> "a"; strip so " a" and "a" share one bucket, and treat an
    # empty first entry (e.g. ", b") as absent instead of keying on "".
    client_ip = forwarded.split(",")[0].strip()
    if client_ip:
        return client_ip
    return get_remote_address(request)
# Rate limiter that buckets requests by custom_key_func. slowapi expects the
# limiter to be reachable as app.state.limiter, so bind it there as well.
limiter = app.state.limiter = Limiter(key_func=custom_key_func)
# Custom Rate Limit Response.
# FastAPI keeps a single handler per key, so this is the only handler for
# status 429; registering slowapi's stock handler for 429 as well would just
# be overwritten by this one.
@app.exception_handler(429)
async def rate_limit_exceeded_handler(request: Request, exc):
    """Return a JSON 429 body for any exception routed to status 429.

    Args:
        request: the incoming request (not used).
        exc: the exception that was routed to the 429 handler (not used).

    Returns:
        JSONResponse with status 429 and a fixed ``error`` message.
    """
    return JSONResponse(
        status_code=429,
        content={"error": "Rate limit exceeded. Please try again later."},
    )
# Add Middleware for Rate Limiting
app.add_middleware(SlowAPIMiddleware)

# Serve the "static" directory that sits next to this file. Anchoring to
# __file__ (as index() does) instead of the process working directory keeps
# the mount working when the server is launched from another directory.
app.mount(
    "/static",
    StaticFiles(directory=str(Path(__file__).parent / "static")),
    name="static",
)
# Entry Endpoint
@app.get('/')
def index() -> FileResponse:
    """Serve static/index.html, located relative to this source file, as HTML."""
    html_file = Path(__file__).parent.joinpath('static', 'index.html')
    return FileResponse(media_type='text/html', path=str(html_file))
# Apply Rate Limiting on Register API (Limit: 5 requests per minute)
@app.post("/register")
@limiter.limit("5/minute")
async def register(request: Request, user: UserRegister, db: Session = Depends(get_db)):
    """Create an account from the submitted username and password.

    ``request`` is not read in the body; it stays in the signature for the
    rate-limit decorator.

    Returns:
        dict with a confirmation ``message`` and the new account's ``user`` name.

    NOTE(review): if register_user does blocking DB work it runs on the event
    loop here; consider a plain ``def`` endpoint or a threadpool — confirm.
    """
    created = register_user(user.username, user.password, db)
    payload = {"message": "User registered successfully"}
    payload["user"] = created.username
    return payload
# Apply Rate Limiting on Login API (Limit: 10 requests per minute)
@app.post("/login", response_model=TokenResponse)
@limiter.limit("10/minute")
async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate a user and issue a new access/refresh token pair.

    ``request`` is not read in the body; it stays in the signature because
    slowapi's ``limiter.limit`` decorator needs the request argument.

    Returns:
        dict with ``access_token`` (signed with SECRET_KEY), ``refresh_token``
        (signed with REFRESH_SECRET_KEY) and ``token_type`` ``"bearer"``.

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy user;
            any other HTTPException raised inside is passed through unchanged;
            500 for any other unexpected error (details are only logged).
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
            secret_key=REFRESH_SECRET_KEY,
        )
    except HTTPException:
        # Intentional HTTP errors (notably the 401 above) must reach the client
        # as-is; otherwise the broad handler below would turn bad credentials
        # into a 500.
        raise
    except Exception as e:
        # logging.exception also records the traceback for diagnosis.
        logging.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshTokenRequest):
    """Exchange a valid refresh token for a new access token.

    The refresh token is not rotated: the submitted value is echoed back.

    Raises:
        HTTPException: 401 when ``verify_token`` raises ``JWTError``.
    """
    try:
        # Validate against the refresh key, then mint a fresh access token.
        subject = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
        new_access_token = create_token(
            data={"sub": subject},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {
        "access_token": new_access_token,
        "refresh_token": refresh_request.refresh_token,
        "token_type": "bearer",
    }
# Apply Rate Limiting on Generate API (Limit: 3 requests per 10 seconds)
@app.post("/generate")
@limiter.limit("3/10seconds")
async def generate(
    request: Request,
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
    stream: bool = Query(False, description="Enable streaming response"),
):
    """Answer ``query_input.query``, streamed or as a single JSON response.

    ``request`` is kept for the rate-limit decorator; ``username`` is resolved
    by the ``verify_access_token`` dependency, which gates access.

    Returns:
        A StreamingResponse (``text/event-stream``) over ``generate_stream``
        when ``stream`` is true, otherwise a JSONResponse wrapping the awaited
        ``generate_response`` result. If ``shutdown_event`` is set, a JSON
        "Server shutting down..." message is returned instead.

    Raises:
        HTTPException: any HTTPException from the helpers is passed through;
            other errors become a 500 (details are only logged).
    """
    if shutdown_event.is_set():
        # NOTE(review): sent with the default 200 status; a 503 may be more
        # accurate, but clients may rely on the current status — confirm.
        return JSONResponse({"message": "Server shutting down..."})
    try:
        if stream:
            # NOTE(review): if generate_stream is a generator function, its body
            # only runs while the response is sent, after this function has
            # returned, so errors during streaming are not caught here.
            return StreamingResponse(generate_stream(query_input.query), media_type="text/event-stream")
        return JSONResponse(await generate_response(query_input.query))
    except HTTPException:
        # Deliberate HTTP errors keep their status instead of becoming a 500.
        raise
    except Exception as e:
        logging.exception("Error in generate endpoint: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
# Startup hook: make sure the local cache directory exists before serving.
@app.on_event("startup")
async def startup_event():
    """Create ./cache (relative to the working directory) if it is missing."""
    Path("./cache").mkdir(parents=True, exist_ok=True)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces
    # (0.0.0.0) at the fixed port 7860. The import is local so uvicorn is only
    # needed when the file is run directly.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|