Spaces:
Sleeping
Sleeping
import os | |
import re | |
import logging | |
from sqlalchemy import inspect | |
from sqlalchemy.sql import text | |
from alembic.config import Config | |
from alembic import command | |
import common.dependencies as DI | |
# Module-level logger named after this module; update() reports upgrades through it.
logger = logging.getLogger(__name__)
def get_old_versions(migration_dir='components/dbo/alembic/versions'):
    """Return every revision id that some migration script names as its parent.

    Each ``*.py`` file in *migration_dir* is parsed (never executed) and its
    first top-level ``down_revision`` assignment is read. A revision that
    appears here already has a newer migration built on top of it, so a
    database stamped with one of these ids is behind the scripts on disk.

    Reading the module's syntax tree instead of matching one exact source
    line also handles double-quoted ids, other (or absent) type annotations
    such as ``Union[str, Sequence[str], None]``, and the tuple/list of
    parents that merge migrations use.

    Args:
        migration_dir: Directory holding the Alembic version scripts. The
            default is relative to the current working directory.

    Returns:
        A sorted, de-duplicated list of parent revision ids. Root migrations
        (``down_revision = None``) contribute nothing.

    Raises:
        OSError: If *migration_dir* cannot be listed or a script cannot be read.
        SyntaxError: If a ``*.py`` file in the directory is not valid Python.
    """
    import ast  # only this function needs it

    old_versions = set()
    for file_name in os.listdir(migration_dir):
        if not file_name.endswith('.py'):
            continue
        file_path = os.path.join(migration_dir, file_name)
        with open(file_path, 'r', encoding='utf-8') as f:
            # Parse instead of importing: migration modules pull in alembic/sqlalchemy.
            tree = ast.parse(f.read(), filename=file_path)

        for node in tree.body:
            # Accept both ``down_revision: <annotation> = ...`` and ``down_revision = ...``.
            if isinstance(node, ast.AnnAssign):
                targets = [node.target]
            elif isinstance(node, ast.Assign):
                targets = node.targets
            else:
                continue
            names_down_revision = any(
                isinstance(target, ast.Name) and target.id == 'down_revision'
                for target in targets
            )
            if node.value is None or not names_down_revision:
                continue

            try:
                parent = ast.literal_eval(node.value)
            except (ValueError, TypeError):
                # Not a plain literal (e.g. computed at import time); skip it.
                break
            if isinstance(parent, str):
                old_versions.add(parent)
            elif isinstance(parent, (tuple, list)):
                # Merge migrations declare several parents.
                old_versions.update(p for p in parent if isinstance(p, str))
            # Only the first top-level down_revision assignment counts.
            break

    return sorted(old_versions)
def get_cur_version():
    """Return the Alembic revision id currently stamped in the database.

    A session is opened via ``DI.get_db()`` (presumably a session factory —
    it is called with no arguments) and is always closed before returning.

    Returns:
        The ``version_num`` of the first row of ``alembic_version``. Returns
        ``None`` when that table does not exist (Alembic has never run) or
        when it has no rows.
    """
    session_factory = DI.get_db()
    # No ``Session`` annotation: that name is not imported in this module.
    session = session_factory()
    try:
        inspector = inspect(session.bind)
        if 'alembic_version' not in inspector.get_table_names():
            return None
        return session.execute(
            text("SELECT version_num FROM alembic_version")
        ).scalar()
    finally:
        session.close()
def update():
    """Run ``alembic upgrade head`` when the database lags the migration scripts.

    The upgrade happens when the database has no Alembic revision yet
    (``get_cur_version()`` returns ``None``) or when its revision is listed
    as some script's parent by ``get_old_versions()``. In every other case
    (the revision is a head, or no script refers to it) the function returns
    without touching the database.
    """
    old_versions = get_old_versions()
    cur_version = get_cur_version()

    needs_upgrade = cur_version is None or cur_version in old_versions
    if not needs_upgrade:
        return

    logger.info(f"Updating the database from migration {cur_version}")
    command.upgrade(Config("alembic.ini"), "head")