amaye15 commited on
Commit
bfe88a9
·
1 Parent(s): a66a02c

Intial Deployment

Browse files
Files changed (15) hide show
  1. .env.example +2 -0
  2. .gitignore +7 -0
  3. Dockerfile +27 -0
  4. README.md +5 -0
  5. app/__init__.py +0 -0
  6. app/api.py +97 -0
  7. app/auth.py +0 -0
  8. app/crud.py +22 -0
  9. app/database.py +47 -0
  10. app/dependencies.py +22 -0
  11. app/main.py +318 -0
  12. app/models.py +8 -0
  13. app/schemas.py +18 -0
  14. app/websocket.py +66 -0
  15. requirements.txt +11 -0
.env.example ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ DATABASE_URL=sqlite+aiosqlite:///./app/app.db
2
+ SECRET_KEY=a_very_secret_key_change_this_in_production # For signing session IDs
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ *.pyc
4
+ *.db
5
+ *.db-journal
6
+ .env
7
+ *.log
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.10-slim
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /code
6
+
7
+ # Copy the requirements file into the container
8
+ COPY ./requirements.txt /code/requirements.txt
9
+
10
+ # Install any needed packages specified in requirements.txt
11
+ RUN pip install --no-cache-dir --upgrade pip && \
12
+ pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy the rest of the application code into the container
15
+ COPY ./app /code/app
16
+ COPY ./.env /code/.env # Copy .env file - for Hugging Face, use Secrets instead
17
+
18
+ # Make port 7860 available to the world outside this container (Gradio default)
19
+ EXPOSE 7860
20
+
21
+ # Ensure the database directory exists if using SQLite relative paths implicitly
22
+ # RUN mkdir -p /code/app
23
+
24
+ # Command to run the application using uvicorn
25
+ # It will run the FastAPI app instance created in app/main.py
26
+ # Host 0.0.0.0 is important to accept connections from outside the container
27
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -7,6 +7,11 @@ sdk: docker
7
  pinned: false
8
  license: apache-2.0
9
  short_description: An app demonstrating Gradio, FastAPI, Docker & SQL DB
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
7
  pinned: false
8
  license: apache-2.0
9
  short_description: An app demonstrating Gradio, FastAPI, Docker & SQL DB
10
+ app_file: app/main.py
11
+ python_version: 3.12
12
+ port: 7860
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
16
+
17
+
app/__init__.py ADDED
File without changes
app/api.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, status, Depends, WebSocket, WebSocketDisconnect
2
+ from fastapi.responses import JSONResponse
3
+ import logging
4
+
5
+ from . import schemas, crud, auth, models
6
+ from .websocket import manager
7
+ from .dependencies import get_required_current_user
8
+
9
+ router = APIRouter()
10
+ logger = logging.getLogger(__name__)
11
+
12
+ @router.post("/register", status_code=status.HTTP_201_CREATED, response_model=models.User)
13
+ async def register_user(user_in: schemas.UserCreate):
14
+ existing_user = await crud.get_user_by_email(user_in.email)
15
+ if existing_user:
16
+ raise HTTPException(
17
+ status_code=status.HTTP_400_BAD_REQUEST,
18
+ detail="Email already registered",
19
+ )
20
+ hashed_password = auth.get_password_hash(user_in.password)
21
+ user_id = await crud.create_user(user_in=user_in, hashed_password=hashed_password)
22
+
23
+ # Send notification to other connected users
24
+ notification_msg = schemas.Notification(
25
+ email=user_in.email,
26
+ message=f"New user registered: {user_in.email}"
27
+ ).model_dump_json() # Use model_dump_json for Pydantic v2
28
+
29
+ # We broadcast but conceptually exclude the sender.
30
+ # Since the new user isn't connected via WebSocket *yet* during registration,
31
+ # we don't have a sender_id from the WebSocket context here.
32
+ # We can pass the new user_id to prevent potential self-notification if
33
+ # the WebSocket connection happens very quickly and maps the ID.
34
+ await manager.broadcast(notification_msg, sender_id=user_id)
35
+
36
+ # Return the newly created user's public info
37
+ # Fetch the user details to return them accurately
38
+ created_user = await crud.get_user_by_id(user_id)
39
+ if not created_user:
40
+ # This case should ideally not happen if create_user is successful
41
+ raise HTTPException(status_code=500, detail="Failed to retrieve created user")
42
+
43
+ # Convert UserInDB to User model for response
44
+ return models.User(id=created_user.id, email=created_user.email)
45
+
46
+
47
+ @router.post("/login", response_model=schemas.Token)
48
+ async def login_for_access_token(form_data: schemas.UserLogin):
49
+ user = await crud.get_user_by_email(form_data.email)
50
+ if not user or not auth.verify_password(form_data.password, user.hashed_password):
51
+ raise HTTPException(
52
+ status_code=status.HTTP_401_UNAUTHORIZED,
53
+ detail="Incorrect email or password",
54
+ headers={"WWW-Authenticate": "Bearer"},
55
+ )
56
+ access_token = auth.create_session_token(user_id=user.id)
57
+ return {"access_token": access_token, "token_type": "bearer"}
58
+
59
+ @router.get("/users/me", response_model=models.User)
60
+ async def read_users_me(current_user: models.User = Depends(get_required_current_user)):
61
+ # This endpoint now relies on the dependency correctly getting the user from the token
62
+ # The token needs to be passed to get_required_current_user somehow.
63
+ # In Gradio's case, we might call this function directly from Gradio's backend
64
+ # passing the token from gr.State, rather than relying on HTTP Headers/Cookies.
65
+ # Let's adjust the dependency call when we use it in main.py.
66
+ return current_user
67
+
68
+
69
+ # WebSocket endpoint (can be associated with the main API router or separate)
70
+ @router.websocket("/ws/{user_id_token}")
71
+ async def websocket_endpoint(websocket: WebSocket, user_id_token: str):
72
+ """
73
+ WebSocket endpoint. Connects user and listens for messages.
74
+ The user_id_token is the signed session token from login.
75
+ """
76
+ user_id = await auth.get_user_id_from_token(user_id_token)
77
+ if user_id is None:
78
+ logger.warning(f"WebSocket connection rejected: Invalid token {user_id_token}")
79
+ await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
80
+ return
81
+
82
+ await manager.connect(websocket, user_id)
83
+ try:
84
+ while True:
85
+ # Keep connection alive, maybe handle incoming messages if needed later
86
+ data = await websocket.receive_text()
87
+ # For now, we just broadcast on registration, not handle client messages
88
+ logger.debug(f"Received message from {user_id}: {data} (currently ignored)")
89
+ # Example: await websocket.send_text(f"Message text was: {data}")
90
+ except WebSocketDisconnect:
91
+ manager.disconnect(websocket)
92
+ logger.info(f"WebSocket disconnected for user {user_id}")
93
+ except Exception as e:
94
+ manager.disconnect(websocket)
95
+ logger.error(f"WebSocket error for user {user_id}: {e}")
96
+ # Optionally close with an error code
97
+ # await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
app/auth.py ADDED
File without changes
app/crud.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .database import database, users
2
+ from .models import UserInDB
3
+ from .schemas import UserCreate
4
+ from typing import Optional
5
+
6
+ async def get_user_by_email(email: str) -> Optional[UserInDB]:
7
+ query = users.select().where(users.c.email == email)
8
+ result = await database.fetch_one(query)
9
+ return UserInDB(**result) if result else None
10
+
11
+ async def create_user(user_in: UserCreate, hashed_password: str) -> int:
12
+ query = users.insert().values(
13
+ email=user_in.email,
14
+ hashed_password=hashed_password
15
+ )
16
+ last_record_id = await database.execute(query)
17
+ return last_record_id
18
+
19
+ async def get_user_by_id(user_id: int) -> Optional[UserInDB]:
20
+ query = users.select().where(users.c.id == user_id)
21
+ result = await database.fetch_one(query)
22
+ return UserInDB(**result) if result else None
app/database.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from databases import Database
3
+ from dotenv import load_dotenv
4
+ from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, text
5
+
6
+ load_dotenv()
7
+
8
+ DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./app/app.db")
9
+
10
+ # Use 'check_same_thread': False only for SQLite, it's generally not needed for server DBs
11
+ connect_args = {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}
12
+
13
+ database = Database(DATABASE_URL, connect_args=connect_args)
14
+ metadata = MetaData()
15
+
16
+ # Define Users table using SQLAlchemy Core (needed for initial setup)
17
+ users = Table(
18
+ "users",
19
+ metadata,
20
+ Column("id", Integer, primary_key=True),
21
+ Column("email", String, unique=True, index=True, nullable=False),
22
+ Column("hashed_password", String, nullable=False),
23
+ )
24
+
25
+ # Create the database and table if they don't exist
26
+ # This synchronous part runs once at startup usually
27
+ engine = create_engine(DATABASE_URL.replace("+aiosqlite", ""), connect_args=connect_args)
28
+
29
+ # Check if table exists before creating
30
+ # Using a try-except block for robustness across DB engines if needed later
31
+ try:
32
+ with engine.connect() as connection:
33
+ connection.execute(text("SELECT 1 FROM users LIMIT 1"))
34
+ print("Users table already exists.")
35
+ except Exception:
36
+ print("Users table not found, creating...")
37
+ metadata.create_all(engine)
38
+ print("Users table created.")
39
+
40
+ # Async connect/disconnect functions for FastAPI lifespan events
41
+ async def connect_db():
42
+ await database.connect()
43
+ print("Database connection established.")
44
+
45
+ async def disconnect_db():
46
+ await database.disconnect()
47
+ print("Database connection closed.")
app/dependencies.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException, status, Request # Request may not be needed if token passed directly
2
+ from typing import Optional
3
+ from . import auth
4
+ from .models import User
5
+
6
+ # This dependency assumes the token is passed somehow,
7
+ # e.g., in headers (less likely from Gradio client code) or as an argument
8
+ # We will adapt how the token is passed from Gradio later.
9
+ async def get_optional_current_user(token: Optional[str] = None) -> Optional[User]:
10
+ if token:
11
+ user = await auth.get_current_user_from_token(token)
12
+ return user
13
+ return None
14
+
15
+ async def get_required_current_user(token: Optional[str] = None) -> User:
16
+ user = await get_optional_current_user(token)
17
+ if user is None:
18
+ raise HTTPException(
19
+ status_code=status.HTTP_401_UNAUTHORIZED,
20
+ detail="Not authenticated",
21
+ )
22
+ return user
app/main.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import httpx # Use httpx for async requests from Gradio backend to API
3
+ import websockets # Use websockets library to connect from Gradio backend
4
+ import asyncio
5
+ import json
6
+ import os
7
+ import logging
8
+ from contextlib import asynccontextmanager
9
+
10
+ from fastapi import FastAPI, Depends # Import FastAPI itself
11
+ from .api import router as api_router # Import the API router
12
+ from .database import connect_db, disconnect_db
13
+ from . import schemas, auth, dependencies # Import necessary modules
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Base URL for the API (since Gradio runs its own server, even if mounted)
20
+ # Assuming Gradio runs on 7860, FastAPI routes mounted under it
21
+ API_BASE_URL = "http://127.0.0.1:7860/api" # Adjust if needed
22
+
23
+ # --- FastAPI Lifespan Event ---
24
+ @asynccontextmanager
25
+ async def lifespan(app: FastAPI):
26
+ # Startup
27
+ await connect_db()
28
+ # Start background task for websocket listening if needed (or handle within component interaction)
29
+ yield
30
+ # Shutdown
31
+ await disconnect_db()
32
+
33
+ # Create the main FastAPI app instance that Gradio will use
34
+ # We attach our API routes to this instance.
35
+ app = FastAPI(lifespan=lifespan)
36
+ app.include_router(api_router, prefix="/api") # Mount API routes under /api
37
+
38
+ # --- Gradio UI Definition ---
39
+
40
+ # Store websocket connection globally (or within a class) for the Gradio app instance
41
+ # This is tricky because Gradio re-runs functions. State management is key.
42
+ # We'll connect the WebSocket *after* login and store the connection task/info in gr.State.
43
+
44
+ # --- Helper functions for Gradio calling the API ---
45
+
46
+ async def make_api_request(method: str, endpoint: str, **kwargs):
47
+ async with httpx.AsyncClient() as client:
48
+ url = f"{API_BASE_URL}{endpoint}"
49
+ try:
50
+ response = await client.request(method, url, **kwargs)
51
+ response.raise_for_status() # Raise exception for 4xx/5xx errors
52
+ return response.json()
53
+ except httpx.RequestError as e:
54
+ logger.error(f"HTTP Request failed: {e.request.method} {e.request.url} - {e}")
55
+ return {"error": f"Network error contacting API: {e}"}
56
+ except httpx.HTTPStatusError as e:
57
+ logger.error(f"HTTP Status error: {e.response.status_code} - {e.response.text}")
58
+ try:
59
+ detail = e.response.json().get("detail", e.response.text)
60
+ except json.JSONDecodeError:
61
+ detail = e.response.text
62
+ return {"error": f"API Error: {detail}"}
63
+ except Exception as e:
64
+ logger.error(f"Unexpected error during API call: {e}")
65
+ return {"error": f"An unexpected error occurred: {str(e)}"}
66
+
67
+ # --- WebSocket handling within Gradio ---
68
+
69
+ async def listen_to_websockets(token: str, notification_state: list):
70
+ """Connects to WS and updates state list when a message arrives."""
71
+ if not token:
72
+ logger.warning("WebSocket listener: No token provided.")
73
+ return notification_state # Return current state if no token
74
+
75
+ ws_url_base = API_BASE_URL.replace("http", "ws")
76
+ ws_url = f"{ws_url_base}/ws/{token}"
77
+ logger.info(f"Attempting to connect to WebSocket: {ws_url}")
78
+
79
+ try:
80
+ # Add timeout to websocket connection attempt
81
+ async with asyncio.wait_for(websockets.connect(ws_url), timeout=10.0) as websocket:
82
+ logger.info(f"WebSocket connected successfully to {ws_url}")
83
+ while True:
84
+ try:
85
+ message_str = await websocket.recv()
86
+ logger.info(f"Received raw message: {message_str}")
87
+ message_data = json.loads(message_str)
88
+ logger.info(f"Parsed message data: {message_data}")
89
+
90
+ # Ensure it's the expected notification format
91
+ if message_data.get("type") == "new_user":
92
+ notification = schemas.Notification(**message_data)
93
+ # Prepend to show newest first
94
+ notification_state.insert(0, notification.message)
95
+ logger.info(f"Notification added: {notification.message}")
96
+ # Limit state history (optional)
97
+ if len(notification_state) > 10:
98
+ notification_state.pop()
99
+ # IMPORTANT: Need to trigger Gradio update. This function itself
100
+ # cannot directly update UI. It modifies the state, and we need
101
+ # a gr.update() returned by a Gradio event handler that *reads*
102
+ # this state. We use a polling mechanism or a hidden button trick.
103
+ # Let's use polling via `every=` parameter in Gradio.
104
+ # This function's primary job is just modifying the list.
105
+ # We return the modified list, but Gradio needs an event.
106
+ # ---> See the use of `gr.Textbox.change` and `every` below.
107
+
108
+ except websockets.ConnectionClosedOK:
109
+ logger.info("WebSocket connection closed normally.")
110
+ break
111
+ except websockets.ConnectionClosedError as e:
112
+ logger.error(f"WebSocket connection closed with error: {e}")
113
+ break
114
+ except json.JSONDecodeError:
115
+ logger.error(f"Failed to decode JSON from WebSocket message: {message_str}")
116
+ except Exception as e:
117
+ logger.error(f"Error in WebSocket listener loop: {e}")
118
+ # Avoid breaking the loop on transient errors maybe? Add delay?
119
+ await asyncio.sleep(1) # Short delay before retrying receive
120
+ except asyncio.TimeoutError:
121
+ logger.error(f"WebSocket connection timed out: {ws_url}")
122
+ except websockets.exceptions.InvalidURI:
123
+ logger.error(f"Invalid WebSocket URI: {ws_url}")
124
+ except websockets.exceptions.WebSocketException as e:
125
+ logger.error(f"WebSocket connection failed: {e}")
126
+ except Exception as e:
127
+ logger.error(f"Unexpected error connecting/listening to WebSocket: {e}")
128
+
129
+ # Return the state as it is if connection failed or ended
130
+ # The calling Gradio function will handle this return value.
131
+ return notification_state
132
+
133
+
134
+ # --- Gradio Interface ---
135
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
136
+ # State variables
137
+ # Holds the session token after login
138
+ auth_token = gr.State(None)
139
+ # Holds user info {id, email} after login
140
+ user_info = gr.State(None)
141
+ # Holds the list of notification messages
142
+ notification_list = gr.State([])
143
+ # Holds the asyncio task for the WebSocket listener
144
+ websocket_task = gr.State(None)
145
+
146
+ # --- UI Components ---
147
+ with gr.Tabs() as tabs:
148
+ # --- Registration Tab ---
149
+ with gr.TabItem("Register", id="register_tab"):
150
+ gr.Markdown("## Create a new account")
151
+ reg_email = gr.Textbox(label="Email", type="email")
152
+ reg_password = gr.Textbox(label="Password (min 8 chars)", type="password")
153
+ reg_confirm_password = gr.Textbox(label="Confirm Password", type="password")
154
+ reg_button = gr.Button("Register")
155
+ reg_status = gr.Textbox(label="Status", interactive=False)
156
+
157
+ # --- Login Tab ---
158
+ with gr.TabItem("Login", id="login_tab"):
159
+ gr.Markdown("## Login to your account")
160
+ login_email = gr.Textbox(label="Email", type="email")
161
+ login_password = gr.Textbox(label="Password", type="password")
162
+ login_button = gr.Button("Login")
163
+ login_status = gr.Textbox(label="Status", interactive=False)
164
+
165
+ # --- Welcome Tab (shown after login) ---
166
+ with gr.TabItem("Welcome", id="welcome_tab", visible=False) as welcome_tab:
167
+ gr.Markdown("## Welcome!", elem_id="welcome_header")
168
+ welcome_message = gr.Markdown("", elem_id="welcome_message")
169
+ logout_button = gr.Button("Logout")
170
+ gr.Markdown("---") # Separator
171
+ gr.Markdown("## Real-time Notifications")
172
+ # Textbox to display notifications, updated periodically
173
+ notification_display = gr.Textbox(
174
+ label="New User Alerts",
175
+ lines=5,
176
+ max_lines=10,
177
+ interactive=False,
178
+ # The `every=1` makes Gradio call the update function every 1 second
179
+ # This function will read the `notification_list` state
180
+ every=1
181
+ )
182
+
183
+ # --- Event Handlers ---
184
+
185
+ # Registration Logic
186
+ async def handle_register(email, password, confirm_password):
187
+ if not email or not password or not confirm_password:
188
+ return gr.update(value="Please fill in all fields.")
189
+ if password != confirm_password:
190
+ return gr.update(value="Passwords do not match.")
191
+ if len(password) < 8:
192
+ return gr.update(value="Password must be at least 8 characters long.")
193
+
194
+ payload = {"email": email, "password": password}
195
+ result = await make_api_request("post", "/register", json=payload)
196
+
197
+ if "error" in result:
198
+ return gr.update(value=f"Registration failed: {result['error']}")
199
+ else:
200
+ # Optionally switch to login tab after successful registration
201
+ return gr.update(value=f"Registration successful for {result.get('email')}! Please log in.")
202
+
203
+ reg_button.click(
204
+ handle_register,
205
+ inputs=[reg_email, reg_password, reg_confirm_password],
206
+ outputs=[reg_status]
207
+ )
208
+
209
+ # Login Logic
210
+ async def handle_login(email, password, current_task):
211
+ if not email or not password:
212
+ return gr.update(value="Please enter email and password."), None, None, None, gr.update(visible=False), current_task
213
+
214
+ payload = {"email": email, "password": password}
215
+ result = await make_api_request("post", "/login", json=payload)
216
+
217
+ if "error" in result:
218
+ return gr.update(value=f"Login failed: {result['error']}"), None, None, None, gr.update(visible=False), current_task
219
+ else:
220
+ token = result.get("access_token")
221
+ # Fetch user details using the token
222
+ user_data = await dependencies.get_optional_current_user(token) # Use dependency directly
223
+
224
+ if not user_data:
225
+ # This shouldn't happen if login succeeded, but check anyway
226
+ return gr.update(value="Login succeeded but failed to fetch user data."), None, None, None, gr.update(visible=False), current_task
227
+
228
+ # Cancel any existing websocket listener task before starting a new one
229
+ if current_task and not current_task.done():
230
+ current_task.cancel()
231
+ try:
232
+ await current_task # Wait for cancellation
233
+ except asyncio.CancelledError:
234
+ logger.info("Previous WebSocket task cancelled.")
235
+
236
+ # Start the WebSocket listener task in the background
237
+ # We pass the notification_list state *object* itself, which the task will modify
238
+ new_task = asyncio.create_task(listen_to_websockets(token, notification_list.value)) # Pass the list
239
+
240
+ # Update state and UI
241
+ welcome_msg = f"Welcome, {user_data.email}!"
242
+ # Switch tabs and show welcome message
243
+ return (
244
+ gr.update(value="Login successful!"), # login_status
245
+ token, # auth_token state
246
+ user_data.model_dump(), # user_info state (store as dict)
247
+ gr.update(selected="welcome_tab"), # Switch Tabs
248
+ gr.update(visible=True), # Make welcome tab visible
249
+ gr.update(value=welcome_msg), # Update welcome message markdown
250
+ new_task # websocket_task state
251
+ )
252
+
253
+ login_button.click(
254
+ handle_login,
255
+ inputs=[login_email, login_password, websocket_task],
256
+ outputs=[login_status, auth_token, user_info, tabs, welcome_tab, welcome_message, websocket_task]
257
+ )
258
+
259
+
260
+ # Function to update the notification display based on the state
261
+ # This function is triggered by the `every=1` on the notification_display Textbox
262
+ def update_notification_ui(notif_list_state):
263
+ # Join the list items into a string for display
264
+ return "\n".join(notif_list_state)
265
+
266
+ notification_display.change( # Use .change with every= setup on the component
267
+ fn=update_notification_ui,
268
+ inputs=[notification_list], # Read the state
269
+ outputs=[notification_display] # Update the component
270
+ )
271
+
272
+
273
+ # Logout Logic
274
+ async def handle_logout(current_task):
275
+ # Cancel the websocket listener task if it's running
276
+ if current_task and not current_task.done():
277
+ current_task.cancel()
278
+ try:
279
+ await current_task
280
+ except asyncio.CancelledError:
281
+ logger.info("WebSocket task cancelled on logout.")
282
+
283
+ # Clear state and switch back to login tab
284
+ return (
285
+ None, # Clear auth_token
286
+ None, # Clear user_info
287
+ [], # Clear notifications
288
+ None, # Clear websocket_task
289
+ gr.update(selected="login_tab"),# Switch Tabs
290
+ gr.update(visible=False), # Hide welcome tab
291
+ gr.update(value=""), # Clear welcome message
292
+ gr.update(value="") # Clear login status
293
+ )
294
+
295
+ logout_button.click(
296
+ handle_logout,
297
+ inputs=[websocket_task],
298
+ outputs=[
299
+ auth_token,
300
+ user_info,
301
+ notification_list,
302
+ websocket_task,
303
+ tabs,
304
+ welcome_tab,
305
+ welcome_message,
306
+ login_status
307
+ ]
308
+ )
309
+
310
+ # Mount the Gradio app onto the FastAPI app at the root
311
+ app = gr.mount_gradio_app(app, demo, path="/")
312
+
313
+ # If running this file directly (for local testing)
314
+ # Use uvicorn to run the FastAPI app (which now includes Gradio)
315
+ if __name__ == "__main__":
316
+ import uvicorn
317
+ # Use port 7860 as Gradio prefers, host 0.0.0.0 for Docker
318
+ uvicorn.run(app, host="0.0.0.0", port=7860)
app/models.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, EmailStr
2
+
3
+ class User(BaseModel):
4
+ id: int
5
+ email: EmailStr
6
+
7
+ class UserInDB(User):
8
+ hashed_password: str
app/schemas.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, EmailStr, Field
2
+
3
+ class UserCreate(BaseModel):
4
+ email: EmailStr
5
+ password: str = Field(min_length=8)
6
+
7
+ class UserLogin(BaseModel):
8
+ email: EmailStr
9
+ password: str
10
+
11
+ class Token(BaseModel):
12
+ access_token: str
13
+ token_type: str = "bearer"
14
+
15
+ class Notification(BaseModel):
16
+ type: str = "new_user"
17
+ email: EmailStr
18
+ message: str
app/websocket.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import WebSocket
2
+ from typing import List, Dict, Optional
3
+ import json
4
+ import logging
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ class ConnectionManager:
9
+ def __init__(self):
10
+ # Store connections with user ID if available for targeted messaging later
11
+ self.active_connections: Dict[Optional[int], List[WebSocket]] = {}
12
+ # Map websocket to user_id for easier removal
13
+ self.websocket_to_user: Dict[WebSocket, Optional[int]] = {}
14
+
15
+ async def connect(self, websocket: WebSocket, user_id: Optional[int] = None):
16
+ await websocket.accept()
17
+ if user_id not in self.active_connections:
18
+ self.active_connections[user_id] = []
19
+ self.active_connections[user_id].append(websocket)
20
+ self.websocket_to_user[websocket] = user_id
21
+ logger.info(f"WebSocket connected: {websocket.client.host}:{websocket.client.port}, User ID: {user_id}")
22
+ logger.info(f"Total connections: {sum(len(conns) for conns in self.active_connections.values())}")
23
+
24
+
25
+ def disconnect(self, websocket: WebSocket):
26
+ user_id = self.websocket_to_user.pop(websocket, None)
27
+ if user_id in self.active_connections:
28
+ try:
29
+ self.active_connections[user_id].remove(websocket)
30
+ if not self.active_connections[user_id]:
31
+ del self.active_connections[user_id]
32
+ except ValueError:
33
+ logger.warning(f"WebSocket not found in active list for user {user_id} during disconnect.")
34
+
35
+ logger.info(f"WebSocket disconnected: {websocket.client.host}:{websocket.client.port}, User ID: {user_id}")
36
+ logger.info(f"Total connections: {sum(len(conns) for conns in self.active_connections.values())}")
37
+
38
+
39
+ async def broadcast(self, message: str, sender_id: Optional[int] = None):
40
+ disconnected_websockets = []
41
+ # Iterate through all connections
42
+ all_websockets = [ws for user_conns in self.active_connections.values() for ws in user_conns]
43
+ logger.info(f"Broadcasting to {len(all_websockets)} connections (excluding sender if ID matches). Sender ID: {sender_id}")
44
+
45
+ for websocket in all_websockets:
46
+ # Send to all *other* users (or all if sender_id is None)
47
+ ws_user_id = self.websocket_to_user.get(websocket)
48
+ # Requirement: "all *other* connected users should see"
49
+ # Send if the websocket isn't associated with the sender_id
50
+ # (Note: If a user has multiple tabs/connections open, they might still receive it
51
+ # if the sender_id check only excludes one specific connection. This simple broadcast
52
+ # targets users based on their ID at connection time).
53
+ # Let's refine: Send if the user_id associated with the WS is not the sender_id
54
+ if ws_user_id != sender_id:
55
+ try:
56
+ await websocket.send_text(message)
57
+ logger.debug(f"Message sent to user {ws_user_id}")
58
+ except Exception as e:
59
+ logger.error(f"Error sending message to websocket {websocket.client}: {e}. Marking for disconnect.")
60
+ disconnected_websockets.append(websocket)
61
+
62
+ # Clean up connections that failed during broadcast
63
+ for ws in disconnected_websockets:
64
+ self.disconnect(ws)
65
+
66
+ manager = ConnectionManager()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.111.0
2
+ uvicorn[standard]==0.29.0
3
+ gradio==4.29.0
4
+ passlib[bcrypt]==1.7.4
5
+ python-dotenv==1.0.1
6
+ databases[sqlite]==0.9.0 # Async DB access
7
+ sqlalchemy==2.0.29 # Core needed by `databases` or for potential future use
8
+ pydantic==2.7.1
9
+ python-multipart==0.0.9 # For form data in FastAPI
10
+ itsdangerous==2.1.2 # Simple secure signing for session IDs
11
+ websockets>=11.0.3,<13.0 # Ensure compatibility