File size: 11,731 Bytes
5ed9749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# Standard Library Imports
import json
from typing import List, Tuple

# Third-Party Library Imports
from sqlalchemy.ext.asyncio import AsyncSession

# Local Application Imports
from src.common import (
    ComparisonType,
    LeaderboardEntry,
    OptionKey,
    OptionMap,
    TTSProviderName,
    VotingResults,
    constants,
    logger,
)
from src.database import (
    AsyncDBSessionMaker,
    create_vote,
    get_head_to_head_battle_stats,
    get_head_to_head_win_rate_stats,
    get_leaderboard_stats,
)


class VotingService:
    """
    Service for handling all database interactions related to voting and leaderboards.

    Encapsulates logic for submitting votes and retrieving formatted leaderboard statistics.
    """

    def __init__(self, db_session_maker: AsyncDBSessionMaker):
        """
        Set up the service with the factory it uses to open database sessions.

        Args:
            db_session_maker: Asynchronous database session factory; it is stored
                on the instance and called whenever a session is required.
        """
        self.db_session_maker = db_session_maker
        logger.debug("VotingService initialized.")

    async def _create_db_session(self) -> AsyncSession | None:
        """
        Open a session from the configured factory, filtering out dummy sessions.

        A session object exposing a truthy ``is_dummy`` attribute is treated as a
        placeholder: it is closed (when it offers a ``close`` attribute) and
        ``None`` is returned so that callers can skip their database work.

        Returns:
            The newly created session, or None when the factory produced a dummy.
        """
        db_session = self.db_session_maker()

        # Real sessions are handed straight back; only dummies need extra care.
        if not getattr(db_session, "is_dummy", False):
            logger.debug("Created new DB session.")
            return db_session

        logger.debug("Using dummy DB session; operations will be skipped.")
        # Release anything the dummy may be holding before discarding it.
        if hasattr(db_session, "close"):
            await db_session.close()
        return None

    def _determine_comparison_type(self, provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType:
        """
        Map a pair of TTS providers to the comparison type that describes it.

        For pairs of two different providers the argument order is irrelevant.
        The only same-provider pair that is recognised is Hume AI vs. Hume AI.

        Args:
            provider_a (TTSProviderName): The first TTS provider.
            provider_b (TTSProviderName): The second TTS provider.

        Returns:
            ComparisonType: The comparison type matching the pair.

        Raises:
            ValueError: If no comparison type exists for the given pair.
        """
        if provider_a == constants.HUME_AI and provider_b == constants.HUME_AI:
            return constants.HUME_TO_HUME

        pair = (provider_a, provider_b)

        # (provider, other provider, resulting comparison type), checked in order.
        mixed_matchups = (
            (constants.HUME_AI, constants.ELEVENLABS, constants.HUME_TO_ELEVENLABS),
            (constants.HUME_AI, constants.OPENAI, constants.HUME_TO_OPENAI),
            (constants.ELEVENLABS, constants.OPENAI, constants.OPENAI_TO_ELEVENLABS),
        )
        for first, second, comparison in mixed_matchups:
            if first in pair and second in pair:
                return comparison

        raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")

    async def _persist_vote(self, voting_results: VotingResults) -> None:
        """
        Store one vote in the database through a short-lived session.

        The vote is logged in every case. When only a dummy session is available
        the database write is skipped. Errors raised while logging or saving the
        vote are logged instead of being propagated, and the session is closed
        afterwards whether or not the write succeeded.

        NOTE(review): no commit or rollback is issued here; presumably
        ``create_vote`` takes care of that -- confirm against its implementation.

        Args:
            voting_results: A dictionary containing the vote details.
        """
        db_session = await self._create_db_session()

        if db_session is not None:
            try:
                self._log_voting_results(voting_results)
                await create_vote(db_session, voting_results)
                logger.info("Vote successfully persisted.")
            except Exception as e:
                logger.error(f"Failed to persist vote record: {e}", exc_info=True)
            finally:
                await db_session.close()
                logger.debug("DB session closed after persisting vote.")
            return

        # No real database behind the factory: the vote only ends up in the log.
        logger.info("Skipping vote persistence (dummy session).")
        self._log_voting_results(voting_results)

    def _log_voting_results(self, voting_results: VotingResults) -> None:
        """
        Log the full voting results dictionary as pretty-printed JSON.

        Logging is best-effort: when the results cannot be serialized, an error
        is logged and the raw object is logged instead, so that a serialization
        problem does not interrupt the caller (for example, vote persistence).

        Args:
            voting_results: The vote details to write to the log.
        """
        try:
            serialized = json.dumps(voting_results, indent=4, default=str)
        except (TypeError, ValueError):
            # TypeError: unsupported dict keys (``default`` only covers values).
            # ValueError: the encoder detected a circular reference.
            logger.error("Could not serialize voting results for logging.")
            logger.info(f"Voting results (raw): {voting_results}")
            return

        logger.info("Voting results:\n%s", serialized)

    def _format_leaderboard_data(self, leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]:
        """
        Render raw leaderboard entries as rows of HTML snippets for the UI table.

        Provider and model names are wrapped in links looked up in
        ``constants.TTS_PROVIDER_LINKS``; "#" is used when no link is configured.

        Args:
            leaderboard_data_raw: Entries that each unpack into
                (rank, provider, model, win_rate, votes).

        Returns:
            One row per entry holding five HTML cells: rank, provider link,
            model link, win rate and votes.
        """
        links_by_provider = constants.TTS_PROVIDER_LINKS
        rows: List[List[str]] = []

        for rank, provider, model, win_rate, votes in leaderboard_data_raw:
            links = links_by_provider.get(provider, {})
            provider_link = links.get("provider_link", "#")
            model_link = links.get("model_link", "#")

            rank_cell = f'<p style="text-align: center;">{rank}</p>'
            provider_cell = f'<a href="{provider_link}" target="_blank" class="provider-link">{provider}</a>'
            model_cell = f'<a href="{model_link}" target="_blank" class="provider-link">{model}</a>'
            win_rate_cell = f'<p style="text-align: center;">{win_rate}</p>'
            votes_cell = f'<p style="text-align: center;">{votes}</p>'

            rows.append([rank_cell, provider_cell, model_cell, win_rate_cell, votes_cell])

        return rows


    def _format_battle_counts_data(self, battle_counts_data_raw: List[List[str]]) -> List[List[str]]:
        """
        Build the provider-by-provider battle count matrix as HTML cells.

        Args:
            battle_counts_data_raw: Entries whose first element is a comparison
                type and whose second element is the count for that comparison.

        Returns:
            One row per provider in ``constants.TTS_PROVIDERS``: a bold label cell
            followed by one cell per provider. Diagonal cells show "-", and pairs
            without a recorded count show "0".
        """
        counts_by_comparison = {}
        for entry in battle_counts_data_raw:
            counts_by_comparison[entry[0]] = str(entry[1])

        def count_cell(row_provider, col_provider) -> str:
            # A provider is never shown against itself in this matrix.
            if row_provider == col_provider:
                cell_value = "-"
            else:
                comparison_key = self._determine_comparison_type(row_provider, col_provider)
                cell_value = counts_by_comparison.get(comparison_key, "0")
            return f'<p style="text-align: center;">{cell_value}</p>'

        providers = constants.TTS_PROVIDERS
        formatted_matrix: List[List[str]] = []
        for row_provider in providers:
            label_cell = f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'
            formatted_matrix.append(
                [label_cell] + [count_cell(row_provider, col_provider) for col_provider in providers]
            )
        return formatted_matrix


    def _format_win_rate_data(self, win_rate_data_raw: List[List[str]]) -> List[List[str]]:
        """
        Build the provider-by-provider win rate matrix as HTML cells.

        Args:
            win_rate_data_raw: Rows that unpack into
                (comparison_type, first_win_rate, second_win_rate), where
                comparison_type is expected to look like 'ProviderA - ProviderB'.
                The first rate is stored for (ProviderA, ProviderB) and the
                second for the reversed pair. Rows whose comparison type does not
                split into exactly two names are skipped with a warning.

        Returns:
            One row per provider in ``constants.TTS_PROVIDERS``: a bold label cell
            followed by one cell per provider. Diagonal cells show "-", and pairs
            without a recorded rate show "0%".
        """
        win_rates = {}
        for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw:
            names = comparison_type.split(" - ")
            if len(names) != 2:
                # Malformed comparison type: report it and move on to the next row.
                logger.warning(f"Could not parse comparison_type '{comparison_type}' in win rate data.")
                continue

            provider1, provider2 = names
            win_rates[(provider1, provider2)] = first_win_rate
            win_rates[(provider2, provider1)] = second_win_rate

        providers = constants.TTS_PROVIDERS
        formatted_matrix: List[List[str]] = []
        for row_provider in providers:
            row = [f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>']
            for col_provider in providers:
                if row_provider == col_provider:
                    cell_value = "-"
                else:
                    cell_value = win_rates.get((row_provider, col_provider), "0%")
                row.append(f'<p style="text-align: center;">{cell_value}</p>')
            formatted_matrix.append(row)
        return formatted_matrix

    async def get_formatted_leaderboard_data(self) -> Tuple[
        List[List[str]],
        List[List[str]],
        List[List[str]],
    ]:
        """
        Load the leaderboard statistics and convert them into UI-ready HTML cells.

        Three raw data sets are fetched through one session (overall rankings,
        head-to-head battle counts, head-to-head win rates) and each is passed
        through its formatter. The session is closed before returning, whether
        or not fetching and formatting succeeded.

        Returns:
            A tuple of three lists of lists, in this order:
            - Leaderboard rankings table
            - Battle counts matrix
            - Win rate matrix
            ([[]], [[]], [[]]) is returned when only a dummy session is available
            or when fetching/formatting raises.
        """
        db_session = await self._create_db_session()
        if db_session is None:
            logger.info("Skipping leaderboard fetch (dummy session).")
            return [[]], [[]], [[]]

        try:
            # Pull everything first so formatting only starts once all reads succeeded.
            raw_rankings = await get_leaderboard_stats(db_session)
            raw_battle_counts = await get_head_to_head_battle_stats(db_session)
            raw_win_rates = await get_head_to_head_win_rate_stats(db_session)
            logger.debug("Fetched raw leaderboard data successfully.")

            return (
                self._format_leaderboard_data(raw_rankings),
                self._format_battle_counts_data(raw_battle_counts),
                self._format_win_rate_data(raw_win_rates),
            )
        except Exception as e:
            logger.error(f"Failed to fetch and format leaderboard data: {e}", exc_info=True)
            # Same placeholder structure as the dummy-session case.
            return [[]], [[]], [[]]
        finally:
            await db_session.close()
            logger.debug("DB session closed after fetching leaderboard data.")

    async def submit_vote(
        self,
        option_map: OptionMap,
        selected_option: OptionKey,
        text_modified: bool,
        character_description: str,
        text: str,
    ) -> None:
        """
        Assemble a vote record from the user's choice and hand it off for storage.

        Every exception raised while building or persisting the record is caught
        and logged here, so nothing propagates to the caller.

        Args:
            option_map: Mapping that holds, per option key, at least the
                "provider" and "generation_id" of that option.
            selected_option: The option key chosen by the user.
            text_modified: Stored as the record's "is_custom_text" flag.
            character_description: Description used for TTS generation.
            text: The text synthesized.
        """
        try:
            # Look each option up once; the access order is kept so that the
            # same missing key is reported first.
            option_a = option_map[constants.OPTION_A_KEY]
            provider_a: TTSProviderName = option_a["provider"]
            option_b = option_map[constants.OPTION_B_KEY]
            provider_b: TTSProviderName = option_b["provider"]

            voting_results: VotingResults = {
                "comparison_type": self._determine_comparison_type(provider_a, provider_b),
                "winning_provider": option_map[selected_option]["provider"],
                "winning_option": selected_option,
                "option_a_provider": provider_a,
                "option_b_provider": provider_b,
                "option_a_generation_id": option_a["generation_id"],
                "option_b_generation_id": option_b["generation_id"],
                "character_description": character_description,
                "text": text,
                "is_custom_text": text_modified,
            }

            await self._persist_vote(voting_results)

        except KeyError as e:
            logger.error(
                f"Missing key in option_map during vote submission: {e}. OptionMap: {option_map}",
                exc_info=True
            )
        except Exception as e:
            logger.error(f"Unexpected error in submit_vote: {e}", exc_info=True)