node-py-test / app.py
pngwn's picture
pngwn HF Staff
Update app.py
d0d3b45 verified
raw
history blame
4 kB
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import subprocess
import signal
import time
import os
# The FastAPI application: a few routes are served by Python, everything else
# is forwarded to the Node.js server by the catch-all route.
app = FastAPI()

# Wide-open CORS policy: every method, header and origin is accepted, and
# credentialed requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside of local testing.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Port uvicorn serves the FastAPI app on (used by the __main__ block).
PYTHON_PORT = 7860
# Port the Node.js server is expected to be listening on.
NODE_PORT = 4321
# Argument handed to the `node` executable when the child process is started.
NODE_SCRIPT_PATH = "build"

# Launch the Node.js server as a child process as soon as this module is
# imported; the handle is used later to stop it on SIGTERM and on app shutdown.
# NOTE(review): NODE_PORT is not passed to the child (no CLI argument or env
# var), so Node must be configured to listen on that port some other way —
# confirm.
node_process = subprocess.Popen(args=["node", NODE_SCRIPT_PATH])
def handle_sigterm(signum, frame):
    """Stop the Node.js child process, then exit the Python process with status 0.

    Intended to be installed as the SIGTERM handler.

    Args:
        signum: Signal number delivered by the interpreter (unused).
        frame: Stack frame that was executing when the signal arrived (unused).
    """
    import sys

    # Seconds to wait for Node to exit after SIGTERM before force-killing it.
    grace_period = 5

    print("Stopping Node.js server...")
    node_process.terminate()
    try:
        # Bounded wait: a child that ignores SIGTERM must not keep this
        # process from ever exiting.
        node_process.wait(timeout=grace_period)
    except subprocess.TimeoutExpired:
        node_process.kill()
        node_process.wait()
    # sys.exit is always available; the bare `exit` builtin is injected by the
    # `site` module and is missing under `python -S` or in frozen builds.
    sys.exit(0)
# Shared async HTTP client used to forward requests to the Node.js server.
# NOTE(review): this client does not appear to be closed anywhere in this
# module — confirm whether that matters for your deployment.
client = httpx.AsyncClient()

# Route SIGTERM through handle_sigterm so the Node child is stopped before the
# Python process exits.
signal.signal(signal.SIGTERM, handle_sigterm)
@app.on_event("shutdown")
def shutdown_event():
    """Stop the Node.js child process when the FastAPI application shuts down.

    Safe to run after the SIGTERM handler has already stopped the child:
    terminating and waiting on an already-reaped process are no-ops.
    """
    # Seconds to wait for Node to exit after SIGTERM before force-killing it.
    grace_period = 5

    print("Stopping Node.js server...")
    node_process.terminate()
    try:
        node_process.wait(timeout=grace_period)
    except subprocess.TimeoutExpired:
        # Node ignored SIGTERM; force it down so shutdown cannot hang forever.
        node_process.kill()
        node_process.wait()
@app.get("/config")
async def route_with_config():
    """Serve a small static JSON object directly from Python (never proxied to Node)."""
    payload = {
        "one": "hello",
        "two": "from",
        "three": "Python",
    }
    return JSONResponse(content=payload)
# async def proxy_to_node(request: Request):
# # Preserve the full path including query parameters
# full_path = request.url.path
# if request.url.query:
# full_path += f"?{request.url.query}"
# url = f"http://localhost:{NODE_PORT}{full_path}"
# headers = {
# k: v
# for k, v in request.headers.items()
# if k.lower() not in ["host", "content-length"]
# }
# print(headers)
# # body = await request.body()
# # async with client:
# # response = await client.request(
# # method=request.method, url=url, headers=headers, content=body
# # )
# # return StreamingResponse(
# # response.iter_bytes(),
# # status_code=response.status_code,
# # headers=response.headers,
# # )
# req = client.build_request("GET", httpx.URL(url), headers=headers)
# r = await client.send(req, stream=True)
# return StreamingResponse(
# r.aiter_raw(), headers=r.headers
# )
# Connection-specific ("hop-by-hop") headers describe a single HTTP connection,
# so a proxy must not copy them between the client and the upstream server.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)
# Request headers not forwarded upstream: hop-by-hop ones, plus content-length,
# which httpx recomputes from the body it actually sends.
_SKIPPED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {"content-length"}


async def proxy_to_node(
    request: Request,
    server_name: str,
    node_port: int,
    python_port: int,
    scheme: str = "http",
    mounted_path: str = "",
):
    """Forward ``request`` to the Node.js server and stream its response back.

    The upstream URL is ``{scheme}://{server_name}:{node_port}`` followed by the
    request path (with a leading ``mounted_path`` removed) and query string.
    The request method and body are forwarded unchanged. The client's headers
    are forwarded minus hop-by-hop headers and ``content-length``, plus
    ``x-gradio-server`` (the Python server's URL) and ``x-gradio-port``.

    Args:
        request: Incoming request to forward.
        server_name: Host used both to reach Node and in the advertised
            Python server URL.
        node_port: Port the Node server listens on.
        python_port: Port of the Python server; when falsy, no port is added
            to the advertised server URL.
        scheme: URL scheme used both to reach Node and in the advertised URL.
        mounted_path: Path prefix this app is mounted under; stripped from the
            start of the forwarded path and appended to the advertised URL.

    Returns:
        A ``StreamingResponse`` carrying Node's status code, headers (minus
        hop-by-hop ones) and raw body, or a 502 ``JSONResponse`` when Node
        cannot be reached.
    """
    start_time = time.time()

    full_path = request.url.path
    # Strip the mount prefix only from the start of the path; str.replace would
    # also delete matching substrings further along the path.
    if mounted_path and full_path.startswith(mounted_path):
        full_path = full_path[len(mounted_path):]
    if request.url.query:
        full_path += f"?{request.url.query}"
    url = f"{scheme}://{server_name}:{node_port}{full_path}"

    # dict(request.headers) keeps the first value of any repeated header, as
    # before; the client's Host header is still forwarded.
    headers = {
        key: value
        for key, value in dict(request.headers).items()
        if key.lower() not in _SKIPPED_REQUEST_HEADERS
    }

    server_url = f"{scheme}://{server_name}"
    if python_port:
        server_url += f":{python_port}"
    if mounted_path:
        server_url += mounted_path
    headers["x-gradio-server"] = server_url
    # NOTE(review): python_port can be None (request.url.port is None when the
    # URL uses the scheme's default port), which sends the string "None" here —
    # confirm what the Node side expects.
    headers["x-gradio-port"] = str(python_port)
    if os.getenv("GRADIO_LOCAL_DEV_MODE"):
        headers["x-gradio-local-dev-mode"] = "1"
    print(
        f"Proxying request from {request.url.path} to {url} with server url {server_url}"
    )

    body = await request.body()
    print(f"Time to prepare request: {time.time() - start_time:.4f} seconds")

    upstream_request = client.build_request(
        request.method, httpx.URL(url), headers=headers, content=body
    )
    try:
        upstream = await client.send(upstream_request, stream=True)
    except httpx.RequestError as exc:
        # Node is down or unreachable: answer 502 instead of an unhandled 500.
        print(f"Proxy error for {url}: {exc!r}")
        return JSONResponse(status_code=502, content={"error": "Bad Gateway"})
    print(f"Time to first response: {time.time() - start_time:.4f} seconds")

    async def _relay_body():
        # Relay Node's raw (still content-encoded) bytes, and always release the
        # upstream connection, even if the client disconnects mid-stream.
        try:
            async for chunk in upstream.aiter_raw():
                yield chunk
        finally:
            await upstream.aclose()

    response_headers = {
        key: value
        for key, value in upstream.headers.items()
        if key.lower() not in _HOP_BY_HOP_HEADERS
    }
    return StreamingResponse(
        _relay_body(),
        status_code=upstream.status_code,
        headers=response_headers,
    )
@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
)
async def catch_all(request: Request, path: str):
    """Proxy every request not matched by an earlier-registered route to Node.

    ``path`` is captured by the route pattern but not used directly: the
    target URL is rebuilt from ``request.url`` by ``proxy_to_node``.
    """
    # NOTE(review): "0.0.0.0" is a bind address; connecting to it reaches the
    # local host on Linux but is not portable. It is also advertised to Node
    # via x-gradio-server, so changing it affects both — confirm before editing.
    return await proxy_to_node(
        request,
        server_name="0.0.0.0",
        node_port=NODE_PORT,  # single source of truth instead of a literal 4321
        python_port=request.url.port,
        scheme=request.url.scheme,
        mounted_path="",
    )
if __name__ == "__main__":
    # Serve the FastAPI app; routes it does not define are proxied to Node.
    # (Plain string: the original f-string had no placeholders — ruff F541.)
    print(
        "Starting dual server. Python handles specific routes, Node handles the rest."
    )
    uvicorn.run(app, host="0.0.0.0", port=PYTHON_PORT)