zach commited on
Commit
8e2cadd
·
1 Parent(s): 704294b

Update db logic and behavior in development environment

Browse files
src/database/__init__.py CHANGED
@@ -1,11 +1,9 @@
1
  from .crud import create_vote
2
  from .database import Base, SessionLocal, engine
3
- from .models import VoteResult
4
 
5
  __all__ = [
6
  "Base",
7
  "SessionLocal",
8
- "VoteResult",
9
  "create_vote",
10
  "engine"
11
  ]
 
1
  from .crud import create_vote
2
  from .database import Base, SessionLocal, engine
 
3
 
4
  __all__ = [
5
  "Base",
6
  "SessionLocal",
 
7
  "create_vote",
8
  "engine"
9
  ]
src/database/database.py CHANGED
@@ -4,23 +4,70 @@ database.py
4
  This module sets up the SQLAlchemy database connection for the Expressive TTS Arena project.
5
  It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
6
  and defines a declarative base class for ORM models.
 
 
7
  """
8
 
 
 
 
9
  # Third-Party Library Imports
10
  from sqlalchemy import create_engine
11
  from sqlalchemy.orm import declarative_base, sessionmaker
12
 
13
  # Local Application Imports
14
- from src.config import validate_env_var
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Validate and retrieve the database URL from environment variables
17
- DATABASE_URL = validate_env_var("DATABASE_URL")
 
18
 
19
- # Create the database engine using the validated URL
20
- engine = create_engine(DATABASE_URL)
21
 
22
- # Create a session factory for database transactions
23
- SessionLocal = sessionmaker(bind=engine)
24
 
25
- # Declarative base class for ORM models
26
  Base = declarative_base()
 
4
  This module sets up the SQLAlchemy database connection for the Expressive TTS Arena project.
5
  It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
6
  and defines a declarative base class for ORM models.
7
+
8
+ If no DATABASE_URL environment variable is set, then create a dummy database to fail gracefully
9
  """
10
 
11
+ # Standard Library Imports
12
+ import os
13
+
14
  # Third-Party Library Imports
15
  from sqlalchemy import create_engine
16
  from sqlalchemy.orm import declarative_base, sessionmaker
17
 
18
  # Local Application Imports
19
+ from src.config import APP_ENV
20
+
21
+ DATABASE_URL = os.getenv("DATABASE_URL")
22
+
23
+ if APP_ENV == "prod":
24
+ # In production, a valid DATABASE_URL is required.
25
+ if not DATABASE_URL:
26
+ raise ValueError("DATABASE_URL must be set in production!")
27
+
28
+ engine = create_engine(DATABASE_URL)
29
+ SessionLocal = sessionmaker(bind=engine)
30
+ # In development, if a DATABASE_URL is provided, use it.
31
+ elif DATABASE_URL:
32
+ engine = create_engine(DATABASE_URL)
33
+ SessionLocal = sessionmaker(bind=engine)
34
+ else:
35
+ # No DATABASE_URL is provided; use a DummySession that does nothing.
36
+ engine = None
37
+
38
+ class DummySession:
39
+ is_dummy = True # Flag to indicate this is a dummy session.
40
+
41
+ def __enter__(self):
42
+ return self
43
+
44
+ def __exit__(self, exc_type, exc_value, traceback):
45
+ pass
46
+
47
+ def add(self, _instance, _warn=True):
48
+ # No-op: simply ignore adding the instance.
49
+ pass
50
+
51
+ def commit(self):
52
+ # Raise an exception to simulate failure when attempting a write.
53
+ raise RuntimeError("DummySession does not support commit operations.")
54
+
55
+ def refresh(self, _instance):
56
+ # Raise an exception to simulate failure when attempting to refresh.
57
+ raise RuntimeError("DummySession does not support refresh operations.")
58
+
59
+ def rollback(self):
60
+ # No-op: there's nothing to roll back.
61
+ pass
62
 
63
+ def close(self):
64
+ # No-op: nothing to close.
65
+ pass
66
 
67
+ def dummy_session_factory():
68
+ return DummySession()
69
 
70
+ SessionLocal = dummy_session_factory
 
71
 
72
+ # Declarative base class for ORM models.
73
  Base = declarative_base()
src/utils.py CHANGED
@@ -15,7 +15,7 @@ from typing import Tuple
15
 
16
  # Local Application Imports
17
  from src import constants
18
- from src.config import AUDIO_DIR, logger
19
  from src.custom_types import (
20
  ComparisonType,
21
  Option,
@@ -24,7 +24,7 @@ from src.custom_types import (
24
  TTSProviderName,
25
  VotingResults,
26
  )
27
- from src.database import SessionLocal, VoteResult, crud
28
 
29
 
30
  def truncate_text(text: str, max_length: int = 50) -> str:
@@ -306,13 +306,39 @@ def determine_comparison_type(
306
  raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")
307
 
308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  def submit_voting_results(
310
  option_map: OptionMap,
311
  selected_option: str,
312
  text_modified: bool,
313
  character_description: str,
314
  text: str,
315
- ) -> VoteResult:
316
  """
317
  Constructs the voting results dictionary from the provided inputs,
318
  logs it, persists a new vote record in the database, and returns the record.
@@ -323,9 +349,6 @@ def submit_voting_results(
323
  text_modified (bool): Indicates whether the text was modified.
324
  character_description (str): Description of the voice/character.
325
  text (str): The text associated with the TTS generation.
326
-
327
- Returns:
328
- VoteResult: The newly created vote record from the database.
329
  """
330
  provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
331
  provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
@@ -344,14 +367,17 @@ def submit_voting_results(
344
  "is_custom_text": text_modified,
345
  }
346
 
347
- logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4))
348
-
349
- # Create a new database session, persist the vote record, and then close the session.
350
  db = SessionLocal()
 
351
  try:
352
- vote_record = crud.create_vote(db, voting_results)
353
- logger.info("Vote record created successfully")
 
 
 
 
 
354
  finally:
355
  db.close()
356
 
357
- return vote_record
 
15
 
16
  # Local Application Imports
17
  from src import constants
18
+ from src.config import APP_ENV, AUDIO_DIR, logger
19
  from src.custom_types import (
20
  ComparisonType,
21
  Option,
 
24
  TTSProviderName,
25
  VotingResults,
26
  )
27
+ from src.database import SessionLocal, crud
28
 
29
 
30
  def truncate_text(text: str, max_length: int = 50) -> str:
 
306
  raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")
307
 
308
 
309
+ def log_voting_results(voting_results: VotingResults) -> None:
310
+ """Log the full voting results."""
311
+ logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4))
312
+
313
+
314
+ def handle_vote_failure(e: Exception, voting_results: VotingResults, is_dummy_db_session: bool) -> None:
315
+ """
316
+ Handles logging when creating a vote record fails.
317
+
318
+ In production (or in dev with a real session):
319
+ - Logs the error (with full traceback in prod) and the voting results.
320
+ - In production, re-raises the exception.
321
+
322
+ In development with a dummy session:
323
+ - Only logs the voting results.
324
+ """
325
+ if APP_ENV == "prod" or (APP_ENV == "dev" and not is_dummy_db_session):
326
+ logger.error("Failed to create vote record: %s", e, exc_info=(APP_ENV == "prod"))
327
+ log_voting_results(voting_results)
328
+ if APP_ENV == "prod":
329
+ raise e
330
+ else:
331
+ # Dev mode with a dummy session: only log the voting results.
332
+ log_voting_results(voting_results)
333
+
334
+
335
  def submit_voting_results(
336
  option_map: OptionMap,
337
  selected_option: str,
338
  text_modified: bool,
339
  character_description: str,
340
  text: str,
341
+ ) -> None:
342
  """
343
  Constructs the voting results dictionary from the provided inputs,
344
  logs it, persists a new vote record in the database, and returns the record.
 
349
  text_modified (bool): Indicates whether the text was modified.
350
  character_description (str): Description of the voice/character.
351
  text (str): The text associated with the TTS generation.
 
 
 
352
  """
353
  provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
354
  provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
 
367
  "is_custom_text": text_modified,
368
  }
369
 
 
 
 
370
  db = SessionLocal()
371
+ is_dummy_db_session = getattr(db, "is_dummy", False)
372
  try:
373
+ crud.create_vote(db, voting_results)
374
+ except Exception as e:
375
+ handle_vote_failure(e, voting_results, is_dummy_db_session)
376
+ else:
377
+ logger.info("Vote record created successfully.")
378
+ if APP_ENV == "dev":
379
+ log_voting_results(voting_results)
380
  finally:
381
  db.close()
382
 
383
+