File size: 4,408 Bytes
9f8a30b
 
e4f5d4a
 
 
 
9f8a30b
546720a
e4f5d4a
 
546720a
e4f5d4a
 
 
 
 
 
 
 
 
546720a
 
 
03afdca
8b883c8
e4f5d4a
 
 
 
 
546720a
e4f5d4a
 
 
 
 
546720a
e4f5d4a
 
 
 
546720a
e4f5d4a
 
 
 
 
 
546720a
e4f5d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546720a
e4f5d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f8a30b
 
e4f5d4a
 
9f8a30b
e4f5d4a
9f8a30b
 
 
 
 
e4f5d4a
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import hashlib
import hmac
import logging
import os
from datetime import timedelta
from pathlib import Path

from fastapi import FastAPI, Depends, HTTPException, status, Query
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from jose import JWTError

from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
from auth import  register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
from utils import generate_stream, generate_response

logging.basicConfig(level=logging.INFO)

# Development-only fallback signing key. Real deployments must provide
# SECRET_KEY through the environment: anyone who can read this file can forge
# tokens signed with the fallback value.
_INSECURE_DEV_SECRET_KEY = 'def6nQHONW99pOPyba9DShny6FB1CJJBigZault'

SECRET_KEY = os.getenv("SECRET_KEY") or _INSECURE_DEV_SECRET_KEY
if SECRET_KEY == _INSECURE_DEV_SECRET_KEY:
    logging.warning(
        "SECRET_KEY is not set; using the insecure built-in development key."
    )

# Refresh tokens must be signed with a key distinct from access tokens.
# Reusing SECRET_KEY verbatim makes the two token types interchangeable: an
# access token would be accepted by /refresh (allowing sessions to be extended
# forever without ever holding a refresh token), and a refresh token may pass
# as an access token depending on how verify_access_token checks it. When no
# explicit key is configured, derive a separate one from SECRET_KEY.
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY") or hmac.new(
    SECRET_KEY.encode("utf-8"), b"refresh-token-signing-key", hashlib.sha256
).hexdigest()

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # lifetime of access tokens issued by /login and /refresh
REFRESH_TOKEN_EXPIRE_DAYS = 7  # lifetime of refresh tokens issued by /login

app = FastAPI()

# Resolve the static directory relative to this module (as index() does)
# rather than the process working directory; StaticFiles refuses to start when
# the directory is missing, so a CWD-relative path breaks launches from
# anywhere but the project root.
app.mount(
    "/static",
    StaticFiles(directory=str(Path(__file__).parent / "static")),
    name="static",
)

# Landing page: serves the front-end HTML shipped next to this module.
@app.get('/')
def index() -> FileResponse:
    """Return ``static/index.html`` (located beside this file) as HTML."""
    html_page = Path(__file__).parent.joinpath('static', 'index.html')
    return FileResponse(media_type='text/html', path=str(html_page))

@app.post("/register")
async def register(user: UserRegister, db: Session = Depends(get_db)):
    """Create an account from the submitted username and password.

    Delegates persistence to ``register_user`` and echoes back the stored
    username together with a confirmation message.
    """
    created = register_user(user.username, user.password, db)
    return dict(message="User registered successfully", user=created.username)

@app.post("/login", response_model=TokenResponse)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate a user and issue a fresh access/refresh token pair.

    Credentials arrive as an OAuth2 password form (``username``/``password``).
    The access token is signed with ``SECRET_KEY`` and the refresh token with
    ``REFRESH_SECRET_KEY``; both carry the username as the ``sub`` claim.

    Returns:
        A dict matching ``TokenResponse``: ``access_token``,
        ``refresh_token`` and ``token_type`` (always ``"bearer"``).

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials;
            500 for any unexpected failure while authenticating or minting
            tokens.
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)

        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )

        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

        access_token = create_token(
            data={"sub": user.username},
            expires_delta=access_token_expires,
            secret_key=SECRET_KEY
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=refresh_token_expires,
            secret_key=REFRESH_SECRET_KEY
        )

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer"
        }
    except HTTPException:
        # Deliberate HTTP errors (such as the 401 above) must reach the client
        # unchanged; without this they would be swallowed by the generic
        # handler below and reported as a 500.
        raise
    except Exception as e:
        # Log the full traceback server-side but keep the client message generic.
        logging.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e

@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshTokenRequest):
    """Exchange a valid refresh token for a new access token.

    The refresh token is verified against ``REFRESH_SECRET_KEY``; the new
    access token is signed with ``SECRET_KEY``. The submitted refresh token is
    returned unchanged (it is not rotated).

    Raises:
        HTTPException: 401 if the refresh token fails verification or yields
            no subject.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the verification step is a client-credential problem; failures
    # while minting the new token are server errors and should not be
    # disguised as a 401.
    try:
        username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
    except JWTError as err:
        raise credentials_error from err

    # NOTE(review): verify_token is defined in auth.py; guard against it
    # returning an empty subject so we never mint a token for ``sub=None``.
    if not username:
        raise credentials_error

    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_token(
        data={"sub": username},
        expires_delta=access_token_expires,
        secret_key=SECRET_KEY
    )

    return {
        "access_token": access_token,
        "refresh_token": refresh_request.refresh_token,  # Return the same refresh token
        "token_type": "bearer"
    }
    
@app.post("/generate")
async def generate(
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
    stream: bool = Query(False, description="Enable streaming response"),
):
    """Answer ``query_input.query`` for an authenticated caller.

    ``username`` is supplied by the ``verify_access_token`` dependency and is
    not otherwise used here; presumably that dependency rejects requests
    without a valid access token (see auth.py). With ``?stream=true`` the
    answer is sent as ``text/event-stream`` via ``generate_stream``; otherwise
    the awaited result of ``generate_response`` is returned as JSON.
    """
    prompt = query_input.query
    if not stream:
        return JSONResponse(await generate_response(prompt))
    return StreamingResponse(generate_stream(prompt), media_type="text/event-stream")
# Startup hook: prepare on-disk state before the first request is served.
@app.on_event("startup")
async def startup_event():
    """Create the ``./cache`` directory if it does not already exist.

    The path is relative to the process working directory. ``exist_ok=True``
    makes repeated startups harmless.

    NOTE(review): the directory is presumably consumed by helpers in utils —
    confirm before moving it. ``on_event`` is deprecated in newer FastAPI
    releases in favour of lifespan handlers; migrating requires changing the
    ``FastAPI()`` construction as well.
    """
    os.makedirs("./cache", exist_ok=True)

if __name__ == "__main__":
    # Run directly as a script: serve the app with uvicorn on all interfaces.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")