expressive-tts-arena / src /core /voting_service.py
Zachary Greathouse
Zg/codebase refactor (#20)
5ed9749 unverified
raw
history blame contribute delete
11.7 kB
# Standard Library Imports
import json
from typing import List, Tuple
# Third-Party Library Imports
from sqlalchemy.ext.asyncio import AsyncSession
# Local Application Imports
from src.common import (
ComparisonType,
LeaderboardEntry,
OptionKey,
OptionMap,
TTSProviderName,
VotingResults,
constants,
logger,
)
from src.database import (
AsyncDBSessionMaker,
create_vote,
get_head_to_head_battle_stats,
get_head_to_head_win_rate_stats,
get_leaderboard_stats,
)
class VotingService:
    """
    Service for handling all database interactions related to voting and leaderboards.

    Encapsulates logic for submitting votes and retrieving formatted leaderboard statistics.
    Public entry points (`submit_vote`, `get_formatted_leaderboard_data`) log failures
    instead of raising them.
    """

    def __init__(self, db_session_maker: AsyncDBSessionMaker):
        """
        Initializes the VotingService.

        Args:
            db_session_maker: An asynchronous database session factory. Calling it returns a
                session; a session with a truthy `is_dummy` attribute is treated as "no
                database" and all DB operations are skipped.
        """
        self.db_session_maker: AsyncDBSessionMaker = db_session_maker
        logger.debug("VotingService initialized.")

    async def _create_db_session(self) -> AsyncSession | None:
        """
        Creates a new database session, returning None if it's a dummy session.

        A session is considered a dummy when it carries a truthy `is_dummy` attribute.
        Dummy sessions are closed immediately (when they expose `close`) before returning.

        Returns:
            An active AsyncSession, or None if the factory produced a dummy session.
        """
        session = self.db_session_maker()
        # The factory may return a placeholder session flagged with `is_dummy`
        # (presumably when no database is configured — confirm with the factory).
        is_dummy_session = getattr(session, "is_dummy", False)
        if is_dummy_session:
            logger.debug("Using dummy DB session; operations will be skipped.")
            # Ensure dummy sessions are also closed if they have resources
            if hasattr(session, "close"):
                await session.close()
            return None
        logger.debug("Created new DB session.")
        return session

    async def _close_session(self, session: AsyncSession, context: str) -> None:
        """
        Closes a DB session, logging (not raising) any failure during close.

        Keeping close errors contained matters because this is called from `finally`
        clauses: an exception raised there would replace the caller's return value.

        Args:
            session: The session to close.
            context: Short description of the operation the session served, used in logs
                (e.g. "persisting vote").
        """
        try:
            await session.close()
        except Exception as e:
            logger.error(f"Failed to close DB session after {context}: {e}", exc_info=True)
        else:
            logger.debug("DB session closed after %s.", context)

    def _determine_comparison_type(self, provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType:
        """
        Determine the comparison type based on the given TTS provider names.

        The check is order-insensitive: (A, B) and (B, A) map to the same comparison type.

        Args:
            provider_a (TTSProviderName): The first TTS provider.
            provider_b (TTSProviderName): The second TTS provider.

        Returns:
            ComparisonType: The determined comparison type.

        Raises:
            ValueError: If the combination of providers is not recognized (this includes
                same-provider pairs other than Hume AI vs. Hume AI).
        """
        if provider_a == constants.HUME_AI and provider_b == constants.HUME_AI:
            return constants.HUME_TO_HUME

        providers = (provider_a, provider_b)
        if constants.HUME_AI in providers and constants.ELEVENLABS in providers:
            return constants.HUME_TO_ELEVENLABS
        if constants.HUME_AI in providers and constants.OPENAI in providers:
            return constants.HUME_TO_OPENAI
        if constants.ELEVENLABS in providers and constants.OPENAI in providers:
            return constants.OPENAI_TO_ELEVENLABS

        raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")

    async def _persist_vote(self, voting_results: VotingResults) -> None:
        """
        Persists a vote record in the database using a dedicated session.

        Opens a session, delegates the insert to `create_vote`, and always closes the
        session afterwards. Errors from `create_vote` are logged, not raised. This method
        performs no explicit commit/rollback itself; presumably `create_vote` does
        (TODO confirm against the CRUD layer). With a dummy session, the vote is only logged.

        Args:
            voting_results: A dictionary containing the vote details.
        """
        session = await self._create_db_session()
        if session is None:
            logger.info("Skipping vote persistence (dummy session).")
            self._log_voting_results(voting_results)
            return

        try:
            self._log_voting_results(voting_results)
            await create_vote(session, voting_results)
            logger.info("Vote successfully persisted.")
        except Exception as e:
            logger.error(f"Failed to persist vote record: {e}", exc_info=True)
        finally:
            await self._close_session(session, "persisting vote")

    def _log_voting_results(self, voting_results: VotingResults) -> None:
        """
        Logs the full voting results dictionary as indented JSON.

        Never raises for unserializable content. Values JSON cannot encode are rendered
        with `str`. If serialization still fails, the raw dict is logged instead. That can
        happen with a non-primitive key (`TypeError`) or a circular reference, which
        `json.dumps` reports as `ValueError`. Keeping this total ensures a logging problem
        can never prevent a vote from being persisted.
        """
        try:
            formatted = json.dumps(voting_results, indent=4, default=str)
        except (TypeError, ValueError):
            logger.error("Could not serialize voting results for logging.")
            logger.info("Voting results (raw): %s", voting_results)
            return
        logger.info("Voting results:\n%s", formatted)

    def _format_leaderboard_data(self, leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]:
        """
        Formats raw leaderboard entries into HTML strings for the UI table.

        Each entry is unpacked as (rank, provider, model, win_rate, votes). Provider and
        model cells link to the URLs in `constants.TTS_PROVIDER_LINKS`, falling back to "#"
        when the provider (or a link) is missing.
        """
        formatted_data = []
        for rank, provider, model, win_rate, votes in leaderboard_data_raw:
            provider_info = constants.TTS_PROVIDER_LINKS.get(provider, {})
            provider_link = provider_info.get("provider_link", "#")
            model_link = provider_info.get("model_link", "#")
            formatted_data.append([
                f'<p style="text-align: center;">{rank}</p>',
                f'<a href="{provider_link}" target="_blank" class="provider-link">{provider}</a>',
                f'<a href="{model_link}" target="_blank" class="provider-link">{model}</a>',
                f'<p style="text-align: center;">{win_rate}</p>',
                f'<p style="text-align: center;">{votes}</p>',
            ])
        return formatted_data

    def _format_battle_counts_data(self, battle_counts_data_raw: List[List[str]]) -> List[List[str]]:
        """
        Formats raw battle counts into an HTML matrix for the UI.

        Each raw item is read as [comparison_type, count]. The matrix has one row per
        provider in `constants.TTS_PROVIDERS`: a bold provider label followed by one cell
        per column provider. Diagonal cells show "-", and missing comparisons show "0".
        """
        battle_counts_dict = {item[0]: str(item[1]) for item in battle_counts_data_raw}
        providers = constants.TTS_PROVIDERS
        formatted_matrix: List[List[str]] = []
        for row_provider in providers:
            row = [f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>']
            for col_provider in providers:
                if row_provider == col_provider:
                    cell_value = "-"
                else:
                    # Order-insensitive lookup: (A, B) and (B, A) share one comparison key.
                    comparison_key = self._determine_comparison_type(row_provider, col_provider)
                    cell_value = battle_counts_dict.get(comparison_key, "0")
                row.append(f'<p style="text-align: center;">{cell_value}</p>')
            formatted_matrix.append(row)
        return formatted_matrix

    def _format_win_rate_data(self, win_rate_data_raw: List[List[str]]) -> List[List[str]]:
        """
        Formats raw win rates into an HTML matrix for the UI.

        Each raw item is read as [comparison_type, first_win_rate, second_win_rate], where
        comparison_type has the form "ProviderA - ProviderB". Cell (row, col) shows row's
        win rate against col. Diagonal cells show "-", and missing pairs show "0%".
        Malformed comparison types are logged and skipped.
        """
        win_rates = {}
        for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw:
            # Comparison type should already be canonical 'ProviderA - ProviderB'
            try:
                provider1, provider2 = comparison_type.split(" - ")
                win_rates[(provider1, provider2)] = first_win_rate
                win_rates[(provider2, provider1)] = second_win_rate
            except ValueError:
                logger.warning(f"Could not parse comparison_type '{comparison_type}' in win rate data.")
                continue  # Skip malformed entry

        providers = constants.TTS_PROVIDERS
        formatted_matrix: List[List[str]] = []
        for row_provider in providers:
            row = [f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>']
            for col_provider in providers:
                cell_value = "-" if row_provider == col_provider else win_rates.get((row_provider, col_provider), "0%")
                row.append(f'<p style="text-align: center;">{cell_value}</p>')
            formatted_matrix.append(row)
        return formatted_matrix

    async def get_formatted_leaderboard_data(self) -> Tuple[
        List[List[str]],
        List[List[str]],
        List[List[str]],
    ]:
        """
        Fetches raw leaderboard stats and formats them for UI display.

        Retrieves overall rankings, battle counts, and win rates, then formats
        them into HTML strings suitable for Gradio DataFrames.

        Returns:
            A tuple containing formatted lists of lists for:
                - Leaderboard rankings table
                - Battle counts matrix
                - Win rate matrix
            Returns empty lists ([[]], [[]], [[]]) on failure (including session creation
            failures) or when a dummy session is in use.
        """
        try:
            session = await self._create_db_session()
        except Exception as e:
            logger.error(f"Failed to create DB session for leaderboard data: {e}", exc_info=True)
            return [[]], [[]], [[]]

        if session is None:
            logger.info("Skipping leaderboard fetch (dummy session).")
            return [[]], [[]], [[]]

        try:
            # Fetch raw data using underlying CRUD functions
            leaderboard_data_raw = await get_leaderboard_stats(session)
            battle_counts_data_raw = await get_head_to_head_battle_stats(session)
            win_rate_data_raw = await get_head_to_head_win_rate_stats(session)
            logger.debug("Fetched raw leaderboard data successfully.")

            # Format the data
            leaderboard_data = self._format_leaderboard_data(leaderboard_data_raw)
            battle_counts_data = self._format_battle_counts_data(battle_counts_data_raw)
            win_rate_data = self._format_win_rate_data(win_rate_data_raw)

            return leaderboard_data, battle_counts_data, win_rate_data
        except Exception as e:
            logger.error(f"Failed to fetch and format leaderboard data: {e}", exc_info=True)
            return [[]], [[]], [[]]  # Return empty structure on error
        finally:
            # _close_session never raises, so it cannot override the value returned above.
            await self._close_session(session, "fetching leaderboard data")

    async def submit_vote(
        self,
        option_map: OptionMap,
        selected_option: OptionKey,
        text_modified: bool,
        character_description: str,
        text: str,
    ) -> None:
        """
        Constructs and persists a vote record based on user selection and context.

        This method is designed to be called safely from background tasks: all exceptions
        raised while building or persisting the vote are caught and logged.

        Args:
            option_map: Mapping of comparison data and TTS options. Each option entry is
                read for its "provider" and "generation_id" keys.
            selected_option: The option key ('option_a' or 'option_b') selected by the user.
            text_modified: Indicates if the text was custom vs. generated.
            character_description: Description used for TTS generation.
            text: The text synthesized.
        """
        try:
            provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
            provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
            comparison_type: ComparisonType = self._determine_comparison_type(provider_a, provider_b)

            voting_results: VotingResults = {
                "comparison_type": comparison_type,
                "winning_provider": option_map[selected_option]["provider"],
                "winning_option": selected_option,
                "option_a_provider": provider_a,
                "option_b_provider": provider_b,
                "option_a_generation_id": option_map[constants.OPTION_A_KEY]["generation_id"],
                "option_b_generation_id": option_map[constants.OPTION_B_KEY]["generation_id"],
                "character_description": character_description,
                "text": text,
                "is_custom_text": text_modified,
            }

            await self._persist_vote(voting_results)
        except KeyError as e:
            logger.error(
                f"Missing key in option_map during vote submission: {e}. OptionMap: {option_map}",
                exc_info=True
            )
        except Exception as e:
            logger.error(f"Unexpected error in submit_vote: {e}", exc_info=True)