AuthenticationApp / streamlit_app.py
amaye15's picture
debug
c0a29ce
raw
history blame
11.8 kB
# streamlit_app.py
import streamlit as st
from streamlit_autorefresh import st_autorefresh # For periodic refresh
import httpx
import asyncio
import websockets
import json
import threading
import queue
import logging
import time
import os
# Import backend components
from app import crud, models, schemas, auth, dependencies
from app.database import ensure_db_and_table_exist # Sync function
from app.websocket import manager # Import the manager instance
# FastAPI imports for mounting
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.routing import Mount
from fastapi.staticfiles import StaticFiles
from app.api import router as api_router # Import the specific API router
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---
# The API endpoint can be overridden through the API_BASE_URL environment
# variable (local vs deployed backend). When it is unset or empty, the default
# points at the local port used for the HF Space deployment.
API_BASE_URL = os.environ.get("API_BASE_URL") or "http://127.0.0.1:7860/api"
# Derive the WebSocket base by swapping only the leading scheme
# (http -> ws, https -> wss); count=1 leaves any later "http" substring in the
# host or path untouched.
WS_BASE_URL = API_BASE_URL.replace("http", "ws", 1)
# --- Ensure DB exists ---
# NOTE: Streamlit re-executes this script from top to bottom on every rerun,
# so this call happens on each rerun rather than once per process.
# Presumably ensure_db_and_table_exist() is idempotent — confirm in app.database.
ensure_db_and_table_exist()

# --- FastAPI app ---
# Build a FastAPI instance (separate from Streamlit) exposing the backend
# router under the /api prefix.
# NOTE(review): nothing in this script mounts or serves api_app. Mounting it
# inside Streamlit's internal server would require monkey-patching, so the UI
# talks to the backend through plain HTTP requests to API_BASE_URL instead;
# confirm that the API is actually served at that address.
api_app = FastAPI(title="Backend API")  # Can add lifespan if needed for API-specific setup later
api_app.include_router(api_router, prefix="/api")

# --- WebSocket listener shared state ---
# Because the script is re-executed on every rerun, plain module-level objects
# would be recreated each time while the listener thread keeps using the ones
# from the run that started it: queued notifications would never be seen by
# later runs and a logout would set an Event nobody is watching. Keeping both
# objects in session state gives every rerun of this session the same
# instances the thread holds.
if "ws_stop_event" not in st.session_state:
    st.session_state["ws_stop_event"] = threading.Event()
if "ws_notification_queue" not in st.session_state:
    st.session_state["ws_notification_queue"] = queue.Queue()
stop_event = st.session_state["ws_stop_event"]
notification_queue = st.session_state["ws_notification_queue"]
def websocket_listener(token: str):
    """Listen for WebSocket notifications; intended to run in a background thread.

    Connects to ``{WS_BASE_URL}/ws/{token}`` and, until ``stop_event`` is set
    or the connection closes, puts the ``message`` of every payload whose
    ``type`` is ``"new_user"`` onto ``notification_queue``. All errors are
    logged rather than raised, so the function always returns normally.

    Args:
        token: Access token; it becomes the last path segment of the
            WebSocket URL.
    """
    logger.info(f"[WS Thread] Listener started for token: {token[:10]}...")
    ws_url = f"{WS_BASE_URL}/ws/{token}"
    # ws_url embeds the full access token, so only a redacted form is logged.
    redacted_ws_url = f"{WS_BASE_URL}/ws/<redacted>"

    async def listen():
        try:
            async with websockets.connect(ws_url, open_timeout=10.0) as ws:
                logger.info(f"[WS Thread] Connected to {redacted_ws_url}")
                # NOTE(review): this runs outside the Streamlit script thread;
                # session-state writes made here may not reach the user's
                # session unless a script-run context is attached — confirm.
                st.session_state['ws_connected'] = True
                while not stop_event.is_set():
                    try:
                        # Short timeout so stop_event is re-checked about once per second.
                        message = await asyncio.wait_for(ws.recv(), timeout=1.0)
                        logger.info(f"[WS Thread] Received message: {message[:100]}...")
                        try:
                            data = json.loads(message)
                            if data.get("type") == "new_user":
                                notification = schemas.Notification(**data)
                                notification_queue.put(notification.message)  # Hand off to the UI
                                logger.info("[WS Thread] Put notification in queue.")
                        except json.JSONDecodeError:
                            logger.error("[WS Thread] Failed to decode JSON.")
                        except Exception as e:
                            # A malformed payload must not tear down the connection.
                            logger.error(f"[WS Thread] Error processing message: {e}")
                    except asyncio.TimeoutError:
                        continue  # No message yet; loop to re-check stop_event
                    except websockets.ConnectionClosed:
                        logger.warning("[WS Thread] Connection closed.")
                        break  # Nothing more to receive
        except Exception as e:
            logger.error(f"[WS Thread] Connection failed or error: {e}")
        finally:
            logger.info("[WS Thread] Listener loop finished.")
            st.session_state['ws_connected'] = False

    try:
        # The thread has no running event loop, so drive the coroutine with a fresh one.
        asyncio.run(listen())
    except Exception as e:
        logger.error(f"[WS Thread] asyncio.run error: {e}")
    logger.info("[WS Thread] Listener thread exiting.")
# --- Streamlit UI ---
st.set_page_config(layout="wide")

# --- Initialize Session State ---
# Seed each key only when it is absent so existing values survive reruns.
# The dict literal is rebuilt on every run, so the list default is never shared.
_SESSION_DEFAULTS = {
    "logged_in": False,
    "token": None,
    "user_email": None,
    "notifications": [],
    "ws_thread": None,
    "ws_connected": False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Notification Processing ---
# Drain whatever the listener thread has queued since the previous run.
new_notifications = []
while True:
    try:
        new_notifications.append(notification_queue.get_nowait())
    except queue.Empty:
        break

if new_notifications:
    logger.info(f"Processing {len(new_notifications)} notifications from queue.")
    # Newest first, keeping at most the 15 most recent entries.
    merged_notifications = new_notifications + st.session_state.notifications
    st.session_state.notifications = merged_notifications[:15]
    # No rerun is requested here: the list is read further down in this same
    # run, and the auto-refresh below schedules the following runs.

# --- Auto Refresh ---
# Re-run the script every 2 seconds so the queue is polled and the display updated.
count = st_autorefresh(interval=2000, limit=None, key="notifrefresh")

# --- API Client ---
client = httpx.AsyncClient(base_url=API_BASE_URL, timeout=10.0)
# --- Helper Functions for API Calls ---
async def api_register(email, password):
    """Register a new account through the backend ``/register`` endpoint.

    Args:
        email: Email address for the new account.
        password: Plain-text password for the new account.

    Returns:
        ``{"success": True, "data": <decoded JSON body>}`` on success, or
        ``{"success": False, "error": <message>}`` on any failure. Errors are
        logged and reported in the result rather than raised.
    """
    try:
        response = await client.post("/register", json={"email": email, "password": password})
        response.raise_for_status()
        return {"success": True, "data": response.json()}
    except httpx.HTTPStatusError as e:
        # An error body is not guaranteed to be a JSON object (e.g. an HTML
        # page from a proxy). Decoding it unguarded would raise from inside
        # this handler and escape the function, so fall back to the raw text.
        try:
            payload = e.response.json()
        except ValueError:
            payload = None
        if isinstance(payload, dict):
            detail = payload.get("detail", e.response.text)
        else:
            detail = e.response.text
        logger.error(f"API Register Error: {e.response.status_code} - {detail}")
        return {"success": False, "error": f"API Error: {detail}"}
    except Exception as e:
        logger.exception("Register call failed")
        return {"success": False, "error": f"Request failed: {e}"}
async def api_login(email, password):
    """Authenticate against the backend ``/login`` endpoint.

    Args:
        email: Account email address.
        password: Plain-text account password.

    Returns:
        ``{"success": True, "data": <decoded JSON body>}`` on success, or
        ``{"success": False, "error": <message>}`` on any failure. Errors are
        logged and reported in the result rather than raised.
    """
    try:
        response = await client.post("/login", json={"email": email, "password": password})
        response.raise_for_status()
        return {"success": True, "data": response.json()}
    except httpx.HTTPStatusError as e:
        # An error body is not guaranteed to be a JSON object (e.g. an HTML
        # page from a proxy). Decoding it unguarded would raise from inside
        # this handler and escape the function, so fall back to the raw text.
        try:
            payload = e.response.json()
        except ValueError:
            payload = None
        if isinstance(payload, dict):
            detail = payload.get("detail", e.response.text)
        else:
            detail = e.response.text
        logger.error(f"API Login Error: {e.response.status_code} - {detail}")
        return {"success": False, "error": f"API Error: {detail}"}
    except Exception as e:
        logger.exception("Login call failed")
        return {"success": False, "error": f"Request failed: {e}"}
# --- UI Rendering ---
def _render_login_tab(tab):
    """Draw the login form in ``tab`` and handle a submission."""
    with tab:
        with st.form("login_form"):
            login_email = st.text_input("Email", key="login_email")
            login_password = st.text_input("Password", type="password", key="login_password")
            submitted = st.form_submit_button("Login")
            if not submitted:
                return
            if not (login_email and login_password):
                st.error("Please enter email and password.")
                return
            # The API helpers are coroutines; drive them from this sync script.
            outcome = asyncio.run(api_login(login_email, login_password))
            if not outcome["success"]:
                st.error(f"Login failed: {outcome['error']}")
                return
            access_token = outcome["data"]["access_token"]
            # No user lookup is made after login; the email typed into the
            # form is stored as the display identity.
            st.session_state.logged_in = True
            st.session_state.token = access_token
            st.session_state.user_email = login_email
            st.session_state.notifications = []  # Drop notifications from an earlier login
            # Start the WebSocket listener with the stop flag lowered.
            stop_event.clear()
            listener = threading.Thread(target=websocket_listener, args=(access_token,), daemon=True)
            st.session_state.ws_thread = listener
            listener.start()
            logger.info("Login successful, WS thread started.")
            st.rerun()  # Switch to the logged-in view immediately


def _render_register_tab(tab):
    """Draw the registration form in ``tab`` and handle a submission."""
    with tab:
        with st.form("register_form"):
            reg_email = st.text_input("Email", key="reg_email")
            reg_password = st.text_input("Password", type="password", key="reg_password")
            reg_confirm = st.text_input("Confirm Password", type="password", key="reg_confirm")
            submitted = st.form_submit_button("Register")
            if not submitted:
                return
            # Client-side validation, checked in the same order as shown to the user.
            if not (reg_email and reg_password and reg_confirm):
                st.error("Please fill all fields.")
                return
            if reg_password != reg_confirm:
                st.error("Passwords do not match.")
                return
            if len(reg_password) < 8:
                st.error("Password must be at least 8 characters.")
                return
            outcome = asyncio.run(api_register(reg_email, reg_password))
            if outcome["success"]:
                st.success(f"Registration successful for {outcome['data']['email']}! Please log in.")
            else:
                st.error(f"Registration failed: {outcome['error']}")


def _logout():
    """Stop the listener thread, reset the session fields, and rerun."""
    logger.info("Logout requested.")
    listener = st.session_state.ws_thread
    if listener and listener.is_alive():
        logger.info("Signalling WS thread to stop.")
        stop_event.set()
        listener.join(timeout=2.0)  # Wait briefly for the thread to exit
        if listener.is_alive():
            logger.warning("WS thread did not exit cleanly.")
    # Reset every session field to its logged-out value.
    st.session_state.logged_in = False
    st.session_state.token = None
    st.session_state.user_email = None
    st.session_state.notifications = []
    st.session_state.ws_thread = None
    st.session_state.ws_connected = False
    logger.info("Session cleared.")
    st.rerun()


def _render_dashboard():
    """Draw the logged-in view: logout control, socket status, notifications."""
    st.sidebar.header(f"Welcome, {st.session_state.user_email}!")
    if st.sidebar.button("Logout"):
        _logout()
    st.header("Dashboard")
    st.subheader("Real-time Notifications")
    if st.session_state.ws_connected:
        ws_status = "Connected"
    else:
        ws_status = "Disconnected"
    st.caption(f"WebSocket Status: {ws_status}")
    if not st.session_state.notifications:
        st.text("No new notifications.")
        return
    for msg in st.session_state.notifications:
        st.info(f"{msg}", icon="🔔")


st.title("Authentication & Notification App (Streamlit)")
if st.session_state.logged_in:
    _render_dashboard()
else:
    st.sidebar.header("Login or Register")
    login_tab, register_tab = st.sidebar.tabs(["Login", "Register"])
    _render_login_tab(login_tab)
    _render_register_tab(register_tab)
# A manual refresh button could be added to the dashboard if needed:
# if st.button("Check for notifications"):
#     st.rerun()  # Force rerun which includes queue check
# --- Final Cleanup ---
# NOTE: the module-level httpx client is never closed explicitly. Closing it
# here with the call below is left disabled, and it might not run anyway
# depending on how Streamlit terminates the script. Scoping the client with a
# context manager would be the cleaner option if it is used more extensively.
# asyncio.run(client.aclose())