Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -14,10 +14,51 @@ import json
|
|
14 |
from typing import Optional
|
15 |
import datetime
|
16 |
from usage_tracker import UsageTracker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
usage_tracker = UsageTracker()
|
18 |
load_dotenv() #idk why this shi
|
19 |
|
20 |
app = FastAPI()
|
|
|
21 |
|
22 |
# Get API keys and secret endpoint from environment variables
|
23 |
api_keys_str = os.getenv('API_KEYS') #deprecated -_-
|
|
|
14 |
from typing import Optional
|
15 |
import datetime
|
16 |
from usage_tracker import UsageTracker
|
17 |
+
from fastapi.middleware.base import BaseHTTPMiddleware
|
18 |
+
from collections import defaultdict
|
19 |
+
|
20 |
+
class RateLimitMiddleware(BaseHTTPMiddleware):
|
21 |
+
def __init__(self, app, requests_per_second: int = 2):
|
22 |
+
super().__init__(app)
|
23 |
+
self.requests_per_second = requests_per_second
|
24 |
+
self.last_request_time = defaultdict(float)
|
25 |
+
self.tokens = defaultdict(lambda: requests_per_second)
|
26 |
+
self.last_update = defaultdict(float)
|
27 |
+
|
28 |
+
async def dispatch(self, request: Request, call_next):
|
29 |
+
client_ip = request.client.host
|
30 |
+
current_time = time.time()
|
31 |
+
|
32 |
+
# Update tokens
|
33 |
+
time_passed = current_time - self.last_update[client_ip]
|
34 |
+
self.last_update[client_ip] = current_time
|
35 |
+
self.tokens[client_ip] = min(
|
36 |
+
self.requests_per_second,
|
37 |
+
self.tokens[client_ip] + time_passed * self.requests_per_second
|
38 |
+
)
|
39 |
+
|
40 |
+
# Check if request can be processed
|
41 |
+
if self.tokens[client_ip] < 1:
|
42 |
+
return JSONResponse(
|
43 |
+
status_code=429,
|
44 |
+
content={
|
45 |
+
"detail": "Too many requests. Please try again later.",
|
46 |
+
"retry_after": round((1 - self.tokens[client_ip]) / self.requests_per_second)
|
47 |
+
}
|
48 |
+
)
|
49 |
+
|
50 |
+
# Consume a token
|
51 |
+
self.tokens[client_ip] -= 1
|
52 |
+
|
53 |
+
# Process the request
|
54 |
+
response = await call_next(request)
|
55 |
+
return response
|
56 |
+
|
57 |
usage_tracker = UsageTracker()
|
58 |
load_dotenv() #idk why this shi
|
59 |
|
60 |
app = FastAPI()
|
61 |
+
app.add_middleware(RateLimitMiddleware, requests_per_second=2)
|
62 |
|
63 |
# Get API keys and secret endpoint from environment variables
|
64 |
api_keys_str = os.getenv('API_KEYS') #deprecated -_-
|