File size: 14,500 Bytes
bfe88a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import gradio as gr
import httpx # Use httpx for async requests from Gradio backend to API
import websockets # Use websockets library to connect from Gradio backend
import asyncio
import json
import os
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, Depends # Import FastAPI itself
from .api import router as api_router # Import the API router
from .database import connect_db, disconnect_db
from . import schemas, auth, dependencies # Import necessary modules

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Base URL the Gradio handlers use to call this same server's REST API over HTTP.
# The default assumes the combined FastAPI+Gradio app listens on 127.0.0.1:7860
# with the API router mounted under /api. Set the API_BASE_URL environment
# variable to point elsewhere (e.g. a different port or host in deployment).
# A trailing slash is stripped because endpoints are appended as "/path".
API_BASE_URL = os.environ.get("API_BASE_URL", "http://127.0.0.1:7860/api").rstrip("/")

# --- FastAPI Lifespan Event ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: connect the database on startup, disconnect on shutdown.

    ``disconnect_db()`` runs in a ``finally`` block so the database is released
    even when the lifespan context is exited with an exception (an exception
    raised into the generator at ``yield`` would otherwise skip the shutdown code).

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    # Startup
    await connect_db()
    try:
        # Application serves requests while suspended here.
        yield
    finally:
        # Shutdown
        await disconnect_db()

# Single FastAPI application serving both the REST API and, once the Gradio
# Blocks are mounted onto it further down, the UI. All API routes are exposed
# under the "/api" prefix; `lifespan` manages the database connection.
app = FastAPI(lifespan=lifespan)
app.include_router(router=api_router, prefix="/api")

# --- Gradio UI Definition ---

# WebSocket connections belong to individual users, so they are not kept in module
# globals (which every browser session would share). Instead, the listener is
# started *after* login and its asyncio task is stored in a per-session gr.State.

# --- Helper functions for Gradio calling the API ---

async def make_api_request(method: str, endpoint: str, **kwargs):
    """Send a request to the app's REST API and return the decoded JSON body.

    Failures are reported as a return value rather than raised, so Gradio
    handlers can simply test ``"error" in result``.

    Args:
        method: HTTP method passed to httpx (e.g. ``"post"``).
        endpoint: Path appended to ``API_BASE_URL``; expected to start with ``/``.
        **kwargs: Forwarded unchanged to ``httpx.AsyncClient.request``
            (e.g. ``json=``, ``headers=``).

    Returns:
        The parsed JSON response for a 2xx status, otherwise a dict
        ``{"error": <human-readable message>}``.
    """
    url = f"{API_BASE_URL}{endpoint}"
    async with httpx.AsyncClient() as client:
        try:
            response = await client.request(method, url, **kwargs)
            response.raise_for_status()  # Raise HTTPStatusError for 4xx/5xx
            return response.json()
        except httpx.RequestError as e:
            logger.error(f"HTTP Request failed: {e.request.method} {e.request.url} - {e}")
            return {"error": f"Network error contacting API: {e}"}
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP Status error: {e.response.status_code} - {e.response.text}")
            # Anything raised *inside* this handler would escape the function
            # (sibling except clauses do not catch it), so inspect the body
            # defensively: it may be non-JSON, a JSON value that is not an
            # object, or an object whose "detail" is a list (e.g. FastAPI's
            # 422 validation errors, a list of {"msg": ...} entries).
            detail = e.response.text
            try:
                body = e.response.json()
            except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
                body = None
            if isinstance(body, dict) and "detail" in body:
                detail = body["detail"]
                if isinstance(detail, list):
                    messages = [
                        str(item["msg"])
                        for item in detail
                        if isinstance(item, dict) and "msg" in item
                    ]
                    if messages:
                        detail = "; ".join(messages)
            return {"error": f"API Error: {detail}"}
        except Exception as e:
            # Top-level boundary for the UI: log with traceback, report as value.
            logger.exception(f"Unexpected error during API call: {e}")
            return {"error": f"An unexpected error occurred: {str(e)}"}

# --- WebSocket handling within Gradio ---

async def listen_to_websockets(token: str, notification_state: list):
    """Stream "new_user" notifications from the API WebSocket into a list.

    Connects to ``<API_BASE_URL with http->ws>/ws/<token>``. For every JSON
    object received whose ``"type"`` is ``"new_user"``, validates it with
    ``schemas.Notification`` and inserts its ``message`` at the front of
    ``notification_state`` (newest first, at most 10 entries kept). The list is
    mutated in place; this coroutine cannot refresh the UI itself, so a Gradio
    handler must read the same list object to display it.

    Returns when the connection closes or cannot be opened (10 s timeout,
    invalid URI, handshake failure). ``asyncio.CancelledError`` is not caught:
    it propagates to whoever awaits the task, after the socket is closed.

    Args:
        token: Session token placed in the WebSocket path. If falsy, no
            connection is made and the list is returned unchanged.
        notification_state: List of message strings, mutated in place.

    Returns:
        The same ``notification_state`` list object.
    """
    if not token:
        logger.warning("WebSocket listener: No token provided.")
        return notification_state  # Return current state if no token

    # Swap only the scheme prefix (http -> ws, https -> wss); an unbounded
    # replace() would also rewrite any later "http" substring in the URL.
    ws_url_base = API_BASE_URL.replace("http", "ws", 1)
    ws_url = f"{ws_url_base}/ws/{token}"
    # The token is a credential: keep it out of the logs.
    safe_url = f"{ws_url_base}/ws/<token>"
    logger.info(f"Attempting to connect to WebSocket: {safe_url}")

    try:
        # asyncio.wait_for() returns a coroutine, which is NOT an async context
        # manager; await it to get the connection and close it explicitly below.
        websocket = await asyncio.wait_for(websockets.connect(ws_url), timeout=10.0)
    except asyncio.TimeoutError:
        logger.error(f"WebSocket connection timed out: {safe_url}")
        return notification_state
    except websockets.exceptions.InvalidURI:
        logger.error(f"Invalid WebSocket URI: {safe_url}")
        return notification_state
    except websockets.exceptions.WebSocketException as e:
        logger.error(f"WebSocket connection failed: {e}")
        return notification_state
    except Exception as e:
        logger.exception(f"Unexpected error connecting/listening to WebSocket: {e}")
        return notification_state

    logger.info(f"WebSocket connected successfully to {safe_url}")
    try:
        while True:
            try:
                message_str = await websocket.recv()
                logger.info(f"Received raw message: {message_str}")
                message_data = json.loads(message_str)
                logger.info(f"Parsed message data: {message_data}")

                # Only "new_user" notification objects are recorded; any other
                # payload (including non-object JSON) is ignored.
                if isinstance(message_data, dict) and message_data.get("type") == "new_user":
                    notification = schemas.Notification(**message_data)
                    notification_state.insert(0, notification.message)  # newest first
                    logger.info(f"Notification added: {notification.message}")
                    # Keep at most 10 entries (also trims any pre-existing overflow).
                    del notification_state[10:]
            except websockets.ConnectionClosedOK:
                logger.info("WebSocket connection closed normally.")
                break
            except websockets.ConnectionClosedError as e:
                logger.error(f"WebSocket connection closed with error: {e}")
                break
            except json.JSONDecodeError:
                logger.error(f"Failed to decode JSON from WebSocket message: {message_str}")
            except Exception as e:
                logger.exception(f"Error in WebSocket listener loop: {e}")
                await asyncio.sleep(1)  # Back off briefly before the next recv()
    finally:
        try:
            await websocket.close()
        except Exception as e:
            # Never let a close failure mask the real exit reason (e.g. cancellation).
            logger.warning(f"Error while closing WebSocket: {e}")

    return notification_state


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # --- Per-session state (Gradio keeps a separate value per browser session) ---
    # Session token returned by /login (None when logged out)
    auth_token = gr.State(None)
    # User info dict (from user_data.model_dump()) after login
    user_info = gr.State(None)
    # Notification messages, newest first. The WebSocket listener task started
    # at login mutates this session's list object in place.
    notification_list = gr.State([])
    # asyncio.Task running listen_to_websockets for this session
    websocket_task = gr.State(None)

    # --- UI Components ---
    with gr.Tabs() as tabs:
        # --- Registration Tab ---
        with gr.TabItem("Register", id="register_tab"):
            gr.Markdown("## Create a new account")
            reg_email = gr.Textbox(label="Email", type="email")
            reg_password = gr.Textbox(label="Password (min 8 chars)", type="password")
            reg_confirm_password = gr.Textbox(label="Confirm Password", type="password")
            reg_button = gr.Button("Register")
            reg_status = gr.Textbox(label="Status", interactive=False)

        # --- Login Tab ---
        with gr.TabItem("Login", id="login_tab"):
            gr.Markdown("## Login to your account")
            login_email = gr.Textbox(label="Email", type="email")
            login_password = gr.Textbox(label="Password", type="password")
            login_button = gr.Button("Login")
            login_status = gr.Textbox(label="Status", interactive=False)

        # --- Welcome Tab (shown after login) ---
        with gr.TabItem("Welcome", id="welcome_tab", visible=False) as welcome_tab:
            gr.Markdown("## Welcome!", elem_id="welcome_header")
            welcome_message = gr.Markdown("", elem_id="welcome_message")
            logout_button = gr.Button("Logout")
            gr.Markdown("---")  # Separator
            gr.Markdown("## Real-time Notifications")
            # Read-only view of notification_list; refreshed by the periodic
            # poller wired up after update_notification_ui below.
            notification_display = gr.Textbox(
                label="New User Alerts",
                lines=5,
                max_lines=10,
                interactive=False,
            )

    # --- Event Handlers ---

    # Registration Logic
    async def handle_register(email, password, confirm_password):
        """Validate the form locally, call POST /register, and return a status update."""
        if not email or not password or not confirm_password:
            return gr.update(value="Please fill in all fields.")
        if password != confirm_password:
            return gr.update(value="Passwords do not match.")
        if len(password) < 8:
            return gr.update(value="Password must be at least 8 characters long.")

        payload = {"email": email, "password": password}
        result = await make_api_request("post", "/register", json=payload)

        if "error" in result:
            return gr.update(value=f"Registration failed: {result['error']}")
        # Optionally switch to login tab after successful registration
        return gr.update(value=f"Registration successful for {result.get('email')}! Please log in.")

    reg_button.click(
        handle_register,
        inputs=[reg_email, reg_password, reg_confirm_password],
        outputs=[reg_status]
    )

    # Login Logic
    async def handle_login(email, password, current_task, notif_list):
        """Log in, (re)start this session's WebSocket listener, and show the Welcome tab.

        Args:
            email, password: Credentials from the Login tab.
            current_task: This session's existing listener task, or None.
            notif_list: This session's notification list (gr.State value).

        Returns:
            One value per output component, in order: login_status, auth_token,
            user_info, tabs, welcome_tab, welcome_message, websocket_task,
            notification_list.
        """
        # Guarantee the listener always has a list object to mutate.
        if notif_list is None:
            notif_list = []

        def _failure(message):
            # Every exit path must return exactly one value per output (8 total).
            return (
                gr.update(value=message),  # login_status
                None,                      # auth_token
                None,                      # user_info
                gr.update(),               # tabs: stay on the current tab
                gr.update(visible=False),  # welcome_tab: hidden
                gr.update(value=""),       # welcome_message: cleared
                current_task,              # websocket_task: unchanged
                notif_list,                # notification_list: unchanged
            )

        if not email or not password:
            return _failure("Please enter email and password.")

        payload = {"email": email, "password": password}
        result = await make_api_request("post", "/login", json=payload)
        if "error" in result:
            return _failure(f"Login failed: {result['error']}")

        token = result.get("access_token")
        # Fetch user details by calling the dependency function directly.
        # NOTE(review): assumes get_optional_current_user accepts the raw token as
        # its first argument when called outside FastAPI's DI — confirm.
        user_data = await dependencies.get_optional_current_user(token)
        if not user_data:
            # This shouldn't happen if login succeeded, but check anyway
            return _failure("Login succeeded but failed to fetch user data.")

        # Cancel any existing websocket listener task before starting a new one
        if current_task and not current_task.done():
            current_task.cancel()
            try:
                await current_task  # Wait for cancellation
            except asyncio.CancelledError:
                logger.info("Previous WebSocket task cancelled.")

        # Give the listener *this session's* list (received as an input and
        # returned as an output below). `notification_list.value` is only the
        # component's default value, not the per-session list the poller reads.
        new_task = asyncio.create_task(listen_to_websockets(token, notif_list))

        welcome_msg = f"Welcome, {user_data.email}!"
        return (
            gr.update(value="Login successful!"),  # login_status
            token,                                 # auth_token state
            user_data.model_dump(),                # user_info state (store as dict)
            gr.update(selected="welcome_tab"),     # Switch Tabs
            gr.update(visible=True),               # Make welcome tab visible
            gr.update(value=welcome_msg),          # Update welcome message markdown
            new_task,                              # websocket_task state
            notif_list,                            # notification_list state
        )

    login_button.click(
        handle_login,
        inputs=[login_email, login_password, websocket_task, notification_list],
        outputs=[
            login_status,
            auth_token,
            user_info,
            tabs,
            welcome_tab,
            welcome_message,
            websocket_task,
            notification_list,
        ]
    )

    async def update_notification_ui(notif_list_state):
        """Render the session's notification list as newline-separated text."""
        # Declared async so Gradio runs it on the event loop (the same thread as
        # the listener task) rather than in a worker thread.
        return "\n".join(notif_list_state or [])

    # Refresh the notifications box once per second. A Textbox's own `every=`
    # only applies when its value is a callable, and a `.change` listener on the
    # box never fires on its own, so drive the refresh with an explicit timer.
    if hasattr(gr, "Timer"):  # newer Gradio releases provide gr.Timer
        gr.Timer(1).tick(
            fn=update_notification_ui,
            inputs=[notification_list],
            outputs=[notification_display],
            show_progress="hidden",
        )
    else:  # older Gradio: periodic load event (requires the queue to be enabled)
        demo.load(
            fn=update_notification_ui,
            inputs=[notification_list],
            outputs=[notification_display],
            every=1,
        )

    # Logout Logic
    async def handle_logout(current_task):
        """Stop this session's listener, clear all session state, and return to Login."""
        # Cancel the websocket listener task if it's running
        if current_task and not current_task.done():
            current_task.cancel()
            try:
                await current_task
            except asyncio.CancelledError:
                logger.info("WebSocket task cancelled on logout.")

        # Clear state and switch back to login tab
        return (
            None,                            # Clear auth_token
            None,                            # Clear user_info
            [],                              # Clear notifications
            None,                            # Clear websocket_task
            gr.update(selected="login_tab"), # Switch Tabs
            gr.update(visible=False),        # Hide welcome tab
            gr.update(value=""),             # Clear welcome message
            gr.update(value="")              # Clear login status
        )

    logout_button.click(
        handle_logout,
        inputs=[websocket_task],
        outputs=[
            auth_token,
            user_info,
            notification_list,
            websocket_task,
            tabs,
            welcome_tab,
            welcome_message,
            login_status
        ]
    )

# Serve the Gradio UI from the site root of the same FastAPI app. The API router
# was included earlier under /api (Starlette matches routes in registration order).
app = gr.mount_gradio_app(app, demo, path="/")

# Local entry point: run the combined FastAPI + Gradio app with uvicorn.
if __name__ == "__main__":
    import uvicorn

    # 0.0.0.0 makes the server reachable from outside a Docker container;
    # 7860 is the port Gradio uses by default.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)