ParthSadaria commited on
Commit
a06f1b3
·
verified ·
1 Parent(s): c04e4db

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -0
main.py CHANGED
@@ -14,10 +14,51 @@ import json
14
  from typing import Optional
15
  import datetime
16
  from usage_tracker import UsageTracker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  usage_tracker = UsageTracker()
18
  load_dotenv() #idk why this shi
19
 
20
  app = FastAPI()
 
21
 
22
  # Get API keys and secret endpoint from environment variables
23
  api_keys_str = os.getenv('API_KEYS') #deprecated -_-
 
14
  from typing import Optional
15
  import datetime
16
  from usage_tracker import UsageTracker
17
+ from fastapi.middleware.base import BaseHTTPMiddleware
18
+ from collections import defaultdict
19
+
20
+ class RateLimitMiddleware(BaseHTTPMiddleware):
21
+ def __init__(self, app, requests_per_second: int = 2):
22
+ super().__init__(app)
23
+ self.requests_per_second = requests_per_second
24
+ self.last_request_time = defaultdict(float)
25
+ self.tokens = defaultdict(lambda: requests_per_second)
26
+ self.last_update = defaultdict(float)
27
+
28
+ async def dispatch(self, request: Request, call_next):
29
+ client_ip = request.client.host
30
+ current_time = time.time()
31
+
32
+ # Update tokens
33
+ time_passed = current_time - self.last_update[client_ip]
34
+ self.last_update[client_ip] = current_time
35
+ self.tokens[client_ip] = min(
36
+ self.requests_per_second,
37
+ self.tokens[client_ip] + time_passed * self.requests_per_second
38
+ )
39
+
40
+ # Check if request can be processed
41
+ if self.tokens[client_ip] < 1:
42
+ return JSONResponse(
43
+ status_code=429,
44
+ content={
45
+ "detail": "Too many requests. Please try again later.",
46
+ "retry_after": round((1 - self.tokens[client_ip]) / self.requests_per_second)
47
+ }
48
+ )
49
+
50
+ # Consume a token
51
+ self.tokens[client_ip] -= 1
52
+
53
+ # Process the request
54
+ response = await call_next(request)
55
+ return response
56
+
57
  usage_tracker = UsageTracker()
58
  load_dotenv() #idk why this shi
59
 
60
  app = FastAPI()
61
+ app.add_middleware(RateLimitMiddleware, requests_per_second=2)
62
 
63
  # Get API keys and secret endpoint from environment variables
64
  api_keys_str = os.getenv('API_KEYS') #deprecated -_-