# (Hugging Face Spaces page header captured with the file -- Spaces status: Running)
# streamlit_app.py | |
import streamlit as st | |
from streamlit_autorefresh import st_autorefresh # For periodic refresh | |
import httpx | |
import asyncio | |
import websockets | |
import json | |
import threading | |
import queue | |
import logging | |
import time | |
import os | |
# Import backend components | |
from app import crud, models, schemas, auth, dependencies | |
from app.database import ensure_db_and_table_exist # Sync function | |
from app.websocket import manager # Import the manager instance | |
# FastAPI imports for mounting | |
from fastapi import FastAPI, Depends, HTTPException, status | |
from fastapi.routing import Mount | |
from fastapi.staticfiles import StaticFiles | |
from app.api import router as api_router # Import the specific API router | |
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---
# The API endpoint defaults to the address used by the HF Space deployment and
# can be overridden with the API_BASE_URL environment variable (local vs
# deployed backends).
API_BASE_URL = os.environ.get("API_BASE_URL", "http://127.0.0.1:7860/api")
# Swap only the leading scheme (http -> ws, https -> wss). An unbounded
# replace would also rewrite any later "http" inside the host or path.
WS_BASE_URL = API_BASE_URL.replace("http", "ws", 1)

# --- Ensure DB exists on first run ---
# Streamlit re-executes this script from top to bottom on every rerun (and the
# app auto-refreshes every 2 seconds), so the check is guarded by a
# per-session flag to make it run once per session rather than on every rerun.
if "_db_checked" not in st.session_state:
    ensure_db_and_table_exist()
    st.session_state["_db_checked"] = True

# --- FastAPI Mounting Setup ---
# A FastAPI instance (separate from Streamlit) carrying the backend router
# under /api. In this file it is only constructed; nothing visible here mounts
# or serves it, and the UI reaches the API through HTTP requests to
# API_BASE_URL.
# NOTE(review): confirm which process actually serves API_BASE_URL.
api_app = FastAPI(title="Backend API")  # Can add lifespan if needed for API-specific setup later
api_app.include_router(api_router, prefix="/api")

# --- WebSocket Listener Thread ---
# The stop flag and the notification queue are shared between the script and
# the background listener thread. Because module-level objects are re-created
# on each rerun, plain `threading.Event()` / `queue.Queue()` assignments here
# would hand the script a brand-new, empty pair every time while a thread
# started by an earlier run kept using the old pair: notifications would never
# reach the UI and logout could never signal the thread. Keeping both objects
# in session state gives every rerun of this session the same instances.
if "ws_stop_event" not in st.session_state:
    st.session_state["ws_stop_event"] = threading.Event()
if "ws_notification_queue" not in st.session_state:
    st.session_state["ws_notification_queue"] = queue.Queue()
stop_event = st.session_state["ws_stop_event"]
notification_queue = st.session_state["ws_notification_queue"]
def websocket_listener(token: str):
    """Listen for WebSocket notifications; intended to run in a background thread.

    Opens ``{WS_BASE_URL}/ws/{token}`` and, until ``stop_event`` is set or the
    connection closes, puts the ``message`` field of every ``new_user``
    payload onto ``notification_queue``. Failures are logged and never raised
    to the caller.

    Args:
        token: Access token embedded in the WebSocket URL path.
    """
    # The token is a credential and is part of the URL, so neither the token
    # nor the full URL is written to the logs.
    logger.info("[WS Thread] Listener started.")
    ws_url = f"{WS_BASE_URL}/ws/{token}"
    loggable_url = f"{WS_BASE_URL}/ws/<token>"

    def set_connected(connected: bool) -> None:
        """Best-effort update of the connection flag shown by the UI."""
        # NOTE(review): this runs on a plain worker thread; whether Streamlit
        # lets such a thread write to the session's state depends on the
        # script-run context -- confirm. A failure here must not tear down an
        # otherwise healthy connection, so it is only logged.
        try:
            st.session_state['ws_connected'] = connected
        except Exception as e:
            logger.warning(f"[WS Thread] Could not update ws_connected: {e}")

    async def listen():
        try:
            async with websockets.connect(ws_url, open_timeout=10.0) as ws:
                logger.info(f"[WS Thread] Connected to {loggable_url}")
                set_connected(True)
                while not stop_event.is_set():
                    try:
                        # Short timeout so stop_event is checked frequently.
                        message = await asyncio.wait_for(ws.recv(), timeout=1.0)
                    except asyncio.TimeoutError:
                        continue  # No message, check stop_event again
                    except websockets.ConnectionClosed:
                        logger.warning("[WS Thread] Connection closed.")
                        break  # Exit loop if closed

                    logger.info(f"[WS Thread] Received message: {message[:100]}...")
                    try:
                        data = json.loads(message)
                        if data.get("type") == "new_user":
                            notification = schemas.Notification(**data)
                            notification_queue.put(notification.message)
                            logger.info("[WS Thread] Put notification in queue.")
                    except json.JSONDecodeError:
                        logger.error("[WS Thread] Failed to decode JSON.")
                    except Exception as e:
                        # A malformed payload must not end the listener.
                        logger.error(f"[WS Thread] Error processing message: {e}")
        except Exception as e:
            logger.error(f"[WS Thread] Connection failed or error: {e}")
        finally:
            logger.info("[WS Thread] Listener loop finished.")
            set_connected(False)

    try:
        asyncio.run(listen())
    except Exception as e:
        logger.error(f"[WS Thread] asyncio.run error: {e}")
    logger.info("[WS Thread] Listener thread exiting.")
# --- Streamlit UI ---
st.set_page_config(layout="wide")

# --- Initialize Session State ---
# Every key the UI relies on, with the value it gets the first time a session
# runs the script. Existing values are left untouched.
_SESSION_DEFAULTS = {
    "logged_in": False,
    "token": None,
    "user_email": None,
    "notifications": [],
    "ws_thread": None,
    "ws_connected": False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# --- Notification Processing ---
# Drain whatever the listener has queued so far without blocking.
new_notifications = []
while True:
    try:
        new_notifications.append(notification_queue.get_nowait())
    except queue.Empty:
        break

if new_notifications:
    logger.info(f"Processing {len(new_notifications)} notifications from queue.")
    # Newest entries go first; history is capped at 15 items.
    st.session_state.notifications = (
        new_notifications + st.session_state.notifications
    )[:15]
    # No explicit rerun here: the autorefresh below re-executes the script
    # periodically, which picks up the updated list.

# --- Auto Refresh ---
# Re-run the script every 2 seconds so the queue is polled and the display
# stays current.
count = st_autorefresh(interval=2000, limit=None, key="notifrefresh")

# --- API Client ---
client = httpx.AsyncClient(base_url=API_BASE_URL, timeout=10.0)
# --- Helper Functions for API Calls --- | |
async def api_register(email: str, password: str) -> dict:
    """Register a user through the backend ``/register`` endpoint.

    Args:
        email: Email address to register.
        password: Plain-text password sent to the API.

    Returns:
        ``{"success": True, "data": <response JSON>}`` on success, otherwise
        ``{"success": False, "error": <message>}``. Errors are reported in the
        result rather than raised.
    """
    try:
        response = await client.post("/register", json={"email": email, "password": password})
        response.raise_for_status()
        return {"success": True, "data": response.json()}
    except httpx.HTTPStatusError as e:
        # An error body is not guaranteed to be a JSON object (it may be plain
        # text or HTML). Parsing it blindly would raise from inside this
        # handler and escape the function, so fall back to the raw text.
        try:
            body = e.response.json()
        except ValueError:
            body = None
        if isinstance(body, dict):
            detail = body.get("detail", e.response.text)
        else:
            detail = e.response.text
        logger.error(f"API Register Error: {e.response.status_code} - {detail}")
        return {"success": False, "error": f"API Error: {detail}"}
    except Exception as e:
        logger.exception("Register call failed")
        return {"success": False, "error": f"Request failed: {e}"}
async def api_login(email: str, password: str) -> dict:
    """Authenticate through the backend ``/login`` endpoint.

    Args:
        email: Email address of the account.
        password: Plain-text password sent to the API.

    Returns:
        ``{"success": True, "data": <response JSON>}`` on success, otherwise
        ``{"success": False, "error": <message>}``. Errors are reported in the
        result rather than raised.
    """
    try:
        response = await client.post("/login", json={"email": email, "password": password})
        response.raise_for_status()
        return {"success": True, "data": response.json()}
    except httpx.HTTPStatusError as e:
        # An error body is not guaranteed to be a JSON object (it may be plain
        # text or HTML). Parsing it blindly would raise from inside this
        # handler and escape the function, so fall back to the raw text.
        try:
            body = e.response.json()
        except ValueError:
            body = None
        if isinstance(body, dict):
            detail = body.get("detail", e.response.text)
        else:
            detail = e.response.text
        logger.error(f"API Login Error: {e.response.status_code} - {detail}")
        return {"success": False, "error": f"API Error: {detail}"}
    except Exception as e:
        logger.exception("Login call failed")
        return {"success": False, "error": f"Request failed: {e}"}
# --- UI Rendering ---
# NOTE: the captured source had lost its indentation; the nesting below
# (form handling inside each form, dashboard inside the logged-in branch) is a
# reconstruction.


def _handle_login_submit(email, password):
    """Validate the login form, authenticate, and start the WS listener."""
    if not (email and password):
        st.error("Please enter email and password.")
        return

    result = asyncio.run(api_login(email, password))  # Run async in sync context
    if not result["success"]:
        st.error(f"Login failed: {result['error']}")
        return

    token = result["data"]["access_token"]
    # The email typed into the form is kept as the display name; no separate
    # user-info request is made.
    st.session_state.logged_in = True
    st.session_state.token = token
    st.session_state.user_email = email
    st.session_state.notifications = []  # Clear old notifications

    # Start the WebSocket listener with a cleared stop flag.
    stop_event.clear()
    listener = threading.Thread(target=websocket_listener, args=(token,), daemon=True)
    st.session_state.ws_thread = listener
    listener.start()
    logger.info("Login successful, WS thread started.")
    st.rerun()  # Rerun immediately to switch view


def _handle_register_submit(email, password, confirm):
    """Validate the registration form and call the register endpoint."""
    if not (email and password and confirm):
        st.error("Please fill all fields.")
    elif password != confirm:
        st.error("Passwords do not match.")
    elif len(password) < 8:
        st.error("Password must be at least 8 characters.")
    else:
        result = asyncio.run(api_register(email, password))
        if result["success"]:
            st.success(f"Registration successful for {result['data']['email']}! Please log in.")
        else:
            st.error(f"Registration failed: {result['error']}")


def _render_auth_sidebar():
    """Draw the Login / Register tabs in the sidebar and handle submissions."""
    st.sidebar.header("Login or Register")
    login_tab, register_tab = st.sidebar.tabs(["Login", "Register"])

    with login_tab:
        with st.form("login_form"):
            login_email = st.text_input("Email", key="login_email")
            login_password = st.text_input("Password", type="password", key="login_password")
            if st.form_submit_button("Login"):
                _handle_login_submit(login_email, login_password)

    with register_tab:
        with st.form("register_form"):
            reg_email = st.text_input("Email", key="reg_email")
            reg_password = st.text_input("Password", type="password", key="reg_password")
            reg_confirm = st.text_input("Confirm Password", type="password", key="reg_confirm")
            if st.form_submit_button("Register"):
                _handle_register_submit(reg_email, reg_password, reg_confirm)


def _logout():
    """Signal the listener thread to stop, reset session state, and rerun."""
    logger.info("Logout requested.")
    listener = st.session_state.ws_thread
    if listener and listener.is_alive():
        logger.info("Signalling WS thread to stop.")
        stop_event.set()
        listener.join(timeout=2.0)  # Wait briefly for thread exit
        if listener.is_alive():
            logger.warning("WS thread did not exit cleanly.")

    # Back to the logged-out defaults.
    for state_key, cleared_value in (
        ("logged_in", False),
        ("token", None),
        ("user_email", None),
        ("notifications", []),
        ("ws_thread", None),
        ("ws_connected", False),
    ):
        st.session_state[state_key] = cleared_value
    logger.info("Session cleared.")
    st.rerun()


def _render_dashboard():
    """Draw the dashboard: connection status and the notification list."""
    st.header("Dashboard")
    st.subheader("Real-time Notifications")
    ws_status = "Connected" if st.session_state.ws_connected else "Disconnected"
    st.caption(f"WebSocket Status: {ws_status}")

    notifications = st.session_state.notifications
    if not notifications:
        st.text("No new notifications.")
        return
    for msg in notifications:
        st.info(f"{msg}", icon="🔔")


st.title("Authentication & Notification App (Streamlit)")
if st.session_state.logged_in:
    # Logged In View
    st.sidebar.header(f"Welcome, {st.session_state.user_email}!")
    if st.sidebar.button("Logout"):
        _logout()
    _render_dashboard()
else:
    _render_auth_sidebar()

# --- Final Cleanup ---
# The module-level httpx client is never closed explicitly; closing it here
# might not run anyway, depending on how Streamlit terminates the script.