File size: 6,183 Bytes
1392014
9f8a30b
e4f5d4a
 
 
1392014
39308e2
546720a
e4f5d4a
 
546720a
e4f5d4a
1392014
 
73ef21f
1392014
73ef21f
1392014
 
 
 
e4f5d4a
 
 
73ef21f
 
 
e4f5d4a
 
 
546720a
 
 
73ef21f
 
 
 
 
 
 
1392014
73ef21f
1392014
 
 
 
 
73ef21f
 
 
 
 
 
 
 
1392014
 
 
03afdca
8b883c8
e4f5d4a
 
 
 
 
546720a
1392014
e4f5d4a
1392014
 
e4f5d4a
 
 
546720a
1392014
e4f5d4a
1392014
 
e4f5d4a
 
546720a
e4f5d4a
 
 
 
 
 
546720a
e4f5d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63a5fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1392014
9f8a30b
1392014
9f8a30b
1392014
e4f5d4a
 
9f8a30b
e4f5d4a
39308e2
 
 
 
 
 
 
 
 
 
 
 
 
e4f5d4a
 
 
 
 
 
 
1392014
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from fastapi import FastAPI, Depends, HTTPException, status, Query, Request
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from jose import JWTError
from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
from utils import generate_stream, generate_response, shutdown_event
from fastapi.security import OAuth2PasswordRequestForm
from pathlib import Path
from datetime import timedelta
import os
import logging
from dotenv import load_dotenv

# Load environment variables from a .env file (python-dotenv) into os.environ.
# This must run before the module-level os.getenv() lookups below
# (SECRET_KEY, REFRESH_SECRET_KEY) so they can see values defined there.
load_dotenv()

# Import SlowAPI for Rate Limiting
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware

logging.basicConfig(level=logging.INFO)

# Fallback signing key used only when SECRET_KEY is not configured. It is
# committed in source, so anyone who has read this file can forge tokens
# signed with it.
_DEFAULT_SECRET_KEY = "default_fallback_256_bit_key"


def _env_secret(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when it is unset or empty.

    An empty value is treated as unset: an empty signing key is never a valid
    configuration, whereas ``os.getenv(name, default)`` would return it as-is.
    """
    value = os.getenv(name)
    return value if value else default


# Security Keys with Default Fallbacks
SECRET_KEY = _env_secret("SECRET_KEY", _DEFAULT_SECRET_KEY)
# Refresh tokens fall back to the access-token key when not configured.
REFRESH_SECRET_KEY = _env_secret("REFRESH_SECRET_KEY", SECRET_KEY)

if SECRET_KEY == _DEFAULT_SECRET_KEY:
    logging.warning(
        "SECRET_KEY is not set; using the insecure built-in fallback key. "
        "Set SECRET_KEY in the environment before deploying."
    )
if REFRESH_SECRET_KEY == SECRET_KEY:
    # NOTE(review): with one shared key, a refresh token is signed exactly like
    # an access token; whether it is then accepted as one depends on the claim
    # checks in auth.verify_access_token — confirm.
    logging.warning(
        "REFRESH_SECRET_KEY is not set or equals SECRET_KEY; "
        "access and refresh tokens share a signing key."
    )

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# The ASGI application. The limiter state, exception handlers, middleware,
# static mount and routes below are all registered on this instance.
app = FastAPI()

# Custom Key Function for Rate Limiting (Handles Proxies)
def custom_key_func(request: Request):
    """Return the client identifier used as the rate-limit key.

    Uses the first (left-most) entry of ``X-Forwarded-For`` when it is present
    and non-blank; otherwise falls back to ``get_remote_address(request)``.

    SECURITY NOTE(review): the left-most ``X-Forwarded-For`` entry is whatever
    the client sent. Proxies that append to the header (a common default)
    preserve it. A caller can therefore choose a fresh key per request and
    sidestep the limits. The trustworthy value is the entry added by your own
    proxy. Revisit this once the deployment's proxy chain is known.
    """
    forwarded = request.headers.get("X-Forwarded-For", "")
    # Strip whitespace around the entry ("a , b" -> "a") so the same client
    # always maps to the same key.
    client_ip = forwarded.split(",")[0].strip()
    # A blank first entry (e.g. header "" or ", 1.2.3.4") must not bucket all
    # such requests under the empty key; fall back to the peer address.
    return client_ip or get_remote_address(request)

# Initialize Rate Limiter
limiter = Limiter(key_func=custom_key_func)
# slowapi's decorators and middleware look the limiter up on app.state.
app.state.limiter = limiter

# HTTP 429 responses are rendered by rate_limit_exceeded_handler, which is
# registered below with @app.exception_handler(429). Registering another
# handler for the same status code here would simply be replaced by that one,
# so none is registered.

# Custom Rate Limit Response
@app.exception_handler(429)
async def rate_limit_exceeded_handler(request: Request, exc):
    """Render every HTTP 429 as a uniform JSON error body.

    The handler is keyed by status code, so it applies to any HTTPException
    carrying status 429. That includes slowapi's RateLimitExceeded, which the
    ``@limiter.limit`` decorators raise. ``request`` and ``exc`` are not
    inspected.
    """
    payload = {"error": "Rate limit exceeded. Please try again later."}
    return JSONResponse(content=payload, status_code=429)

# Add Middleware for Rate Limiting
app.add_middleware(SlowAPIMiddleware)

# Resolve the static directory relative to this file, the same way index()
# locates index.html. A bare "static" is resolved against the current working
# directory, which breaks the mount when the server is started elsewhere.
app.mount(
    "/static",
    StaticFiles(directory=str(Path(__file__).parent / "static")),
    name="static",
)

# Entry Endpoint 
@app.get('/')
def index() -> FileResponse:
    """Serve the frontend page: ``static/index.html`` next to this module."""
    html_page = Path(__file__).parent.joinpath('static', 'index.html')
    return FileResponse(media_type='text/html', path=str(html_page))

# Apply Rate Limiting on Register API (Limit: 5 requests per minute)
@app.post("/register")
@limiter.limit("5/minute")
async def register(request: Request, user: UserRegister, db: Session = Depends(get_db)):
    """Create an account from the submitted username and password.

    ``request`` is not read in the body; ``@limiter.limit`` needs it in the
    signature. Errors raised by ``register_user`` (presumably for a taken
    username — confirm in auth.py) propagate unchanged.
    """
    created = register_user(user.username, user.password, db)
    return {
        "message": "User registered successfully",
        "user": created.username,
    }

# Apply Rate Limiting on Login API (Limit: 10 requests per minute)
@app.post("/login", response_model=TokenResponse)
@limiter.limit("10/minute")
async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate a user and issue an access/refresh token pair.

    ``request`` is not read in the body; ``@limiter.limit`` needs it in the
    signature. The access token is signed with SECRET_KEY and expires after
    ACCESS_TOKEN_EXPIRE_MINUTES. The refresh token is signed with
    REFRESH_SECRET_KEY and expires after REFRESH_TOKEN_EXPIRE_DAYS.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials;
            500 when any other error occurs while authenticating or creating
            the tokens.
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)

        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )

        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

        access_token = create_token(
            data={"sub": user.username},
            expires_delta=access_token_expires,
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=refresh_token_expires,
            secret_key=REFRESH_SECRET_KEY,
        )

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
        }
    except HTTPException:
        # Deliberate HTTP errors (the 401 above) must reach the client as-is.
        # Otherwise the generic handler below would turn a failed login into
        # a 500.
        raise
    except Exception as e:
        logging.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e

@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshTokenRequest):
    """Exchange a refresh token for a new access token.

    The presented refresh token is verified against REFRESH_SECRET_KEY and
    echoed back unchanged; it is not rotated. A new access token for the same
    subject is signed with SECRET_KEY. Any JWTError raised while verifying the
    refresh token or signing the new access token is reported as a 401.
    """
    presented = refresh_request.refresh_token
    try:
        subject = verify_token(presented, REFRESH_SECRET_KEY)
        new_access_token = create_token(
            data={"sub": subject},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return {
        "access_token": new_access_token,
        "refresh_token": presented,
        "token_type": "bearer",
    }
        
# Apply Rate Limiting on Generate API (Limit: 3 requests per 10 seconds)
@app.post("/generate")
@limiter.limit("3/10seconds")
async def generate(
    request: Request,
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
    stream: bool = Query(False, description="Enable streaming response"),
):
    """Answer ``query_input.query`` for an authenticated caller.

    ``request`` is not read in the body; ``@limiter.limit`` needs it in the
    signature. ``username`` is also unused here. The ``verify_access_token``
    dependency is what gates access (presumably it raises on an invalid token
    — confirm in auth.py).

    With ``stream=true`` the output of ``generate_stream`` is sent as a
    ``text/event-stream`` response. Otherwise the awaited result of
    ``generate_response`` is returned as JSON. Once ``shutdown_event`` is set,
    a fixed shutdown message is returned instead.

    Raises:
        HTTPException: re-raised unchanged if one of the helpers raises it;
            any other exception becomes a 500.
    """
    if shutdown_event.is_set():
        # NOTE(review): this answers with status 200. A 503 would let clients
        # tell "shutting down" apart from a real answer. It is kept as-is for
        # compatibility with existing clients.
        return JSONResponse({"message": "Server shutting down..."})

    try:
        if stream:
            # StreamingResponse consumes the iterator after this function has
            # returned. Only errors raised while *creating* it are caught below;
            # errors raised mid-stream are not converted into a 500 here.
            return StreamingResponse(generate_stream(query_input.query), media_type="text/event-stream")
        return JSONResponse(await generate_response(query_input.query))
    except HTTPException:
        # Preserve deliberate HTTP errors instead of masking them as a 500.
        raise
    except Exception as e:
        logging.exception("Error in generate endpoint: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e

# Startup hook: make sure the ./cache directory exists before serving requests
@app.on_event("startup")
async def startup_event():
    """Create ``./cache`` (relative to the working directory) if it is missing."""
    cache_dir = Path("./cache")
    cache_dir.mkdir(parents=True, exist_ok=True)

if __name__ == "__main__":
    # uvicorn is imported only when this file is executed as a script.
    import uvicorn

    # Listen on all interfaces on port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")