"""FastAPI service exposing a named-entity-recognition endpoint behind API keys.

API keys are persisted as JSON in ``API_KEY_FILE`` with the shape
``{"users": {<api_key>: {"usage_count": int, "label": str}}}``. New keys are
minted by an administrator who authenticates with the ``ADMIN_KEY``
environment variable.
"""

import json
import logging
import os
import secrets
import tempfile
import threading
import time

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from model import ner_pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Rate limiting keyed on the client's remote address.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# NOTE: /tmp is typically ephemeral, so keys may not survive a container or host
# restart. Point this at durable storage if registered keys must persist.
API_KEY_FILE = "/tmp/api_keys.json"

# Reads the caller's key from the "X-API-KEY" request header.
token_header = APIKeyHeader(name="X-API-KEY")

# Guards every read-modify-write of API_KEYS and every write of API_KEY_FILE.
# Sync (`def`) endpoints run in FastAPI's threadpool, so handlers can overlap.
# Re-entrant so save_keys() can be called by code that already holds it.
# NOTE(review): this only protects a single process; multiple workers would
# each keep their own in-memory copy of the key store.
_keys_lock = threading.RLock()


def _write_json_atomic(path, data):
    """Serialize *data* as JSON to *path* without ever exposing a partial file.

    The JSON goes to a temporary file in the same directory, which is then
    moved over *path* with ``os.replace``. That rename is atomic when both are
    on the same filesystem. ``mkstemp`` creates the file readable only by its
    owner, which suits a file holding API keys.
    """
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(
        dir=directory, prefix=f".{os.path.basename(path)}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp path is gone; otherwise clean it up.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _load_key_store(path):
    """Return the key store saved at *path*, creating an empty one if absent.

    A malformed file is deliberately left to raise. Silently starting from an
    empty store would overwrite every registered key on the next save.
    """
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    store = {"users": {}}
    _write_json_atomic(path, store)
    return store


API_KEYS_STORE = _load_key_store(API_KEY_FILE)
# Maps each issued API key to its metadata dict ("usage_count", "label").
API_KEYS = API_KEYS_STORE.setdefault("users", {})

# Admin secret for /register_user and /list_users. If it is unset or empty,
# every admin request is rejected.
ADMIN_KEY = os.getenv("ADMIN_KEY")
if not ADMIN_KEY:
    logger.warning("ADMIN_KEY is not set; admin endpoints will reject every request.")


class TextRequest(BaseModel):
    """Request body for /ner: the raw text to tag."""

    text: str


class RegisterRequest(BaseModel):
    """Request body for /register_user: a human-readable label for the new key."""

    label: str = "user"


def save_keys():
    """Persist the in-memory API_KEYS mapping to API_KEY_FILE atomically."""
    with _keys_lock:
        API_KEYS_STORE["users"] = API_KEYS
        _write_json_atomic(API_KEY_FILE, API_KEYS_STORE)


def verify_token(x_api_key: str = Depends(token_header)):
    """Dependency that admits only callers whose key is in API_KEYS.

    Returns:
        The validated API key.

    Raises:
        HTTPException: 403 if the key has not been issued.
    """
    if x_api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Unauthorized")
    return x_api_key


def _require_admin(x_api_key):
    """Raise 403 unless *x_api_key* equals ADMIN_KEY, compared in constant time."""
    # Compare bytes: compare_digest rejects str arguments with non-ASCII characters.
    is_admin = bool(ADMIN_KEY) and secrets.compare_digest(
        x_api_key.encode("utf-8"), ADMIN_KEY.encode("utf-8")
    )
    if not is_admin:
        raise HTTPException(status_code=403, detail="Admin access required")


def _clean_entity(text, ent):
    """Return a JSON-friendly copy of one pipeline entity with a tidied ``word``.

    ``score`` is cast to a plain float so it serializes as JSON (presumably the
    pipeline returns a NumPy scalar here). When ``start``/``end`` character
    offsets are present, ``word`` is re-read from *text* and stripped. The
    entity then reflects the user's input verbatim rather than the tokenizer's
    reconstruction. Without offsets, the pipeline's own ``word`` is kept.
    Slicing with ``None`` would otherwise return the whole text.
    """
    start, end = ent.get("start"), ent.get("end")
    if start is not None and end is not None:
        word = text[start:end].strip()
    else:
        word = ent.get("word")
    return {**ent, "score": float(ent["score"]), "word": word}


@app.post("/ner")
@limiter.limit("10/minute")
def ner_predict(body: TextRequest, request: Request, api_key: str = Depends(verify_token)):
    """Tag named entities in ``body.text`` and bump the caller's usage counter.

    ``request`` is required by slowapi's ``@limiter.limit`` decorator and is
    also used to log the client address.

    Returns:
        ``{"entities": [...], "usage": <caller's total request count>}``.
    """
    logger.info("NER request from IP: %s", get_remote_address(request))
    raw_predictions = ner_pipeline(body.text)
    predictions = [_clean_entity(body.text, ent) for ent in raw_predictions]

    # Increment and persist under the lock so concurrent requests don't lose updates.
    with _keys_lock:
        record = API_KEYS[api_key]
        record["usage_count"] = record.get("usage_count", 0) + 1
        usage = record["usage_count"]
        save_keys()

    return {"entities": predictions, "usage": usage}


@app.post("/register_user")
def register_user(request: RegisterRequest, x_api_key: str = Depends(token_header)):
    """Admin-only: mint, store and return a fresh random API key."""
    _require_admin(x_api_key)
    new_key = secrets.token_urlsafe(32)
    with _keys_lock:
        API_KEYS[new_key] = {"usage_count": 0, "label": request.label}
        save_keys()
    return {"message": "User registered", "api_key": new_key}


@app.get("/list_users")
def list_users(x_api_key: str = Depends(token_header)):
    """Admin-only: return every issued key with its metadata.

    NOTE: the response contains the raw API keys themselves.
    """
    _require_admin(x_api_key)
    # Snapshot under the lock so serialization never races a concurrent insert.
    with _keys_lock:
        snapshot = {key: dict(info) for key, info in API_KEYS.items()}
    return {"users": snapshot}


@app.get("/", include_in_schema=False)
def root():
    """Serve a minimal landing page pointing visitors to the interactive docs."""
    return HTMLResponse("""
Visit /docs to try the API.
""")