"""FastAPI application exposing a token login endpoint and a WebSocket
endpoint that streams generated text back to the client."""

import logging
import os
from contextlib import suppress

from fastapi import (
    Depends,
    FastAPI,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordRequestForm

from auth import authenticate_user, create_access_token, get_current_user
from gen import generate_text_stream
from middleware import setup_rate_limiter

logger = logging.getLogger(__name__)

app = FastAPI()

# Apply rate limiting middleware (registered before CORS; order is preserved
# because middleware registration order affects how requests are wrapped).
setup_rate_limiter(app)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests. Left unchanged so existing
# clients keep working; consider an explicit origin allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Login endpoint
@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange a username/password form for a bearer access token.

    Args:
        form_data: OAuth2 password-flow form carrying ``username`` and
            ``password``.

    Returns:
        A dict with ``access_token`` and ``token_type`` (always ``"bearer"``).

    Raises:
        HTTPException: status 400 when ``authenticate_user`` returns a falsy
            value for the supplied credentials.
    """
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=400, detail="Incorrect username or password"
        )
    access_token = create_access_token(data={"sub": user["username"]})
    return {"access_token": access_token, "token_type": "bearer"}


# WebSocket endpoint for streaming
@app.websocket("/generate")
async def websocket_generate(websocket: WebSocket, token: str):
    """Stream generated text for a single prompt over a WebSocket.

    The connection is accepted, ``token`` is checked with
    ``get_current_user``, one text message is read as the prompt, and
    ``generate_text_stream`` is handed the prompt and the socket. Any failure
    is reported to the client as an ``"Error: ..."`` text frame when the
    connection is still usable, and the socket is closed afterwards.

    Args:
        websocket: The client connection.
        token: Access token supplied by the client (query parameter).
    """
    await websocket.accept()
    try:
        # NOTE(review): called synchronously, as in the original. If
        # get_current_user is a coroutine function this call never runs its
        # body and authentication is silently skipped -- confirm in auth.py.
        get_current_user(token)
        prompt = await websocket.receive_text()
        await generate_text_stream(prompt, websocket)
    except WebSocketDisconnect:
        # The client went away; there is nobody left to report an error to.
        logger.info("WebSocket client disconnected from /generate")
    except Exception as e:
        logger.exception("WebSocket /generate request failed")
        # Best effort: the connection may already be broken, in which case
        # sending raises and would otherwise mask the original failure.
        with suppress(RuntimeError, WebSocketDisconnect):
            await websocket.send_text(f"Error: {str(e)}")
    finally:
        # Closing an already-closed/disconnected socket raises; ignore that so
        # a normal client disconnect does not surface as a server error.
        with suppress(RuntimeError, WebSocketDisconnect):
            await websocket.close()