Zachary Greathouse
Zg/codebase refactor (#20)
5ed9749 unverified
raw
history blame contribute delete
3.06 kB
# Standard Library Imports
from typing import Callable, Optional, TypeAlias, Union
# Third-Party Library Imports
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
# Local Application Imports
from src.common import Config, logger
# Define the SQLAlchemy Base
class Base(DeclarativeBase):
    """Declarative base class for this application's SQLAlchemy ORM models."""
class DummyAsyncSession:
    """Stand-in for SQLAlchemy's ``AsyncSession`` when no DATABASE_URL is configured.

    ``add``, ``rollback`` and ``close`` are no-ops. ``commit`` and ``refresh`` raise
    ``RuntimeError`` so that attempted writes fail loudly instead of silently doing
    nothing. Like the real session, it is used as an async context manager
    (``async with factory() as session: ...``).
    """

    is_dummy = True  # Flag to indicate this is a dummy session.

    async def __aenter__(self) -> "DummyAsyncSession":
        # Entry of ``async with``: the session itself is bound to the ``as`` target.
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None (falsy) lets any exception raised inside the block propagate.
        return None

    def add(self, _instance, _warn=True) -> None:
        # Synchronous, like ``AsyncSession.add``, so callers that don't ``await`` it
        # do not leave an un-awaited coroutine behind. The instance is ignored.
        return None

    async def commit(self):
        # Raise an exception to simulate failure when attempting a write.
        raise RuntimeError("DummyAsyncSession does not support commit operations.")

    async def refresh(self, _instance):
        # Raise an exception to simulate failure when attempting to refresh.
        raise RuntimeError("DummyAsyncSession does not support refresh operations.")

    async def rollback(self):
        # No-op: there's nothing to roll back.
        pass

    async def close(self):
        # No-op: nothing to close.
        pass
# Session factory type returned by init_db(): either a real SQLAlchemy async
# sessionmaker, or a zero-argument callable that builds a DummyAsyncSession.
AsyncDBSessionMaker: TypeAlias = Union[
    async_sessionmaker[AsyncSession],
    Callable[[], DummyAsyncSession],
]

# Module-level engine handle, rebound by init_db(); None until then, and None
# when no database URL is configured.
engine: Optional[AsyncEngine] = None
def init_db(config: Config) -> AsyncDBSessionMaker:
    """
    Initialize the database engine and return a session factory based on the provided configuration.

    When ``config.app_env == "prod"``, a DATABASE_URL is required. Any other ``app_env`` is
    treated as development: a DATABASE_URL is used if provided; otherwise a factory producing
    ``DummyAsyncSession`` objects is returned so the app can run without a database and fail
    gracefully on writes.

    Side effect: rebinds the module-level ``engine`` to the created ``AsyncEngine``, or to
    ``None`` when the dummy factory is returned. It is left unchanged if ``ValueError`` is raised.

    Args:
        config (Config): The application configuration.

    Returns:
        AsyncDBSessionMaker: A sessionmaker bound to the engine, or a dummy session factory.

    Raises:
        ValueError: If ``app_env`` is ``"prod"`` and no DATABASE_URL is set.
    """
    # ruff doesn't like setting global variables, but this is practical here
    global engine  # noqa

    # In production, a valid DATABASE_URL is required.
    if config.app_env == "prod" and not config.database_url:
        raise ValueError("DATABASE_URL must be set in production!")

    # Whenever a URL is available (prod or dev), build the real engine and sessionmaker.
    if config.database_url:
        engine = create_async_engine(config.database_url)
        # expire_on_commit=False keeps loaded attributes usable after commit without a re-fetch.
        return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)

    # No DATABASE_URL is provided; use a DummyAsyncSession that does nothing.
    engine = None
    logger.warning("No DATABASE_URL provided - database operations will use DummyAsyncSession")

    def async_dummy_session_factory() -> DummyAsyncSession:
        return DummyAsyncSession()

    return async_dummy_session_factory