Spaces:
Running
Running
File size: 11,731 Bytes
5ed9749 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
# Standard Library Imports
import json
from typing import List, Tuple
# Third-Party Library Imports
from sqlalchemy.ext.asyncio import AsyncSession
# Local Application Imports
from src.common import (
ComparisonType,
LeaderboardEntry,
OptionKey,
OptionMap,
TTSProviderName,
VotingResults,
constants,
logger,
)
from src.database import (
AsyncDBSessionMaker,
create_vote,
get_head_to_head_battle_stats,
get_head_to_head_win_rate_stats,
get_leaderboard_stats,
)
class VotingService:
    """
    Service for handling all database interactions related to voting and leaderboards.

    Encapsulates logic for submitting votes and retrieving formatted leaderboard statistics.
    """

    def __init__(self, db_session_maker: AsyncDBSessionMaker):
        """
        Initializes the VotingService.

        Args:
            db_session_maker: An asynchronous database session factory. It is called with no
                arguments to obtain each session; a session exposing a truthy ``is_dummy``
                attribute is treated as "no database available".
        """
        self.db_session_maker: AsyncDBSessionMaker = db_session_maker
        logger.debug("VotingService initialized.")

    async def _create_db_session(self) -> AsyncSession | None:
        """
        Creates a new database session, returning None if it's a dummy session.

        A dummy session (one with a truthy ``is_dummy`` attribute) is closed right away when it
        has a ``close`` method, so callers never need to manage it.

        Returns:
            An active AsyncSession, or None if the factory produced a dummy session.
        """
        session = self.db_session_maker()
        # Factories without database access may hand out placeholder sessions flagged with
        # `is_dummy`; real sessions simply lack the attribute.
        if getattr(session, "is_dummy", False):
            logger.debug("Using dummy DB session; operations will be skipped.")
            # Release anything the dummy holds, if it supports closing at all.
            if hasattr(session, "close"):
                await session.close()
            return None
        logger.debug("Created new DB session.")
        return session

    def _determine_comparison_type(self, provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType:
        """
        Determine the comparison type based on the given TTS provider names.

        Apart from the Hume-vs-Hume case, the order of the two providers does not matter.

        Args:
            provider_a (TTSProviderName): The first TTS provider.
            provider_b (TTSProviderName): The second TTS provider.

        Returns:
            ComparisonType: The determined comparison type.

        Raises:
            ValueError: If the combination of providers is not recognized (for example, two
                identical non-Hume providers).
        """
        if provider_a == constants.HUME_AI and provider_b == constants.HUME_AI:
            return constants.HUME_TO_HUME

        providers = (provider_a, provider_b)
        if constants.HUME_AI in providers and constants.ELEVENLABS in providers:
            return constants.HUME_TO_ELEVENLABS
        if constants.HUME_AI in providers and constants.OPENAI in providers:
            return constants.HUME_TO_OPENAI
        if constants.ELEVENLABS in providers and constants.OPENAI in providers:
            return constants.OPENAI_TO_ELEVENLABS

        raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")

    async def _persist_vote(self, voting_results: VotingResults) -> None:
        """
        Persists a vote record in the database using a dedicated session.

        The voting results are always logged. With a real session, the record is written via
        ``create_vote``, and any exception it raises is logged rather than propagated.

        This method issues no explicit commit or rollback. Committing is presumably done
        inside ``create_vote`` (TODO confirm). The session is always closed in ``finally``;
        per SQLAlchemy, closing releases the connection and rolls back any transaction still
        open.

        Args:
            voting_results: A dictionary containing the vote details.
        """
        session = await self._create_db_session()
        if session is None:
            logger.info("Skipping vote persistence (dummy session).")
            self._log_voting_results(voting_results)
            return

        try:
            self._log_voting_results(voting_results)
            await create_vote(session, voting_results)
            logger.info("Vote successfully persisted.")
        except Exception as e:
            # Background-task boundary: log the failure instead of raising it.
            logger.error("Failed to persist vote record: %s", e, exc_info=True)
        finally:
            await session.close()
            logger.debug("DB session closed after persisting vote.")

    def _log_voting_results(self, voting_results: VotingResults) -> None:
        """
        Logs the full voting results dictionary as indented JSON.

        Non-JSON-native values are stringified through ``default=str``. If serialization still
        fails (e.g. unsupported key types raise TypeError, circular references raise
        ValueError), an error is logged and the raw dict repr is logged instead.
        """
        try:
            logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4, default=str))
        except (TypeError, ValueError):
            logger.error("Could not serialize voting results for logging.")
            logger.info("Voting results (raw): %s", voting_results)

    def _format_leaderboard_data(self, leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]:
        """
        Formats raw leaderboard entries into HTML strings for the UI table.

        Args:
            leaderboard_data_raw: Entries unpacked as (rank, provider, model, win_rate, votes).

        Returns:
            One row of five HTML cells per entry. Provider and model cells link to the URLs in
            ``constants.TTS_PROVIDER_LINKS``, falling back to ``"#"`` when none is configured.
        """
        formatted_data = []
        for rank, provider, model, win_rate, votes in leaderboard_data_raw:
            provider_info = constants.TTS_PROVIDER_LINKS.get(provider, {})
            provider_link = provider_info.get("provider_link", "#")
            model_link = provider_info.get("model_link", "#")

            formatted_data.append([
                f'<p style="text-align: center;">{rank}</p>',
                f'<a href="{provider_link}" target="_blank" class="provider-link">{provider}</a>',
                f'<a href="{model_link}" target="_blank" class="provider-link">{model}</a>',
                f'<p style="text-align: center;">{win_rate}</p>',
                f'<p style="text-align: center;">{votes}</p>',
            ])
        return formatted_data

    def _format_battle_counts_data(self, battle_counts_data_raw: List[List[str]]) -> List[List[str]]:
        """
        Formats raw battle counts into an HTML matrix for the UI.

        Args:
            battle_counts_data_raw: Rows whose first item is a comparison type and whose second
                item is the battle count for that comparison.

        Returns:
            One row per provider in ``constants.TTS_PROVIDERS``: a bold provider label followed
            by one centered cell per column provider. Diagonal cells show ``-``. A cell shows
            ``0`` when its comparison has no raw data, or when the provider pair maps to no
            known comparison type.
        """
        battle_counts_dict = {item[0]: str(item[1]) for item in battle_counts_data_raw}

        providers = constants.TTS_PROVIDERS
        formatted_matrix: List[List[str]] = []
        for row_provider in providers:
            row = [f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>']
            for col_provider in providers:
                if row_provider == col_provider:
                    cell_value = "-"
                else:
                    try:
                        comparison_key = self._determine_comparison_type(row_provider, col_provider)
                    except ValueError:
                        # An unmapped provider pair must not blank the whole leaderboard;
                        # degrade only this cell, like the win-rate matrix's "0%" fallback.
                        logger.warning(
                            "No comparison type for providers %s and %s; showing 0 battles.",
                            row_provider,
                            col_provider,
                        )
                        cell_value = "0"
                    else:
                        cell_value = battle_counts_dict.get(comparison_key, "0")
                row.append(f'<p style="text-align: center;">{cell_value}</p>')
            formatted_matrix.append(row)
        return formatted_matrix

    def _format_win_rate_data(self, win_rate_data_raw: List[List[str]]) -> List[List[str]]:
        """
        Formats raw win rates into an HTML matrix for the UI.

        Args:
            win_rate_data_raw: Rows of (comparison_type, first_win_rate, second_win_rate).
                ``comparison_type`` is split on ``" - "`` into exactly two provider names.
                The first rate is credited to the first provider and the second rate to the
                second. Rows that do not split into exactly two names are logged and skipped.

        Returns:
            One row per provider in ``constants.TTS_PROVIDERS``: a bold provider label followed
            by one centered cell per column provider. Diagonal cells show ``-``, and pairs
            without data show ``0%``.
        """
        # (row_provider, col_provider) -> row_provider's win rate against col_provider.
        win_rates: dict[tuple[str, str], str] = {}
        for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw:
            try:
                provider1, provider2 = comparison_type.split(" - ")
            except ValueError:
                logger.warning("Could not parse comparison_type '%s' in win rate data.", comparison_type)
                continue  # Skip malformed entry
            win_rates[(provider1, provider2)] = first_win_rate
            win_rates[(provider2, provider1)] = second_win_rate

        providers = constants.TTS_PROVIDERS
        formatted_matrix: List[List[str]] = []
        for row_provider in providers:
            row = [f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>']
            for col_provider in providers:
                cell_value = "-" if row_provider == col_provider else win_rates.get((row_provider, col_provider), "0%")
                row.append(f'<p style="text-align: center;">{cell_value}</p>')
            formatted_matrix.append(row)
        return formatted_matrix

    async def get_formatted_leaderboard_data(self) -> Tuple[
        List[List[str]],
        List[List[str]],
        List[List[str]],
    ]:
        """
        Fetches raw leaderboard stats and formats them for UI display.

        Retrieves overall rankings, battle counts, and win rates, then formats
        them into HTML strings suitable for Gradio DataFrames.

        Returns:
            A tuple containing formatted lists of lists for:
                - Leaderboard rankings table
                - Battle counts matrix
                - Win rate matrix
            Returns ([[]], [[]], [[]]) in three cases: the session is a dummy, session creation
            fails, or fetching/formatting raises.
        """
        try:
            session = await self._create_db_session()
        except Exception as e:
            logger.error("Failed to create DB session for leaderboard fetch: %s", e, exc_info=True)
            return [[]], [[]], [[]]

        if session is None:
            logger.info("Skipping leaderboard fetch (dummy session).")
            return [[]], [[]], [[]]

        try:
            # Fetch raw data using underlying CRUD functions
            leaderboard_data_raw = await get_leaderboard_stats(session)
            battle_counts_data_raw = await get_head_to_head_battle_stats(session)
            win_rate_data_raw = await get_head_to_head_win_rate_stats(session)
            logger.debug("Fetched raw leaderboard data successfully.")

            # Format the data
            leaderboard_data = self._format_leaderboard_data(leaderboard_data_raw)
            battle_counts_data = self._format_battle_counts_data(battle_counts_data_raw)
            win_rate_data = self._format_win_rate_data(win_rate_data_raw)

            return leaderboard_data, battle_counts_data, win_rate_data
        except Exception as e:
            logger.error("Failed to fetch and format leaderboard data: %s", e, exc_info=True)
            return [[]], [[]], [[]]  # Return empty structure on error
        finally:
            await session.close()
            logger.debug("DB session closed after fetching leaderboard data.")

    async def submit_vote(
        self,
        option_map: OptionMap,
        selected_option: OptionKey,
        text_modified: bool,
        character_description: str,
        text: str,
    ) -> None:
        """
        Constructs and persists a vote record based on user selection and context.

        This method is designed to be called safely from background tasks. A missing key in
        ``option_map`` and any other exception are logged, never raised.

        Args:
            option_map: Mapping of comparison data and TTS options; each option entry must
                carry "provider" and "generation_id".
            selected_option: The option key ('option_a' or 'option_b') selected by the user.
            text_modified: Indicates if the text was custom vs. generated.
            character_description: Description used for TTS generation.
            text: The text synthesized.
        """
        try:
            provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
            provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
            comparison_type: ComparisonType = self._determine_comparison_type(provider_a, provider_b)

            voting_results: VotingResults = {
                "comparison_type": comparison_type,
                "winning_provider": option_map[selected_option]["provider"],
                "winning_option": selected_option,
                "option_a_provider": provider_a,
                "option_b_provider": provider_b,
                "option_a_generation_id": option_map[constants.OPTION_A_KEY]["generation_id"],
                "option_b_generation_id": option_map[constants.OPTION_B_KEY]["generation_id"],
                "character_description": character_description,
                "text": text,
                "is_custom_text": text_modified,
            }

            await self._persist_vote(voting_results)
        except KeyError as e:
            logger.error(
                "Missing key in option_map during vote submission: %s. OptionMap: %s",
                e,
                option_map,
                exc_info=True,
            )
        except Exception as e:
            logger.error("Unexpected error in submit_vote: %s", e, exc_info=True)
|