File size: 3,036 Bytes
bfe88a9
9e56290
bfe88a9
9e56290
4809b28
9e56290
bfe88a9
 
4809b28
bfe88a9
d576ad8
3222a21
9e56290
 
 
d576ad8
9e56290
 
 
 
d576ad8
9e56290
 
 
 
 
 
 
 
 
4809b28
9e56290
bfe88a9
 
 
 
 
 
 
9e56290
 
1cff830
9e56290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d576ad8
bfe88a9
9e56290
 
 
 
 
 
 
 
1cff830
9e56290
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from databases import Database
from dotenv import load_dotenv
from sqlalchemy import MetaData, Table, Column, Integer, String
import logging
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode

load_dotenv()
logger = logging.getLogger(__name__)


def _redact_db_url(url):
    """Return *url* with the password (if any) replaced by ``***``.

    Only used for log output, so credentials embedded in ``DATABASE_URL``
    (e.g. ``postgresql://user:secret@host/db``) are not written to logs.
    """
    password = urlparse(url).password
    if not password:
        return url
    # ``password`` is the raw (still percent-encoded) text taken from the
    # URL, so this matches the exact ``:<password>@`` authority segment.
    return url.replace(f":{password}@", ":***@", 1)


def _with_sqlite_defaults(url):
    """Return *url* with ``check_same_thread=False`` appended unless already set.

    The query string is edited as plain text instead of being round-tripped
    through ``urlparse``/``urlunparse``. For schemes that are not listed in
    ``urllib.parse.uses_netloc`` (such as ``sqlite+aiosqlite``) and have an
    empty host, ``urlunparse`` can drop the ``//`` after the scheme: e.g.
    ``sqlite+aiosqlite:///./app.db`` becomes ``sqlite+aiosqlite:/./app.db``,
    and on some Python versions ``sqlite+aiosqlite:////tmp/app.db`` becomes
    ``sqlite+aiosqlite://tmp/app.db`` (host ``tmp``, relative file ``app.db``).
    Blank query values are kept so existing options survive re-encoding.
    """
    base, _, query = url.partition("?")
    params = parse_qs(query, keep_blank_values=True)
    if "check_same_thread" in params:
        return url
    params["check_same_thread"] = ["False"]
    return f"{base}?{urlencode(params, doseq=True)}"


DEFAULT_DB_PATH = "/tmp/app.db"
raw_db_url = os.getenv("DATABASE_URL", f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}")

if raw_db_url.startswith("sqlite+aiosqlite"):
    final_database_url = _with_sqlite_defaults(raw_db_url)
    logger.info("Using final async DB URL: %s", _redact_db_url(final_database_url))
else:
    final_database_url = raw_db_url
    logger.info("Using non-SQLite async DB URL: %s", _redact_db_url(final_database_url))

# Single shared connection object; connect_db()/disconnect_db() manage it.
database = Database(final_database_url)

# Registry that collects the table definitions declared in this module.
metadata = MetaData()

# "users" table: integer primary key, a required e-mail column that is both
# unique and indexed, and a required hashed_password column.
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column(
        "email",
        String,
        nullable=False,
        unique=True,
        index=True,
    ),
    Column(
        "hashed_password",
        String,
        nullable=False,
    ),
)

def _sqlite_db_file(url):
    """Return the database file path encoded in a SQLite URL.

    The path is the URL's path component with one leading ``/`` removed, so
    ``sqlite:///app.db`` -> ``app.db`` (relative) and
    ``sqlite:////tmp/app.db`` -> ``/tmp/app.db`` (absolute). The query
    string is ignored, and a URL without a path (``sqlite://``) yields ``""``.
    """
    path = urlparse(url).path
    return path[1:] if path.startswith("/") else path


def _ensure_sqlite_dir(db_file_path):
    """Best-effort: create the directory holding *db_file_path* and check it is writable.

    Problems are only logged, never raised; the connection attempt that
    follows reports any real failure.
    """
    db_dir = os.path.dirname(db_file_path)
    if not db_dir:
        return  # bare file name (current directory), ':memory:', or no path
    if not os.path.exists(db_dir):
        logger.info(f"DB directory '{db_dir}' missing, creating.")
        try:
            os.makedirs(db_dir, exist_ok=True)
        except OSError as mkdir_err:
            logger.error(f"Failed to create DB directory '{db_dir}': {mkdir_err}")
    # Check writability (best effort)
    if not os.path.exists(db_dir):
        logger.error(f"CRITICAL: DB directory '{db_dir}' does not exist and could not be created!")
    elif not os.access(db_dir, os.W_OK):
        logger.error(f"CRITICAL: DB directory '{db_dir}' exists but is not writable!")


async def connect_db():
    """Connect ``database`` (built from ``final_database_url``) unless already connected.

    For SQLite URLs the directory of the database file is prepared first
    (created if missing, writability checked; problems are only logged).

    Raises:
        Exception: whatever the directory preparation or ``database.connect()``
            raises; it is logged with a traceback and re-raised unchanged.
    """
    try:
        if final_database_url.startswith("sqlite"):
            _ensure_sqlite_dir(_sqlite_db_file(final_database_url))

        if database.is_connected:
            logger.info("Database connection already established.")
            return

        await database.connect()
        # Mask any password so credentials never reach the logs.
        password = urlparse(final_database_url).password
        shown_url = (
            final_database_url.replace(f":{password}@", ":***@", 1)
            if password
            else final_database_url
        )
        logger.info(f"Database connection established: {shown_url}")
    except Exception as e:
        logger.exception(f"FATAL: Failed to establish async database connection: {e}")
        raise

async def disconnect_db():
    """Close the shared ``database`` connection if it is currently open.

    This is best-effort: any exception raised while checking or closing the
    connection is logged with its traceback and not propagated.
    """
    try:
        if not database.is_connected:
            logger.info("Database already disconnected.")
            return
        await database.disconnect()
        logger.info("Database connection closed.")
    except Exception as err:
        logger.exception(f"Error closing database connection: {err}")