import logging
import os
from datetime import timedelta
from pathlib import Path

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles  # noqa: F401  (not referenced in this module)
from jose import JWTError

from auth import (
    Session,
    authenticate_user,
    create_token,
    get_db,
    register_user,
    verify_access_token,
    verify_token,
)
from schemas import QueryInput, RefreshTokenRequest, TokenResponse, UserRegister
from utils import generate_stream

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Development-only fallback. Anyone who can read this source can sign tokens with it,
# so a real deployment must provide SECRET_KEY via the environment.
_DEV_SECRET_KEY = 'def6nQHONW99pOPyba9DShny6FB1CJJBigZault'

SECRET_KEY = os.getenv("SECRET_KEY", _DEV_SECRET_KEY)
# When REFRESH_SECRET_KEY is unset, refresh tokens are signed with the access-token key.
# Both token kinds are created below with the same {"sub": username} payload, so they
# would then differ only in lifetime. Presumably either could be accepted where the
# other is expected -- confirm against verify_token / verify_access_token in auth.
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY", SECRET_KEY)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# Make insecure fallbacks visible in the logs instead of silently using them.
if SECRET_KEY == _DEV_SECRET_KEY:
    logger.warning(
        "SECRET_KEY is not set; using the built-in development key. "
        "Set SECRET_KEY in the environment for any real deployment."
    )
if REFRESH_SECRET_KEY == SECRET_KEY:
    logger.warning(
        "REFRESH_SECRET_KEY is not set (or equals SECRET_KEY); "
        "access and refresh tokens share one signing key."
    )

app = FastAPI()


# Entry Endpoint
@app.get('/')
def index() -> FileResponse:
    """Serve ``static/index.html`` from the directory containing this module."""
    file_path = Path(__file__).parent / 'static' / 'index.html'
    return FileResponse(path=str(file_path), media_type='text/html')


@app.post("/register")
async def register(user: UserRegister, db: Session = Depends(get_db)):
    """Registers a new user.

    Delegates creation to ``auth.register_user`` and echoes back the stored username.
    Any validation/duplicate handling is done inside ``register_user`` (not visible here).
    """
    new_user = register_user(user.username, user.password, db)
    return {"message": "User registered successfully", "user": new_user.username}


@app.post("/login", response_model=TokenResponse)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange OAuth2 password-form credentials for an access/refresh token pair.

    Returns:
        ``{"access_token", "refresh_token", "token_type": "bearer"}``. The access token
        lives ACCESS_TOKEN_EXPIRE_MINUTES and is signed with SECRET_KEY. The refresh
        token lives REFRESH_TOKEN_EXPIRE_DAYS and is signed with REFRESH_SECRET_KEY.

    Raises:
        HTTPException(401): ``authenticate_user`` returned a falsy user.
        HTTPException(500): any other unexpected error while authenticating or
            minting tokens (logged with traceback).
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )

        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

        access_token = create_token(
            data={"sub": user.username},
            expires_delta=access_token_expires,
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=refresh_token_expires,
            secret_key=REFRESH_SECRET_KEY,
        )

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
        }
    except HTTPException:
        # Deliberate HTTP errors (notably the 401 above) must reach the client as-is.
        # Without this clause the generic handler below turned bad logins into 500s.
        raise
    except Exception as e:
        logger.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e


@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshTokenRequest):
    """Issue a new access token from a valid refresh token.

    The refresh token is verified against REFRESH_SECRET_KEY and returned unchanged.
    It is not rotated, so it keeps whatever expiry it was created with.

    Raises:
        HTTPException(401): the token fails verification (``JWTError``) or yields no
            subject.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Verify the refresh token
        username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
        # verify_token's handling of a token without a "sub" claim lives in auth (not
        # visible here). Guard defensively so we never mint a token with sub=None.
        if not username:
            raise credentials_exception

        # Create new access token
        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_token(
            data={"sub": username},
            expires_delta=access_token_expires,
            secret_key=SECRET_KEY,
        )

        return {
            "access_token": access_token,
            "refresh_token": refresh_request.refresh_token,  # Return the same refresh token
            "token_type": "bearer",
        }
    except JWTError as e:
        raise credentials_exception from e


@app.post("/search")
async def search(
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
):
    """Stream the result of ``generate_stream`` for the query as ``text/plain``.

    ``username`` is unused in the body. The dependency exists to require a valid
    access token before the stream starts.
    """
    return StreamingResponse(generate_stream(query_input.query), media_type="text/plain")


# Startup hook: ensure the ./cache directory (relative to the working directory) exists.
# NOTE: on_event is deprecated in recent FastAPI versions in favour of a lifespan handler.
@app.on_event("startup")
async def startup_event():
    os.makedirs("./cache", exist_ok=True)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)