amaye15's picture
Initial Deployment
bfe88a9
raw
history blame
14.5 kB
import gradio as gr
import httpx # Use httpx for async requests from Gradio backend to API
import websockets # Use websockets library to connect from Gradio backend
import asyncio
import json
import os
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends # Import FastAPI itself
from .api import router as api_router # Import the API router
from .database import connect_db, disconnect_db
from . import schemas, auth, dependencies # Import necessary modules
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Base URL the Gradio handlers use to reach this app's own REST API over HTTP.
# The API router is mounted under "/api" on the same server that serves Gradio,
# so the default assumes that server listens on localhost:7860. Set the
# API_BASE_URL environment variable when it runs elsewhere (another port, a
# reverse proxy, ...). A trailing slash is stripped because endpoints are
# appended as "/path".
API_BASE_URL = os.environ.get("API_BASE_URL", "http://127.0.0.1:7860/api").rstrip("/")
# --- FastAPI Lifespan Event ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect to the database on startup and disconnect on shutdown.

    Used as the FastAPI ``lifespan`` handler: code before ``yield`` runs at
    startup, code after it at shutdown.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    # Startup
    await connect_db()
    # NOTE: a background WebSocket task could be started here; this app instead
    # starts one per Gradio session after login.
    try:
        yield
    finally:
        # Shutdown. `finally` guarantees the connection is released even when an
        # exception or cancellation is thrown into the context manager at `yield`.
        await disconnect_db()
# The FastAPI application that serves both the REST API and (once mounted
# further below) the Gradio UI. The API routes are registered first and all
# live under the "/api" prefix.
app = FastAPI(lifespan=lifespan)
app.include_router(router=api_router, prefix="/api")
# --- Gradio UI Definition ---
# Store websocket connection globally (or within a class) for the Gradio app instance
# This is tricky because Gradio re-runs functions. State management is key.
# We'll connect the WebSocket *after* login and store the connection task/info in gr.State.
# --- Helper functions for Gradio calling the API ---
def _extract_error_detail(response):
    """Return the most useful error text from an error HTTP response.

    Uses the ``"detail"`` field of a JSON object body when present, otherwise the
    raw response text. Never raises for malformed or non-object bodies.
    """
    try:
        body = response.json()
    except ValueError:  # json.JSONDecodeError (and decode errors) subclass ValueError
        return response.text
    if isinstance(body, dict) and "detail" in body:
        return body["detail"]
    # A JSON list/string/number error body has no "detail" key; fall back to text.
    return response.text


async def make_api_request(method: str, endpoint: str, **kwargs):
    """Call this app's REST API and return the decoded JSON response.

    Failures during the request are reported as ``{"error": "<message>"}``
    instead of being raised. This lets Gradio handlers simply test
    ``"error" in result``.

    Args:
        method: HTTP method name, e.g. ``"post"``.
        endpoint: Path appended to ``API_BASE_URL``, e.g. ``"/login"``.
        **kwargs: Forwarded to ``httpx.AsyncClient.request`` (``json=``, ``params=`` ...).

    Returns:
        The parsed JSON body on success, otherwise an ``{"error": ...}`` dict.
    """
    url = f"{API_BASE_URL}{endpoint}"
    async with httpx.AsyncClient() as client:
        try:
            response = await client.request(method, url, **kwargs)
            response.raise_for_status()  # turn 4xx/5xx into HTTPStatusError
            return response.json()
        except httpx.RequestError as e:
            logger.error(f"HTTP Request failed: {e.request.method} {e.request.url} - {e}")
            return {"error": f"Network error contacting API: {e}"}
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP Status error: {e.response.status_code} - {e.response.text}")
            # The helper never raises, so an unusual error body cannot escape
            # this handler and break the "always return a dict" contract.
            detail = _extract_error_detail(e.response)
            return {"error": f"API Error: {detail}"}
        except Exception as e:
            # Log the traceback so unexpected failures are diagnosable.
            logger.exception(f"Unexpected error during API call: {e}")
            return {"error": f"An unexpected error occurred: {str(e)}"}
# --- WebSocket handling within Gradio ---
async def listen_to_websockets(token: str, notification_state: list, *, max_notifications: int = 10):
    """Listen on the API WebSocket and record "new_user" notifications.

    Connects to ``<API_BASE_URL as ws>/ws/<token>`` and keeps receiving until
    the server closes the connection or the task is cancelled. For every
    message whose ``type`` is ``"new_user"``, the notification's ``message``
    is inserted at the front of ``notification_state``, newest first. The list
    is then trimmed to ``max_notifications`` entries.

    The list is mutated in place on purpose: this coroutine runs as a
    background task and cannot update the UI itself, so the UI polls the same
    list object.

    Args:
        token: Session token; embedded in the WebSocket URL path.
        notification_state: List of notification strings, mutated in place.
        max_notifications: Maximum number of messages to keep (default 10).

    Returns:
        ``notification_state`` once the listener stops. It is returned
        immediately if ``token`` is falsy.
    """
    if not token:
        logger.warning("WebSocket listener: No token provided.")
        return notification_state  # Nothing to connect with; keep current state

    # Convert only the scheme: "http" -> "ws", "https" -> "wss".
    ws_url_base = API_BASE_URL.replace("http", "ws", 1)
    ws_url = f"{ws_url_base}/ws/{token}"
    # The token is a credential; never write it to the logs.
    safe_url = f"{ws_url_base}/ws/<token>"
    logger.info(f"Attempting to connect to WebSocket: {safe_url}")
    try:
        # asyncio.wait_for() returns a coroutine, not an async context manager,
        # so await the connection first (with a timeout) and close it in `finally`.
        websocket = await asyncio.wait_for(websockets.connect(ws_url), timeout=10.0)
        try:
            logger.info(f"WebSocket connected successfully to {safe_url}")
            while True:
                try:
                    message_str = await websocket.recv()
                    logger.info(f"Received raw message: {message_str}")
                    message_data = json.loads(message_str)
                    logger.info(f"Parsed message data: {message_data}")
                    # Only "new_user" notifications are shown; anything else is ignored.
                    if isinstance(message_data, dict) and message_data.get("type") == "new_user":
                        notification = schemas.Notification(**message_data)
                        notification_state.insert(0, notification.message)  # newest first
                        logger.info(f"Notification added: {notification.message}")
                        del notification_state[max_notifications:]  # cap history length
                except websockets.ConnectionClosedOK:
                    logger.info("WebSocket connection closed normally.")
                    break
                except websockets.ConnectionClosedError as e:
                    logger.error(f"WebSocket connection closed with error: {e}")
                    break
                except json.JSONDecodeError:
                    logger.error(f"Failed to decode JSON from WebSocket message: {message_str}")
                except Exception as e:
                    # Treat anything else (e.g. a Notification validation error) as
                    # transient: log it, pause briefly, then keep receiving.
                    logger.error(f"Error in WebSocket listener loop: {e}")
                    await asyncio.sleep(1)
        finally:
            await websocket.close()
    except asyncio.TimeoutError:
        logger.error(f"WebSocket connection timed out: {safe_url}")
    except websockets.exceptions.InvalidURI:
        logger.error(f"Invalid WebSocket URI: {safe_url}")
    except websockets.exceptions.WebSocketException as e:
        logger.error(f"WebSocket connection failed: {e}")
    except Exception as e:
        logger.error(f"Unexpected error connecting/listening to WebSocket: {e}")
    # Connection failed or ended: hand back the (possibly updated) list.
    return notification_state
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # --- Per-session state (gr.State keeps one value per browser session) ---
    # Session token from the /login endpoint (None when logged out).
    auth_token = gr.State(None)
    # Logged-in user's data as a dict (from `user_data.model_dump()`).
    user_info = gr.State(None)
    # Notification messages, newest first. The WebSocket listener task mutates
    # this session's list object in place; the poller below renders it.
    notification_list = gr.State([])
    # asyncio.Task running `listen_to_websockets` for this session.
    websocket_task = gr.State(None)

    # --- UI Components ---
    with gr.Tabs() as tabs:
        # --- Registration Tab ---
        with gr.TabItem("Register", id="register_tab"):
            gr.Markdown("## Create a new account")
            reg_email = gr.Textbox(label="Email", type="email")
            reg_password = gr.Textbox(label="Password (min 8 chars)", type="password")
            reg_confirm_password = gr.Textbox(label="Confirm Password", type="password")
            reg_button = gr.Button("Register")
            reg_status = gr.Textbox(label="Status", interactive=False)

        # --- Login Tab ---
        with gr.TabItem("Login", id="login_tab"):
            gr.Markdown("## Login to your account")
            login_email = gr.Textbox(label="Email", type="email")
            login_password = gr.Textbox(label="Password", type="password")
            login_button = gr.Button("Login")
            login_status = gr.Textbox(label="Status", interactive=False)

        # --- Welcome Tab (shown after login) ---
        with gr.TabItem("Welcome", id="welcome_tab", visible=False) as welcome_tab:
            gr.Markdown("## Welcome!", elem_id="welcome_header")
            welcome_message = gr.Markdown("", elem_id="welcome_message")
            logout_button = gr.Button("Logout")
            gr.Markdown("---")  # Separator
            gr.Markdown("## Real-time Notifications")
            # Read-only view of `notification_list`, refreshed by the poller
            # registered below. A component-level `every=` only has an effect
            # when `value` is a callable, so it cannot drive this textbox.
            notification_display = gr.Textbox(
                label="New User Alerts",
                lines=5,
                max_lines=10,
                interactive=False,
            )

    # --- Event Handlers ---

    # Registration Logic
    async def handle_register(email, password, confirm_password):
        """Validate the registration form locally, then call POST /register.

        Returns a `gr.update` carrying the status text for `reg_status`.
        """
        if not email or not password or not confirm_password:
            return gr.update(value="Please fill in all fields.")
        if password != confirm_password:
            return gr.update(value="Passwords do not match.")
        if len(password) < 8:
            return gr.update(value="Password must be at least 8 characters long.")
        payload = {"email": email, "password": password}
        result = await make_api_request("post", "/register", json=payload)
        if "error" in result:
            return gr.update(value=f"Registration failed: {result['error']}")
        # Optionally switch to login tab after successful registration
        return gr.update(value=f"Registration successful for {result.get('email')}! Please log in.")

    reg_button.click(
        handle_register,
        inputs=[reg_email, reg_password, reg_confirm_password],
        outputs=[reg_status]
    )

    # Login Logic
    def _login_failure(message, current_task):
        """Build the 7 `handle_login` outputs for a failed login attempt.

        Shows `message`, clears token and user info, hides the welcome tab and
        leaves the tab selection, welcome text and any running listener task
        unchanged. The tuple length must match `login_button.click(outputs=...)`.
        """
        return (
            gr.update(value=message),  # login_status
            None,                      # auth_token state
            None,                      # user_info state
            gr.update(),               # tabs: keep current selection
            gr.update(visible=False),  # welcome_tab
            gr.update(),               # welcome_message: unchanged
            current_task,              # websocket_task state
        )

    async def handle_login(email, password, current_task, notif_list=None):
        """Log in via POST /login and start this session's WebSocket listener.

        Args:
            email: Email entered in the login form.
            password: Password entered in the login form.
            current_task: This session's previous listener task, or None.
            notif_list: This session's `notification_list` value. The new
                listener task inserts into this exact list object, which the
                poller reads.

        Returns:
            A 7-tuple matching the `outputs` of `login_button.click`.
        """
        if not email or not password:
            return _login_failure("Please enter email and password.", current_task)

        payload = {"email": email, "password": password}
        result = await make_api_request("post", "/login", json=payload)
        if "error" in result:
            return _login_failure(f"Login failed: {result['error']}", current_task)

        token = result.get("access_token")
        # Calls the project dependency directly, outside FastAPI's DI system.
        # NOTE(review): assumes it accepts the raw token and returns a pydantic
        # model with `.email` (or a falsy value) — confirm against dependencies.py.
        user_data = await dependencies.get_optional_current_user(token)
        if not user_data:
            # This shouldn't happen if login succeeded, but check anyway
            return _login_failure("Login succeeded but failed to fetch user data.", current_task)

        # Cancel any existing listener so a session never runs two at once.
        if current_task and not current_task.done():
            current_task.cancel()
            try:
                await current_task  # Wait for cancellation
            except asyncio.CancelledError:
                logger.info("Previous WebSocket task cancelled.")

        if notif_list is None:
            notif_list = []
        # Pass the *session's* list (not `notification_list.value`, the shared
        # component default) so updates land where the poller reads them.
        # NOTE(review): the task is only cancelled on logout/re-login; it keeps
        # running if the browser is closed without logging out.
        new_task = asyncio.create_task(listen_to_websockets(token, notif_list))

        welcome_msg = f"Welcome, {user_data.email}!"
        return (
            gr.update(value="Login successful!"),  # login_status
            token,                                 # auth_token state
            user_data.model_dump(),                # user_info state (store as dict)
            gr.update(selected="welcome_tab"),     # Switch Tabs
            gr.update(visible=True),               # Make welcome tab visible
            gr.update(value=welcome_msg),          # Update welcome message markdown
            new_task,                              # websocket_task state
        )

    login_button.click(
        handle_login,
        inputs=[login_email, login_password, websocket_task, notification_list],
        outputs=[login_status, auth_token, user_info, tabs, welcome_tab, welcome_message, websocket_task]
    )

    # Notification display
    def update_notification_ui(notif_list_state):
        """Render this session's notifications as newline-separated text."""
        return "\n".join(notif_list_state or [])

    # The listener task cannot push UI updates, so poll the session's list once
    # per second. gr.Timer is the supported mechanism on newer Gradio releases;
    # older releases fall back to a repeating load event (requires the queue).
    if hasattr(gr, "Timer"):
        notification_timer = gr.Timer(1)
        notification_timer.tick(
            fn=update_notification_ui,
            inputs=[notification_list],
            outputs=[notification_display],
            show_progress="hidden",
        )
    else:
        demo.load(
            fn=update_notification_ui,
            inputs=[notification_list],
            outputs=[notification_display],
            every=1,
        )

    # Logout Logic
    async def handle_logout(current_task):
        """Stop the listener task, clear session state and return to the login tab.

        Returns an 8-tuple matching the `outputs` of `logout_button.click`.
        """
        if current_task and not current_task.done():
            current_task.cancel()
            try:
                await current_task
            except asyncio.CancelledError:
                logger.info("WebSocket task cancelled on logout.")
        return (
            None,                             # Clear auth_token
            None,                             # Clear user_info
            [],                               # Fresh notification list for this session
            None,                             # Clear websocket_task
            gr.update(selected="login_tab"),  # Switch Tabs
            gr.update(visible=False),         # Hide welcome tab
            gr.update(value=""),              # Clear welcome message
            gr.update(value=""),              # Clear login status
        )

    logout_button.click(
        handle_logout,
        inputs=[websocket_task],
        outputs=[
            auth_token,
            user_info,
            notification_list,
            websocket_task,
            tabs,
            welcome_tab,
            welcome_message,
            login_status
        ]
    )
# Serve the Gradio UI at the site root of the FastAPI app; the REST API routes
# registered earlier remain reachable under /api.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")

# Local entry point: run the combined FastAPI + Gradio application with uvicorn.
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces so the server is reachable from outside a Docker
    # container. 7860 is Gradio's conventional port, and the port the default
    # API base URL points at.
    uvicorn.run(app=app, host="0.0.0.0", port=7860)