File size: 3,858 Bytes
bb6d7b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Database initialization for the application.

This module checks whether the database has been initialized (currently: whether
a ``users`` table exists) and, if not, runs ``scripts/init_db.py`` to create the
tables. Call ``ensure_database_initialized()`` at application startup.
"""
import os
import logging
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.future import select
import subprocess
import sys

# Root logging setup: INFO and above, each record stamped with time, logger
# name and level. basicConfig only takes effect if the root logger has no
# handlers yet.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Database URL from environment
DEFAULT_ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"

# URL prefixes treated as PostgreSQL. "postgres://" is a legacy alias that some
# hosting providers still emit (SQLAlchemy 1.4+ no longer accepts it itself).
_POSTGRES_URL_PREFIXES = ("postgresql+asyncpg://", "postgresql://", "postgres://")


def build_async_database_url(url, default=DEFAULT_ASYNC_DATABASE_URL):
    """Return *url* rewritten for SQLAlchemy's asyncpg driver.

    ``postgresql://``, ``postgres://`` and ``postgresql+asyncpg://`` URLs are
    normalised to the ``postgresql+asyncpg://`` scheme, and every ``sslmode=``
    query parameter is removed because it causes issues with asyncpg. The
    remaining query parameters keep their original order and encoding; empty
    segments (e.g. a trailing ``?`` or ``&&``) are dropped.

    Args:
        url: Database URL, typically the ``DATABASE_URL`` environment value.
        default: Returned when *url* is empty/None or not a PostgreSQL URL.

    Returns:
        The asyncpg-flavoured URL, or *default*.

    NOTE(review): dropping ``sslmode`` also discards any SSL requirement it
    expressed (e.g. ``sslmode=require``); confirm SSL is enforced elsewhere.
    """
    if not url:
        return default

    for prefix in _POSTGRES_URL_PREFIXES:
        if url.startswith(prefix):
            remainder = url[len(prefix):]
            break
    else:
        return default

    # Split manually rather than via urllib.parse so the remaining parameters
    # are not re-encoded.
    base, _, query = remainder.partition("?")
    kept = [p for p in query.split("&") if p and not p.startswith("sslmode=")]
    async_url = f"postgresql+asyncpg://{base}"
    return f"{async_url}?{'&'.join(kept)}" if kept else async_url


db_url = os.getenv("DATABASE_URL", "")
ASYNC_DATABASE_URL = build_async_database_url(db_url)
if db_url and not db_url.startswith(_POSTGRES_URL_PREFIXES):
    # A set-but-unusable DATABASE_URL is almost certainly a misconfiguration;
    # don't fall back to localhost silently. The URL itself is not logged
    # because it may contain credentials.
    logging.getLogger(__name__).warning(
        "DATABASE_URL is not a PostgreSQL URL; falling back to the local default database."
    )


async def check_db_initialized():
    """Check if the database is initialized with required tables.

    The database counts as initialized when a table named ``users`` appears in
    ``information_schema.tables``. Connection and query errors are logged and
    reported as ``False`` instead of being raised.

    Returns:
        True if the ``users`` table exists, otherwise False.
    """
    from sqlalchemy import text

    engine = None
    try:
        engine = create_async_engine(
            ASYNC_DATABASE_URL,
            echo=False,
        )

        # Create session factory
        async_session = sessionmaker(
            engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

        async with async_session() as session:
            # TODO: replace 'users' with the real sentinel table(s) once the
            # schema is defined.
            try:
                query = text("SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')")
                result = await session.execute(query)
                exists = result.scalar()

                if exists:
                    logger.info("Database is initialized.")
                    return True
                else:
                    logger.warning("Database tables are not initialized.")
                    return False
            except Exception as e:
                logger.error(f"Error checking tables: {e}")
                return False
    except Exception as e:
        logger.error(f"Failed to connect to database: {e}")
        return False
    finally:
        # Close the engine's pooled connections while the event loop is still
        # running (this coroutine is typically driven by asyncio.run(), which
        # closes the loop right afterwards). Guarded so this function still
        # never raises.
        if engine is not None:
            try:
                await engine.dispose()
            except Exception as e:
                logger.warning(f"Error disposing database engine: {e}")


def initialize_database():
    """Create the database tables by running ``scripts/init_db.py``.

    The script lives in a ``scripts`` directory next to this file and is run
    in a child process with the current Python interpreter, capturing its
    output.

    Returns:
        True if the script exits with status 0, False if it exits non-zero or
        launching it raises. Failures are logged, never raised.
    """
    try:
        logger.info("Initializing database...")

        here = os.path.dirname(os.path.abspath(__file__))
        init_script = os.path.join(here, "scripts", "init_db.py")

        # Same interpreter as the app, so the script sees the same packages.
        completed = subprocess.run([sys.executable, init_script], capture_output=True, text=True)

        if completed.returncode != 0:
            logger.error(f"Failed to initialize database: {completed.stderr}")
            return False

        logger.info("Database initialized successfully.")
        logger.debug(completed.stdout)
        return True
    except Exception as e:
        logger.error(f"Error initializing database: {e}")
        return False


def ensure_database_initialized():
    """Make sure the database tables exist, creating them when missing.

    Runs ``check_db_initialized()`` to completion via ``asyncio.run`` and only
    falls back to ``initialize_database()`` when that check reports False.

    Returns:
        True if the tables already existed or initialization succeeded,
        otherwise False.

    NOTE(review): ``asyncio.run`` raises RuntimeError when called while an
    event loop is already running, so call this from synchronous startup code.
    """
    if asyncio.run(check_db_initialized()):
        return True
    return initialize_database()