# Hugging Face Spaces status: Running
from fastapi import FastAPI, Depends, HTTPException, status, Query | |
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from jose import JWTError | |
from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput | |
from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session | |
from utils import generate_stream, generate_response, shutdown_event | |
from fastapi.security import OAuth2PasswordRequestForm | |
from pathlib import Path | |
from datetime import timedelta | |
import os | |
import logging | |
logging.basicConfig(level=logging.INFO)

# --- Token configuration -----------------------------------------------------
# NOTE: this fallback literal is committed to source, so anyone who can read the
# file can forge tokens signed with it. It is kept only so deployments that never
# set SECRET_KEY keep starting; always set SECRET_KEY in the environment.
_INSECURE_DEFAULT_SECRET_KEY = 'def6nQHONW99pOPyba9DShny6FB1CJJBigZault'
SECRET_KEY = os.getenv("SECRET_KEY", _INSECURE_DEFAULT_SECRET_KEY)
if SECRET_KEY == _INSECURE_DEFAULT_SECRET_KEY:
    logging.warning("SECRET_KEY is not set; using the insecure built-in default key.")

# login() issues access and refresh tokens with identical {"sub": username}
# claims. If both were signed with the same key, refresh() would accept an access
# token as a refresh token, letting a client extend its session indefinitely.
# So derive a distinct key when REFRESH_SECRET_KEY is unset (or empty).
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY") or f"{SECRET_KEY}:refresh"
if REFRESH_SECRET_KEY == SECRET_KEY:
    logging.warning(
        "REFRESH_SECRET_KEY equals SECRET_KEY; access tokens will be accepted as refresh tokens."
    )

ALGORITHM = "HS256"  # not referenced in this module; presumably mirrors auth's setting — confirm
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

app = FastAPI()
# Resolve the static directory relative to this file (as index() does) so the
# mount works regardless of the process's current working directory.
app.mount(
    "/static",
    StaticFiles(directory=str(Path(__file__).parent / "static")),
    name="static",
)
# Entry endpoint: serves the front-end page.
def index() -> FileResponse:
    """Serve ``static/index.html`` (located next to this module) as ``text/html``."""
    page = Path(__file__).parent.joinpath('static', 'index.html')
    return FileResponse(media_type='text/html', path=str(page))
async def register(user: UserRegister, db: Session = Depends(get_db)):
    """Create an account for ``user`` and report the stored username.

    Args:
        user: Registration payload; its ``username`` and ``password`` are
            passed to ``register_user``.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A confirmation dict with keys ``message`` and ``user``.
    """
    created_account = register_user(user.username, user.password, db)
    payload = {"message": "User registered successfully"}
    payload["user"] = created_account.username
    return payload
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange a username/password form for an access/refresh token pair.

    Args:
        form_data: OAuth2 password-flow form carrying ``username`` and ``password``.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with:
        - ``access_token``: created with a lifetime of
          ACCESS_TOKEN_EXPIRE_MINUTES minutes and signed with SECRET_KEY;
        - ``refresh_token``: created with a lifetime of
          REFRESH_TOKEN_EXPIRE_DAYS days and signed with REFRESH_SECRET_KEY;
        - ``token_type``: ``"bearer"``.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials
            (or any HTTPException raised by the helpers, passed through
            unchanged); 500 for any other unexpected failure.
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
            secret_key=REFRESH_SECRET_KEY,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 401 above) must reach the client
        # unchanged; otherwise the generic handler below would turn them into 500s.
        raise
    except Exception as e:
        logging.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
async def refresh(refresh_request: RefreshTokenRequest):
    """Mint a new access token from a refresh token.

    The submitted refresh token is returned as-is (it is not rotated).

    Args:
        refresh_request: Body carrying the ``refresh_token`` string.

    Returns:
        A dict with a new ``access_token`` (signed with SECRET_KEY), the original
        ``refresh_token``, and ``token_type`` ``"bearer"``.

    Raises:
        HTTPException: 401 when ``verify_token`` raises ``JWTError`` for the
            refresh token, or yields no subject.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Only verification failures are credential errors; keep the try narrow so
    # a failure while minting the new token is not misreported as a 401.
    try:
        username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
    except JWTError:
        # `from None`: the JWT library's exception is not useful to the client.
        raise credentials_error from None
    if not username:
        # Never issue an access token for an empty/absent subject.
        raise credentials_error

    access_token = create_token(
        data={"sub": username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        secret_key=SECRET_KEY,
    )
    return {
        "access_token": access_token,
        "refresh_token": refresh_request.refresh_token,  # returned unchanged (no rotation)
        "token_type": "bearer",
    }
async def generate(
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
    stream: bool = Query(False, description="Enable streaming response"),
):
    """Answer ``query_input.query``, either streamed or as a single JSON payload.

    Args:
        query_input: Request body; only its ``query`` field is used.
        username: Resolved by the ``verify_access_token`` dependency. It is
            unused in the body; presumably it exists so requests without a
            valid access token are rejected — confirm in auth.
        stream: When true, return a ``text/event-stream`` response fed by
            ``generate_stream``. Otherwise, await ``generate_response`` and
            return its result as JSON.

    Returns:
        ``{"message": "Server shutting down..."}`` once ``shutdown_event`` is
        set. Otherwise a StreamingResponse or JSONResponse as described above.

    Raises:
        HTTPException: passed through unchanged if a helper raises one;
            500 for any other exception raised while building the response.
    """
    if shutdown_event.is_set():
        # NOTE(review): this is sent with status 200; 503 may be more accurate
        # — confirm what clients expect before changing it.
        return JSONResponse({"message": "Server shutting down..."})
    try:
        if stream:
            # Only construction is guarded here: StreamingResponse iterates its
            # body when the response is sent, so errors raised mid-stream are
            # not caught by the handlers below.
            return StreamingResponse(
                generate_stream(query_input.query), media_type="text/event-stream"
            )
        return JSONResponse(await generate_response(query_input.query))
    except HTTPException:
        # Keep deliberate HTTP errors from the helpers instead of masking them as 500.
        raise
    except Exception as e:
        logging.exception("Error in generate endpoint: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e
# Startup hook that prepares the local cache directory.
# NOTE(review): presumably registered via @app.on_event("startup") — confirm;
# if it is not registered, it never runs.
async def startup_event():
    """Create ``./cache`` (relative to the working directory) if it is missing.

    Safe to call repeatedly: an existing directory is left as-is.
    """
    Path("./cache").mkdir(parents=True, exist_ok=True)
if __name__ == "__main__":
    # Direct `python <file>` launch: serve `app` on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")