zach commited on
Commit
ad1ff58
·
1 Parent(s): a4afe51

Fix types in database package

Browse files
src/config.py CHANGED
@@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, ClassVar, Optional
20
  # Third-Party Library Imports
21
  from dotenv import load_dotenv
22
 
 
23
  if TYPE_CHECKING:
24
  from src.integrations import AnthropicConfig, ElevenLabsConfig, HumeConfig
25
 
26
-
27
  logger: logging.Logger = logging.getLogger("tts_arena")
28
 
29
 
 
20
  # Third-Party Library Imports
21
  from dotenv import load_dotenv
22
 
23
+ # Local Application Imports
24
  if TYPE_CHECKING:
25
  from src.integrations import AnthropicConfig, ElevenLabsConfig, HumeConfig
26
 
 
27
  logger: logging.Logger = logging.getLogger("tts_arena")
28
 
29
 
src/database/crud.py CHANGED
@@ -32,7 +32,7 @@ def create_vote(db: Session, vote_data: VotingResults) -> VoteResult:
32
  option_b_provider=vote_data["option_b_provider"],
33
  option_a_generation_id=vote_data["option_a_generation_id"],
34
  option_b_generation_id=vote_data["option_b_generation_id"],
35
- voice_description=vote_data["voice_description"],
36
  text=vote_data["text"],
37
  is_custom_text=vote_data["is_custom_text"],
38
  )
 
32
  option_b_provider=vote_data["option_b_provider"],
33
  option_a_generation_id=vote_data["option_a_generation_id"],
34
  option_b_generation_id=vote_data["option_b_generation_id"],
35
+ voice_description=vote_data["character_description"],
36
  text=vote_data["text"],
37
  is_custom_text=vote_data["is_custom_text"],
38
  )
src/database/database.py CHANGED
@@ -5,7 +5,7 @@ This module sets up the SQLAlchemy database connection for the Expressive TTS Ar
5
  It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
6
  and defines a declarative base class for ORM models.
7
 
8
- If no DATABASE_URL environment variable is set, then create a dummy database to fail gracefully
9
  """
10
 
11
  # Standard Library Imports
@@ -13,12 +13,20 @@ from typing import Callable, Optional
13
 
14
  # Third-Party Library Imports
15
  from sqlalchemy import Engine, create_engine
16
- from sqlalchemy.orm import declarative_base, sessionmaker
17
 
18
  # Local Application Imports
19
  from src.config import Config
20
 
21
 
 
 
 
 
 
 
 
 
22
  class DummySession:
23
  is_dummy = True # Flag to indicate this is a dummy session.
24
 
@@ -49,32 +57,42 @@ class DummySession:
49
  pass
50
 
51
 
52
- Base = declarative_base()
53
- engine: Optional[Engine] = None
54
-
55
-
56
  DBSessionMaker = sessionmaker | Callable[[], DummySession]
57
 
58
 
59
  def init_db(config: Config) -> DBSessionMaker:
 
 
 
 
 
 
 
 
 
 
 
 
60
  # ruff doesn't like setting global variables, but this is practical here
61
- global engine # noqa
62
 
63
  if config.app_env == "prod":
64
  # In production, a valid DATABASE_URL is required.
65
  if not config.database_url:
66
  raise ValueError("DATABASE_URL must be set in production!")
67
-
68
  engine = create_engine(config.database_url)
69
  return sessionmaker(bind=engine)
 
70
  # In development, if a DATABASE_URL is provided, use it.
71
  if config.database_url:
72
  engine = create_engine(config.database_url)
73
  return sessionmaker(bind=engine)
 
74
  # No DATABASE_URL is provided; use a DummySession that does nothing.
75
  engine = None
76
 
77
- def dummy_session_factory():
78
  return DummySession()
79
 
80
  return dummy_session_factory
 
5
  It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
6
  and defines a declarative base class for ORM models.
7
 
8
+ If no DATABASE_URL environment variable is set, then create a dummy database to fail gracefully.
9
  """
10
 
11
  # Standard Library Imports
 
13
 
14
  # Third-Party Library Imports
15
  from sqlalchemy import Engine, create_engine
16
+ from sqlalchemy.orm import DeclarativeBase, sessionmaker
17
 
18
  # Local Application Imports
19
  from src.config import Config
20
 
21
 
22
+ # Define the SQLAlchemy Base using SQLAlchemy 2.0 style.
23
+ class Base(DeclarativeBase):
24
+ pass
25
+
26
+
27
+ engine: Optional[Engine] = None
28
+
29
+
30
  class DummySession:
31
  is_dummy = True # Flag to indicate this is a dummy session.
32
 
 
57
  pass
58
 
59
 
60
+ # DBSessionMaker is either a sessionmaker instance or a callable that returns a DummySession.
 
 
 
61
  DBSessionMaker = sessionmaker | Callable[[], DummySession]
62
 
63
 
64
  def init_db(config: Config) -> DBSessionMaker:
65
+ """
66
+ Initialize the database engine and return a session factory based on the provided configuration.
67
+
68
+ In production, a valid DATABASE_URL is required. In development, if a DATABASE_URL is provided,
69
+ it is used; otherwise, a dummy session factory is returned to allow graceful failure.
70
+
71
+ Args:
72
+ config (Config): The application configuration.
73
+
74
+ Returns:
75
+ DBSessionMaker: A sessionmaker bound to the engine, or a dummy session factory.
76
+ """
77
  # ruff doesn't like setting global variables, but this is practical here
78
+ global engine # noqa
79
 
80
  if config.app_env == "prod":
81
  # In production, a valid DATABASE_URL is required.
82
  if not config.database_url:
83
  raise ValueError("DATABASE_URL must be set in production!")
 
84
  engine = create_engine(config.database_url)
85
  return sessionmaker(bind=engine)
86
+
87
  # In development, if a DATABASE_URL is provided, use it.
88
  if config.database_url:
89
  engine = create_engine(config.database_url)
90
  return sessionmaker(bind=engine)
91
+
92
  # No DATABASE_URL is provided; use a DummySession that does nothing.
93
  engine = None
94
 
95
+ def dummy_session_factory() -> DummySession:
96
  return DummySession()
97
 
98
  return dummy_session_factory
src/database/models.py CHANGED
@@ -42,7 +42,7 @@ class VoteResult(Base):
42
  created_at = Column(DateTime, nullable=False, server_default=func.now())
43
  comparison_type = Column(String(50), nullable=False)
44
  winning_provider = Column(String(50), nullable=False)
45
- winning_option = Column(saEnum(OptionEnum), nullable=False)
46
  option_a_provider = Column(String(50), nullable=False)
47
  option_b_provider = Column(String(50), nullable=False)
48
  option_a_generation_id = Column(String(100), nullable=True)
 
42
  created_at = Column(DateTime, nullable=False, server_default=func.now())
43
  comparison_type = Column(String(50), nullable=False)
44
  winning_provider = Column(String(50), nullable=False)
45
+ winning_option = Column(saEnum(OptionEnum), nullable=False) # type: ignore
46
  option_a_provider = Column(String(50), nullable=False)
47
  option_b_provider = Column(String(50), nullable=False)
48
  option_a_generation_id = Column(String(100), nullable=True)