# (Hosting-page status residue captured with this file: "Spaces: Running". Not part of the module.)
import asyncio
import socket

import httpx
import viser
import websockets
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response
class ViserProxyManager:
    """Manages Viser server instances for Gradio applications.

    This class handles the creation, retrieval, and cleanup of Viser server
    instances, and defines handlers that proxy HTTP and WebSocket traffic to
    the Viser server associated with a given server ID.

    Args:
        app: The FastAPI application the proxy handlers are meant to serve.
        min_local_port: Minimum local port number to use for Viser servers.
            Defaults to 8000. These ports are used only for internal
            communication and don't need to be publicly exposed.
        max_local_port: Maximum local port number to use for Viser servers.
            Defaults to 9000. These ports are used only for internal
            communication and don't need to be publicly exposed.
        max_message_size: Maximum WebSocket message size in bytes.
            Defaults to 300MB.

    Raises:
        ValueError: If ``min_local_port`` is greater than ``max_local_port``.
    """

    # Connection-scoped ("hop-by-hop") headers describe a single transport hop
    # and must not be relayed by a proxy.
    _HOP_BY_HOP_HEADERS = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailers",
            "transfer-encoding",
            "upgrade",
        }
    )

    # The upstream body is fully read (and decoded) before being re-sent, so
    # the upstream framing/encoding headers no longer describe the payload.
    _DROPPED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {
        "content-length",
        "content-encoding",
    }

    def __init__(
        self,
        app: FastAPI,
        min_local_port: int = 8000,
        max_local_port: int = 9000,
        max_message_size: int = 300 * 1024 * 1024,  # 300MB default
    ) -> None:
        if min_local_port > max_local_port:
            # An empty range would otherwise only fail later, inside
            # start_server(), with a confusing arithmetic error.
            raise ValueError(
                f"min_local_port ({min_local_port}) must not exceed "
                f"max_local_port ({max_local_port})"
            )

        self._min_port = min_local_port
        self._max_port = max_local_port
        self._max_message_size = max_message_size
        self._server_from_session_hash: dict[str, viser.ViserServer] = {}
        # Last port handed out; start one below the range so that the first
        # search begins at min_local_port.
        self._last_port = self._min_port - 1

        # NOTE(review): nothing visible here attaches `proxy` or
        # `websocket_proxy` to `app` (no route decorators or add-route calls).
        # Confirm the registration against the real file; the URL patterns
        # cannot be recovered from this view, so none are invented here.

        async def proxy(request: Request, server_id: str, proxy_path: str):
            """Proxy HTTP requests to the appropriate Viser server."""
            server = self._server_from_session_hash.get(server_id)
            if server is None:
                return Response(content="Server not found", status_code=404)

            # Build target URL
            path_suffix = f"/{proxy_path}" if proxy_path else "/"
            target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
            if request.url.query:
                target_url += f"?{request.url.query}"

            # Forward the original headers, minus the ones that only make
            # sense on the client <-> proxy hop.
            headers = {
                name: value
                for name, value in request.headers.items()
                if name.lower() not in self._HOP_BY_HOP_HEADERS
            }
            headers.pop("host", None)  # Remove host header to avoid conflicts
            headers["accept-encoding"] = "identity"  # Disable compression

            async with httpx.AsyncClient() as client:
                proxied_req = client.build_request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    content=await request.body(),
                )
                proxied_resp = await client.send(proxied_req, stream=True)
                content = await proxied_resp.aread()

            response_headers = {
                name: value
                for name, value in proxied_resp.headers.items()
                if name.lower() not in self._DROPPED_RESPONSE_HEADERS
            }
            return Response(
                content=content,
                status_code=proxied_resp.status_code,
                headers=response_headers,
            )

        # WebSocket Proxy
        async def websocket_proxy(websocket: WebSocket, server_id: str):
            """Proxy WebSocket connections to the appropriate Viser server."""
            try:
                await websocket.accept()

                server = self._server_from_session_hash.get(server_id)
                if server is None:
                    await websocket.close(code=1008, reason="Not Found")
                    return

                target_ws_url = f"ws://127.0.0.1:{server.get_port()}"

                try:
                    # Connect to the target WebSocket with increased message
                    # size and explicit keepalive/close timeouts.
                    async with websockets.connect(
                        target_ws_url,
                        max_size=self._max_message_size,
                        ping_interval=30,  # Send ping every 30 seconds
                        ping_timeout=10,  # Wait 10 seconds for pong response
                        close_timeout=5,  # Wait 5 seconds for close handshake
                    ) as ws_target:

                        async def forward_to_target():
                            """Forward messages from the client to the target WebSocket."""
                            try:
                                while True:
                                    data = await websocket.receive_bytes()
                                    await ws_target.send(data, text=False)
                            except WebSocketDisconnect:
                                try:
                                    await ws_target.close()
                                except RuntimeError:
                                    pass

                        async def forward_from_target():
                            """Forward messages from the target WebSocket to the client."""
                            try:
                                while True:
                                    data = await ws_target.recv(decode=False)
                                    await websocket.send_bytes(data)
                            except websockets.exceptions.ConnectionClosed:
                                try:
                                    await websocket.close()
                                except RuntimeError:
                                    pass

                        # Run both forwarding directions concurrently.
                        relay_tasks = [
                            asyncio.create_task(forward_to_target()),
                            asyncio.create_task(forward_from_target()),
                        ]
                        try:
                            # Either direction finishing means one side of the
                            # connection has gone away.
                            await asyncio.wait(
                                relay_tasks,
                                return_when=asyncio.FIRST_COMPLETED,
                            )
                        finally:
                            # Stop the surviving direction and wait for both
                            # tasks, so none outlives the upstream connection
                            # and every task exception is retrieved.
                            for task in relay_tasks:
                                task.cancel()  # No-op for finished tasks
                            outcomes = await asyncio.gather(
                                *relay_tasks, return_exceptions=True
                            )

                        # Cancellation shows up as CancelledError (not an
                        # Exception subclass); anything else is a real error.
                        for outcome in outcomes:
                            if isinstance(outcome, Exception):
                                print(f"WebSocket proxy error: {outcome}")

                except websockets.exceptions.ConnectionClosedError as e:
                    print(f"WebSocket connection closed with error: {e}")
                    await websocket.close(
                        code=1011, reason="Connection to target closed"
                    )

            except Exception as e:
                print(f"WebSocket proxy error: {e}")
                try:
                    # Limit reason length
                    await websocket.close(code=1011, reason=str(e)[:120])
                except Exception:
                    pass  # Best effort: the client socket is already closed

    def start_server(self, server_id: str) -> viser.ViserServer:
        """Start a new Viser server and associate it with the given server ID.

        Finds an available port within the configured min_local_port and
        max_local_port range, searching round-robin from the port after the
        last one handed out. These ports are used only for internal
        communication and don't need to be publicly exposed.

        Note:
            If ``server_id`` is already registered, its entry is overwritten;
            the previous server is not stopped by this method.

        Args:
            server_id: The unique identifier to associate with the new server.

        Returns:
            The newly created Viser server instance.

        Raises:
            RuntimeError: If no free ports are available in the configured range.
        """
        port_range_size = self._max_port - self._min_port + 1
        # Offset (within the range) of the first candidate, with wraparound.
        first_offset = (self._last_port + 1 - self._min_port) % port_range_size

        # Try each port in the range exactly once.
        for step in range(port_range_size):
            port = self._min_port + (first_offset + step) % port_range_size
            try:
                # Probe availability by binding; the probe socket is closed
                # again before the server is created on the same port.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.bind(("127.0.0.1", port))
                server = viser.ViserServer(port=port)
            except OSError:
                # Port is in use, try the next one
                continue

            self._server_from_session_hash[server_id] = server
            self._last_port = port
            return server

        # If we get here, no ports were available
        raise RuntimeError(
            f"No available local ports in range {self._min_port}-{self._max_port}"
        )

    def get_server(self, server_id: str) -> viser.ViserServer:
        """Retrieve a Viser server instance by its ID.

        Args:
            server_id: The unique identifier of the server to retrieve.

        Returns:
            The Viser server instance associated with the given ID.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        return self._server_from_session_hash[server_id]

    def stop_server(self, server_id: str) -> None:
        """Stop a Viser server and remove it from the manager.

        The entry is removed from the registry first, then the server's
        ``stop()`` is called.

        Args:
            server_id: The unique identifier of the server to stop.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        self._server_from_session_hash.pop(server_id).stop()