Spaces:
Running
Running
File size: 11,844 Bytes
915d8f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
# streamlit_app.py
import streamlit as st
from streamlit_autorefresh import st_autorefresh # For periodic refresh
import httpx
import asyncio
import websockets
import json
import threading
import queue
import logging
import time
import os
# Import backend components
from app import crud, models, schemas, auth, dependencies
from app.database import ensure_db_and_table_exist # Sync function
from app.websocket import manager # Import the manager instance
# FastAPI imports for mounting
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.routing import Mount
from fastapi.staticfiles import StaticFiles
from app.api import router as api_router # Import the specific API router
# --- Logging ---
# Configure the root logger at INFO (logging.basicConfig is a no-op if the root
# logger already has handlers) and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# --- Configuration ---
# API endpoint: the API_BASE_URL environment variable when set, otherwise the
# in-process default used for the HF Space deployment.
# NOTE(review): nothing in this file serves /api on port 7860 — confirm the API is
# actually started elsewhere when relying on the default.
API_BASE_URL = os.environ.get("API_BASE_URL", "http://127.0.0.1:7860/api")
# WebSocket base URL derived from the HTTP one. Only the first "http" (the scheme)
# is rewritten: http:// -> ws://, https:// -> wss://; hosts or paths that happen to
# contain "http" are left untouched.
WS_BASE_URL = API_BASE_URL.replace("http", "ws", 1)
# --- Ensure DB exists (once per server process) ---
# Streamlit re-executes this whole script on every rerun (each widget interaction,
# and st_autorefresh further down can trigger one every 2 seconds), so a bare call
# here would repeat the DB check on every rerun of every session.
# st.cache_resource memoizes the call so it runs once per process. If it raises,
# nothing is cached and the next rerun retries. show_spinner=False keeps it from
# emitting a spinner element before st.set_page_config() below.
@st.cache_resource(show_spinner=False)
def _ensure_db_once() -> bool:
    """Run ensure_db_and_table_exist() a single time per process; later calls hit the cache."""
    ensure_db_and_table_exist()
    return True


_ensure_db_once()
# --- FastAPI app object (not served by this script) ---
# A standalone FastAPI instance carrying the API router under the /api prefix.
# NOTE(review): nothing in this file mounts, serves, or otherwise references
# `api_app`. The Streamlit UI below reaches the API over HTTP/WebSocket at
# API_BASE_URL instead. Serving this object would need its own server process
# (e.g. uvicorn); confirm where the API is actually run.
api_app = FastAPI(title="Backend API")
api_app.include_router(router=api_router, prefix="/api")
# --- WebSocket Listener Thread ---
# Streamlit re-executes this script from top to bottom on every rerun (every widget
# interaction and every st_autorefresh tick). Plain module-level objects created
# here would therefore be rebuilt on each rerun. The listener thread started during
# the login run and the UI of later runs could then hold different Event/Queue
# instances: the UI drains an empty queue, and Logout sets an Event the thread never
# checks. Keeping them in st.session_state gives every rerun of one browser session
# the same instances.
if "ws_stop_event" not in st.session_state:
    st.session_state.ws_stop_event = threading.Event()
if "notification_queue" not in st.session_state:
    st.session_state.notification_queue = queue.Queue()
if "ws_connected_event" not in st.session_state:
    st.session_state.ws_connected_event = threading.Event()
stop_event = st.session_state.ws_stop_event  # set -> the listener leaves its receive loop
notification_queue = st.session_state.notification_queue  # listener -> UI messages
ws_connected_event = st.session_state.ws_connected_event  # set while the socket is open

# A custom thread has no Streamlit script-run context, so the listener cannot update
# st.session_state itself. Instead it sets ws_connected_event, which is mirrored
# into session state here on each rerun.
st.session_state.ws_connected = ws_connected_event.is_set()


def websocket_listener(token: str):
    """Background-thread target: relay WebSocket notifications until ``stop_event`` is set.

    Connects to ``{WS_BASE_URL}/ws/{token}``. Each JSON message whose ``type`` is
    ``"new_user"`` is turned into a ``schemas.Notification``, and its ``message`` is
    put on ``notification_queue`` for the Streamlit script thread to drain.
    Connection status is published through ``ws_connected_event``. Errors are
    logged, not raised.

    Args:
        token: access token. It is embedded in the WebSocket URL, so only a short
            prefix of it is ever written to the log.
    """
    token_hint = f"{token[:10]}..."
    logger.info(f"[WS Thread] Listener started for token: {token_hint}")
    ws_url = f"{WS_BASE_URL}/ws/{token}"

    async def listen():
        try:
            async with websockets.connect(ws_url, open_timeout=10.0) as ws:
                # Log the URL with the token truncated, like the start message above.
                logger.info(f"[WS Thread] Connected to {WS_BASE_URL}/ws/{token_hint}")
                ws_connected_event.set()
                while not stop_event.is_set():
                    try:
                        # Short timeout so stop_event is re-checked about once per second.
                        message = await asyncio.wait_for(ws.recv(), timeout=1.0)
                    except asyncio.TimeoutError:
                        continue  # no message yet; check stop_event again
                    except websockets.ConnectionClosed:
                        logger.warning("[WS Thread] Connection closed.")
                        break
                    logger.info(f"[WS Thread] Received message: {message[:100]}...")
                    try:
                        data = json.loads(message)
                        if data.get("type") == "new_user":
                            notification = schemas.Notification(**data)
                            notification_queue.put(notification.message)
                            logger.info("[WS Thread] Put notification in queue.")
                    except json.JSONDecodeError:
                        logger.error("[WS Thread] Failed to decode JSON.")
                    except Exception as e:
                        # A single bad message must not stop the listener.
                        logger.error(f"[WS Thread] Error processing message: {e}")
        except Exception as e:
            logger.error(f"[WS Thread] Connection failed or error: {e}")
        finally:
            logger.info("[WS Thread] Listener loop finished.")
            ws_connected_event.clear()

    # This thread has no event loop of its own; asyncio.run creates and closes one.
    try:
        asyncio.run(listen())
    except Exception as e:
        logger.error(f"[WS Thread] asyncio.run error: {e}")
    logger.info("[WS Thread] Listener thread exiting.")
# --- Streamlit UI ---
st.set_page_config(layout="wide")

# --- Initialize Session State ---
# Seed each per-session key only on the first run; later reruns keep whatever the
# login/logout handlers stored there.
for _state_key, _initial in (
    ("logged_in", False),
    ("token", None),
    ("user_email", None),
    ("notifications", []),
    ("ws_thread", None),
    ("ws_connected", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
# --- Notification Processing ---
# Drain everything the listener thread has queued since the previous rerun. This
# runs before the dashboard is rendered below, so new messages appear in this same
# run; st_autorefresh is what triggers the periodic reruns.
_MAX_NOTIFICATIONS = 15  # how many notifications the dashboard keeps

new_notifications = []
while True:
    try:
        new_notifications.append(notification_queue.get_nowait())
    except queue.Empty:
        break

if new_notifications:
    logger.info(f"Processing {len(new_notifications)} notifications from queue.")
    # The queue yields oldest-first. Reverse the batch so the stored list is
    # newest-first throughout, and truncation discards the oldest entries rather
    # than the newest ones of a large batch.
    new_notifications.reverse()
    st.session_state.notifications = (
        new_notifications + st.session_state.notifications
    )[:_MAX_NOTIFICATIONS]
# --- Auto Refresh ---
# While logged in, rerun the script every 2 seconds so the notification-processing
# step above drains the listener queue and the dashboard updates. The logged-out
# login/register page has no listener to poll, so no timer is rendered there; every
# tick would otherwise rerun the whole script for nothing.
# `count` holds st_autorefresh's return value (presumably its refresh counter), or 0
# when no timer is active.
if st.session_state.logged_in:
    count = st_autorefresh(interval=2000, limit=None, key="notifrefresh")
else:
    count = 0
# --- API Client ---
def _api_client():
    """Return a new AsyncClient for API_BASE_URL; the caller closes it (``async with``).

    A fresh client is used per request for two reasons:
    - Each helper below is driven by its own ``asyncio.run()`` call, i.e. its own
      event loop.
    - A module-level client would be rebuilt on every Streamlit rerun and never
      closed.
    """
    return httpx.AsyncClient(base_url=API_BASE_URL, timeout=10.0)


def extract_error_detail(response):
    """Return the ``detail`` field of an error response's JSON body, else its raw text.

    Falls back to ``response.text`` in two cases:
    - the body is not valid JSON (e.g. an HTML error page from a proxy);
    - the body is not a JSON object containing ``detail``.
    """
    try:
        body = response.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return response.text
    if isinstance(body, dict):
        return body.get("detail", response.text)
    return response.text


async def _post_credentials(path, email, password, action):
    """POST ``{"email", "password"}`` to *path* and wrap the outcome in a result dict.

    Args:
        path: endpoint path relative to API_BASE_URL (e.g. "/login").
        email, password: credentials sent as the JSON body.
        action: label used in log messages ("Login" / "Register").

    Returns:
        On a 2xx response: ``{"success": True, "data": <parsed JSON body>}``.
        Otherwise: ``{"success": False, "error": <message>}``. Exceptions are
        caught, logged, and turned into error results.
    """
    try:
        async with _api_client() as client:
            response = await client.post(path, json={"email": email, "password": password})
            response.raise_for_status()
            return {"success": True, "data": response.json()}
    except httpx.HTTPStatusError as e:
        detail = extract_error_detail(e.response)
        logger.error(f"API {action} Error: {e.response.status_code} - {detail}")
        return {"success": False, "error": f"API Error: {detail}"}
    except Exception as e:
        logger.exception(f"{action} call failed")
        return {"success": False, "error": f"Request failed: {e}"}


# --- Helper Functions for API Calls ---
async def api_register(email, password):
    """Register an account via POST /register; result dict as in _post_credentials."""
    return await _post_credentials("/register", email, password, "Register")


async def api_login(email, password):
    """Log in via POST /login; on success ``data`` is the parsed response body."""
    return await _post_credentials("/login", email, password, "Login")
# --- UI Rendering ---
def _render_login_tab():
    """Render the login form and handle a successful login.

    On success it stores the session details, starts the WebSocket listener
    thread, and reruns into the logged-in view.
    """
    with st.form("login_form"):
        login_email = st.text_input("Email", key="login_email")
        login_password = st.text_input("Password", type="password", key="login_password")
        login_button = st.form_submit_button("Login")
    if not login_button:
        return
    if not login_email or not login_password:
        st.error("Please enter email and password.")
        return
    # The API helpers are coroutines; asyncio.run drives one to completion here.
    result = asyncio.run(api_login(login_email, login_password))
    if not result["success"]:
        st.error(f"Login failed: {result['error']}")
        return
    data = result["data"]
    token = data.get("access_token") if isinstance(data, dict) else None
    if not token:
        # Without this check a 2xx body lacking a token raised KeyError and aborted the run.
        st.error("Login failed: response did not include an access token.")
        return
    st.session_state.logged_in = True
    st.session_state.token = token
    # /users/me is not queried; the email typed into the form is shown as the identity.
    st.session_state.user_email = login_email
    st.session_state.notifications = []  # drop notifications from any earlier session
    # Start the WebSocket listener as a daemon thread (it will not block interpreter exit).
    stop_event.clear()  # discard a stop request left over from a previous logout
    thread = threading.Thread(target=websocket_listener, args=(token,), daemon=True)
    st.session_state.ws_thread = thread
    thread.start()
    logger.info("Login successful, WS thread started.")
    st.rerun()  # switch to the logged-in view immediately


def _render_register_tab():
    """Render the registration form and report the outcome; the user logs in afterwards."""
    with st.form("register_form"):
        reg_email = st.text_input("Email", key="reg_email")
        reg_password = st.text_input("Password", type="password", key="reg_password")
        reg_confirm = st.text_input("Confirm Password", type="password", key="reg_confirm")
        register_button = st.form_submit_button("Register")
    if not register_button:
        return
    if not reg_email or not reg_password or not reg_confirm:
        st.error("Please fill all fields.")
    elif reg_password != reg_confirm:
        st.error("Passwords do not match.")
    elif len(reg_password) < 8:
        st.error("Password must be at least 8 characters.")
    else:
        result = asyncio.run(api_register(reg_email, reg_password))
        if result["success"]:
            data = result["data"]
            # Fall back to the submitted email if the response does not echo it back.
            email = data.get("email", reg_email) if isinstance(data, dict) else reg_email
            st.success(f"Registration successful for {email}! Please log in.")
        else:
            st.error(f"Registration failed: {result['error']}")


def _logout():
    """Stop the listener thread (best effort), reset the session keys and rerun."""
    logger.info("Logout requested.")
    thread = st.session_state.ws_thread
    if thread and thread.is_alive():
        logger.info("Signalling WS thread to stop.")
        stop_event.set()
        # The listener's receive loop re-checks stop_event roughly once per second.
        thread.join(timeout=2.0)
        if thread.is_alive():
            logger.warning("WS thread did not exit cleanly.")
    st.session_state.logged_in = False
    st.session_state.token = None
    st.session_state.user_email = None
    st.session_state.notifications = []
    st.session_state.ws_thread = None
    st.session_state.ws_connected = False
    logger.info("Session cleared.")
    st.rerun()


def _render_dashboard():
    """Show the WebSocket status caption and the stored notifications."""
    st.header("Dashboard")
    st.subheader("Real-time Notifications")
    ws_status = "Connected" if st.session_state.ws_connected else "Disconnected"
    st.caption(f"WebSocket Status: {ws_status}")
    if st.session_state.notifications:
        for msg in st.session_state.notifications:
            st.info(f"{msg}", icon="🔔")
    else:
        st.text("No new notifications.")


st.title("Authentication & Notification App (Streamlit)")
if not st.session_state.logged_in:
    st.sidebar.header("Login or Register")
    login_tab, register_tab = st.sidebar.tabs(["Login", "Register"])
    with login_tab:
        _render_login_tab()
    with register_tab:
        _render_register_tab()
else:  # Logged-in view
    st.sidebar.header(f"Welcome, {st.session_state.user_email}!")
    if st.sidebar.button("Logout"):
        _logout()
    _render_dashboard()
# --- Final Cleanup ---
# Ensure httpx client is closed if script exits abnormally
# (This might not always run depending on how Streamlit terminates)
# Ideally handled within context managers if used more extensively
# asyncio.run(client.aclose()) |