Spaces:
Running
Running
File size: 4,034 Bytes
e4f5d4a 8b883c8 e4f5d4a 8b883c8 546720a e4f5d4a 546720a e4f5d4a 546720a 8b883c8 e4f5d4a 546720a e4f5d4a 546720a e4f5d4a 546720a e4f5d4a 546720a e4f5d4a 546720a e4f5d4a 8b883c8 e4f5d4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from jose import JWTError
from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
from utils import generate_stream
from fastapi.security import OAuth2PasswordRequestForm
from pathlib import Path
from datetime import timedelta
import os
import logging
logging.basicConfig(level=logging.INFO)

# NOTE(review): this literal is the signing key used whenever the SECRET_KEY
# environment variable is unset. Anyone who can read this source can forge
# tokens signed with it. The value is intentionally left unchanged because
# other modules (e.g. auth.verify_access_token) presumably verify tokens with
# the same fallback - confirm, and always set SECRET_KEY when deployed.
_DEFAULT_SECRET_KEY = 'def6nQHONW99pOPyba9DShny6FB1CJJBigZault'

# Key used to sign access tokens (in /login and /refresh).
SECRET_KEY = os.getenv("SECRET_KEY", _DEFAULT_SECRET_KEY)
if SECRET_KEY == _DEFAULT_SECRET_KEY:
    # Make the insecure fallback visible instead of silently using it.
    logging.warning(
        "SECRET_KEY is not configured; falling back to the built-in default key. "
        "Tokens signed with it can be forged - set the SECRET_KEY environment variable."
    )

# Key used to sign refresh tokens (/login) and verify them (/refresh);
# defaults to SECRET_KEY when REFRESH_SECRET_KEY is unset.
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY", SECRET_KEY)
if REFRESH_SECRET_KEY == SECRET_KEY:
    # Access and refresh tokens are both created in this file with the same
    # {"sub": username} payload, so with a shared key a token issued as an
    # access token may also pass verification at /refresh.
    logging.warning(
        "REFRESH_SECRET_KEY is identical to SECRET_KEY; access tokens may be "
        "accepted as refresh tokens. Set a distinct REFRESH_SECRET_KEY."
    )

# NOTE(review): ALGORITHM is not referenced in this file; presumably it mirrors
# the algorithm used by auth.create_token / verify_token - confirm.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # lifetime passed to create_token for access tokens
REFRESH_TOKEN_EXPIRE_DAYS = 7  # lifetime passed to create_token for refresh tokens

app = FastAPI()
# Entry Endpoint
@app.get('/')
def index() -> FileResponse:
    """Serve the frontend page ``static/index.html`` located next to this module."""
    html_page = Path(__file__).parent.joinpath('static', 'index.html')
    return FileResponse(path=str(html_page), media_type='text/html')
@app.post("/register")
async def register(user: UserRegister, db: Session = Depends(get_db)):
    """Create a new account from the submitted username and password.

    Delegates to ``register_user`` and responds with a confirmation message
    plus the ``username`` attribute of the object it returns.
    """
    # NOTE(review): register_user is called without await; if it performs
    # blocking DB I/O it runs on the event-loop thread inside this async def.
    created = register_user(user.username, user.password, db)
    payload = {"message": "User registered successfully"}
    payload["user"] = created.username
    return payload
@app.post("/login", response_model=TokenResponse)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate a user and issue an access/refresh token pair.

    Credentials arrive as an OAuth2 password form (``username``/``password``).

    Returns:
        A dict shaped like ``TokenResponse``: ``access_token`` (created with
        ``SECRET_KEY`` and an expiry delta of ``ACCESS_TOKEN_EXPIRE_MINUTES``
        minutes), ``refresh_token`` (created with ``REFRESH_SECRET_KEY`` and an
        expiry delta of ``REFRESH_TOKEN_EXPIRE_DAYS`` days) and
        ``token_type`` set to ``"bearer"``.

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy value;
            500 for any other unexpected error (logged with traceback).
    """
    try:
        # NOTE(review): authenticate_user is called without await; if it does
        # blocking DB I/O it runs on the event-loop thread.
        user = authenticate_user(form_data.username, form_data.password, db)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
        access_token = create_token(
            data={"sub": user.username},
            expires_delta=access_token_expires,
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=refresh_token_expires,
            secret_key=REFRESH_SECRET_KEY,
        )
        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
        }
    except HTTPException:
        # Deliberate HTTP errors (the 401 above) must reach the client as-is.
        # Without this clause the generic handler below converted every failed
        # login into a 500 "Internal server error".
        raise
    except Exception as e:
        # Unexpected failure: keep the traceback in the logs, hide details
        # from the client.
        logging.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e
@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshTokenRequest):
    """Exchange a refresh token for a new access token.

    The submitted refresh token is checked with ``verify_token`` against
    ``REFRESH_SECRET_KEY``. On success, a new access token for the returned
    username is created with ``SECRET_KEY``. The refresh token itself is
    returned unchanged (it is not rotated).

    Raises:
        HTTPException: 401 when verification raises ``JWTError`` or yields
            an empty/None username.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only token verification can legitimately fail with a credentials error;
    # keep it alone in the try so unrelated failures are not reported as 401.
    try:
        username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
    except JWTError as exc:
        raise credentials_error from exc

    if not username:
        # Defensive: never mint an access token whose subject is None/empty,
        # in case verify_token returns instead of raising for a token without
        # a usable "sub" claim.
        raise credentials_error

    access_token = create_token(
        data={"sub": username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        secret_key=SECRET_KEY,
    )
    return {
        "access_token": access_token,
        "refresh_token": refresh_request.refresh_token,  # returned as-is, not rotated
        "token_type": "bearer",
    }
@app.post("/search")
async def search(
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
):
    """Stream the result of ``generate_stream`` for the query as plain text.

    ``verify_access_token`` runs as a dependency before this body; presumably
    it rejects requests without a valid access token (confirm in auth).
    ``username`` is unused here but the parameter is what makes FastAPI run
    that dependency.
    """
    chunks = generate_stream(query_input.query)
    return StreamingResponse(chunks, media_type="text/plain")
# Startup hook: create the ./cache directory before the app starts serving
@app.on_event("startup")
async def startup_event():
    """Ensure ``./cache`` (relative to the working directory) exists; no-op if it already does."""
    Path("./cache").mkdir(parents=True, exist_ok=True)
if __name__ == "__main__":
    # Local import: uvicorn is only needed when this file is run directly.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)