Spaces:
Running
Running
File size: 3,062 Bytes
8e2cadd 104737f 8e2cadd 20cccb6 40403f3 20cccb6 5ed9749 8e2cadd 104737f ad1ff58 104737f 1ed6720 8e2cadd 40403f3 1ed6720 8e2cadd 40403f3 1ed6720 8e2cadd 40403f3 1ed6720 8e2cadd 40403f3 1ed6720 104737f 8e2cadd 40403f3 1ed6720 104737f 20cccb6 40403f3 1ed6720 20cccb6 40403f3 1ed6720 20cccb6 104737f 1ed6720 40403f3 ad1ff58 40403f3 ad1ff58 1ed6720 ad1ff58 1ed6720 104737f de305ed 40403f3 104737f 40403f3 ad1ff58 1ed6720 de305ed 40403f3 104737f 40403f3 ad1ff58 104737f 1ed6720 104737f 1ed6720 104737f 40403f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# Standard Library Imports
from typing import Callable, Optional, TypeAlias, Union
# Third-Party Library Imports
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
# Local Application Imports
from src.common import Config, logger
# Define the SQLAlchemy Base
class Base(DeclarativeBase):
    """SQLAlchemy declarative base class; adds no configuration of its own."""
class DummyAsyncSession:
    """Stand-in for ``AsyncSession`` used when no database is configured.

    ``add``, ``rollback`` and ``close`` are no-ops, while ``commit`` and
    ``refresh`` raise ``RuntimeError`` so write paths fail loudly instead of
    pretending data was persisted. Instances can be used with both
    ``async with`` (like a real ``AsyncSession``) and a plain ``with``.
    """

    is_dummy = True  # Flag to indicate this is a dummy session.

    async def __aenter__(self) -> "DummyAsyncSession":
        # Real sessions are typically used as ``async with factory() as session``;
        # the dummy must support the same protocol to be a drop-in replacement.
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None (falsy) lets exceptions from the body propagate.
        return None

    # These must be plain (non-async) methods: an ``async def __exit__`` returns a
    # coroutine object, which is truthy and would silently suppress every
    # exception raised inside a ``with`` body.
    def __enter__(self) -> "DummyAsyncSession":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        return None

    def add(self, _instance, _warn=True) -> None:
        # ``AsyncSession.add`` is synchronous, so this is too; an ``async def``
        # here would leave an un-awaited coroutine behind on ``session.add(obj)``.
        # No-op: simply ignore adding the instance.
        return None

    async def commit(self):
        # Raise an exception to simulate failure when attempting a write.
        raise RuntimeError("DummyAsyncSession does not support commit operations.")

    async def refresh(self, _instance):
        # Raise an exception to simulate failure when attempting to refresh.
        raise RuntimeError("DummyAsyncSession does not support refresh operations.")

    async def rollback(self):
        # No-op: there's nothing to roll back.
        pass

    async def close(self):
        # No-op: nothing to close.
        pass
# What `init_db` can return: a real SQLAlchemy async session factory, or a
# zero-argument callable producing a DummyAsyncSession.
AsyncDBSessionMaker: TypeAlias = Union[
    async_sessionmaker[AsyncSession],
    Callable[[], DummyAsyncSession],
]

# Module-level engine handle; `init_db` assigns it, and sets it back to None
# when it falls back to the dummy session factory.
engine: Optional[AsyncEngine] = None
def init_db(config: Config) -> AsyncDBSessionMaker:
    """
    Initialize the database engine and return a session factory based on the provided configuration.

    If ``config.database_url`` is set (in any environment), an async engine is created from it,
    stored in the module-level ``engine``, and an ``async_sessionmaker`` bound to it is returned.
    If it is not set:

    * when ``config.app_env == "prod"``, ``ValueError`` is raised and ``engine`` is left unchanged;
    * in every other environment, ``engine`` is reset to ``None``, a warning is logged, and a
      factory producing ``DummyAsyncSession`` instances is returned to allow graceful failure.

    Args:
        config (Config): The application configuration.

    Returns:
        AsyncDBSessionMaker: A sessionmaker bound to the engine, or a dummy session factory.

    Raises:
        ValueError: If ``app_env`` is ``"prod"`` and no DATABASE_URL is configured.
    """
    # ruff doesn't like setting global variables, but this is practical here
    global engine  # noqa

    database_url = config.database_url
    if database_url:
        # A configured URL is used identically regardless of environment.
        engine = create_async_engine(database_url)
        # expire_on_commit=False keeps loaded attributes usable after commit
        # (avoids implicit refresh I/O when accessing them in async code).
        return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)

    if config.app_env == "prod":
        # In production, a valid DATABASE_URL is required.
        raise ValueError("DATABASE_URL must be set in production!")

    # Non-production without a DATABASE_URL: use a DummyAsyncSession that does nothing.
    engine = None
    logger.warning("No DATABASE_URL provided - database operations will use DummyAsyncSession")

    def async_dummy_session_factory() -> DummyAsyncSession:
        return DummyAsyncSession()

    return async_dummy_session_factory
|