Spaces:
Running
Running
Intial Deployment
Browse files- .env.example +2 -0
- .gitignore +7 -0
- Dockerfile +27 -0
- README.md +5 -0
- app/__init__.py +0 -0
- app/api.py +97 -0
- app/auth.py +0 -0
- app/crud.py +22 -0
- app/database.py +47 -0
- app/dependencies.py +22 -0
- app/main.py +318 -0
- app/models.py +8 -0
- app/schemas.py +18 -0
- app/websocket.py +66 -0
- requirements.txt +11 -0
.env.example
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
DATABASE_URL=sqlite+aiosqlite:///./app/app.db
|
2 |
+
SECRET_KEY=a_very_secret_key_change_this_in_production # For signing session IDs
|
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
venv/
|
2 |
+
__pycache__/
|
3 |
+
*.pyc
|
4 |
+
*.db
|
5 |
+
*.db-journal
|
6 |
+
.env
|
7 |
+
*.log
|
Dockerfile
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use an official Python runtime as a parent image
|
2 |
+
FROM python:3.10-slim
|
3 |
+
|
4 |
+
# Set the working directory in the container
|
5 |
+
WORKDIR /code
|
6 |
+
|
7 |
+
# Copy the requirements file into the container
|
8 |
+
COPY ./requirements.txt /code/requirements.txt
|
9 |
+
|
10 |
+
# Install any needed packages specified in requirements.txt
|
11 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
12 |
+
pip install --no-cache-dir -r requirements.txt
|
13 |
+
|
14 |
+
# Copy the rest of the application code into the container
|
15 |
+
COPY ./app /code/app
|
16 |
+
COPY ./.env /code/.env # Copy .env file - for Hugging Face, use Secrets instead
|
17 |
+
|
18 |
+
# Make port 7860 available to the world outside this container (Gradio default)
|
19 |
+
EXPOSE 7860
|
20 |
+
|
21 |
+
# Ensure the database directory exists if using SQLite relative paths implicitly
|
22 |
+
# RUN mkdir -p /code/app
|
23 |
+
|
24 |
+
# Command to run the application using uvicorn
|
25 |
+
# It will run the FastAPI app instance created in app/main.py
|
26 |
+
# Host 0.0.0.0 is important to accept connections from outside the container
|
27 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
@@ -7,6 +7,11 @@ sdk: docker
|
|
7 |
pinned: false
|
8 |
license: apache-2.0
|
9 |
short_description: An app demonstrating Gradio, FastAPI, Docker & SQL DB
|
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
7 |
pinned: false
|
8 |
license: apache-2.0
|
9 |
short_description: An app demonstrating Gradio, FastAPI, Docker & SQL DB
|
10 |
+
app_file: app/main.py
|
11 |
+
python_version: 3.12
|
12 |
+
port: 7860
|
13 |
---
|
14 |
|
15 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
16 |
+
|
17 |
+
|
app/__init__.py
ADDED
File without changes
|
app/api.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, HTTPException, status, Depends, WebSocket, WebSocketDisconnect
|
2 |
+
from fastapi.responses import JSONResponse
|
3 |
+
import logging
|
4 |
+
|
5 |
+
from . import schemas, crud, auth, models
|
6 |
+
from .websocket import manager
|
7 |
+
from .dependencies import get_required_current_user
|
8 |
+
|
9 |
+
router = APIRouter()
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
@router.post("/register", status_code=status.HTTP_201_CREATED, response_model=models.User)
|
13 |
+
async def register_user(user_in: schemas.UserCreate):
|
14 |
+
existing_user = await crud.get_user_by_email(user_in.email)
|
15 |
+
if existing_user:
|
16 |
+
raise HTTPException(
|
17 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
18 |
+
detail="Email already registered",
|
19 |
+
)
|
20 |
+
hashed_password = auth.get_password_hash(user_in.password)
|
21 |
+
user_id = await crud.create_user(user_in=user_in, hashed_password=hashed_password)
|
22 |
+
|
23 |
+
# Send notification to other connected users
|
24 |
+
notification_msg = schemas.Notification(
|
25 |
+
email=user_in.email,
|
26 |
+
message=f"New user registered: {user_in.email}"
|
27 |
+
).model_dump_json() # Use model_dump_json for Pydantic v2
|
28 |
+
|
29 |
+
# We broadcast but conceptually exclude the sender.
|
30 |
+
# Since the new user isn't connected via WebSocket *yet* during registration,
|
31 |
+
# we don't have a sender_id from the WebSocket context here.
|
32 |
+
# We can pass the new user_id to prevent potential self-notification if
|
33 |
+
# the WebSocket connection happens very quickly and maps the ID.
|
34 |
+
await manager.broadcast(notification_msg, sender_id=user_id)
|
35 |
+
|
36 |
+
# Return the newly created user's public info
|
37 |
+
# Fetch the user details to return them accurately
|
38 |
+
created_user = await crud.get_user_by_id(user_id)
|
39 |
+
if not created_user:
|
40 |
+
# This case should ideally not happen if create_user is successful
|
41 |
+
raise HTTPException(status_code=500, detail="Failed to retrieve created user")
|
42 |
+
|
43 |
+
# Convert UserInDB to User model for response
|
44 |
+
return models.User(id=created_user.id, email=created_user.email)
|
45 |
+
|
46 |
+
|
47 |
+
@router.post("/login", response_model=schemas.Token)
|
48 |
+
async def login_for_access_token(form_data: schemas.UserLogin):
|
49 |
+
user = await crud.get_user_by_email(form_data.email)
|
50 |
+
if not user or not auth.verify_password(form_data.password, user.hashed_password):
|
51 |
+
raise HTTPException(
|
52 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
53 |
+
detail="Incorrect email or password",
|
54 |
+
headers={"WWW-Authenticate": "Bearer"},
|
55 |
+
)
|
56 |
+
access_token = auth.create_session_token(user_id=user.id)
|
57 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
58 |
+
|
59 |
+
@router.get("/users/me", response_model=models.User)
|
60 |
+
async def read_users_me(current_user: models.User = Depends(get_required_current_user)):
|
61 |
+
# This endpoint now relies on the dependency correctly getting the user from the token
|
62 |
+
# The token needs to be passed to get_required_current_user somehow.
|
63 |
+
# In Gradio's case, we might call this function directly from Gradio's backend
|
64 |
+
# passing the token from gr.State, rather than relying on HTTP Headers/Cookies.
|
65 |
+
# Let's adjust the dependency call when we use it in main.py.
|
66 |
+
return current_user
|
67 |
+
|
68 |
+
|
69 |
+
# WebSocket endpoint (can be associated with the main API router or separate)
|
70 |
+
@router.websocket("/ws/{user_id_token}")
|
71 |
+
async def websocket_endpoint(websocket: WebSocket, user_id_token: str):
|
72 |
+
"""
|
73 |
+
WebSocket endpoint. Connects user and listens for messages.
|
74 |
+
The user_id_token is the signed session token from login.
|
75 |
+
"""
|
76 |
+
user_id = await auth.get_user_id_from_token(user_id_token)
|
77 |
+
if user_id is None:
|
78 |
+
logger.warning(f"WebSocket connection rejected: Invalid token {user_id_token}")
|
79 |
+
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
80 |
+
return
|
81 |
+
|
82 |
+
await manager.connect(websocket, user_id)
|
83 |
+
try:
|
84 |
+
while True:
|
85 |
+
# Keep connection alive, maybe handle incoming messages if needed later
|
86 |
+
data = await websocket.receive_text()
|
87 |
+
# For now, we just broadcast on registration, not handle client messages
|
88 |
+
logger.debug(f"Received message from {user_id}: {data} (currently ignored)")
|
89 |
+
# Example: await websocket.send_text(f"Message text was: {data}")
|
90 |
+
except WebSocketDisconnect:
|
91 |
+
manager.disconnect(websocket)
|
92 |
+
logger.info(f"WebSocket disconnected for user {user_id}")
|
93 |
+
except Exception as e:
|
94 |
+
manager.disconnect(websocket)
|
95 |
+
logger.error(f"WebSocket error for user {user_id}: {e}")
|
96 |
+
# Optionally close with an error code
|
97 |
+
# await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
|
app/auth.py
ADDED
File without changes
|
app/crud.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .database import database, users
|
2 |
+
from .models import UserInDB
|
3 |
+
from .schemas import UserCreate
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
async def get_user_by_email(email: str) -> Optional[UserInDB]:
|
7 |
+
query = users.select().where(users.c.email == email)
|
8 |
+
result = await database.fetch_one(query)
|
9 |
+
return UserInDB(**result) if result else None
|
10 |
+
|
11 |
+
async def create_user(user_in: UserCreate, hashed_password: str) -> int:
|
12 |
+
query = users.insert().values(
|
13 |
+
email=user_in.email,
|
14 |
+
hashed_password=hashed_password
|
15 |
+
)
|
16 |
+
last_record_id = await database.execute(query)
|
17 |
+
return last_record_id
|
18 |
+
|
19 |
+
async def get_user_by_id(user_id: int) -> Optional[UserInDB]:
|
20 |
+
query = users.select().where(users.c.id == user_id)
|
21 |
+
result = await database.fetch_one(query)
|
22 |
+
return UserInDB(**result) if result else None
|
app/database.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from databases import Database
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, text
|
5 |
+
|
6 |
+
load_dotenv()
|
7 |
+
|
8 |
+
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./app/app.db")
|
9 |
+
|
10 |
+
# Use 'check_same_thread': False only for SQLite, it's generally not needed for server DBs
|
11 |
+
connect_args = {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}
|
12 |
+
|
13 |
+
database = Database(DATABASE_URL, connect_args=connect_args)
|
14 |
+
metadata = MetaData()
|
15 |
+
|
16 |
+
# Define Users table using SQLAlchemy Core (needed for initial setup)
|
17 |
+
users = Table(
|
18 |
+
"users",
|
19 |
+
metadata,
|
20 |
+
Column("id", Integer, primary_key=True),
|
21 |
+
Column("email", String, unique=True, index=True, nullable=False),
|
22 |
+
Column("hashed_password", String, nullable=False),
|
23 |
+
)
|
24 |
+
|
25 |
+
# Create the database and table if they don't exist
|
26 |
+
# This synchronous part runs once at startup usually
|
27 |
+
engine = create_engine(DATABASE_URL.replace("+aiosqlite", ""), connect_args=connect_args)
|
28 |
+
|
29 |
+
# Check if table exists before creating
|
30 |
+
# Using a try-except block for robustness across DB engines if needed later
|
31 |
+
try:
|
32 |
+
with engine.connect() as connection:
|
33 |
+
connection.execute(text("SELECT 1 FROM users LIMIT 1"))
|
34 |
+
print("Users table already exists.")
|
35 |
+
except Exception:
|
36 |
+
print("Users table not found, creating...")
|
37 |
+
metadata.create_all(engine)
|
38 |
+
print("Users table created.")
|
39 |
+
|
40 |
+
# Async connect/disconnect functions for FastAPI lifespan events
|
41 |
+
async def connect_db():
|
42 |
+
await database.connect()
|
43 |
+
print("Database connection established.")
|
44 |
+
|
45 |
+
async def disconnect_db():
|
46 |
+
await database.disconnect()
|
47 |
+
print("Database connection closed.")
|
app/dependencies.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import HTTPException, status, Request # Request may not be needed if token passed directly
|
2 |
+
from typing import Optional
|
3 |
+
from . import auth
|
4 |
+
from .models import User
|
5 |
+
|
6 |
+
# This dependency assumes the token is passed somehow,
|
7 |
+
# e.g., in headers (less likely from Gradio client code) or as an argument
|
8 |
+
# We will adapt how the token is passed from Gradio later.
|
9 |
+
async def get_optional_current_user(token: Optional[str] = None) -> Optional[User]:
|
10 |
+
if token:
|
11 |
+
user = await auth.get_current_user_from_token(token)
|
12 |
+
return user
|
13 |
+
return None
|
14 |
+
|
15 |
+
async def get_required_current_user(token: Optional[str] = None) -> User:
|
16 |
+
user = await get_optional_current_user(token)
|
17 |
+
if user is None:
|
18 |
+
raise HTTPException(
|
19 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
20 |
+
detail="Not authenticated",
|
21 |
+
)
|
22 |
+
return user
|
app/main.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import httpx # Use httpx for async requests from Gradio backend to API
|
3 |
+
import websockets # Use websockets library to connect from Gradio backend
|
4 |
+
import asyncio
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import logging
|
8 |
+
from contextlib import asynccontextmanager
|
9 |
+
|
10 |
+
from fastapi import FastAPI, Depends # Import FastAPI itself
|
11 |
+
from .api import router as api_router # Import the API router
|
12 |
+
from .database import connect_db, disconnect_db
|
13 |
+
from . import schemas, auth, dependencies # Import necessary modules
|
14 |
+
|
15 |
+
# Configure logging
|
16 |
+
logging.basicConfig(level=logging.INFO)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
# Base URL for the API (since Gradio runs its own server, even if mounted)
|
20 |
+
# Assuming Gradio runs on 7860, FastAPI routes mounted under it
|
21 |
+
API_BASE_URL = "http://127.0.0.1:7860/api" # Adjust if needed
|
22 |
+
|
23 |
+
# --- FastAPI Lifespan Event ---
|
24 |
+
@asynccontextmanager
|
25 |
+
async def lifespan(app: FastAPI):
|
26 |
+
# Startup
|
27 |
+
await connect_db()
|
28 |
+
# Start background task for websocket listening if needed (or handle within component interaction)
|
29 |
+
yield
|
30 |
+
# Shutdown
|
31 |
+
await disconnect_db()
|
32 |
+
|
33 |
+
# Create the main FastAPI app instance that Gradio will use
|
34 |
+
# We attach our API routes to this instance.
|
35 |
+
app = FastAPI(lifespan=lifespan)
|
36 |
+
app.include_router(api_router, prefix="/api") # Mount API routes under /api
|
37 |
+
|
38 |
+
# --- Gradio UI Definition ---
|
39 |
+
|
40 |
+
# Store websocket connection globally (or within a class) for the Gradio app instance
|
41 |
+
# This is tricky because Gradio re-runs functions. State management is key.
|
42 |
+
# We'll connect the WebSocket *after* login and store the connection task/info in gr.State.
|
43 |
+
|
44 |
+
# --- Helper functions for Gradio calling the API ---
|
45 |
+
|
46 |
+
async def make_api_request(method: str, endpoint: str, **kwargs):
|
47 |
+
async with httpx.AsyncClient() as client:
|
48 |
+
url = f"{API_BASE_URL}{endpoint}"
|
49 |
+
try:
|
50 |
+
response = await client.request(method, url, **kwargs)
|
51 |
+
response.raise_for_status() # Raise exception for 4xx/5xx errors
|
52 |
+
return response.json()
|
53 |
+
except httpx.RequestError as e:
|
54 |
+
logger.error(f"HTTP Request failed: {e.request.method} {e.request.url} - {e}")
|
55 |
+
return {"error": f"Network error contacting API: {e}"}
|
56 |
+
except httpx.HTTPStatusError as e:
|
57 |
+
logger.error(f"HTTP Status error: {e.response.status_code} - {e.response.text}")
|
58 |
+
try:
|
59 |
+
detail = e.response.json().get("detail", e.response.text)
|
60 |
+
except json.JSONDecodeError:
|
61 |
+
detail = e.response.text
|
62 |
+
return {"error": f"API Error: {detail}"}
|
63 |
+
except Exception as e:
|
64 |
+
logger.error(f"Unexpected error during API call: {e}")
|
65 |
+
return {"error": f"An unexpected error occurred: {str(e)}"}
|
66 |
+
|
67 |
+
# --- WebSocket handling within Gradio ---
|
68 |
+
|
69 |
+
async def listen_to_websockets(token: str, notification_state: list):
|
70 |
+
"""Connects to WS and updates state list when a message arrives."""
|
71 |
+
if not token:
|
72 |
+
logger.warning("WebSocket listener: No token provided.")
|
73 |
+
return notification_state # Return current state if no token
|
74 |
+
|
75 |
+
ws_url_base = API_BASE_URL.replace("http", "ws")
|
76 |
+
ws_url = f"{ws_url_base}/ws/{token}"
|
77 |
+
logger.info(f"Attempting to connect to WebSocket: {ws_url}")
|
78 |
+
|
79 |
+
try:
|
80 |
+
# Add timeout to websocket connection attempt
|
81 |
+
async with asyncio.wait_for(websockets.connect(ws_url), timeout=10.0) as websocket:
|
82 |
+
logger.info(f"WebSocket connected successfully to {ws_url}")
|
83 |
+
while True:
|
84 |
+
try:
|
85 |
+
message_str = await websocket.recv()
|
86 |
+
logger.info(f"Received raw message: {message_str}")
|
87 |
+
message_data = json.loads(message_str)
|
88 |
+
logger.info(f"Parsed message data: {message_data}")
|
89 |
+
|
90 |
+
# Ensure it's the expected notification format
|
91 |
+
if message_data.get("type") == "new_user":
|
92 |
+
notification = schemas.Notification(**message_data)
|
93 |
+
# Prepend to show newest first
|
94 |
+
notification_state.insert(0, notification.message)
|
95 |
+
logger.info(f"Notification added: {notification.message}")
|
96 |
+
# Limit state history (optional)
|
97 |
+
if len(notification_state) > 10:
|
98 |
+
notification_state.pop()
|
99 |
+
# IMPORTANT: Need to trigger Gradio update. This function itself
|
100 |
+
# cannot directly update UI. It modifies the state, and we need
|
101 |
+
# a gr.update() returned by a Gradio event handler that *reads*
|
102 |
+
# this state. We use a polling mechanism or a hidden button trick.
|
103 |
+
# Let's use polling via `every=` parameter in Gradio.
|
104 |
+
# This function's primary job is just modifying the list.
|
105 |
+
# We return the modified list, but Gradio needs an event.
|
106 |
+
# ---> See the use of `gr.Textbox.change` and `every` below.
|
107 |
+
|
108 |
+
except websockets.ConnectionClosedOK:
|
109 |
+
logger.info("WebSocket connection closed normally.")
|
110 |
+
break
|
111 |
+
except websockets.ConnectionClosedError as e:
|
112 |
+
logger.error(f"WebSocket connection closed with error: {e}")
|
113 |
+
break
|
114 |
+
except json.JSONDecodeError:
|
115 |
+
logger.error(f"Failed to decode JSON from WebSocket message: {message_str}")
|
116 |
+
except Exception as e:
|
117 |
+
logger.error(f"Error in WebSocket listener loop: {e}")
|
118 |
+
# Avoid breaking the loop on transient errors maybe? Add delay?
|
119 |
+
await asyncio.sleep(1) # Short delay before retrying receive
|
120 |
+
except asyncio.TimeoutError:
|
121 |
+
logger.error(f"WebSocket connection timed out: {ws_url}")
|
122 |
+
except websockets.exceptions.InvalidURI:
|
123 |
+
logger.error(f"Invalid WebSocket URI: {ws_url}")
|
124 |
+
except websockets.exceptions.WebSocketException as e:
|
125 |
+
logger.error(f"WebSocket connection failed: {e}")
|
126 |
+
except Exception as e:
|
127 |
+
logger.error(f"Unexpected error connecting/listening to WebSocket: {e}")
|
128 |
+
|
129 |
+
# Return the state as it is if connection failed or ended
|
130 |
+
# The calling Gradio function will handle this return value.
|
131 |
+
return notification_state
|
132 |
+
|
133 |
+
|
134 |
+
# --- Gradio Interface ---
|
135 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
136 |
+
# State variables
|
137 |
+
# Holds the session token after login
|
138 |
+
auth_token = gr.State(None)
|
139 |
+
# Holds user info {id, email} after login
|
140 |
+
user_info = gr.State(None)
|
141 |
+
# Holds the list of notification messages
|
142 |
+
notification_list = gr.State([])
|
143 |
+
# Holds the asyncio task for the WebSocket listener
|
144 |
+
websocket_task = gr.State(None)
|
145 |
+
|
146 |
+
# --- UI Components ---
|
147 |
+
with gr.Tabs() as tabs:
|
148 |
+
# --- Registration Tab ---
|
149 |
+
with gr.TabItem("Register", id="register_tab"):
|
150 |
+
gr.Markdown("## Create a new account")
|
151 |
+
reg_email = gr.Textbox(label="Email", type="email")
|
152 |
+
reg_password = gr.Textbox(label="Password (min 8 chars)", type="password")
|
153 |
+
reg_confirm_password = gr.Textbox(label="Confirm Password", type="password")
|
154 |
+
reg_button = gr.Button("Register")
|
155 |
+
reg_status = gr.Textbox(label="Status", interactive=False)
|
156 |
+
|
157 |
+
# --- Login Tab ---
|
158 |
+
with gr.TabItem("Login", id="login_tab"):
|
159 |
+
gr.Markdown("## Login to your account")
|
160 |
+
login_email = gr.Textbox(label="Email", type="email")
|
161 |
+
login_password = gr.Textbox(label="Password", type="password")
|
162 |
+
login_button = gr.Button("Login")
|
163 |
+
login_status = gr.Textbox(label="Status", interactive=False)
|
164 |
+
|
165 |
+
# --- Welcome Tab (shown after login) ---
|
166 |
+
with gr.TabItem("Welcome", id="welcome_tab", visible=False) as welcome_tab:
|
167 |
+
gr.Markdown("## Welcome!", elem_id="welcome_header")
|
168 |
+
welcome_message = gr.Markdown("", elem_id="welcome_message")
|
169 |
+
logout_button = gr.Button("Logout")
|
170 |
+
gr.Markdown("---") # Separator
|
171 |
+
gr.Markdown("## Real-time Notifications")
|
172 |
+
# Textbox to display notifications, updated periodically
|
173 |
+
notification_display = gr.Textbox(
|
174 |
+
label="New User Alerts",
|
175 |
+
lines=5,
|
176 |
+
max_lines=10,
|
177 |
+
interactive=False,
|
178 |
+
# The `every=1` makes Gradio call the update function every 1 second
|
179 |
+
# This function will read the `notification_list` state
|
180 |
+
every=1
|
181 |
+
)
|
182 |
+
|
183 |
+
# --- Event Handlers ---
|
184 |
+
|
185 |
+
# Registration Logic
|
186 |
+
async def handle_register(email, password, confirm_password):
|
187 |
+
if not email or not password or not confirm_password:
|
188 |
+
return gr.update(value="Please fill in all fields.")
|
189 |
+
if password != confirm_password:
|
190 |
+
return gr.update(value="Passwords do not match.")
|
191 |
+
if len(password) < 8:
|
192 |
+
return gr.update(value="Password must be at least 8 characters long.")
|
193 |
+
|
194 |
+
payload = {"email": email, "password": password}
|
195 |
+
result = await make_api_request("post", "/register", json=payload)
|
196 |
+
|
197 |
+
if "error" in result:
|
198 |
+
return gr.update(value=f"Registration failed: {result['error']}")
|
199 |
+
else:
|
200 |
+
# Optionally switch to login tab after successful registration
|
201 |
+
return gr.update(value=f"Registration successful for {result.get('email')}! Please log in.")
|
202 |
+
|
203 |
+
reg_button.click(
|
204 |
+
handle_register,
|
205 |
+
inputs=[reg_email, reg_password, reg_confirm_password],
|
206 |
+
outputs=[reg_status]
|
207 |
+
)
|
208 |
+
|
209 |
+
# Login Logic
|
210 |
+
async def handle_login(email, password, current_task):
|
211 |
+
if not email or not password:
|
212 |
+
return gr.update(value="Please enter email and password."), None, None, None, gr.update(visible=False), current_task
|
213 |
+
|
214 |
+
payload = {"email": email, "password": password}
|
215 |
+
result = await make_api_request("post", "/login", json=payload)
|
216 |
+
|
217 |
+
if "error" in result:
|
218 |
+
return gr.update(value=f"Login failed: {result['error']}"), None, None, None, gr.update(visible=False), current_task
|
219 |
+
else:
|
220 |
+
token = result.get("access_token")
|
221 |
+
# Fetch user details using the token
|
222 |
+
user_data = await dependencies.get_optional_current_user(token) # Use dependency directly
|
223 |
+
|
224 |
+
if not user_data:
|
225 |
+
# This shouldn't happen if login succeeded, but check anyway
|
226 |
+
return gr.update(value="Login succeeded but failed to fetch user data."), None, None, None, gr.update(visible=False), current_task
|
227 |
+
|
228 |
+
# Cancel any existing websocket listener task before starting a new one
|
229 |
+
if current_task and not current_task.done():
|
230 |
+
current_task.cancel()
|
231 |
+
try:
|
232 |
+
await current_task # Wait for cancellation
|
233 |
+
except asyncio.CancelledError:
|
234 |
+
logger.info("Previous WebSocket task cancelled.")
|
235 |
+
|
236 |
+
# Start the WebSocket listener task in the background
|
237 |
+
# We pass the notification_list state *object* itself, which the task will modify
|
238 |
+
new_task = asyncio.create_task(listen_to_websockets(token, notification_list.value)) # Pass the list
|
239 |
+
|
240 |
+
# Update state and UI
|
241 |
+
welcome_msg = f"Welcome, {user_data.email}!"
|
242 |
+
# Switch tabs and show welcome message
|
243 |
+
return (
|
244 |
+
gr.update(value="Login successful!"), # login_status
|
245 |
+
token, # auth_token state
|
246 |
+
user_data.model_dump(), # user_info state (store as dict)
|
247 |
+
gr.update(selected="welcome_tab"), # Switch Tabs
|
248 |
+
gr.update(visible=True), # Make welcome tab visible
|
249 |
+
gr.update(value=welcome_msg), # Update welcome message markdown
|
250 |
+
new_task # websocket_task state
|
251 |
+
)
|
252 |
+
|
253 |
+
login_button.click(
|
254 |
+
handle_login,
|
255 |
+
inputs=[login_email, login_password, websocket_task],
|
256 |
+
outputs=[login_status, auth_token, user_info, tabs, welcome_tab, welcome_message, websocket_task]
|
257 |
+
)
|
258 |
+
|
259 |
+
|
260 |
+
# Function to update the notification display based on the state
|
261 |
+
# This function is triggered by the `every=1` on the notification_display Textbox
|
262 |
+
def update_notification_ui(notif_list_state):
|
263 |
+
# Join the list items into a string for display
|
264 |
+
return "\n".join(notif_list_state)
|
265 |
+
|
266 |
+
notification_display.change( # Use .change with every= setup on the component
|
267 |
+
fn=update_notification_ui,
|
268 |
+
inputs=[notification_list], # Read the state
|
269 |
+
outputs=[notification_display] # Update the component
|
270 |
+
)
|
271 |
+
|
272 |
+
|
273 |
+
# Logout Logic
|
274 |
+
async def handle_logout(current_task):
|
275 |
+
# Cancel the websocket listener task if it's running
|
276 |
+
if current_task and not current_task.done():
|
277 |
+
current_task.cancel()
|
278 |
+
try:
|
279 |
+
await current_task
|
280 |
+
except asyncio.CancelledError:
|
281 |
+
logger.info("WebSocket task cancelled on logout.")
|
282 |
+
|
283 |
+
# Clear state and switch back to login tab
|
284 |
+
return (
|
285 |
+
None, # Clear auth_token
|
286 |
+
None, # Clear user_info
|
287 |
+
[], # Clear notifications
|
288 |
+
None, # Clear websocket_task
|
289 |
+
gr.update(selected="login_tab"),# Switch Tabs
|
290 |
+
gr.update(visible=False), # Hide welcome tab
|
291 |
+
gr.update(value=""), # Clear welcome message
|
292 |
+
gr.update(value="") # Clear login status
|
293 |
+
)
|
294 |
+
|
295 |
+
logout_button.click(
|
296 |
+
handle_logout,
|
297 |
+
inputs=[websocket_task],
|
298 |
+
outputs=[
|
299 |
+
auth_token,
|
300 |
+
user_info,
|
301 |
+
notification_list,
|
302 |
+
websocket_task,
|
303 |
+
tabs,
|
304 |
+
welcome_tab,
|
305 |
+
welcome_message,
|
306 |
+
login_status
|
307 |
+
]
|
308 |
+
)
|
309 |
+
|
310 |
+
# Mount the Gradio app onto the FastAPI app at the root
|
311 |
+
app = gr.mount_gradio_app(app, demo, path="/")
|
312 |
+
|
313 |
+
# If running this file directly (for local testing)
|
314 |
+
# Use uvicorn to run the FastAPI app (which now includes Gradio)
|
315 |
+
if __name__ == "__main__":
|
316 |
+
import uvicorn
|
317 |
+
# Use port 7860 as Gradio prefers, host 0.0.0.0 for Docker
|
318 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
app/models.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, EmailStr
|
2 |
+
|
3 |
+
class User(BaseModel):
|
4 |
+
id: int
|
5 |
+
email: EmailStr
|
6 |
+
|
7 |
+
class UserInDB(User):
|
8 |
+
hashed_password: str
|
app/schemas.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, EmailStr, Field
|
2 |
+
|
3 |
+
class UserCreate(BaseModel):
|
4 |
+
email: EmailStr
|
5 |
+
password: str = Field(min_length=8)
|
6 |
+
|
7 |
+
class UserLogin(BaseModel):
|
8 |
+
email: EmailStr
|
9 |
+
password: str
|
10 |
+
|
11 |
+
class Token(BaseModel):
|
12 |
+
access_token: str
|
13 |
+
token_type: str = "bearer"
|
14 |
+
|
15 |
+
class Notification(BaseModel):
|
16 |
+
type: str = "new_user"
|
17 |
+
email: EmailStr
|
18 |
+
message: str
|
app/websocket.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import WebSocket
|
2 |
+
from typing import List, Dict, Optional
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
class ConnectionManager:
|
9 |
+
def __init__(self):
|
10 |
+
# Store connections with user ID if available for targeted messaging later
|
11 |
+
self.active_connections: Dict[Optional[int], List[WebSocket]] = {}
|
12 |
+
# Map websocket to user_id for easier removal
|
13 |
+
self.websocket_to_user: Dict[WebSocket, Optional[int]] = {}
|
14 |
+
|
15 |
+
async def connect(self, websocket: WebSocket, user_id: Optional[int] = None):
|
16 |
+
await websocket.accept()
|
17 |
+
if user_id not in self.active_connections:
|
18 |
+
self.active_connections[user_id] = []
|
19 |
+
self.active_connections[user_id].append(websocket)
|
20 |
+
self.websocket_to_user[websocket] = user_id
|
21 |
+
logger.info(f"WebSocket connected: {websocket.client.host}:{websocket.client.port}, User ID: {user_id}")
|
22 |
+
logger.info(f"Total connections: {sum(len(conns) for conns in self.active_connections.values())}")
|
23 |
+
|
24 |
+
|
25 |
+
def disconnect(self, websocket: WebSocket):
|
26 |
+
user_id = self.websocket_to_user.pop(websocket, None)
|
27 |
+
if user_id in self.active_connections:
|
28 |
+
try:
|
29 |
+
self.active_connections[user_id].remove(websocket)
|
30 |
+
if not self.active_connections[user_id]:
|
31 |
+
del self.active_connections[user_id]
|
32 |
+
except ValueError:
|
33 |
+
logger.warning(f"WebSocket not found in active list for user {user_id} during disconnect.")
|
34 |
+
|
35 |
+
logger.info(f"WebSocket disconnected: {websocket.client.host}:{websocket.client.port}, User ID: {user_id}")
|
36 |
+
logger.info(f"Total connections: {sum(len(conns) for conns in self.active_connections.values())}")
|
37 |
+
|
38 |
+
|
39 |
+
async def broadcast(self, message: str, sender_id: Optional[int] = None):
|
40 |
+
disconnected_websockets = []
|
41 |
+
# Iterate through all connections
|
42 |
+
all_websockets = [ws for user_conns in self.active_connections.values() for ws in user_conns]
|
43 |
+
logger.info(f"Broadcasting to {len(all_websockets)} connections (excluding sender if ID matches). Sender ID: {sender_id}")
|
44 |
+
|
45 |
+
for websocket in all_websockets:
|
46 |
+
# Send to all *other* users (or all if sender_id is None)
|
47 |
+
ws_user_id = self.websocket_to_user.get(websocket)
|
48 |
+
# Requirement: "all *other* connected users should see"
|
49 |
+
# Send if the websocket isn't associated with the sender_id
|
50 |
+
# (Note: If a user has multiple tabs/connections open, they might still receive it
|
51 |
+
# if the sender_id check only excludes one specific connection. This simple broadcast
|
52 |
+
# targets users based on their ID at connection time).
|
53 |
+
# Let's refine: Send if the user_id associated with the WS is not the sender_id
|
54 |
+
if ws_user_id != sender_id:
|
55 |
+
try:
|
56 |
+
await websocket.send_text(message)
|
57 |
+
logger.debug(f"Message sent to user {ws_user_id}")
|
58 |
+
except Exception as e:
|
59 |
+
logger.error(f"Error sending message to websocket {websocket.client}: {e}. Marking for disconnect.")
|
60 |
+
disconnected_websockets.append(websocket)
|
61 |
+
|
62 |
+
# Clean up connections that failed during broadcast
|
63 |
+
for ws in disconnected_websockets:
|
64 |
+
self.disconnect(ws)
|
65 |
+
|
66 |
+
manager = ConnectionManager()
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi==0.111.0
|
2 |
+
uvicorn[standard]==0.29.0
|
3 |
+
gradio==4.29.0
|
4 |
+
passlib[bcrypt]==1.7.4
|
5 |
+
python-dotenv==1.0.1
|
6 |
+
databases[sqlite]==0.9.0 # Async DB access
|
7 |
+
sqlalchemy==2.0.29 # Core needed by `databases` or for potential future use
|
8 |
+
pydantic==2.7.1
|
9 |
+
python-multipart==0.0.9 # For form data in FastAPI
|
10 |
+
itsdangerous==2.1.2 # Simple secure signing for session IDs
|
11 |
+
websockets>=11.0.3,<13.0 # Ensure compatibility
|