File size: 9,662 Bytes
4d1a850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import asyncio

import httpx
import viser
import websockets
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response


class ViserProxyManager:
    """Manages Viser server instances for Gradio applications.

    This class handles the creation, retrieval, and cleanup of Viser server instances,
    as well as proxying HTTP and WebSocket requests to the appropriate Viser server.

    Args:
        app: The FastAPI application to which the proxy routes will be added.
        min_local_port: Minimum local port number to use for Viser servers. Defaults to 8000.
            These ports are used only for internal communication and don't need to be publicly exposed.
        max_local_port: Maximum local port number to use for Viser servers. Defaults to 9000.
            These ports are used only for internal communication and don't need to be publicly exposed.
        max_message_size: Maximum WebSocket message size in bytes. Defaults to 300MB.
    """

    # Upstream response headers that are NOT copied onto the proxied response.
    # Hop-by-hop headers describe a single connection only, and the framing
    # headers (length/encoding) no longer match once the body has been fully
    # buffered and is re-sent by this process.
    _UNFORWARDABLE_RESPONSE_HEADERS = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailer",
            "transfer-encoding",
            "upgrade",
            "content-length",
            "content-encoding",
        }
    )

    def __init__(
        self,
        app: FastAPI,
        min_local_port: int = 8000,
        max_local_port: int = 9000,
        max_message_size: int = 300 * 1024 * 1024,  # 300MB default
    ) -> None:
        self._min_port = min_local_port
        self._max_port = max_local_port
        self._max_message_size = max_message_size
        # Maps the externally visible server ID to its running Viser server.
        self._server_from_session_hash: dict[str, viser.ViserServer] = {}
        # Last port handed out; start_server() resumes its search just after it.
        self._last_port = self._min_port - 1

        @app.get("/viser/{server_id}/{proxy_path:path}")
        async def proxy(request: Request, server_id: str, proxy_path: str):
            """Proxy HTTP requests to the appropriate Viser server."""
            server = self._server_from_session_hash.get(server_id)
            if server is None:
                return Response(content="Server not found", status_code=404)

            # Build target URL on the loopback interface.
            path_suffix = f"/{proxy_path}" if proxy_path else "/"
            target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
            if request.url.query:
                target_url += f"?{request.url.query}"

            # Forward the original headers, but remove any problematic ones.
            headers = dict(request.headers)
            headers.pop("host", None)  # Remove host header to avoid conflicts
            headers["accept-encoding"] = "identity"  # Disable compression

            async with httpx.AsyncClient() as client:
                proxied_req = client.build_request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    content=await request.body(),
                )
                # The body is relayed in one piece, so let httpx buffer it
                # instead of opening a stream and draining it by hand.
                proxied_resp = await client.send(proxied_req)

            return Response(
                content=proxied_resp.content,
                status_code=proxied_resp.status_code,
                headers=self._forwardable_response_headers(proxied_resp.headers),
            )

        # WebSocket Proxy
        @app.websocket("/viser/{server_id}")
        async def websocket_proxy(websocket: WebSocket, server_id: str):
            """Proxy WebSocket connections to the appropriate Viser server."""
            try:
                await websocket.accept()

                server = self._server_from_session_hash.get(server_id)
                if server is None:
                    await websocket.close(code=1008, reason="Not Found")
                    return

                target_ws_url = f"ws://127.0.0.1:{server.get_port()}"

                try:
                    # Connect to the target WebSocket with increased message size and timeout
                    async with websockets.connect(
                        target_ws_url,
                        max_size=self._max_message_size,
                        ping_interval=30,  # Send ping every 30 seconds
                        ping_timeout=10,  # Wait 10 seconds for pong response
                        close_timeout=5,  # Wait 5 seconds for close handshake
                    ) as ws_target:

                        async def forward_to_target():
                            """Forward messages from the client to the target WebSocket."""
                            try:
                                while True:
                                    data = await websocket.receive_bytes()
                                    await ws_target.send(data, text=False)
                            except WebSocketDisconnect:
                                try:
                                    await ws_target.close()
                                except RuntimeError:
                                    pass

                        async def forward_from_target():
                            """Forward messages from the target WebSocket to the client."""
                            try:
                                while True:
                                    data = await ws_target.recv(decode=False)
                                    await websocket.send_bytes(data)
                            except websockets.exceptions.ConnectionClosed:
                                try:
                                    await websocket.close()
                                except RuntimeError:
                                    pass

                        # Run both forwarding directions concurrently.
                        tasks = [
                            asyncio.create_task(forward_to_target()),
                            asyncio.create_task(forward_from_target()),
                        ]
                        try:
                            # The first direction to finish means one side went away.
                            done, _ = await asyncio.wait(
                                tasks,
                                return_when=asyncio.FIRST_COMPLETED,
                            )
                        finally:
                            # Cancel whatever is still running and wait for it to
                            # unwind so no task outlives the target connection.
                            # cancel() is a no-op on tasks that already finished.
                            for task in tasks:
                                task.cancel()
                            await asyncio.gather(*tasks, return_exceptions=True)

                        # Surface a failure of the finished direction so that it
                        # is reported by the handlers below instead of being
                        # dropped as a never-retrieved task exception.
                        for task in done:
                            error = None if task.cancelled() else task.exception()
                            if error is not None:
                                raise error

                except websockets.exceptions.ConnectionClosedError as e:
                    print(f"WebSocket connection closed with error: {e}")
                    await websocket.close(code=1011, reason="Connection to target closed")

            except Exception as e:
                print(f"WebSocket proxy error: {e}")
                try:
                    await websocket.close(code=1011, reason=str(e)[:120])  # Limit reason length
                except Exception:
                    # Best effort only: the socket is usually already closed.
                    # Not a bare except, so cancellation is never swallowed.
                    pass

    @staticmethod
    def _forwardable_response_headers(headers) -> dict[str, str]:
        """Return the upstream response headers that are safe to relay.

        Args:
            headers: Any mapping-like object exposing ``items()`` that yields
                ``(name, value)`` string pairs.

        Returns:
            A new dict without the entries whose lower-cased name appears in
            ``_UNFORWARDABLE_RESPONSE_HEADERS``.
        """
        blocked = ViserProxyManager._UNFORWARDABLE_RESPONSE_HEADERS
        return {
            name: value for name, value in headers.items() if name.lower() not in blocked
        }

    def start_server(self, server_id: str) -> viser.ViserServer:
        """Start a new Viser server and associate it with the given server ID.

        Finds an available port within the configured min_local_port and max_local_port range.
        These ports are used only for internal communication and don't need to be publicly exposed.
        The search starts just after the most recently assigned port and wraps around.

        If a server is already registered under ``server_id``, its entry in the
        mapping is overwritten; the previous server is not stopped here.

        Args:
            server_id: The unique identifier to associate with the new server.

        Returns:
            The newly created Viser server instance.

        Raises:
            RuntimeError: If no free ports are available in the configured range
                (including when the configured range is empty).
        """
        import socket

        no_port_message = (
            f"No available local ports in range {self._min_port}-{self._max_port}"
        )

        port_range_size = self._max_port - self._min_port + 1
        if port_range_size <= 0:
            # Empty or inverted range: fail with the documented error rather
            # than dividing by zero in the wraparound arithmetic below.
            raise RuntimeError(no_port_message)

        # Start searching from the last port + 1 (with wraparound).
        first_offset = (self._last_port + 1 - self._min_port) % port_range_size

        # Try each port once.
        for step in range(port_range_size):
            port = self._min_port + (first_offset + step) % port_range_size
            try:
                # Check if port is available by attempting to bind to it.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.bind(("127.0.0.1", port))
            except OSError:
                # Port is in use, try the next one.
                continue

            # The probe socket is closed at this point, so the port is free for
            # the server to bind. Creating the server outside the try block also
            # keeps its own errors from being mistaken for "port in use".
            server = viser.ViserServer(port=port)
            self._server_from_session_hash[server_id] = server
            self._last_port = port
            return server

        # If we get here, no ports were available.
        raise RuntimeError(no_port_message)

    def get_server(self, server_id: str) -> viser.ViserServer:
        """Retrieve a Viser server instance by its ID.

        Args:
            server_id: The unique identifier of the server to retrieve.

        Returns:
            The Viser server instance associated with the given ID.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        return self._server_from_session_hash[server_id]

    def stop_server(self, server_id: str) -> None:
        """Stop a Viser server and remove it from the manager.

        Args:
            server_id: The unique identifier of the server to stop.

        Raises:
            KeyError: If no server is registered under ``server_id``.
        """
        self._server_from_session_hash.pop(server_id).stop()