amaye15 commited on
Commit
d576ad8
·
1 Parent(s): b6b902a

Docker optimise

Browse files
Files changed (10) hide show
  1. Dockerfile +68 -11
  2. README.md +0 -1
  3. app/api.py +3 -5
  4. app/auth.py +3 -8
  5. app/database.py +4 -112
  6. app/dependencies.py +4 -11
  7. app/main.py +4 -42
  8. app/websocket.py +0 -10
  9. requirements.txt +2 -3
  10. tests/api.py +405 -0
Dockerfile CHANGED
@@ -1,20 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Use an official Python runtime as a parent image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  FROM python:3.12-slim
3
 
4
- WORKDIR /code
 
 
 
 
5
 
6
- COPY ./requirements.txt /code/requirements.txt
7
 
8
- RUN pip install --no-cache-dir --upgrade pip && \
9
- pip install --no-cache-dir -r requirements.txt
 
 
10
 
11
- # Copy application code
12
- COPY ./app /code/app
13
- # Copy static files and templates
14
- COPY ./static /code/static
15
- COPY ./templates /code/templates
16
 
 
 
 
 
 
 
17
  EXPOSE 7860
18
 
19
- # Command to run the FastAPI application
20
- CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
 
 
 
 
 
 
1
+ # # Use an official Python runtime as a parent image
2
+ # FROM python:3.12-slim
3
+
4
+ # WORKDIR /code
5
+
6
+ # COPY ./requirements.txt /code/requirements.txt
7
+
8
+ # RUN pip install --no-cache-dir --upgrade pip && \
9
+ # pip install --no-cache-dir -r requirements.txt
10
+
11
+ # # Copy application code
12
+ # COPY ./app /code/app
13
+ # # Copy static files and templates
14
+ # COPY ./static /code/static
15
+ # COPY ./templates /code/templates
16
+
17
+ # EXPOSE 7860
18
+
19
+ # # Command to run the FastAPI application
20
+ # CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
21
+
22
  # Use an official Python runtime as a parent image
23
+ FROM python:3.12-slim as builder
24
+
25
+ # Set environment variables
26
+ ENV PYTHONDONTWRITEBYTECODE=1 \
27
+ PYTHONUNBUFFERED=1 \
28
+ PIP_NO_CACHE_DIR=1 \
29
+ PIP_DISABLE_PIP_VERSION_CHECK=1 \
30
+ PYTHONOPTIMIZE=2
31
+
32
+ WORKDIR /build
33
+
34
+ # Copy only requirements first to leverage Docker caching
35
+ COPY ./requirements.txt .
36
+
37
+ # Install dependencies into a virtual environment
38
+ RUN python -m venv /venv && \
39
+ /venv/bin/pip install --no-cache-dir --upgrade pip && \
40
+ /venv/bin/pip install --no-cache-dir -r requirements.txt
41
+
42
+ # Final stage
43
  FROM python:3.12-slim
44
 
45
+ # Set environment variables
46
+ ENV PYTHONDONTWRITEBYTECODE=1 \
47
+ PYTHONUNBUFFERED=1 \
48
+ PYTHONOPTIMIZE=2 \
49
+ PATH="/venv/bin:$PATH"
50
 
51
+ WORKDIR /app
52
 
53
+ # Create a non-root user
54
+ RUN addgroup --system app && \
55
+ adduser --system --group app && \
56
+ chown -R app:app /app
57
 
58
+ # Copy the virtual environment from the builder stage
59
+ COPY --from=builder /venv /venv
 
 
 
60
 
61
+ # Copy application code and assets
62
+ COPY --chown=app:app ./app ./app
63
+ COPY --chown=app:app ./static ./static
64
+ COPY --chown=app:app ./templates ./templates
65
+
66
+ # Expose the port
67
  EXPOSE 7860
68
 
69
+ # Switch to non-root user
70
+ USER app
71
+
72
+ # Add healthcheck
73
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
74
+ CMD curl -f http://localhost:7860/health || exit 1
75
+
76
+ # Command to run the FastAPI application with optimized settings
77
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "4", "--proxy-headers"]
README.md CHANGED
@@ -10,7 +10,6 @@ short_description: An app demonstrating Gradio, FastAPI, Docker & SQL DB
10
  app_file: app/main.py
11
  python_version: 3.12
12
  port: 7860
13
- fullWidth: false
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
10
  app_file: app/main.py
11
  python_version: 3.12
12
  port: 7860
 
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app/api.py CHANGED
@@ -1,4 +1,3 @@
1
- # app/api.py
2
  from fastapi import APIRouter, HTTPException, status, Depends, WebSocket, WebSocketDisconnect
3
  import logging
4
 
@@ -9,8 +8,8 @@ from .dependencies import get_required_current_user
9
  router = APIRouter()
10
  logger = logging.getLogger(__name__)
11
 
12
- # --- FIX THE DECORATORS HERE ---
13
- @router.post("/register", status_code=status.HTTP_201_CREATED, response_model=models.User) # <-- FIX HERE
14
  async def register_user(user_in: schemas.UserCreate):
15
  existing_user = await crud.get_user_by_email(user_in.email)
16
  if existing_user:
@@ -23,14 +22,13 @@ async def register_user(user_in: schemas.UserCreate):
23
  if not created_user: raise HTTPException(status_code=500, detail="Failed to retrieve created user")
24
  return models.User(id=created_user.id, email=created_user.email)
25
 
26
- @router.post("/login", response_model=schemas.Token) # <-- FIX HERE
27
  async def login_for_access_token(form_data: schemas.UserLogin):
28
  user = await crud.get_user_by_email(form_data.email)
29
  if not user or not auth.verify_password(form_data.password, user.hashed_password):
30
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password", headers={"WWW-Authenticate": "Bearer"})
31
  access_token = auth.create_session_token(user_id=user.id)
32
  return {"access_token": access_token, "token_type": "bearer"}
33
- # --- END FIXES ---
34
 
35
  @router.get("/users/me", response_model=models.User)
36
  async def read_users_me(current_user: models.User = Depends(get_required_current_user)):
 
 
1
  from fastapi import APIRouter, HTTPException, status, Depends, WebSocket, WebSocketDisconnect
2
  import logging
3
 
 
8
  router = APIRouter()
9
  logger = logging.getLogger(__name__)
10
 
11
+
12
+ @router.post("/register", status_code=status.HTTP_201_CREATED, response_model=models.User)
13
  async def register_user(user_in: schemas.UserCreate):
14
  existing_user = await crud.get_user_by_email(user_in.email)
15
  if existing_user:
 
22
  if not created_user: raise HTTPException(status_code=500, detail="Failed to retrieve created user")
23
  return models.User(id=created_user.id, email=created_user.email)
24
 
25
+ @router.post("/login", response_model=schemas.Token)
26
  async def login_for_access_token(form_data: schemas.UserLogin):
27
  user = await crud.get_user_by_email(form_data.email)
28
  if not user or not auth.verify_password(form_data.password, user.hashed_password):
29
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password", headers={"WWW-Authenticate": "Bearer"})
30
  access_token = auth.create_session_token(user_id=user.id)
31
  return {"access_token": access_token, "token_type": "bearer"}
 
32
 
33
  @router.get("/users/me", response_model=models.User)
34
  async def read_users_me(current_user: models.User = Depends(get_required_current_user)):
app/auth.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- from datetime import datetime, timedelta, timezone
3
  from passlib.context import CryptContext
4
  from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
5
  from dotenv import load_dotenv
@@ -8,8 +7,8 @@ from . import crud, models
8
 
9
  load_dotenv()
10
 
11
- SECRET_KEY = os.getenv("SECRET_KEY", "super-secret") # Fallback, but .env should be used
12
- # Use URLSafeTimedSerializer for session tokens that expire
13
  serializer = URLSafeTimedSerializer(SECRET_KEY)
14
 
15
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -20,8 +19,6 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
20
  def get_password_hash(password: str) -> str:
21
  return pwd_context.hash(password)
22
 
23
- # Session Token generation (using itsdangerous for simplicity)
24
- # Stores user_id securely signed with a timestamp
25
  def create_session_token(user_id: int) -> str:
26
  return serializer.dumps(user_id)
27
 
@@ -30,8 +27,7 @@ async def get_user_id_from_token(token: str) -> Optional[int]:
30
  if not token:
31
  return None
32
  try:
33
- # Set max_age to something reasonable, e.g., 1 day
34
- user_id = serializer.loads(token, max_age=86400) # 24 hours * 60 min * 60 sec
35
  return int(user_id)
36
  except (SignatureExpired, BadSignature, ValueError):
37
  return None
@@ -43,6 +39,5 @@ async def get_current_user_from_token(token: str) -> Optional[models.User]:
43
  return None
44
  user = await crud.get_user_by_id(user_id)
45
  if user:
46
- # Return the public User model, not UserInDB
47
  return models.User(id=user.id, email=user.email)
48
  return None
 
1
  import os
 
2
  from passlib.context import CryptContext
3
  from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
4
  from dotenv import load_dotenv
 
7
 
8
  load_dotenv()
9
 
10
+ SECRET_KEY = os.getenv("SECRET_KEY", "super-secret")
11
+
12
  serializer = URLSafeTimedSerializer(SECRET_KEY)
13
 
14
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 
19
  def get_password_hash(password: str) -> str:
20
  return pwd_context.hash(password)
21
 
 
 
22
  def create_session_token(user_id: int) -> str:
23
  return serializer.dumps(user_id)
24
 
 
27
  if not token:
28
  return None
29
  try:
30
+ user_id = serializer.loads(token, max_age=86400)
 
31
  return int(user_id)
32
  except (SignatureExpired, BadSignature, ValueError):
33
  return None
 
39
  return None
40
  user = await crud.get_user_by_id(user_id)
41
  if user:
 
42
  return models.User(id=user.id, email=user.email)
43
  return None
app/database.py CHANGED
@@ -1,135 +1,33 @@
1
- # # app/database.py
2
- # import os
3
- # from databases import Database
4
- # from dotenv import load_dotenv
5
- # # --- Keep only these SQLAlchemy imports ---
6
- # from sqlalchemy import MetaData, Table, Column, Integer, String
7
- # import logging
8
- # from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
9
-
10
- # load_dotenv()
11
- # logger = logging.getLogger(__name__)
12
-
13
- # # --- Database URL Configuration ---
14
- # DEFAULT_DB_PATH = "/tmp/app.db" # Store DB in the temporary directory
15
- # raw_db_url = os.getenv("DATABASE_URL", f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}")
16
-
17
- # final_database_url = raw_db_url
18
- # if raw_db_url.startswith("sqlite+aiosqlite"):
19
- # parsed_url = urlparse(raw_db_url)
20
- # query_params = parse_qs(parsed_url.query)
21
- # if 'check_same_thread' not in query_params:
22
- # query_params['check_same_thread'] = ['False']
23
- # new_query = urlencode(query_params, doseq=True)
24
- # final_database_url = urlunparse(parsed_url._replace(query=new_query))
25
- # logger.info(f"Using final async DB URL: {final_database_url}")
26
- # else:
27
- # logger.info(f"Using non-SQLite async DB URL: {final_database_url}")
28
-
29
- # # --- Async Database Instance ---
30
- # database = Database(final_database_url)
31
-
32
- # # --- Metadata and Table Definition (Still needed for DDL generation) ---
33
- # metadata = MetaData()
34
- # users = Table(
35
- # "users",
36
- # metadata,
37
- # Column("id", Integer, primary_key=True),
38
- # Column("email", String, unique=True, index=True, nullable=False),
39
- # Column("hashed_password", String, nullable=False),
40
- # )
41
-
42
- # # --- REMOVE ALL SYNCHRONOUS ENGINE AND TABLE CREATION LOGIC ---
43
-
44
- # # --- Keep and refine Async connect/disconnect functions ---
45
- # async def connect_db():
46
- # """Connects to the database, ensuring the parent directory exists."""
47
- # try:
48
- # # Ensure the directory exists just before connecting
49
- # db_file_path = final_database_url.split("sqlite:///")[-1].split("?")[0]
50
- # db_dir = os.path.dirname(db_file_path)
51
- # if db_dir: # Only proceed if a directory path was found
52
- # if not os.path.exists(db_dir):
53
- # logger.info(f"Database directory {db_dir} does not exist. Attempting creation...")
54
- # try:
55
- # os.makedirs(db_dir, exist_ok=True)
56
- # logger.info(f"Created database directory {db_dir}.")
57
- # except Exception as mkdir_err:
58
- # # Log error but proceed, connection might still work if path is valid but dir creation failed weirdly
59
- # logger.error(f"Failed to create directory {db_dir}: {mkdir_err}")
60
- # # Check writability after ensuring existence attempt
61
- # if os.path.exists(db_dir) and not os.access(db_dir, os.W_OK):
62
- # logger.error(f"CRITICAL: Directory {db_dir} exists but is not writable!")
63
- # elif not os.path.exists(db_dir):
64
- # logger.error(f"CRITICAL: Directory {db_dir} does not exist and could not be created!")
65
-
66
-
67
- # # Now attempt connection
68
- # await database.connect()
69
- # logger.info(f"Database connection established (async): {final_database_url}")
70
- # # Table creation will happen in main.py lifespan event using this connection
71
- # except Exception as e:
72
- # logger.exception(f"Failed to establish async database connection: {e}")
73
- # raise # Reraise critical error during startup
74
-
75
- # async def disconnect_db():
76
- # """Disconnects from the database if connected."""
77
- # try:
78
- # if database.is_connected:
79
- # await database.disconnect()
80
- # logger.info("Database connection closed (async).")
81
- # else:
82
- # logger.info("Database already disconnected (async).")
83
- # except Exception as e:
84
- # logger.exception(f"Error closing async database connection: {e}")
85
-
86
-
87
-
88
- # app/database.py
89
  import os
90
  from databases import Database
91
  from dotenv import load_dotenv
92
- # --- Keep only these SQLAlchemy imports ---
93
- # MetaData and Table are needed for defining the table structure
94
- # which is used by crud.py and for DDL generation in main.py
95
  from sqlalchemy import MetaData, Table, Column, Integer, String
96
  import logging
97
  from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
98
 
99
- # Load environment variables from .env file (if it exists)
100
  load_dotenv()
101
  logger = logging.getLogger(__name__)
102
 
103
- # --- Database URL Configuration ---
104
- # Use /tmp directory for the SQLite file as it's generally writable in containers
105
  DEFAULT_DB_PATH = "/tmp/app.db"
106
- # Get the URL from environment or use the default /tmp path
107
  raw_db_url = os.getenv("DATABASE_URL", f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}")
108
 
109
  final_database_url = raw_db_url
110
- # Ensure 'check_same_thread=False' is in the URL query string for SQLite async connection
111
  if raw_db_url.startswith("sqlite+aiosqlite"):
112
  parsed_url = urlparse(raw_db_url)
113
  query_params = parse_qs(parsed_url.query)
114
  if 'check_same_thread' not in query_params:
115
- query_params['check_same_thread'] = ['False'] # Value needs to be a list for urlencode
116
  new_query = urlencode(query_params, doseq=True)
117
- # Rebuild the URL using _replace method of the named tuple
118
  final_database_url = urlunparse(parsed_url._replace(query=new_query))
119
  logger.info(f"Using final async DB URL: {final_database_url}")
120
  else:
121
  logger.info(f"Using non-SQLite async DB URL: {final_database_url}")
122
 
123
 
124
- # --- Async Database Instance ---
125
- # This 'database' object will be used by crud.py and main.py lifespan
126
  database = Database(final_database_url)
127
 
128
-
129
- # --- Metadata and Table Definition ---
130
- # These definitions are needed by:
131
- # 1. crud.py to construct queries (e.g., users.select())
132
- # 2. main.py (lifespan) to generate the CREATE TABLE statement
133
  metadata = MetaData()
134
  users = Table(
135
  "users",
@@ -139,13 +37,9 @@ users = Table(
139
  Column("hashed_password", String, nullable=False),
140
  )
141
 
142
-
143
- # --- Async connect/disconnect functions ---
144
- # Called by the FastAPI lifespan event handler in main.py
145
  async def connect_db():
146
  """Connects to the database defined by 'final_database_url'."""
147
  try:
148
- # Optional: Check/create directory if using file-based DB like SQLite
149
  if final_database_url.startswith("sqlite"):
150
  db_file_path = final_database_url.split("sqlite:///")[-1].split("?")[0]
151
  db_dir = os.path.dirname(db_file_path)
@@ -162,16 +56,14 @@ async def connect_db():
162
  elif not os.path.exists(db_dir):
163
  logger.error(f"CRITICAL: DB directory '{db_dir}' does not exist and could not be created!")
164
 
165
- # Connect using the 'databases' library instance
166
  if not database.is_connected:
167
  await database.connect()
168
  logger.info(f"Database connection established: {final_database_url}")
169
  else:
170
  logger.info("Database connection already established.")
171
- # Note: Table creation happens in main.py lifespan after connection
172
  except Exception as e:
173
  logger.exception(f"FATAL: Failed to establish async database connection: {e}")
174
- raise # Stop application startup if DB connection fails
175
 
176
  async def disconnect_db():
177
  """Disconnects from the database if connected."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from databases import Database
3
  from dotenv import load_dotenv
 
 
 
4
  from sqlalchemy import MetaData, Table, Column, Integer, String
5
  import logging
6
  from urllib.parse import urlparse, urlunparse, parse_qs, urlencode
7
 
 
8
  load_dotenv()
9
  logger = logging.getLogger(__name__)
10
 
11
+
 
12
  DEFAULT_DB_PATH = "/tmp/app.db"
 
13
  raw_db_url = os.getenv("DATABASE_URL", f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}")
14
 
15
  final_database_url = raw_db_url
16
+
17
  if raw_db_url.startswith("sqlite+aiosqlite"):
18
  parsed_url = urlparse(raw_db_url)
19
  query_params = parse_qs(parsed_url.query)
20
  if 'check_same_thread' not in query_params:
21
+ query_params['check_same_thread'] = ['False']
22
  new_query = urlencode(query_params, doseq=True)
 
23
  final_database_url = urlunparse(parsed_url._replace(query=new_query))
24
  logger.info(f"Using final async DB URL: {final_database_url}")
25
  else:
26
  logger.info(f"Using non-SQLite async DB URL: {final_database_url}")
27
 
28
 
 
 
29
  database = Database(final_database_url)
30
 
 
 
 
 
 
31
  metadata = MetaData()
32
  users = Table(
33
  "users",
 
37
  Column("hashed_password", String, nullable=False),
38
  )
39
 
 
 
 
40
  async def connect_db():
41
  """Connects to the database defined by 'final_database_url'."""
42
  try:
 
43
  if final_database_url.startswith("sqlite"):
44
  db_file_path = final_database_url.split("sqlite:///")[-1].split("?")[0]
45
  db_dir = os.path.dirname(db_file_path)
 
56
  elif not os.path.exists(db_dir):
57
  logger.error(f"CRITICAL: DB directory '{db_dir}' does not exist and could not be created!")
58
 
 
59
  if not database.is_connected:
60
  await database.connect()
61
  logger.info(f"Database connection established: {final_database_url}")
62
  else:
63
  logger.info("Database connection already established.")
 
64
  except Exception as e:
65
  logger.exception(f"FATAL: Failed to establish async database connection: {e}")
66
+ raise
67
 
68
  async def disconnect_db():
69
  """Disconnects from the database if connected."""
app/dependencies.py CHANGED
@@ -1,10 +1,8 @@
1
  from fastapi import Depends, HTTPException, status
2
- from fastapi.security import OAuth2PasswordBearer # Use FastAPI's built-in helper
3
  from typing import Optional
4
  from . import auth, models
5
 
6
- # Setup OAuth2 scheme pointing to the login *API* endpoint
7
- # tokenUrl relative to the path where the app is mounted
8
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/login")
9
 
10
  async def get_optional_current_user(token: str = Depends(oauth2_scheme)) -> Optional[models.User]:
@@ -14,25 +12,20 @@ async def get_optional_current_user(token: str = Depends(oauth2_scheme)) -> Opti
14
  Handles potential exceptions during token decoding/validation gracefully for optional user.
15
  """
16
  try:
17
- # OAuth2PasswordBearer already extracts the token from the header
18
  user = await auth.get_current_user_from_token(token)
19
  return user
20
- except Exception: # Catch exceptions if the token is invalid but we don't want to fail hard
21
  return None
22
 
23
  async def get_required_current_user(token: str = Depends(oauth2_scheme)) -> models.User:
24
  """
25
  Dependency to get the current user, raising HTTP 401 if not authenticated.
26
  """
27
- # OAuth2PasswordBearer will raise a 401 if the header is missing/malformed
28
  user = await auth.get_current_user_from_token(token)
29
  if user is None:
30
- # This case covers valid token format but expired/invalid signature/user not found
31
  raise HTTPException(
32
  status_code=status.HTTP_401_UNAUTHORIZED,
33
- detail="Could not validate credentials", # Keep detail generic
34
  headers={"WWW-Authenticate": "Bearer"},
35
  )
36
- return user
37
-
38
- # Modify the /users/me endpoint in api.py to use the new dependency
 
1
  from fastapi import Depends, HTTPException, status
2
+ from fastapi.security import OAuth2PasswordBearer
3
  from typing import Optional
4
  from . import auth, models
5
 
 
 
6
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/login")
7
 
8
  async def get_optional_current_user(token: str = Depends(oauth2_scheme)) -> Optional[models.User]:
 
12
  Handles potential exceptions during token decoding/validation gracefully for optional user.
13
  """
14
  try:
 
15
  user = await auth.get_current_user_from_token(token)
16
  return user
17
+ except Exception:
18
  return None
19
 
20
  async def get_required_current_user(token: str = Depends(oauth2_scheme)) -> models.User:
21
  """
22
  Dependency to get the current user, raising HTTP 401 if not authenticated.
23
  """
 
24
  user = await auth.get_current_user_from_token(token)
25
  if user is None:
 
26
  raise HTTPException(
27
  status_code=status.HTTP_401_UNAUTHORIZED,
28
+ detail="Could not validate credentials",
29
  headers={"WWW-Authenticate": "Bearer"},
30
  )
31
+ return user
 
 
app/main.py CHANGED
@@ -1,27 +1,13 @@
1
- # app/main.py
2
- # Remove Gradio imports if any remain
3
- # import gradio as gr <--- REMOVE
4
-
5
- import httpx # Keep if needed, but not used in this version of main.py
6
- import websockets # Keep if needed, but not used in this version of main.py
7
- import asyncio
8
- import json
9
- import os
10
  import logging
11
  from contextlib import asynccontextmanager
12
 
13
- from fastapi import FastAPI, Depends, Request # Add Request
14
- from fastapi.responses import HTMLResponse # Add HTMLResponse
15
- from fastapi.staticfiles import StaticFiles # Add StaticFiles
16
- from fastapi.templating import Jinja2Templates # Add Jinja2Templates (optional, but good practice)
17
 
18
- # --- Import necessary items from database.py ---
19
- from .database import connect_db, disconnect_db, database, metadata, users
20
  from .api import router as api_router
21
- from . import schemas, auth, dependencies
22
- from .websocket import manager # Keep
23
 
24
- # --- Import SQLAlchemy helpers for DDL generation ---
25
  from sqlalchemy.schema import CreateTable
26
  from sqlalchemy.dialects import sqlite
27
 
@@ -29,13 +15,8 @@ from sqlalchemy.dialects import sqlite
29
  logging.basicConfig(level=logging.INFO)
30
  logger = logging.getLogger(__name__)
31
 
32
- # --- REMOVE API_BASE_URL if not needed elsewhere ---
33
- # API_BASE_URL = "http://127.0.0.1:7860/api"
34
-
35
- # --- Lifespan Event (remains the same) ---
36
  @asynccontextmanager
37
  async def lifespan(app: FastAPI):
38
- # ... (same DB setup code as previous correct version) ...
39
  logger.info("Application startup: Connecting DB...")
40
  await connect_db()
41
  logger.info("Application startup: DB Connected. Checking/Creating tables...")
@@ -65,22 +46,12 @@ async def lifespan(app: FastAPI):
65
  logger.info("Application shutdown: DB Disconnected.")
66
 
67
 
68
- # Create the main FastAPI app instance
69
  app = FastAPI(lifespan=lifespan)
70
 
71
- # Mount API routes FIRST
72
  app.include_router(api_router, prefix="/api")
73
 
74
- # --- Mount Static files ---
75
- # Ensure the path exists relative to where you run uvicorn (or use absolute paths)
76
- # Since main.py is in app/, static/ is one level up
77
- # Adjust 'directory' path if needed based on your execution context
78
  app.mount("/static", StaticFiles(directory="static"), name="static")
79
 
80
- # --- Optional: Use Jinja2Templates for more flexibility ---
81
- # templates = Jinja2Templates(directory="templates")
82
-
83
- # --- Serve the main HTML page ---
84
  @app.get("/", response_class=HTMLResponse)
85
  async def read_root(request: Request):
86
  # Simple way: Read the file directly
@@ -91,16 +62,7 @@ async def read_root(request: Request):
91
  except FileNotFoundError:
92
  logger.error("templates/index.html not found!")
93
  return HTMLResponse(content="<html><body><h1>Error: Frontend not found</h1></body></html>", status_code=500)
94
- # Jinja2 way (if using templates):
95
- # return templates.TemplateResponse("index.html", {"request": request})
96
-
97
-
98
- # --- REMOVE Gradio mounting ---
99
- # app = gr.mount_gradio_app(app, demo, path="/")
100
 
101
- # --- Uvicorn run command (no changes needed here) ---
102
  if __name__ == "__main__":
103
  import uvicorn
104
- # Note: If running from the project root directory (fastapi_gradio_auth/),
105
- # the app path is "app.main:app"
106
  uvicorn.run("app.main:app", host="0.0.0.0", port=7860, reload=True)
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
  from contextlib import asynccontextmanager
3
 
4
+ from fastapi import FastAPI, Request
5
+ from fastapi.responses import HTMLResponse
6
+ from fastapi.staticfiles import StaticFiles
 
7
 
8
+ from .database import connect_db, disconnect_db, database, users
 
9
  from .api import router as api_router
 
 
10
 
 
11
  from sqlalchemy.schema import CreateTable
12
  from sqlalchemy.dialects import sqlite
13
 
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
18
  @asynccontextmanager
19
  async def lifespan(app: FastAPI):
 
20
  logger.info("Application startup: Connecting DB...")
21
  await connect_db()
22
  logger.info("Application startup: DB Connected. Checking/Creating tables...")
 
46
  logger.info("Application shutdown: DB Disconnected.")
47
 
48
 
 
49
  app = FastAPI(lifespan=lifespan)
50
 
 
51
  app.include_router(api_router, prefix="/api")
52
 
 
 
 
 
53
  app.mount("/static", StaticFiles(directory="static"), name="static")
54
 
 
 
 
 
55
  @app.get("/", response_class=HTMLResponse)
56
  async def read_root(request: Request):
57
  # Simple way: Read the file directly
 
62
  except FileNotFoundError:
63
  logger.error("templates/index.html not found!")
64
  return HTMLResponse(content="<html><body><h1>Error: Frontend not found</h1></body></html>", status_code=500)
 
 
 
 
 
 
65
 
 
66
  if __name__ == "__main__":
67
  import uvicorn
 
 
68
  uvicorn.run("app.main:app", host="0.0.0.0", port=7860, reload=True)
app/websocket.py CHANGED
@@ -1,15 +1,12 @@
1
  from fastapi import WebSocket
2
  from typing import List, Dict, Optional
3
- import json
4
  import logging
5
 
6
  logger = logging.getLogger(__name__)
7
 
8
  class ConnectionManager:
9
  def __init__(self):
10
- # Store connections with user ID if available for targeted messaging later
11
  self.active_connections: Dict[Optional[int], List[WebSocket]] = {}
12
- # Map websocket to user_id for easier removal
13
  self.websocket_to_user: Dict[WebSocket, Optional[int]] = {}
14
 
15
  async def connect(self, websocket: WebSocket, user_id: Optional[int] = None):
@@ -43,14 +40,7 @@ class ConnectionManager:
43
  logger.info(f"Broadcasting to {len(all_websockets)} connections (excluding sender if ID matches). Sender ID: {sender_id}")
44
 
45
  for websocket in all_websockets:
46
- # Send to all *other* users (or all if sender_id is None)
47
  ws_user_id = self.websocket_to_user.get(websocket)
48
- # Requirement: "all *other* connected users should see"
49
- # Send if the websocket isn't associated with the sender_id
50
- # (Note: If a user has multiple tabs/connections open, they might still receive it
51
- # if the sender_id check only excludes one specific connection. This simple broadcast
52
- # targets users based on their ID at connection time).
53
- # Let's refine: Send if the user_id associated with the WS is not the sender_id
54
  if ws_user_id != sender_id:
55
  try:
56
  await websocket.send_text(message)
 
1
  from fastapi import WebSocket
2
  from typing import List, Dict, Optional
 
3
  import logging
4
 
5
  logger = logging.getLogger(__name__)
6
 
7
  class ConnectionManager:
8
  def __init__(self):
 
9
  self.active_connections: Dict[Optional[int], List[WebSocket]] = {}
 
10
  self.websocket_to_user: Dict[WebSocket, Optional[int]] = {}
11
 
12
  async def connect(self, websocket: WebSocket, user_id: Optional[int] = None):
 
40
  logger.info(f"Broadcasting to {len(all_websockets)} connections (excluding sender if ID matches). Sender ID: {sender_id}")
41
 
42
  for websocket in all_websockets:
 
43
  ws_user_id = self.websocket_to_user.get(websocket)
 
 
 
 
 
 
44
  if ws_user_id != sender_id:
45
  try:
46
  await websocket.send_text(message)
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  fastapi==0.111.0
2
  uvicorn[standard]==0.29.0
3
- # gradio==4.29.0 # REMOVE
4
  bcrypt==4.1.3
5
  passlib[bcrypt]==1.7.4
6
  python-dotenv==1.0.1
@@ -10,5 +9,5 @@ pydantic==2.7.1
10
  python-multipart==0.0.9
11
  itsdangerous==2.1.2
12
  websockets>=11.0.3,<13.0
13
- aiofiles==23.2.1 # <-- ADD for StaticFiles
14
- httpx==0.27.0 # <-- ADD (useful if backend needs to call itself, good practice)
 
1
  fastapi==0.111.0
2
  uvicorn[standard]==0.29.0
 
3
  bcrypt==4.1.3
4
  passlib[bcrypt]==1.7.4
5
  python-dotenv==1.0.1
 
9
  python-multipart==0.0.9
10
  itsdangerous==2.1.2
11
  websockets>=11.0.3,<13.0
12
+ aiofiles==23.2.1
13
+ httpx==0.27.0
tests/api.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import requests
2
+ # import time
3
+ # from faker import Faker
4
+
5
+
6
+ # class AuthClient:
7
+ # """
8
+ # Python client for interacting with the Authentication API
9
+ # """
10
+
11
+ # def __init__(self, base_url="http://localhost:7860/api"):
12
+ # """
13
+ # Initialize the client with the API base URL
14
+
15
+ # Args:
16
+ # base_url (str): The base URL of the API
17
+ # """
18
+ # self.base_url = base_url
19
+ # self.token = None
20
+
21
+ # def register(self, email, password):
22
+ # """
23
+ # Register a new user
24
+
25
+ # Args:
26
+ # email (str): User's email
27
+ # password (str): User's password (should be at least 8 characters)
28
+
29
+ # Returns:
30
+ # dict: The user data returned by the API
31
+
32
+ # Raises:
33
+ # Exception: If registration fails
34
+ # """
35
+ # url = f"{self.base_url}/register"
36
+ # data = {
37
+ # "email": email,
38
+ # "password": password
39
+ # }
40
+
41
+ # response = requests.post(url, json=data)
42
+
43
+ # if response.status_code == 201:
44
+ # return response.json()
45
+ # else:
46
+ # error_detail = response.json().get("detail", "Unknown error")
47
+ # raise Exception(f"Registration failed: {error_detail} (Status: {response.status_code})")
48
+
49
+ # def login(self, email, password):
50
+ # """
51
+ # Login to obtain an authentication token
52
+
53
+ # Args:
54
+ # email (str): User's email
55
+ # password (str): User's password
56
+
57
+ # Returns:
58
+ # dict: The token data returned by the API
59
+
60
+ # Raises:
61
+ # Exception: If login fails
62
+ # """
63
+ # url = f"{self.base_url}/login"
64
+ # data = {
65
+ # "email": email,
66
+ # "password": password
67
+ # }
68
+
69
+ # response = requests.post(url, json=data)
70
+
71
+ # if response.status_code == 200:
72
+ # token_data = response.json()
73
+ # self.token = token_data["access_token"]
74
+ # return token_data
75
+ # else:
76
+ # error_detail = response.json().get("detail", "Unknown error")
77
+ # raise Exception(f"Login failed: {error_detail} (Status: {response.status_code})")
78
+
79
+ # def get_current_user(self):
80
+ # """
81
+ # Get information about the current logged-in user
82
+
83
+ # Returns:
84
+ # dict: The user data returned by the API
85
+
86
+ # Raises:
87
+ # Exception: If not authenticated or request fails
88
+ # """
89
+ # if not self.token:
90
+ # raise Exception("Not authenticated. Please login first.")
91
+
92
+ # url = f"{self.base_url}/users/me"
93
+ # headers = {"Authorization": f"Bearer {self.token}"}
94
+
95
+ # response = requests.get(url, headers=headers)
96
+
97
+ # if response.status_code == 200:
98
+ # return response.json()
99
+ # else:
100
+ # error_detail = response.json().get("detail", "Unknown error")
101
+ # raise Exception(f"Failed to get user info: {error_detail} (Status: {response.status_code})")
102
+
103
+ # def logout(self):
104
+ # """Clear the authentication token"""
105
+ # self.token = None
106
+
107
+
108
+ # # Example usage
109
+ # def main():
110
+ # # Initialize the client
111
+ # client = AuthClient("https://amaye15-authenticationapp.hf.space/api")
112
+
113
+ # # Initialize Faker
114
+ # fake = Faker()
115
+
116
+ # for i in range(10):
117
+ # try:
118
+ # # Generate random user data
119
+ # first_name = fake.first_name()
120
+ # last_name = fake.last_name()
121
+ # email = fake.email()
122
+ # password = fake.password(length=12, special_chars=True, digits=True, upper_case=True, lower_case=True)
123
+
124
+ # # Register a new user
125
+ # print(f"Registering a new user: {first_name} {last_name}...")
126
+ # try:
127
+ # user = client.register(email, password)
128
+ # print(f"Registered user: {user}")
129
+ # except Exception as e:
130
+ # print(f"Registration failed: {e}")
131
+
132
+ # # Login
133
+ # print("\nLogging in...")
134
+ # token_data = client.login(email, password)
135
+ # print(f"Login successful, token: {token_data['access_token'][:10]}...")
136
+
137
+ # # Get current user
138
+ # print("\nGetting current user info...")
139
+ # user_info = client.get_current_user()
140
+ # print(f"Current user: {user_info}")
141
+
142
+ # # Logout
143
+ # print("\nLogging out...")
144
+ # client.logout()
145
+ # print("Logged out successfully")
146
+
147
+ # except Exception as e:
148
+ # print(f"Error: {e}")
149
+
150
+
151
+ # if __name__ == "__main__":
152
+ # main()
153
+
154
+ import asyncio
155
+ import aiohttp
156
+ import time
157
+ from faker import Faker
158
+
159
+
160
+ class AuthClient:
161
+ """
162
+ Asynchronous Python client for interacting with the Authentication API
163
+ """
164
+
165
+ def __init__(self, base_url="http://localhost:7860/api"):
166
+ """
167
+ Initialize the client with the API base URL
168
+
169
+ Args:
170
+ base_url (str): The base URL of the API
171
+ """
172
+ self.base_url = base_url
173
+ self.token = None
174
+ self.session = None
175
+
176
+ async def __aenter__(self):
177
+ """Create and enter an aiohttp session"""
178
+ self.session = aiohttp.ClientSession()
179
+ return self
180
+
181
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
182
+ """Close the aiohttp session"""
183
+ if self.session:
184
+ await self.session.close()
185
+
186
+ async def _get_session(self):
187
+ """Get or create an aiohttp session"""
188
+ if self.session is None:
189
+ self.session = aiohttp.ClientSession()
190
+ return self.session
191
+
192
+ async def register(self, email, password):
193
+ """
194
+ Register a new user
195
+
196
+ Args:
197
+ email (str): User's email
198
+ password (str): User's password (should be at least 8 characters)
199
+
200
+ Returns:
201
+ dict: The user data returned by the API
202
+
203
+ Raises:
204
+ Exception: If registration fails
205
+ """
206
+ url = f"{self.base_url}/register"
207
+ data = {
208
+ "email": email,
209
+ "password": password
210
+ }
211
+
212
+ session = await self._get_session()
213
+ async with session.post(url, json=data) as response:
214
+ if response.status == 201:
215
+ return await response.json()
216
+ else:
217
+ error_data = await response.json()
218
+ error_detail = error_data.get("detail", "Unknown error")
219
+ raise Exception(f"Registration failed: {error_detail} (Status: {response.status})")
220
+
221
+ async def login(self, email, password):
222
+ """
223
+ Login to obtain an authentication token
224
+
225
+ Args:
226
+ email (str): User's email
227
+ password (str): User's password
228
+
229
+ Returns:
230
+ dict: The token data returned by the API
231
+
232
+ Raises:
233
+ Exception: If login fails
234
+ """
235
+ url = f"{self.base_url}/login"
236
+ data = {
237
+ "email": email,
238
+ "password": password
239
+ }
240
+
241
+ session = await self._get_session()
242
+ async with session.post(url, json=data) as response:
243
+ if response.status == 200:
244
+ token_data = await response.json()
245
+ self.token = token_data["access_token"]
246
+ return token_data
247
+ else:
248
+ error_data = await response.json()
249
+ error_detail = error_data.get("detail", "Unknown error")
250
+ raise Exception(f"Login failed: {error_detail} (Status: {response.status})")
251
+
252
+ async def get_current_user(self):
253
+ """
254
+ Get information about the current logged-in user
255
+
256
+ Returns:
257
+ dict: The user data returned by the API
258
+
259
+ Raises:
260
+ Exception: If not authenticated or request fails
261
+ """
262
+ if not self.token:
263
+ raise Exception("Not authenticated. Please login first.")
264
+
265
+ url = f"{self.base_url}/users/me"
266
+ headers = {"Authorization": f"Bearer {self.token}"}
267
+
268
+ session = await self._get_session()
269
+ async with session.get(url, headers=headers) as response:
270
+ if response.status == 200:
271
+ return await response.json()
272
+ else:
273
+ error_data = await response.json()
274
+ error_detail = error_data.get("detail", "Unknown error")
275
+ raise Exception(f"Failed to get user info: {error_detail} (Status: {response.status})")
276
+
277
+ def logout(self):
278
+ """Clear the authentication token"""
279
+ self.token = None
280
+
281
+
282
+ # Load testing function
283
+ async def load_test(num_users=10, concurrency=5, base_url="https://amaye15-authenticationapp.hf.space/api"):
284
+ """
285
+ Run a load test with multiple simulated users
286
+
287
+ Args:
288
+ num_users (int): Total number of users to simulate
289
+ concurrency (int): Number of concurrent users
290
+ base_url (str): The base URL of the API
291
+ """
292
+ fake = Faker()
293
+
294
+ start_time = time.time()
295
+ completed = 0
296
+ success_count = 0
297
+ failure_count = 0
298
+
299
+ # Semaphore to limit concurrency
300
+ sem = asyncio.Semaphore(concurrency)
301
+
302
+ # For progress tracking
303
+ progress_lock = asyncio.Lock()
304
+
305
+ async def run_single_user():
306
+ nonlocal completed, success_count, failure_count
307
+
308
+ async with sem: # This limits concurrency
309
+ async with AuthClient(base_url) as client:
310
+ try:
311
+ # Generate random user data
312
+ email = fake.email()
313
+ password = fake.password(length=12, special_chars=True, digits=True,
314
+ upper_case=True, lower_case=True)
315
+
316
+ # Complete user flow
317
+ await client.register(email, password)
318
+ await client.login(email, password)
319
+ await client.get_current_user()
320
+ client.logout()
321
+
322
+ async with progress_lock:
323
+ completed += 1
324
+ success_count += 1
325
+ # Print progress
326
+ print(f"Progress: {completed}/{num_users} users completed", end="\r")
327
+
328
+ except Exception as e:
329
+ async with progress_lock:
330
+ completed += 1
331
+ failure_count += 1
332
+ print(f"Error: {e}")
333
+ print(f"Progress: {completed}/{num_users} users completed", end="\r")
334
+
335
+ # Create all tasks
336
+ tasks = [run_single_user() for _ in range(num_users)]
337
+
338
+ # Display start message
339
+ print(f"Starting load test with {num_users} users (max {concurrency} concurrent)...")
340
+
341
+ # Run all tasks
342
+ await asyncio.gather(*tasks)
343
+
344
+ # Calculate stats
345
+ end_time = time.time()
346
+ duration = end_time - start_time
347
+
348
+ # Display results
349
+ print("\n\n--- Load Test Results ---")
350
+ print(f"Total users: {num_users}")
351
+ print(f"Concurrency level: {concurrency}")
352
+ print(f"Successful flows: {success_count} ({success_count/num_users*100:.1f}%)")
353
+ print(f"Failed flows: {failure_count} ({failure_count/num_users*100:.1f}%)")
354
+ print(f"Total duration: {duration:.2f} seconds")
355
+
356
+ if success_count > 0:
357
+ print(f"Average time per successful user: {duration/success_count:.2f} seconds")
358
+ print(f"Requests per second: {success_count/duration:.2f}")
359
+
360
+
361
+ # Example usage
362
+ async def main():
363
+ # Initialize the client
364
+ base_url = "https://amaye15-authenticationapp.hf.space/api"
365
+
366
+ # Run a simple example with a single user
367
+ fake = Faker()
368
+ async with AuthClient(base_url) as client:
369
+ # Generate random user data
370
+ first_name = fake.first_name()
371
+ last_name = fake.last_name()
372
+ email = fake.email()
373
+ password = fake.password(length=12, special_chars=True, digits=True, upper_case=True, lower_case=True)
374
+
375
+ try:
376
+ # Register a new user
377
+ print(f"Registering a new user: {first_name} {last_name}...")
378
+ user = await client.register(email, password)
379
+ print(f"Registered user: {user}")
380
+
381
+ # Login
382
+ print("\nLogging in...")
383
+ token_data = await client.login(email, password)
384
+ print(f"Login successful, token: {token_data['access_token'][:10]}...")
385
+
386
+ # Get current user
387
+ print("\nGetting current user info...")
388
+ user_info = await client.get_current_user()
389
+ print(f"Current user: {user_info}")
390
+
391
+ # Logout
392
+ print("\nLogging out...")
393
+ client.logout()
394
+ print("Logged out successfully")
395
+
396
+ except Exception as e:
397
+ print(f"Error: {e}")
398
+
399
+ # Run a load test
400
+ print("\nRunning load test...")
401
+ await load_test(10, 5, base_url)
402
+
403
+
404
+ if __name__ == "__main__":
405
+ asyncio.run(main())