File size: 3,468 Bytes
d2237d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ec338
 
 
d2237d8
 
 
 
 
 
 
e1ec338
 
 
 
d2237d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75c8918
d2237d8
 
 
 
 
 
 
 
 
 
 
75c8918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2237d8
 
 
 
 
 
 
 
 
 
c43f30f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
import logging
import os
import secrets
import time

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from model import ner_pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Per-client rate limiting, keyed on the caller's remote address.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# JSON file that persists issued API keys and their usage counters.
# Override with the API_KEY_FILE environment variable to keep the store
# somewhere durable; note that /tmp may not survive a restart on many hosts.
API_KEY_FILE = os.getenv("API_KEY_FILE", "/tmp/api_keys.json")

# Authenticated endpoints read the caller's key from the X-API-KEY header.
token_header = APIKeyHeader(name="X-API-KEY")

# Load the key store, or create an empty one on first start.
if os.path.exists(API_KEY_FILE):
    with open(API_KEY_FILE, "r", encoding="utf-8") as f:
        API_KEYS_STORE = json.load(f)
else:
    API_KEYS_STORE = {"users": {}}
    with open(API_KEY_FILE, "w", encoding="utf-8") as f:
        json.dump(API_KEYS_STORE, f)

# API key -> per-key metadata (e.g. "usage_count", "label").
API_KEYS = API_KEYS_STORE.get("users", {})

# Key that authorizes the admin endpoints. Without it those endpoints reject
# every request, so make that visible at startup instead of failing silently.
ADMIN_KEY = os.getenv("ADMIN_KEY")
if not ADMIN_KEY:
    logger.warning("ADMIN_KEY is not set; admin endpoints will reject all requests.")

class TextRequest(BaseModel):
    """JSON body accepted by ``POST /ner``: the raw text to tag."""

    text: str

class RegisterRequest(BaseModel):
    """JSON body accepted by ``POST /register_user``.

    ``label`` is a free-form tag stored next to the issued key; it falls
    back to ``"user"`` when the client omits it.
    """

    label: str = "user"

def save_keys():
    """Persist the in-memory key registry to ``API_KEY_FILE``.

    Syncs ``API_KEYS`` back into ``API_KEYS_STORE["users"]`` and writes the
    whole store as indented JSON.

    The JSON is first written to a uniquely named temporary file in the same
    directory, then moved over the target with ``os.replace``. As a result:

    - A crash or error mid-write leaves the previous key file intact, rather
      than a truncated file that ``json.load`` would choke on at next start.
    - Concurrent callers never interleave writes on one file.
    - ``tempfile.mkstemp`` creates the file owner-readable/writable only, and
      on POSIX ``os.replace`` carries that mode over to the key file.
    """
    import tempfile

    API_KEYS_STORE["users"] = API_KEYS
    target_dir = os.path.dirname(os.path.abspath(API_KEY_FILE))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix=".api_keys-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(API_KEYS_STORE, f, indent=2)
        os.replace(tmp_path, API_KEY_FILE)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the real key file
        # has not been touched, so re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def verify_token(x_api_key: str = Depends(token_header)):
    """FastAPI dependency that authenticates a caller by API key.

    Hands the key back unchanged when it is registered in ``API_KEYS`` so the
    endpoint can look up the caller's record; otherwise responds with 403.
    """
    is_registered = x_api_key in API_KEYS
    if is_registered:
        return x_api_key
    raise HTTPException(status_code=403, detail="Unauthorized")

"""
@app.post("/ner")
@limiter.limit("10/minute")
def ner_predict(request: TextRequest, api_key: str = Depends(verify_token), req: Request = None):
    logger.info("Received NER request from IP: %s", get_remote_address(req))
    predictions = ner_pipeline(request.text)
    API_KEYS[api_key]["usage_count"] = API_KEYS[api_key].get("usage_count", 0) + 1
    save_keys()
    return {
        "entities": predictions,
        "usage": API_KEYS[api_key]["usage_count"]
    }
"""

from fastapi import Request

@app.post("/ner")
@limiter.limit("10/minute")
def ner_predict(
    body: TextRequest, 
    req: Request, 
    api_key: str = Depends(verify_token)
):
    logger.info("NER request from IP: %s", get_remote_address(req))
    predictions = ner_pipeline(body.text)
    API_KEYS[api_key]["usage_count"] = API_KEYS[api_key].get("usage_count", 0) + 1
    save_keys()
    return {
        "entities": predictions,
        "usage": API_KEYS[api_key]["usage_count"]
    }



@app.post("/register_user")
def register_user(request: RegisterRequest, x_api_key: str = Depends(token_header)):
    if x_api_key != ADMIN_KEY:
        raise HTTPException(status_code=403, detail="Admin access required")

    new_key = secrets.token_urlsafe(32)
    API_KEYS[new_key] = {"usage_count": 0, "label": request.label}
    save_keys()
    return {"message": "User registered", "api_key": new_key}


@app.get("/list_users")
def list_users(x_api_key: str = Depends(token_header)):
    if x_api_key != ADMIN_KEY:
        raise HTTPException(status_code=403, detail="Admin access required")
    return {"users": API_KEYS}

@app.get("/", include_in_schema=False)
def root():
    return HTMLResponse("""
    <html>
      <head><title>MagBERT-NER API</title></head>
      <body style="font-family:sans-serif;">
        <h1>✅ MagBERT-NER API is running!</h1>
        <p>This is a private FastAPI app.</p>
        <p>Visit <a href="/docs">/docs</a> to explore the API interface.</p>
      </body>
    </html>
    """)