File size: 3,244 Bytes
de26925
89878fd
de26925
a9fdc8f
fb0b173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de26925
9781b82
de26925
fb0b173
9781b82
 
 
 
fb0b173
 
 
a9fdc8f
fb0b173
 
9781b82
fb0b173
de26925
fb0b173
de26925
 
 
9781b82
de26925
 
9781b82
de26925
9781b82
de26925
 
9781b82
 
de26925
 
c1d606a
de26925
 
39b1c1e
de26925
 
 
 
 
 
 
 
 
 
89878fd
 
 
de26925
 
 
 
 
 
 
 
 
 
 
95359fb
9781b82
de26925
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import Response
import httpx
import os
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

# Custom key function to extract api_key from headers or query params
def get_api_key(request: Request) -> str:
    # Check headers first
    api_key = request.headers.get("X-API-Key")
    if not api_key:
        # Fallback to query parameters
        api_key = request.query_params.get("api_key")
    if not api_key:
        raise HTTPException(status_code=400, detail="API key is required in 'X-API-Key' header or 'api_key' query parameter")
    return api_key

# Initialize rate limiter with custom key function
limiter = Limiter(key_func=get_api_key)

# FastAPI app setup
app = FastAPI(
    title="Dhwani API Proxy",
    description="A proxy that forwards all requests to the Dhwani API target server.",
    version="1.0.0",
    redirect_slashes=False,
)

# Add rate limit exceeded handler
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Target server to forward requests to
TARGET_SERVER = os.getenv("DWANI_API_BASE_URL", "http://localhost:8000")

# Catch-all route to forward all requests with rate limiting
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
@limiter.limit("5/minute")  # Limit to 100 requests per minute per api_key
async def proxy(request: Request, path: str):
    # Construct the target URL
    target_url = f"{TARGET_SERVER}/{path}"
    
    # Prepare query parameters
    query_params = dict(request.query_params)
    
    # Prepare headers, excluding FastAPI-specific headers
    headers = {
        key: value for key, value in request.headers.items()
        if key.lower() not in ("host", "connection", "accept-encoding")
    }
    
    # Get the request body, if any
    body = await request.body()
    
    # Create an HTTPX client for making the request
    async with httpx.AsyncClient(timeout=60) as client:
        try:
            # Forward the request to the target server
            response = await client.request(
                method=request.method,
                url=target_url,
                params=query_params,
                headers=headers,
                content=body,
                follow_redirects=False
            )
            
            # Return the response directly
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=dict(response.headers),
                media_type=response.headers.get("content-type", "application/json")
            )
            
        except httpx.TimeoutException:
            raise HTTPException(status_code=504, detail="Target server timeout")
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Failed to forward request: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # Run the proxy server