File size: 3,062 Bytes
8e2cadd
104737f
8e2cadd
20cccb6
40403f3
 
20cccb6
 
5ed9749
8e2cadd
 
104737f
ad1ff58
 
 
104737f
1ed6720
8e2cadd
40403f3
1ed6720
8e2cadd
40403f3
1ed6720
8e2cadd
40403f3
1ed6720
 
8e2cadd
40403f3
1ed6720
104737f
8e2cadd
40403f3
1ed6720
104737f
20cccb6
40403f3
1ed6720
 
20cccb6
40403f3
1ed6720
 
20cccb6
104737f
 
1ed6720
40403f3
ad1ff58
 
 
 
 
 
 
 
 
 
40403f3
ad1ff58
1ed6720
ad1ff58
1ed6720
 
 
 
 
104737f
de305ed
40403f3
104737f
40403f3
ad1ff58
1ed6720
 
de305ed
40403f3
104737f
40403f3
ad1ff58
104737f
1ed6720
104737f
1ed6720
104737f
 
40403f3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Standard Library Imports
from typing import Callable, Optional, TypeAlias, Union

# Third-Party Library Imports
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

# Local Application Imports
from src.common import Config, logger


# Define the SQLAlchemy Base
class Base(DeclarativeBase):
    """Declarative base for the application's SQLAlchemy ORM models; mapped classes subclass it."""

class DummyAsyncSession:
    """Placeholder for ``AsyncSession`` used when no DATABASE_URL is configured.

    Mirrors a small subset of the ``AsyncSession`` API so code written against a
    real session can still run without a database:

    * ``add``, ``rollback`` and ``close`` are no-ops;
    * ``commit`` and ``refresh`` raise ``RuntimeError`` so that attempted writes
      fail loudly instead of being silently dropped.

    Like ``AsyncSession``, instances are used with ``async with``.
    """

    is_dummy = True  # Flag to indicate this is a dummy session.

    async def __aenter__(self) -> "DummyAsyncSession":
        # Async context-manager entry (the protocol ``async with`` looks up);
        # binds the session itself, matching AsyncSession.
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None (falsy) lets any exception raised inside the
        # ``async with`` body propagate instead of being suppressed.
        pass

    def add(self, _instance, _warn=True) -> None:
        # Synchronous on purpose: AsyncSession.add() is not a coroutine, so
        # callers invoke it without ``await``.
        # No-op: simply ignore adding the instance.
        pass

    async def commit(self) -> None:
        # Raise an exception to simulate failure when attempting a write.
        raise RuntimeError("DummyAsyncSession does not support commit operations.")

    async def refresh(self, _instance) -> None:
        # Raise an exception to simulate failure when attempting to refresh.
        raise RuntimeError("DummyAsyncSession does not support refresh operations.")

    async def rollback(self) -> None:
        # No-op: there's nothing to roll back.
        pass

    async def close(self) -> None:
        # No-op: nothing to close.
        pass

# A session factory: either SQLAlchemy's async_sessionmaker or a zero-argument
# callable producing DummyAsyncSession instances (see init_db).
AsyncDBSessionMaker: TypeAlias = Union[
    async_sessionmaker[AsyncSession],
    Callable[[], DummyAsyncSession],
]

# Module-level engine handle; init_db() assigns it (None when no DATABASE_URL).
engine: Optional[AsyncEngine] = None

def init_db(config: Config) -> AsyncDBSessionMaker:
    """
    Build the async database engine and hand back a matching session factory.

    Production (``app_env == "prod"``) requires ``database_url``; a missing URL is
    a hard error. In any other environment a configured ``database_url`` is used
    the same way, and when it is absent a factory of ``DummyAsyncSession`` objects
    is returned so the app can run without a database.

    Side effect: rebinds the module-level ``engine`` (set to ``None`` in the
    dummy case).

    Args:
        config (Config): The application configuration.

    Returns:
        AsyncDBSessionMaker: A sessionmaker bound to the engine, or a dummy session factory.

    Raises:
        ValueError: If running in production without a DATABASE_URL.
    """
    # ruff doesn't like setting global variables, but this is practical here
    global engine  # noqa

    running_in_prod = config.app_env == "prod"
    if running_in_prod and not config.database_url:
        raise ValueError("DATABASE_URL must be set in production!")

    db_url = config.database_url
    if db_url:
        # Same engine/sessionmaker setup regardless of environment.
        engine = create_async_engine(db_url)
        return async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)

    # Only reachable outside production: fall back to no-op dummy sessions.
    engine = None
    logger.warning("No DATABASE_URL provided - database operations will use DummyAsyncSession")

    def make_dummy_session() -> DummyAsyncSession:
        return DummyAsyncSession()

    return make_dummy_session