# viser_train1 / viser_proxy_manager.py
# Junyi42's picture
# code
# 4d1a850
import asyncio
import httpx
import viser
import websockets
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response
class ViserProxyManager:
    """Manages Viser server instances for Gradio applications.

    This class handles the creation, retrieval, and cleanup of Viser server
    instances, as well as proxying HTTP and WebSocket requests to the
    appropriate Viser server. Constructing it registers two routes on ``app``:
    ``GET /viser/{server_id}/{proxy_path}`` and the WebSocket route
    ``/viser/{server_id}``.

    Args:
        app: The FastAPI application to which the proxy routes will be added.
        min_local_port: Minimum local port number to use for Viser servers.
            Defaults to 8000. These ports are used only for internal
            communication and don't need to be publicly exposed.
        max_local_port: Maximum local port number to use for Viser servers.
            Defaults to 9000. These ports are used only for internal
            communication and don't need to be publicly exposed.
        max_message_size: Maximum WebSocket message size in bytes.
            Defaults to 300MB.
    """

    # Upstream response headers that are not copied to the proxied response.
    # The first group is the hop-by-hop set, which describes a single
    # connection. The last two describe the body framing/encoding of the
    # upstream transfer; the body is re-sent here as one fully buffered,
    # already-decoded payload, so stale values would contradict what is
    # actually sent.
    _EXCLUDED_RESPONSE_HEADERS = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailer",
            "transfer-encoding",
            "upgrade",
            "content-encoding",
            "content-length",
        }
    )

    def __init__(
        self,
        app: FastAPI,
        min_local_port: int = 8000,
        max_local_port: int = 9000,
        max_message_size: int = 300 * 1024 * 1024,  # 300MB default
    ) -> None:
        self._min_port = min_local_port
        self._max_port = max_local_port
        self._max_message_size = max_message_size
        # Maps the caller-supplied server ID to its running Viser server.
        self._server_from_session_hash: dict[str, viser.ViserServer] = {}
        # Last port handed out; start_server() resumes its scan after it.
        self._last_port = self._min_port - 1

        @app.get("/viser/{server_id}/{proxy_path:path}")
        async def proxy(request: Request, server_id: str, proxy_path: str):
            """Proxy HTTP requests to the appropriate Viser server."""
            server = self._server_from_session_hash.get(server_id)
            if server is None:
                return Response(content="Server not found", status_code=404)

            # Build the target URL on the loopback interface.
            path_suffix = f"/{proxy_path}" if proxy_path else "/"
            target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
            if request.url.query:
                target_url += f"?{request.url.query}"

            # Forward the original headers, minus the problematic ones.
            headers = dict(request.headers)
            headers.pop("host", None)  # Remove host header to avoid conflicts
            headers["accept-encoding"] = "identity"  # Disable compression

            async with httpx.AsyncClient() as client:
                proxied_req = client.build_request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    content=await request.body(),
                )
                # The whole body is buffered before replying, so a plain
                # (non-streaming) send is enough and closes the upstream
                # response for us.
                proxied_resp = await client.send(proxied_req)

            response_headers = {
                name: value
                for name, value in proxied_resp.headers.items()
                if name.lower() not in self._EXCLUDED_RESPONSE_HEADERS
            }
            return Response(
                content=proxied_resp.content,
                status_code=proxied_resp.status_code,
                headers=response_headers,
            )

        # WebSocket Proxy
        @app.websocket("/viser/{server_id}")
        async def websocket_proxy(websocket: WebSocket, server_id: str):
            """Proxy WebSocket connections to the appropriate Viser server."""
            try:
                await websocket.accept()
                server = self._server_from_session_hash.get(server_id)
                if server is None:
                    await websocket.close(code=1008, reason="Not Found")
                    return

                target_ws_url = f"ws://127.0.0.1:{server.get_port()}"
                try:
                    # Connect to the target WebSocket with increased message
                    # size and explicit keepalive/close timeouts.
                    async with websockets.connect(
                        target_ws_url,
                        max_size=self._max_message_size,
                        ping_interval=30,  # Send ping every 30 seconds
                        ping_timeout=10,  # Wait 10 seconds for pong response
                        close_timeout=5,  # Wait 5 seconds for close handshake
                    ) as ws_target:
                        await self._relay_websocket(websocket, ws_target)
                except websockets.exceptions.ConnectionClosedError as e:
                    print(f"WebSocket connection closed with error: {e}")
                    await websocket.close(
                        code=1011, reason="Connection to target closed"
                    )
            except Exception as e:
                print(f"WebSocket proxy error: {e}")
                try:
                    # Limit reason length
                    await websocket.close(code=1011, reason=str(e)[:120])
                except Exception:
                    # Best effort: the client socket is already closed.
                    pass

    async def _relay_websocket(self, websocket: WebSocket, ws_target) -> None:
        """Pump binary messages in both directions until one side closes.

        Args:
            websocket: The accepted client-facing WebSocket.
            ws_target: The open connection to the local Viser server.

        Raises:
            Exception: Whatever unexpected error ended the first relay
                direction to finish, so the caller can report it instead of
                it being dropped with the task.
        """

        async def forward_to_target():
            """Forward messages from the client to the target WebSocket."""
            try:
                while True:
                    data = await websocket.receive_bytes()
                    await ws_target.send(data, text=False)
            except WebSocketDisconnect:
                try:
                    await ws_target.close()
                except RuntimeError:
                    pass

        async def forward_from_target():
            """Forward messages from the target WebSocket to the client."""
            try:
                while True:
                    data = await ws_target.recv(decode=False)
                    await websocket.send_bytes(data)
            except websockets.exceptions.ConnectionClosed:
                try:
                    await websocket.close()
                except RuntimeError:
                    pass

        tasks = [
            asyncio.create_task(forward_to_target()),
            asyncio.create_task(forward_from_target()),
        ]
        try:
            # Either direction finishing means one connection was closed.
            done, _ = await asyncio.wait(
                tasks, return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            # Stop the direction that is still running and wait for it to
            # unwind, so no task outlives the connection and every task
            # result (including errors) is retrieved.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

        for task in done:
            if task.cancelled():
                continue
            error = task.exception()
            if error is not None:
                raise error

    def start_server(self, server_id: str) -> viser.ViserServer:
        """Start a new Viser server and associate it with the given server ID.

        Finds an available port within the configured min_local_port and
        max_local_port range, scanning round-robin from the port after the
        last one handed out. These ports are used only for internal
        communication and don't need to be publicly exposed.

        Args:
            server_id: The unique identifier to associate with the new server.

        Returns:
            The newly created Viser server instance.

        Raises:
            RuntimeError: If no free ports are available in the configured
                range (including when the range itself is empty).
        """
        import socket

        port_range_size = self._max_port - self._min_port + 1
        if port_range_size <= 0:
            # An empty range would otherwise surface as ZeroDivisionError
            # from the modulo below rather than the documented error.
            raise RuntimeError(
                f"No available local ports in range {self._min_port}-{self._max_port}"
            )

        # Start searching from the last port + 1 (with wraparound).
        start_offset = (self._last_port + 1 - self._min_port) % port_range_size

        # Try each port once.
        for step in range(port_range_size):
            port = self._min_port + (start_offset + step) % port_range_size
            try:
                # Check if the port is available by attempting to bind to it;
                # the probe socket is released again before Viser binds.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.bind(("127.0.0.1", port))
                server = viser.ViserServer(port=port)
            except OSError:
                # Port is in use, try the next one.
                continue
            # NOTE(review): an existing entry for server_id is overwritten
            # here without being stopped - confirm callers never reuse IDs.
            self._server_from_session_hash[server_id] = server
            self._last_port = port
            return server

        # If we get here, no ports were available.
        raise RuntimeError(
            f"No available local ports in range {self._min_port}-{self._max_port}"
        )

    def get_server(self, server_id: str) -> viser.ViserServer:
        """Retrieve a Viser server instance by its ID.

        Args:
            server_id: The unique identifier of the server to retrieve.

        Returns:
            The Viser server instance associated with the given ID.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        return self._server_from_session_hash[server_id]

    def stop_server(self, server_id: str) -> None:
        """Stop a Viser server and remove it from the manager.

        Args:
            server_id: The unique identifier of the server to stop.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        self._server_from_session_hash.pop(server_id).stop()