Spaces:
Running
Running
zach
commited on
Commit
·
ad1ff58
1
Parent(s):
a4afe51
Fix types in database package
Browse files- src/config.py +1 -1
- src/database/crud.py +1 -1
- src/database/database.py +27 -9
- src/database/models.py +1 -1
src/config.py
CHANGED
@@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, ClassVar, Optional
|
|
20 |
# Third-Party Library Imports
|
21 |
from dotenv import load_dotenv
|
22 |
|
|
|
23 |
if TYPE_CHECKING:
|
24 |
from src.integrations import AnthropicConfig, ElevenLabsConfig, HumeConfig
|
25 |
|
26 |
-
|
27 |
logger: logging.Logger = logging.getLogger("tts_arena")
|
28 |
|
29 |
|
|
|
20 |
# Third-Party Library Imports
|
21 |
from dotenv import load_dotenv
|
22 |
|
23 |
+
# Local Application Imports
|
24 |
if TYPE_CHECKING:
|
25 |
from src.integrations import AnthropicConfig, ElevenLabsConfig, HumeConfig
|
26 |
|
|
|
27 |
logger: logging.Logger = logging.getLogger("tts_arena")
|
28 |
|
29 |
|
src/database/crud.py
CHANGED
@@ -32,7 +32,7 @@ def create_vote(db: Session, vote_data: VotingResults) -> VoteResult:
|
|
32 |
option_b_provider=vote_data["option_b_provider"],
|
33 |
option_a_generation_id=vote_data["option_a_generation_id"],
|
34 |
option_b_generation_id=vote_data["option_b_generation_id"],
|
35 |
-
voice_description=vote_data["
|
36 |
text=vote_data["text"],
|
37 |
is_custom_text=vote_data["is_custom_text"],
|
38 |
)
|
|
|
32 |
option_b_provider=vote_data["option_b_provider"],
|
33 |
option_a_generation_id=vote_data["option_a_generation_id"],
|
34 |
option_b_generation_id=vote_data["option_b_generation_id"],
|
35 |
+
voice_description=vote_data["character_description"],
|
36 |
text=vote_data["text"],
|
37 |
is_custom_text=vote_data["is_custom_text"],
|
38 |
)
|
src/database/database.py
CHANGED
@@ -5,7 +5,7 @@ This module sets up the SQLAlchemy database connection for the Expressive TTS Ar
|
|
5 |
It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
|
6 |
and defines a declarative base class for ORM models.
|
7 |
|
8 |
-
If no DATABASE_URL environment variable is set, then create a dummy database to fail gracefully
|
9 |
"""
|
10 |
|
11 |
# Standard Library Imports
|
@@ -13,12 +13,20 @@ from typing import Callable, Optional
|
|
13 |
|
14 |
# Third-Party Library Imports
|
15 |
from sqlalchemy import Engine, create_engine
|
16 |
-
from sqlalchemy.orm import
|
17 |
|
18 |
# Local Application Imports
|
19 |
from src.config import Config
|
20 |
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
class DummySession:
|
23 |
is_dummy = True # Flag to indicate this is a dummy session.
|
24 |
|
@@ -49,32 +57,42 @@ class DummySession:
|
|
49 |
pass
|
50 |
|
51 |
|
52 |
-
|
53 |
-
engine: Optional[Engine] = None
|
54 |
-
|
55 |
-
|
56 |
DBSessionMaker = sessionmaker | Callable[[], DummySession]
|
57 |
|
58 |
|
59 |
def init_db(config: Config) -> DBSessionMaker:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
# ruff doesn't like setting global variables, but this is practical here
|
61 |
-
global engine
|
62 |
|
63 |
if config.app_env == "prod":
|
64 |
# In production, a valid DATABASE_URL is required.
|
65 |
if not config.database_url:
|
66 |
raise ValueError("DATABASE_URL must be set in production!")
|
67 |
-
|
68 |
engine = create_engine(config.database_url)
|
69 |
return sessionmaker(bind=engine)
|
|
|
70 |
# In development, if a DATABASE_URL is provided, use it.
|
71 |
if config.database_url:
|
72 |
engine = create_engine(config.database_url)
|
73 |
return sessionmaker(bind=engine)
|
|
|
74 |
# No DATABASE_URL is provided; use a DummySession that does nothing.
|
75 |
engine = None
|
76 |
|
77 |
-
def dummy_session_factory():
|
78 |
return DummySession()
|
79 |
|
80 |
return dummy_session_factory
|
|
|
5 |
It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
|
6 |
and defines a declarative base class for ORM models.
|
7 |
|
8 |
+
If no DATABASE_URL environment variable is set, then create a dummy database to fail gracefully.
|
9 |
"""
|
10 |
|
11 |
# Standard Library Imports
|
|
|
13 |
|
14 |
# Third-Party Library Imports
|
15 |
from sqlalchemy import Engine, create_engine
|
16 |
+
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
17 |
|
18 |
# Local Application Imports
|
19 |
from src.config import Config
|
20 |
|
21 |
|
22 |
+
# Define the SQLAlchemy Base using SQLAlchemy 2.0 style.
|
23 |
+
class Base(DeclarativeBase):
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
engine: Optional[Engine] = None
|
28 |
+
|
29 |
+
|
30 |
class DummySession:
|
31 |
is_dummy = True # Flag to indicate this is a dummy session.
|
32 |
|
|
|
57 |
pass
|
58 |
|
59 |
|
60 |
+
# DBSessionMaker is either a sessionmaker instance or a callable that returns a DummySession.
|
|
|
|
|
|
|
61 |
DBSessionMaker = sessionmaker | Callable[[], DummySession]
|
62 |
|
63 |
|
64 |
def init_db(config: Config) -> DBSessionMaker:
|
65 |
+
"""
|
66 |
+
Initialize the database engine and return a session factory based on the provided configuration.
|
67 |
+
|
68 |
+
In production, a valid DATABASE_URL is required. In development, if a DATABASE_URL is provided,
|
69 |
+
it is used; otherwise, a dummy session factory is returned to allow graceful failure.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
config (Config): The application configuration.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
DBSessionMaker: A sessionmaker bound to the engine, or a dummy session factory.
|
76 |
+
"""
|
77 |
# ruff doesn't like setting global variables, but this is practical here
|
78 |
+
global engine # noqa
|
79 |
|
80 |
if config.app_env == "prod":
|
81 |
# In production, a valid DATABASE_URL is required.
|
82 |
if not config.database_url:
|
83 |
raise ValueError("DATABASE_URL must be set in production!")
|
|
|
84 |
engine = create_engine(config.database_url)
|
85 |
return sessionmaker(bind=engine)
|
86 |
+
|
87 |
# In development, if a DATABASE_URL is provided, use it.
|
88 |
if config.database_url:
|
89 |
engine = create_engine(config.database_url)
|
90 |
return sessionmaker(bind=engine)
|
91 |
+
|
92 |
# No DATABASE_URL is provided; use a DummySession that does nothing.
|
93 |
engine = None
|
94 |
|
95 |
+
def dummy_session_factory() -> DummySession:
|
96 |
return DummySession()
|
97 |
|
98 |
return dummy_session_factory
|
src/database/models.py
CHANGED
@@ -42,7 +42,7 @@ class VoteResult(Base):
|
|
42 |
created_at = Column(DateTime, nullable=False, server_default=func.now())
|
43 |
comparison_type = Column(String(50), nullable=False)
|
44 |
winning_provider = Column(String(50), nullable=False)
|
45 |
-
winning_option = Column(saEnum(OptionEnum), nullable=False)
|
46 |
option_a_provider = Column(String(50), nullable=False)
|
47 |
option_b_provider = Column(String(50), nullable=False)
|
48 |
option_a_generation_id = Column(String(100), nullable=True)
|
|
|
42 |
created_at = Column(DateTime, nullable=False, server_default=func.now())
|
43 |
comparison_type = Column(String(50), nullable=False)
|
44 |
winning_provider = Column(String(50), nullable=False)
|
45 |
+
winning_option = Column(saEnum(OptionEnum), nullable=False) # type: ignore
|
46 |
option_a_provider = Column(String(50), nullable=False)
|
47 |
option_b_provider = Column(String(50), nullable=False)
|
48 |
option_a_generation_id = Column(String(100), nullable=True)
|