Zachary Greathouse twitchard committed on
Commit
9ed181c
·
unverified ·
1 Parent(s): 548169b

Zg/add openai (#18)

Browse files

* Add OpenAI python SDK to dependencies

* Fix Anthropic clean API Error message.

* Update constants and custom types associated with TTS providers to include OpenAI

* Add OpenAI integration

* Update logic for selecting providers, add OpenAI tts to UI

* Fix typo in openai_api.py

* Update docstrings in openai_api.py

* Update leaderboard results query to include OpenAI results

* Add citation

* Adjust padding in UI components

* Adjust padding in UI components in citation

* Add transitive dependency override for sounddevice in pyproject.toml

* remove sounddevice

* Add warning toast for custom text inputs

* Improve leaderboard results query to account for zero records, and update to only include relevant comparison types for each provider.

---------

Co-authored-by: twitchard <[email protected]>

pyproject.toml CHANGED
@@ -12,13 +12,17 @@ dependencies = [
12
  "gradio>=5.18.0",
13
  "greenlet>=2.0.0",
14
  "hume>=0.7.8",
 
15
  "python-dotenv>=1.0.1",
16
  "sqlalchemy>=2.0.0",
17
  "tenacity>=9.0.0",
18
  ]
19
 
20
  [tool.uv]
21
- override-dependencies = ["aiofiles==24.1.0"]
 
 
 
22
  dev-dependencies = [
23
  "mypy>=1.15.0",
24
  "pre-commit>=4.1.0",
@@ -84,7 +88,7 @@ select = [
84
  "TID",
85
  "W",
86
  ]
87
- per-file-ignores = { "src/constants.py" = ["E501"] }
88
 
89
  [tool.ruff.lint.pycodestyle]
90
  max-line-length = 120
 
12
  "gradio>=5.18.0",
13
  "greenlet>=2.0.0",
14
  "hume>=0.7.8",
15
+ "openai>=1.68.0",
16
  "python-dotenv>=1.0.1",
17
  "sqlalchemy>=2.0.0",
18
  "tenacity>=9.0.0",
19
  ]
20
 
21
  [tool.uv]
22
+ override-dependencies = [
23
+ "aiofiles==24.1.0",
24
+ "sounddevice; sys_platform == 'never'",
25
+ ]
26
  dev-dependencies = [
27
  "mypy>=1.15.0",
28
  "pre-commit>=4.1.0",
 
88
  "TID",
89
  "W",
90
  ]
91
+ per-file-ignores = { "src/constants.py" = ["E501"], "src/frontend.py" = ["E501"] }
92
 
93
  [tool.ruff.lint.pycodestyle]
94
  max-line-length = 120
src/config.py CHANGED
@@ -22,7 +22,7 @@ from dotenv import load_dotenv
22
 
23
  # Local Application Imports
24
  if TYPE_CHECKING:
25
- from src.integrations import AnthropicConfig, ElevenLabsConfig, HumeConfig
26
 
27
  logger: logging.Logger = logging.getLogger("expressive_tts_arena")
28
 
@@ -37,6 +37,7 @@ class Config:
37
  anthropic_config: "AnthropicConfig"
38
  hume_config: "HumeConfig"
39
  elevenlabs_config: "ElevenLabsConfig"
 
40
 
41
  @classmethod
42
  def get(cls) -> "Config":
@@ -79,7 +80,7 @@ class Config:
79
  if debug:
80
  logger.debug("DEBUG mode enabled.")
81
 
82
- from src.integrations import AnthropicConfig, ElevenLabsConfig, HumeConfig
83
 
84
  return Config(
85
  app_env=app_env,
@@ -89,4 +90,5 @@ class Config:
89
  anthropic_config=AnthropicConfig(),
90
  hume_config=HumeConfig(),
91
  elevenlabs_config=ElevenLabsConfig(),
 
92
  )
 
22
 
23
  # Local Application Imports
24
  if TYPE_CHECKING:
25
+ from src.integrations import AnthropicConfig, ElevenLabsConfig, HumeConfig, OpenAIConfig
26
 
27
  logger: logging.Logger = logging.getLogger("expressive_tts_arena")
28
 
 
37
  anthropic_config: "AnthropicConfig"
38
  hume_config: "HumeConfig"
39
  elevenlabs_config: "ElevenLabsConfig"
40
+ openai_config: "OpenAIConfig"
41
 
42
  @classmethod
43
  def get(cls) -> "Config":
 
80
  if debug:
81
  logger.debug("DEBUG mode enabled.")
82
 
83
+ from src.integrations import AnthropicConfig, ElevenLabsConfig, HumeConfig, OpenAIConfig
84
 
85
  return Config(
86
  app_env=app_env,
 
90
  anthropic_config=AnthropicConfig(),
91
  hume_config=HumeConfig(),
92
  elevenlabs_config=ElevenLabsConfig(),
93
+ openai_config=OpenAIConfig(),
94
  )
src/constants.py CHANGED
@@ -10,6 +10,7 @@ from typing import Dict, List
10
  # Third-Party Library Imports
11
  from src.custom_types import (
12
  ComparisonType,
 
13
  OptionKey,
14
  OptionLabel,
15
  TTSProviderName,
@@ -23,8 +24,9 @@ RATE_LIMIT_ERROR_CODE = 429
23
  # UI constants
24
  HUME_AI: TTSProviderName = "Hume AI"
25
  ELEVENLABS: TTSProviderName = "ElevenLabs"
 
26
 
27
- TTS_PROVIDERS: List[TTSProviderName] = ["Hume AI", "ElevenLabs"]
28
  TTS_PROVIDER_LINKS = {
29
  "Hume AI": {
30
  "provider_link": "https://hume.ai/",
@@ -33,11 +35,17 @@ TTS_PROVIDER_LINKS = {
33
  "ElevenLabs": {
34
  "provider_link": "https://elevenlabs.io/",
35
  "model_link": "https://elevenlabs.io/blog/rvg",
 
 
 
 
36
  }
37
  }
38
 
39
  HUME_TO_HUME: ComparisonType = "Hume AI - Hume AI"
40
  HUME_TO_ELEVENLABS: ComparisonType = "Hume AI - ElevenLabs"
 
 
41
 
42
  CHARACTER_DESCRIPTION_MIN_LENGTH: int = 20
43
  CHARACTER_DESCRIPTION_MAX_LENGTH: int = 400
@@ -162,3 +170,9 @@ META_TAGS: List[Dict[str, str]] = [
162
  }
163
  ]
164
 
 
 
 
 
 
 
 
10
  # Third-Party Library Imports
11
  from src.custom_types import (
12
  ComparisonType,
13
+ LeaderboardEntry,
14
  OptionKey,
15
  OptionLabel,
16
  TTSProviderName,
 
24
  # UI constants
25
  HUME_AI: TTSProviderName = "Hume AI"
26
  ELEVENLABS: TTSProviderName = "ElevenLabs"
27
+ OPENAI: TTSProviderName = "OpenAI"
28
 
29
+ TTS_PROVIDERS: List[TTSProviderName] = ["Hume AI", "ElevenLabs", "OpenAI"]
30
  TTS_PROVIDER_LINKS = {
31
  "Hume AI": {
32
  "provider_link": "https://hume.ai/",
 
35
  "ElevenLabs": {
36
  "provider_link": "https://elevenlabs.io/",
37
  "model_link": "https://elevenlabs.io/blog/rvg",
38
+ },
39
+ "OpenAI": {
40
+ "provider_link": "https://openai.com/",
41
+ "model_link": "https://platform.openai.com/docs/models/gpt-4o-mini-tts",
42
  }
43
  }
44
 
45
  HUME_TO_HUME: ComparisonType = "Hume AI - Hume AI"
46
  HUME_TO_ELEVENLABS: ComparisonType = "Hume AI - ElevenLabs"
47
+ HUME_TO_OPENAI: ComparisonType = "Hume AI - OpenAI"
48
+ OPENAI_TO_ELEVENLABS: ComparisonType = "OpenAI - ElevenLabs"
49
 
50
  CHARACTER_DESCRIPTION_MIN_LENGTH: int = 20
51
  CHARACTER_DESCRIPTION_MAX_LENGTH: int = 400
 
170
  }
171
  ]
172
 
173
+ # Reflects an empty leaderboard state
174
+ DEFAULT_LEADERBOARD: List[LeaderboardEntry] = [
175
+ LeaderboardEntry("1", "", "", "0%", "0"),
176
+ LeaderboardEntry("2", "", "", "0%", "0"),
177
+ LeaderboardEntry("3", "", "", "0%", "0"),
178
+ ]
src/custom_types.py CHANGED
@@ -7,11 +7,16 @@ This module defines custom types for the application.
7
  # Standard Library Imports
8
  from typing import List, Literal, NamedTuple, Optional, TypedDict
9
 
10
- TTSProviderName = Literal["Hume AI", "ElevenLabs"]
11
  """TTSProviderName represents the allowed provider names for TTS services."""
12
 
13
 
14
- ComparisonType = Literal["Hume AI - Hume AI", "Hume AI - ElevenLabs"]
 
 
 
 
 
15
  """Comparison type denoting which providers are compared."""
16
 
17
 
 
7
  # Standard Library Imports
8
  from typing import List, Literal, NamedTuple, Optional, TypedDict
9
 
10
+ TTSProviderName = Literal["Hume AI", "ElevenLabs", "OpenAI"]
11
  """TTSProviderName represents the allowed provider names for TTS services."""
12
 
13
 
14
+ ComparisonType = Literal[
15
+ "Hume AI - Hume AI",
16
+ "Hume AI - ElevenLabs",
17
+ "Hume AI - OpenAI",
18
+ "OpenAI - ElevenLabs"
19
+ ]
20
  """Comparison type denoting which providers are compared."""
21
 
22
 
src/database/crud.py CHANGED
@@ -12,6 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
12
 
13
  # Local Application Imports
14
  from src.config import logger
 
15
  from src.custom_types import LeaderboardEntry, LeaderboardTableEntries, VotingResults
16
  from src.database.models import VoteResult
17
 
@@ -72,8 +73,8 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
72
  """
73
  Fetches voting statistics from the database to populate a leaderboard.
74
 
75
- This function calculates voting statistics for TTS providers, excluding Hume-to-Hume
76
- comparisons, and returns data structured for a leaderboard display.
77
 
78
  Args:
79
  db (AsyncSession): The SQLAlchemy async database session.
@@ -82,46 +83,54 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
82
  LeaderboardTableEntries: A list of LeaderboardEntry objects containing rank,
83
  provider name, model name, win rate, and total votes.
84
  """
85
- default_leaderboard = [
86
- LeaderboardEntry("1", "", "", "0%", "0"),
87
- LeaderboardEntry("2", "", "", "0%", "0")
88
- ]
89
-
90
  try:
91
  query = text(
92
  """
93
- WITH provider_stats AS (
94
- -- Get wins for Hume AI
 
 
95
  SELECT
96
  'Hume AI' as provider,
97
  COUNT(*) as total_comparisons,
98
  SUM(CASE WHEN winning_provider = 'Hume AI' THEN 1 ELSE 0 END) as wins
99
  FROM vote_results
100
- WHERE comparison_type != 'Hume AI - Hume AI'
101
 
102
  UNION ALL
103
 
104
- -- Get wins for ElevenLabs
105
  SELECT
106
  'ElevenLabs' as provider,
107
  COUNT(*) as total_comparisons,
108
  SUM(CASE WHEN winning_provider = 'ElevenLabs' THEN 1 ELSE 0 END) as wins
109
  FROM vote_results
110
- WHERE comparison_type != 'Hume AI - Hume AI'
 
 
 
 
 
 
 
 
 
111
  )
112
  SELECT
113
- provider,
114
  CASE
115
- WHEN provider = 'Hume AI' THEN 'Octave'
116
- WHEN provider = 'ElevenLabs' THEN 'Voice Design'
 
117
  END as model,
118
  CASE
119
- WHEN total_comparisons > 0 THEN ROUND((wins * 100.0 / total_comparisons)::numeric, 2)
 
120
  ELSE 0
121
  END as win_rate,
122
- wins as total_votes
123
- FROM provider_stats
124
- ORDER BY win_rate DESC;
 
125
  """
126
  )
127
 
@@ -143,13 +152,14 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
143
 
144
  # If no data was found, return default entries
145
  if not leaderboard_data:
146
- return default_leaderboard
147
 
148
  return leaderboard_data
149
 
150
  except SQLAlchemyError as e:
151
  logger.error(f"Database error while fetching leaderboard stats: {e}")
152
- return default_leaderboard
153
  except Exception as e:
154
  logger.error(f"Unexpected error while fetching leaderboard stats: {e}")
155
- return default_leaderboard
 
 
12
 
13
  # Local Application Imports
14
  from src.config import logger
15
+ from src.constants import DEFAULT_LEADERBOARD
16
  from src.custom_types import LeaderboardEntry, LeaderboardTableEntries, VotingResults
17
  from src.database.models import VoteResult
18
 
 
73
  """
74
  Fetches voting statistics from the database to populate a leaderboard.
75
 
76
+ This function calculates voting statistics for TTS providers, using only the relevant
77
+ comparison types for each provider, and returns data structured for a leaderboard display.
78
 
79
  Args:
80
  db (AsyncSession): The SQLAlchemy async database session.
 
83
  LeaderboardTableEntries: A list of LeaderboardEntry objects containing rank,
84
  provider name, model name, win rate, and total votes.
85
  """
 
 
 
 
 
86
  try:
87
  query = text(
88
  """
89
+ WITH all_providers AS (
90
+ SELECT provider FROM (VALUES ('Hume AI'), ('ElevenLabs'), ('OpenAI')) AS p(provider)
91
+ ),
92
+ provider_stats AS (
93
  SELECT
94
  'Hume AI' as provider,
95
  COUNT(*) as total_comparisons,
96
  SUM(CASE WHEN winning_provider = 'Hume AI' THEN 1 ELSE 0 END) as wins
97
  FROM vote_results
98
+ WHERE comparison_type IN ('Hume AI - ElevenLabs', 'Hume AI - OpenAI')
99
 
100
  UNION ALL
101
 
 
102
  SELECT
103
  'ElevenLabs' as provider,
104
  COUNT(*) as total_comparisons,
105
  SUM(CASE WHEN winning_provider = 'ElevenLabs' THEN 1 ELSE 0 END) as wins
106
  FROM vote_results
107
+ WHERE comparison_type IN ('Hume AI - ElevenLabs', 'OpenAI - ElevenLabs')
108
+
109
+ UNION ALL
110
+
111
+ SELECT
112
+ 'OpenAI' as provider,
113
+ COUNT(*) as total_comparisons,
114
+ SUM(CASE WHEN winning_provider = 'OpenAI' THEN 1 ELSE 0 END) as wins
115
+ FROM vote_results
116
+ WHERE comparison_type IN ('Hume AI - OpenAI', 'OpenAI - ElevenLabs')
117
  )
118
  SELECT
119
+ p.provider,
120
  CASE
121
+ WHEN p.provider = 'Hume AI' THEN 'Octave'
122
+ WHEN p.provider = 'ElevenLabs' THEN 'Voice Design'
123
+ WHEN p.provider = 'OpenAI' THEN 'gpt-4o-mini-tts'
124
  END as model,
125
  CASE
126
+ WHEN COALESCE(ps.total_comparisons, 0) > 0
127
+ THEN ROUND((COALESCE(ps.wins, 0) * 100.0 / COALESCE(ps.total_comparisons, 1))::numeric, 2)
128
  ELSE 0
129
  END as win_rate,
130
+ COALESCE(ps.wins, 0) as total_votes
131
+ FROM all_providers p
132
+ LEFT JOIN provider_stats ps ON p.provider = ps.provider
133
+ ORDER BY win_rate DESC, total_votes DESC;
134
  """
135
  )
136
 
 
152
 
153
  # If no data was found, return default entries
154
  if not leaderboard_data:
155
+ return DEFAULT_LEADERBOARD
156
 
157
  return leaderboard_data
158
 
159
  except SQLAlchemyError as e:
160
  logger.error(f"Database error while fetching leaderboard stats: {e}")
161
+ return DEFAULT_LEADERBOARD
162
  except Exception as e:
163
  logger.error(f"Unexpected error while fetching leaderboard stats: {e}")
164
+ return DEFAULT_LEADERBOARD
165
+
src/frontend.py CHANGED
@@ -13,7 +13,7 @@ import asyncio
13
  import hashlib
14
  import json
15
  import time
16
- from typing import List, Tuple
17
 
18
  # Third-Party Library Imports
19
  import gradio as gr
@@ -27,15 +27,17 @@ from src.integrations import (
27
  AnthropicError,
28
  ElevenLabsError,
29
  HumeError,
 
30
  generate_text_with_claude,
31
  text_to_speech_with_elevenlabs,
32
  text_to_speech_with_hume,
 
33
  )
34
  from src.utils import (
35
  create_shuffled_tts_options,
36
  determine_selected_option,
37
  get_leaderboard_data,
38
- get_random_provider,
39
  submit_voting_results,
40
  validate_character_description_length,
41
  validate_text_length,
@@ -52,40 +54,40 @@ class Frontend:
52
 
53
  # leaderboard update state
54
  self._leaderboard_data: List[List[str]] = [[]]
55
- self._leaderboard_cache_hash = None
56
- self._last_leaderboard_update_time = 0
57
  self._min_refresh_interval = 30
58
 
59
  async def _update_leaderboard_data(self, force: bool = False) -> bool:
60
  """
61
  Fetches the latest leaderboard data only if needed based on cache and time constraints.
62
-
63
  Args:
64
  force (bool): If True, bypass the time-based throttling.
65
-
66
  Returns:
67
  bool: True if the leaderboard was updated, False otherwise.
68
  """
69
  current_time = time.time()
70
  time_since_last_update = current_time - self._last_leaderboard_update_time
71
-
72
  # Skip update if it's been less than min_refresh_interval seconds and not forced
73
  if not force and time_since_last_update < self._min_refresh_interval:
74
  logger.debug(f"Skipping leaderboard update: last updated {time_since_last_update:.1f}s ago.")
75
  return False
76
-
77
  # Fetch the latest data
78
  latest_leaderboard_data = await get_leaderboard_data(self.db_session_maker)
79
-
80
  # Generate a hash of the new data to check if it's changed
81
  data_str = json.dumps(str(latest_leaderboard_data))
82
  data_hash = hashlib.md5(data_str.encode()).hexdigest()
83
-
84
  # Check if the data has changed
85
  if data_hash == self._leaderboard_cache_hash and not force:
86
  logger.debug("Leaderboard data unchanged since last fetch.")
87
  return False
88
-
89
  # Update the cache and timestamp
90
  self._leaderboard_data = latest_leaderboard_data
91
  self._leaderboard_cache_hash = data_hash
@@ -125,6 +127,24 @@ class Frontend:
125
  logger.error(f"Text Generation Failed: Unexpected error while generating text: {e!s}")
126
  raise gr.Error("Failed to generate text. Please try again shortly.")
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  async def _synthesize_speech(
129
  self,
130
  character_description: str,
@@ -135,9 +155,7 @@ class Frontend:
135
  Synthesizes two text-to-speech outputs, updates UI state components, and returns additional TTS metadata.
136
 
137
  This function generates TTS outputs using different providers based on the input text and its modification
138
- state. Depending on the selected providers, it may:
139
- - Synthesize one Hume and one ElevenLabs output (50% chance), or
140
- - Synthesize two Hume outputs (50% chance).
141
 
142
  The outputs are processed and shuffled, and the corresponding UI components for two audio players are updated.
143
  Additional metadata such as the comparison type, generation IDs, and state information are also returned.
@@ -150,8 +168,8 @@ class Frontend:
150
 
151
  Returns:
152
  Tuple containing:
153
- - dict: Update for the first audio player (with autoplay enabled).
154
- - dict: Update for the second audio player.
155
  - OptionMap: A mapping of option constants to their corresponding TTS providers.
156
  - bool: Flag indicating whether the text was modified.
157
  - str: The original text that was synthesized.
@@ -169,22 +187,19 @@ class Frontend:
169
  raise gr.Error(str(ve))
170
 
171
  text_modified = text != generated_text_state
172
- provider_a = constants.HUME_AI # always compare with Hume
173
- provider_b = get_random_provider(text_modified)
174
 
175
  tts_provider_funcs = {
176
  constants.HUME_AI: text_to_speech_with_hume,
 
177
  constants.ELEVENLABS: text_to_speech_with_elevenlabs,
178
  }
179
 
180
- if provider_b not in tts_provider_funcs:
181
- raise ValueError(f"Unsupported provider: {provider_b}")
182
-
183
  try:
184
  logger.info(f"Starting speech synthesis with providers: {provider_a} and {provider_b}")
185
 
186
  # Create two tasks for concurrent execution
187
- task_a = text_to_speech_with_hume(character_description, text, self.config)
188
  task_b = tts_provider_funcs[provider_b](character_description, text, self.config)
189
 
190
  # Await both tasks concurrently using asyncio.gather()
@@ -204,12 +219,15 @@ class Frontend:
204
  character_description,
205
  True,
206
  )
207
- except ElevenLabsError as ee:
208
- logger.error(f"Synthesis failed with ElevenLabsError during TTS generation: {ee!s}")
209
- raise gr.Error(f'There was an issue communicating with the Elevenlabs API: "{ee.message}"')
210
  except HumeError as he:
211
  logger.error(f"Synthesis failed with HumeError during TTS generation: {he!s}")
212
  raise gr.Error(f'There was an issue communicating with the Hume API: "{he.message}"')
 
 
 
 
 
 
213
  except Exception as e:
214
  logger.error(f"Synthesis failed with an unexpected error during TTS generation: {e!s}")
215
  raise gr.Error("An unexpected error occurred. Please try again shortly.")
@@ -243,7 +261,7 @@ class Frontend:
243
 
244
  Returns:
245
  A tuple of:
246
- - A boolean indicating if the vote was accepted.
247
  - A dict update for hiding vote button A.
248
  - A dict update for hiding vote button B.
249
  - A dict update for showing vote result A textbox.
@@ -330,13 +348,12 @@ class Frontend:
330
  # Only return an update if the data changed or force=True
331
  if data_updated:
332
  return gr.update(value=self._leaderboard_data)
333
- else:
334
- return gr.skip()
335
 
336
  async def _handle_tab_select(self, evt: gr.SelectData):
337
  """
338
  Handles tab selection events and refreshes the leaderboard if the Leaderboard tab is selected.
339
-
340
  Args:
341
  evt (gr.SelectData): Event data containing information about the selected tab
342
 
@@ -431,7 +448,7 @@ class Frontend:
431
  Builds the Title section
432
  """
433
  gr.HTML(
434
- """
435
  <div class="title-container">
436
  <h1>Expressive TTS Arena</h1>
437
  <div class="social-links">
@@ -468,9 +485,9 @@ class Frontend:
468
  with gr.Row():
469
  with gr.Column(scale=5):
470
  gr.HTML(
471
- """
472
  <h2 class="tab-header">📋 Instructions</h2>
473
- <ol>
474
  <li>
475
  Select a sample character, or input a custom character description and click
476
  <strong>"Generate Text"</strong>, to generate your text input.
@@ -487,7 +504,8 @@ class Frontend:
487
  <strong>"Select Option B"</strong>.
488
  </li>
489
  </ol>
490
- """
 
491
  )
492
  randomize_all_button = gr.Button(
493
  "🎲 Randomize All",
@@ -726,6 +744,13 @@ class Frontend:
726
  ],
727
  )
728
 
 
 
 
 
 
 
 
729
  # "Synthesize Speech" button click event handler chain:
730
  # 1. Disable components in the UI
731
  # 2. Reset UI state for audio players and voting results
@@ -854,15 +879,16 @@ class Frontend:
854
  with gr.Row():
855
  with gr.Column(scale=5):
856
  gr.HTML(
857
- """
858
  <h2 class="tab-header">🏆 Leaderboard</h2>
859
- <p>
860
  This leaderboard presents community voting results for different TTS providers, showing which
861
  ones users found more expressive and natural-sounding. The win rate reflects how often each
862
  provider was selected as the preferred option in head-to-head comparisons. Click the refresh
863
  button to see the most up-to-date voting results.
864
  </p>
865
- """
 
866
  )
867
  refresh_button = gr.Button(
868
  "↻ Refresh",
@@ -883,10 +909,64 @@ class Frontend:
883
  elem_id="leaderboard-table"
884
  )
885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
  # Wrapper for the async refresh function
887
  async def async_refresh_handler():
888
  return await self._refresh_leaderboard(force=True)
889
-
890
  # Handler to re-enable the button after a refresh
891
  def reenable_button():
892
  time.sleep(3) # wait 3 seconds before enabling to prevent excessive data fetching
 
13
  import hashlib
14
  import json
15
  import time
16
+ from typing import List, Optional, Tuple
17
 
18
  # Third-Party Library Imports
19
  import gradio as gr
 
27
  AnthropicError,
28
  ElevenLabsError,
29
  HumeError,
30
+ OpenAIError,
31
  generate_text_with_claude,
32
  text_to_speech_with_elevenlabs,
33
  text_to_speech_with_hume,
34
+ text_to_speech_with_openai,
35
  )
36
  from src.utils import (
37
  create_shuffled_tts_options,
38
  determine_selected_option,
39
  get_leaderboard_data,
40
+ get_random_providers,
41
  submit_voting_results,
42
  validate_character_description_length,
43
  validate_text_length,
 
54
 
55
  # leaderboard update state
56
  self._leaderboard_data: List[List[str]] = [[]]
57
+ self._leaderboard_cache_hash: Optional[str] = None
58
+ self._last_leaderboard_update_time: float = 0.0
59
  self._min_refresh_interval = 30
60
 
61
  async def _update_leaderboard_data(self, force: bool = False) -> bool:
62
  """
63
  Fetches the latest leaderboard data only if needed based on cache and time constraints.
64
+
65
  Args:
66
  force (bool): If True, bypass the time-based throttling.
67
+
68
  Returns:
69
  bool: True if the leaderboard was updated, False otherwise.
70
  """
71
  current_time = time.time()
72
  time_since_last_update = current_time - self._last_leaderboard_update_time
73
+
74
  # Skip update if it's been less than min_refresh_interval seconds and not forced
75
  if not force and time_since_last_update < self._min_refresh_interval:
76
  logger.debug(f"Skipping leaderboard update: last updated {time_since_last_update:.1f}s ago.")
77
  return False
78
+
79
  # Fetch the latest data
80
  latest_leaderboard_data = await get_leaderboard_data(self.db_session_maker)
81
+
82
  # Generate a hash of the new data to check if it's changed
83
  data_str = json.dumps(str(latest_leaderboard_data))
84
  data_hash = hashlib.md5(data_str.encode()).hexdigest()
85
+
86
  # Check if the data has changed
87
  if data_hash == self._leaderboard_cache_hash and not force:
88
  logger.debug("Leaderboard data unchanged since last fetch.")
89
  return False
90
+
91
  # Update the cache and timestamp
92
  self._leaderboard_data = latest_leaderboard_data
93
  self._leaderboard_cache_hash = data_hash
 
127
  logger.error(f"Text Generation Failed: Unexpected error while generating text: {e!s}")
128
  raise gr.Error("Failed to generate text. Please try again shortly.")
129
 
130
+ def _warn_user_about_custom_text(self, text: str, generated_text: str) -> None:
131
+ """
132
+ Shows a warning to the user if they have modified the generated text.
133
+
134
+ When users edit the generated text instead of using it as-is, only Hume Octave
135
+ outputs will be generated for comparison rather than comparing against other
136
+ providers. This function displays a warning to inform users of this limitation.
137
+
138
+ Args:
139
+ text (str): The current text that will be used for synthesis.
140
+ generated_text (str): The original text that was generated by the system.
141
+
142
+ Returns:
143
+ None: This function displays a warning but does not return any value.
144
+ """
145
+ if text != generated_text:
146
+ gr.Warning("When custom text is used, only Hume Octave outputs are generated.")
147
+
148
  async def _synthesize_speech(
149
  self,
150
  character_description: str,
 
155
  Synthesizes two text-to-speech outputs, updates UI state components, and returns additional TTS metadata.
156
 
157
  This function generates TTS outputs using different providers based on the input text and its modification
158
+ state.
 
 
159
 
160
  The outputs are processed and shuffled, and the corresponding UI components for two audio players are updated.
161
  Additional metadata such as the comparison type, generation IDs, and state information are also returned.
 
168
 
169
  Returns:
170
  Tuple containing:
171
+ - gr.Audio: Update for the first audio player (with autoplay enabled).
172
+ - gr.Audio: Update for the second audio player.
173
  - OptionMap: A mapping of option constants to their corresponding TTS providers.
174
  - bool: Flag indicating whether the text was modified.
175
  - str: The original text that was synthesized.
 
187
  raise gr.Error(str(ve))
188
 
189
  text_modified = text != generated_text_state
190
+ provider_a, provider_b = get_random_providers(text_modified)
 
191
 
192
  tts_provider_funcs = {
193
  constants.HUME_AI: text_to_speech_with_hume,
194
+ constants.OPENAI: text_to_speech_with_openai,
195
  constants.ELEVENLABS: text_to_speech_with_elevenlabs,
196
  }
197
 
 
 
 
198
  try:
199
  logger.info(f"Starting speech synthesis with providers: {provider_a} and {provider_b}")
200
 
201
  # Create two tasks for concurrent execution
202
+ task_a = tts_provider_funcs[provider_a](character_description, text, self.config)
203
  task_b = tts_provider_funcs[provider_b](character_description, text, self.config)
204
 
205
  # Await both tasks concurrently using asyncio.gather()
 
219
  character_description,
220
  True,
221
  )
 
 
 
222
  except HumeError as he:
223
  logger.error(f"Synthesis failed with HumeError during TTS generation: {he!s}")
224
  raise gr.Error(f'There was an issue communicating with the Hume API: "{he.message}"')
225
+ except OpenAIError as oe:
226
+ logger.error(f"Synthesis failed with OpenAIError during TTS generation: {oe!s}")
227
+ raise gr.Error(f'There was an issue communicating with the OpenAI API: "{oe.message}"')
228
+ except ElevenLabsError as ee:
229
+ logger.error(f"Synthesis failed with ElevenLabsError during TTS generation: {ee!s}")
230
+ raise gr.Error(f'There was an issue communicating with the Elevenlabs API: "{ee.message}"')
231
  except Exception as e:
232
  logger.error(f"Synthesis failed with an unexpected error during TTS generation: {e!s}")
233
  raise gr.Error("An unexpected error occurred. Please try again shortly.")
 
261
 
262
  Returns:
263
  A tuple of:
264
+ - bool: A boolean indicating if the vote was accepted.
265
  - A dict update for hiding vote button A.
266
  - A dict update for hiding vote button B.
267
  - A dict update for showing vote result A textbox.
 
348
  # Only return an update if the data changed or force=True
349
  if data_updated:
350
  return gr.update(value=self._leaderboard_data)
351
+ return gr.skip()
 
352
 
353
  async def _handle_tab_select(self, evt: gr.SelectData):
354
  """
355
  Handles tab selection events and refreshes the leaderboard if the Leaderboard tab is selected.
356
+
357
  Args:
358
  evt (gr.SelectData): Event data containing information about the selected tab
359
 
 
448
  Builds the Title section
449
  """
450
  gr.HTML(
451
+ value="""
452
  <div class="title-container">
453
  <h1>Expressive TTS Arena</h1>
454
  <div class="social-links">
 
485
  with gr.Row():
486
  with gr.Column(scale=5):
487
  gr.HTML(
488
+ value="""
489
  <h2 class="tab-header">📋 Instructions</h2>
490
+ <ol style="padding-left: 8px;">
491
  <li>
492
  Select a sample character, or input a custom character description and click
493
  <strong>"Generate Text"</strong>, to generate your text input.
 
504
  <strong>"Select Option B"</strong>.
505
  </li>
506
  </ol>
507
+ """,
508
+ padding=False,
509
  )
510
  randomize_all_button = gr.Button(
511
  "🎲 Randomize All",
 
744
  ],
745
  )
746
 
747
+ # "Text Input" blur event handler
748
+ text_input.blur(
749
+ fn=self._warn_user_about_custom_text,
750
+ inputs=[text_input, generated_text_state],
751
+ outputs=[],
752
+ )
753
+
754
  # "Synthesize Speech" button click event handler chain:
755
  # 1. Disable components in the UI
756
  # 2. Reset UI state for audio players and voting results
 
879
  with gr.Row():
880
  with gr.Column(scale=5):
881
  gr.HTML(
882
+ value="""
883
  <h2 class="tab-header">🏆 Leaderboard</h2>
884
+ <p style="padding-left: 8px;">
885
  This leaderboard presents community voting results for different TTS providers, showing which
886
  ones users found more expressive and natural-sounding. The win rate reflects how often each
887
  provider was selected as the preferred option in head-to-head comparisons. Click the refresh
888
  button to see the most up-to-date voting results.
889
  </p>
890
+ """,
891
+ padding=False,
892
  )
893
  refresh_button = gr.Button(
894
  "↻ Refresh",
 
909
  elem_id="leaderboard-table"
910
  )
911
 
912
+ with gr.Accordion(label="Citation", open=False):
913
+ with gr.Column(variant="panel"):
914
+ with gr.Column(variant="panel"):
915
+ gr.HTML(
916
+ value="""
917
+ <h2>Citation</h2>
918
+ <p style="padding: 0 8px;">
919
+ When referencing this leaderboard or its dataset in academic publications, please cite:
920
+ </p>
921
+ """,
922
+ padding=False,
923
+ )
924
+ gr.Markdown(
925
+ value="""
926
+ **BibTeX**
927
+ ```BibTeX
928
+ @misc{expressive-tts-arena,
929
+ title = {Expressive TTS Arena: An Open Platform for Evaluating Text-to-Speech Expressiveness by Human Preference},
930
+ author = {Alan Cowen, Zachary Greathouse, Richard Marmorstein, Jeremy Hadfield},
931
+ year = {2025},
932
+ publisher = {Hugging Face},
933
+ howpublished = {\\url{https://huggingface.co/spaces/HumeAI/expressive-tts-arena}}
934
+ }
935
+ ```
936
+ """
937
+ )
938
+ gr.HTML(
939
+ value="""
940
+ <h2>Terms of Use</h2>
941
+ <p style="padding: 0 8px;">
942
+ Users are required to agree to the following terms before using the service:
943
+ </p>
944
+ <p style="padding: 0 8px;">
945
+ All generated audio clips are provided for research and evaluation purposes only.
946
+ The audio content may not be redistributed or used for commercial purposes without
947
+ explicit permission. Users should not upload any private or personally identifiable
948
+ information. Please report any bugs, issues, or concerns to our
949
+ <a href="https://discord.com/invite/humeai" target="_blank" class="provider-link">
950
+ Discord community
951
+ </a>.
952
+ </p>
953
+ """,
954
+ padding=False,
955
+ )
956
+ gr.HTML(
957
+ value="""
958
+ <h2>Acknowledgements</h2>
959
+ <p style="padding: 0 8px;">
960
+ We thank all participants who contributed their votes to help build this leaderboard.
961
+ </p>
962
+ """,
963
+ padding=False,
964
+ )
965
+
966
  # Wrapper for the async refresh function
967
  async def async_refresh_handler():
968
  return await self._refresh_leaderboard(force=True)
969
+
970
  # Handler to re-enable the button after a refresh
971
  def reenable_button():
972
  time.sleep(3) # wait 3 seconds before enabling to prevent excessive data fetching
src/integrations/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
  from .anthropic_api import AnthropicConfig, AnthropicError, generate_text_with_claude
2
  from .elevenlabs_api import ElevenLabsConfig, ElevenLabsError, text_to_speech_with_elevenlabs
3
  from .hume_api import HumeConfig, HumeError, text_to_speech_with_hume
 
4
 
5
  __all__ = [
6
  "AnthropicConfig",
@@ -9,7 +10,10 @@ __all__ = [
9
  "ElevenLabsError",
10
  "HumeConfig",
11
  "HumeError",
 
 
12
  "generate_text_with_claude",
13
  "text_to_speech_with_elevenlabs",
14
  "text_to_speech_with_hume",
 
15
  ]
 
1
  from .anthropic_api import AnthropicConfig, AnthropicError, generate_text_with_claude
2
  from .elevenlabs_api import ElevenLabsConfig, ElevenLabsError, text_to_speech_with_elevenlabs
3
  from .hume_api import HumeConfig, HumeError, text_to_speech_with_hume
4
+ from .openai_api import OpenAIConfig, OpenAIError, text_to_speech_with_openai
5
 
6
  __all__ = [
7
  "AnthropicConfig",
 
10
  "ElevenLabsError",
11
  "HumeConfig",
12
  "HumeError",
13
+ "OpenAIConfig",
14
+ "OpenAIError",
15
  "generate_text_with_claude",
16
  "text_to_speech_with_elevenlabs",
17
  "text_to_speech_with_hume",
18
+ "text_to_speech_with_openai",
19
  ]
src/integrations/anthropic_api.py CHANGED
@@ -23,7 +23,7 @@ from tenacity import after_log, before_log, retry, retry_if_exception, stop_afte
23
 
24
  # Local Application Imports
25
  from src.config import Config, logger
26
- from src.constants import CLIENT_ERROR_CODE, SERVER_ERROR_CODE
27
  from src.utils import truncate_text, validate_env_var
28
 
29
  PROMPT_TEMPLATE: str = """
@@ -246,7 +246,7 @@ def _extract_anthropic_error_message(e: APIError) -> str:
246
  Returns:
247
  str: A clean, user-friendly error message suitable for display to end users.
248
  """
249
- clean_message = "An unknown error has occurred. Please try again later."
250
 
251
  if hasattr(e, 'body') and isinstance(e.body, dict):
252
  error_body = e.body
 
23
 
24
  # Local Application Imports
25
  from src.config import Config, logger
26
+ from src.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, SERVER_ERROR_CODE
27
  from src.utils import truncate_text, validate_env_var
28
 
29
  PROMPT_TEMPLATE: str = """
 
246
  Returns:
247
  str: A clean, user-friendly error message suitable for display to end users.
248
  """
249
+ clean_message = GENERIC_API_ERROR_MESSAGE
250
 
251
  if hasattr(e, 'body') and isinstance(e.body, dict):
252
  error_body = e.body
src/integrations/openai_api.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ openai_api.py
3
+
4
+ This file defines the interaction with the OpenAI text-to-speech (TTS) API using the
5
+ OpenAI Python SDK. It includes functionality for API request handling and processing API responses.
6
+
7
+ Key Features:
8
+ - Encapsulates all logic related to the OpenAI TTS API.
9
+ - Implements retry logic using Tenacity for handling transient API errors.
10
+ - Handles received audio and processes it for playback on the web.
11
+ - Provides detailed logging for debugging and error tracking.
12
+ - Utilizes robust error handling (EAFP) to validate API responses.
13
+ """
14
+
15
+ # Standard Library Imports
16
+ import logging
17
+ import random
18
+ import time
19
+ from dataclasses import dataclass, field
20
+ from pathlib import Path
21
+ from typing import Literal, Tuple, Union
22
+
23
+ # Third-Party Library Imports
24
+ from openai import APIError, AsyncOpenAI
25
+ from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
26
+
27
+ # Local Application Imports
28
+ from src.config import Config, logger
29
+ from src.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, SERVER_ERROR_CODE
30
+ from src.utils import validate_env_var
31
+
32
+
33
@dataclass(frozen=True)
class OpenAIConfig:
    """Immutable configuration for interacting with the OpenAI TTS API."""

    # Resolved from the OPENAI_API_KEY environment variable in __post_init__.
    api_key: str = field(init=False)
    model: str = "gpt-4o-mini-tts"
    response_format: Literal['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm'] = "mp3"

    def __post_init__(self) -> None:
        """Resolve and pin the API key after dataclass initialization."""
        # object.__setattr__ is required because the dataclass is frozen.
        object.__setattr__(self, "api_key", validate_env_var("OPENAI_API_KEY"))

    @property
    def client(self) -> AsyncOpenAI:
        """
        Build an asynchronous OpenAI client configured with this API key.

        NOTE(review): a new AsyncOpenAI instance is constructed on every
        access of this property; it is not cached.

        Returns:
            AsyncOpenAI: Configured async client instance.
        """
        return AsyncOpenAI(api_key=self.api_key)

    @staticmethod
    def select_random_base_voice() -> str:
        """
        Pick one of OpenAI's built-in TTS voices at random.

        The OpenAI Python SDK does not export a type enumerating the base
        voice names, so the available options are hardcoded here.

        Returns:
            str: A randomly selected OpenAI base voice name (e.g., 'alloy', 'nova').
        """
        voice_options = ["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
        return random.choice(voice_options)
70
+
71
+
72
class OpenAIError(Exception):
    """Custom exception for errors related to the OpenAI TTS API."""

    def __init__(self, message: str, original_exception: Union[Exception, None] = None):
        """
        Args:
            message (str): Human-readable description of the failure.
            original_exception (Exception | None): The underlying exception, if any.
        """
        super().__init__(message)
        self.message = message
        self.original_exception = original_exception
79
+
80
+
81
class UnretryableOpenAIError(OpenAIError):
    """OpenAI TTS API error that should not be retried (e.g. client-side 4xx failures)."""

    def __init__(self, message: str, original_exception: Union[Exception, None] = None):
        """
        Args:
            message (str): Human-readable description of the failure.
            original_exception (Exception | None): The underlying exception, if any.
        """
        # The parent initializer already records both `message` and
        # `original_exception`, so no further assignment is needed here.
        super().__init__(message, original_exception)
88
+
89
+
90
@retry(
    retry=retry_if_exception(lambda e: not isinstance(e, UnretryableOpenAIError)),
    stop=stop_after_attempt(2),
    wait=wait_fixed(2),
    before=before_log(logger, logging.DEBUG),
    after=after_log(logger, logging.DEBUG),
    reraise=True,
)
async def text_to_speech_with_openai(
    character_description: str,
    text: str,
    config: Config,
) -> Tuple[None, str]:
    """
    Asynchronously synthesizes speech using the OpenAI TTS API and writes the audio to a file.

    This function uses the OpenAI Python SDK to request speech synthesis with a randomly
    selected base voice, passing the character description as voice instructions. The
    audio response is streamed directly to disk and the relative file path is returned.

    Args:
        character_description (str): Description used to instruct the voice style.
        text (str): Text to be converted to speech.
        config (Config): Application configuration containing OpenAI API settings.

    Returns:
        Tuple[None, str]: A tuple containing:
            - generation_id (None): OpenAI does not return a generation ID; None keeps
              the return shape consistent with the other TTS integrations.
            - audio_file_path (str): Path to the saved audio file, relative to the CWD.

    Raises:
        OpenAIError: For errors communicating with the OpenAI API.
        UnretryableOpenAIError: For client-side HTTP errors (status code 4xx).
    """
    logger.debug(f"Synthesizing speech with OpenAI. Text length: {len(text)} characters.")
    openai_config = config.openai_config
    client = openai_config.client
    start_time = time.time()
    try:
        voice = openai_config.select_random_base_voice()
        async with client.audio.speech.with_streaming_response.create(
            model=openai_config.model,
            input=text,
            instructions=character_description,
            response_format=openai_config.response_format,
            voice=voice,  # OpenAI requires a base voice to be specified
        ) as response:
            elapsed_time = time.time() - start_time
            logger.info(f"OpenAI API request completed in {elapsed_time:.2f} seconds")

            # Include the audio format as the file extension so players and
            # browsers can recognize the file; use an integer timestamp to
            # avoid a raw float in the filename.
            filename = f"openai_{voice}_{int(start_time)}.{openai_config.response_format}"
            audio_file_path = Path(config.audio_dir) / filename
            await response.stream_to_file(audio_file_path)
            relative_audio_file_path = audio_file_path.relative_to(Path.cwd())

            return None, str(relative_audio_file_path)

    except APIError as e:
        elapsed_time = time.time() - start_time
        logger.error(f"OpenAI API request failed after {elapsed_time:.2f} seconds: {e!s}")
        logger.error(f"Full OpenAI API error: {e!s}")
        clean_message = _extract_openai_error_message(e)

        # 4xx responses indicate a client-side problem (bad request, auth,
        # rate limit policy violation) that retrying cannot fix.
        if (
            hasattr(e, 'status_code')
            and e.status_code is not None
            and CLIENT_ERROR_CODE <= e.status_code < SERVER_ERROR_CODE
        ):
            raise UnretryableOpenAIError(message=clean_message, original_exception=e) from e

        raise OpenAIError(message=clean_message, original_exception=e) from e

    except Exception as e:
        error_type = type(e).__name__
        error_message = str(e) if str(e) else f"An error of type {error_type} occurred"
        logger.error("Error during OpenAI API call: %s - %s", error_type, error_message)
        clean_message = GENERIC_API_ERROR_MESSAGE

        raise OpenAIError(message=clean_message, original_exception=e) from e
169
+
170
+
171
def _extract_openai_error_message(e: APIError) -> str:
    """
    Extract a clean, user-friendly error message from an OpenAI API error response.

    Falls back to a generic message when the error body does not carry a
    nested ``{'error': {'message': ...}}`` payload.

    Args:
        e (APIError): The OpenAI API error exception containing response information.

    Returns:
        str: A clean, user-friendly error message suitable for display to end users.
    """
    body = getattr(e, 'body', None)
    if isinstance(body, dict):
        error_details = body.get('error')
        if isinstance(error_details, dict) and 'message' in error_details:
            return error_details['message']

    return GENERIC_API_ERROR_MESSAGE
src/utils.py CHANGED
@@ -204,22 +204,37 @@ def save_base64_audio_to_file(base64_audio: str, filename: str, config: Config)
204
  return str(relative_path)
205
 
206
 
207
- def get_random_provider(text_modified: bool) -> TTSProviderName:
208
  """
209
- Select a TTS provider based on whether the text has been modified.
 
 
 
 
 
 
 
 
210
 
211
  Args:
212
- text_modified (bool): A flag indicating whether the text has been modified.
213
 
214
  Returns:
215
- provider: A TTS provider selected based on the following criteria:
216
- - If the text has been modified, it will be "Hume AI"
217
- - Otherwise, it will be "Hume AI" 30% of the time and "ElevenLabs" 70% of the time
218
  """
219
  if text_modified:
220
- return constants.HUME_AI
221
 
222
- return constants.HUME_AI if random.random() < 0.3 else constants.ELEVENLABS
 
 
 
 
 
 
 
 
 
223
 
224
 
225
  def create_shuffled_tts_options(option_a: Option, option_b: Option) -> OptionMap:
@@ -285,9 +300,6 @@ def _determine_comparison_type(provider_a: TTSProviderName, provider_b: TTSProvi
285
  """
286
  Determine the comparison type based on the given TTS provider names.
287
 
288
- If both providers are HUME_AI, the comparison type is HUME_TO_HUME.
289
- If either provider is ELEVENLABS, the comparison type is HUME_TO_ELEVENLABS.
290
-
291
  Args:
292
  provider_a (TTSProviderName): The first TTS provider.
293
  provider_b (TTSProviderName): The second TTS provider.
@@ -302,9 +314,17 @@ def _determine_comparison_type(provider_a: TTSProviderName, provider_b: TTSProvi
302
  if provider_a == constants.HUME_AI and provider_b == constants.HUME_AI:
303
  return constants.HUME_TO_HUME
304
 
305
- if constants.ELEVENLABS in (provider_a, provider_b):
 
 
306
  return constants.HUME_TO_ELEVENLABS
307
 
 
 
 
 
 
 
308
  raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")
309
 
310
 
 
204
  return str(relative_path)
205
 
206
 
207
def get_random_providers(text_modified: bool) -> Tuple[TTSProviderName, TTSProviderName]:
    """
    Pick the pair of TTS providers to compare for a single generation round.

    Custom (modified) text always yields a Hume-to-Hume comparison. Otherwise
    the pair is drawn from a weighted distribution:
      - 50%: HUME_AI, OPENAI
      - 25%: OPENAI, ELEVENLABS
      - 20%: HUME_AI, ELEVENLABS
      -  5%: HUME_AI, HUME_AI

    Args:
        text_modified (bool): A flag indicating whether the text has been modified,
            indicating a custom text input.

    Returns:
        tuple: A tuple (TTSProviderName, TTSProviderName)
    """
    if text_modified:
        return constants.HUME_AI, constants.HUME_AI

    # Each entry couples a provider pair with its selection probability, so
    # the weights can never drift out of sync with the pair ordering.
    weighted_pairs = [
        ((constants.HUME_AI, constants.OPENAI), 0.5),
        ((constants.OPENAI, constants.ELEVENLABS), 0.25),
        ((constants.HUME_AI, constants.ELEVENLABS), 0.2),
        ((constants.HUME_AI, constants.HUME_AI), 0.05),
    ]
    pairs = [pair for pair, _ in weighted_pairs]
    probabilities = [weight for _, weight in weighted_pairs]

    return random.choices(pairs, weights=probabilities, k=1)[0]
238
 
239
 
240
  def create_shuffled_tts_options(option_a: Option, option_b: Option) -> OptionMap:
 
300
  """
301
  Determine the comparison type based on the given TTS provider names.
302
 
 
 
 
303
  Args:
304
  provider_a (TTSProviderName): The first TTS provider.
305
  provider_b (TTSProviderName): The second TTS provider.
 
314
  if provider_a == constants.HUME_AI and provider_b == constants.HUME_AI:
315
  return constants.HUME_TO_HUME
316
 
317
+ providers = (provider_a, provider_b)
318
+
319
+ if constants.HUME_AI in providers and constants.ELEVENLABS in providers:
320
  return constants.HUME_TO_ELEVENLABS
321
 
322
+ if constants.HUME_AI in providers and constants.OPENAI in providers:
323
+ return constants.HUME_TO_OPENAI
324
+
325
+ if constants.ELEVENLABS in providers and constants.OPENAI in providers:
326
+ return constants.OPENAI_TO_ELEVENLABS
327
+
328
  raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")
329
 
330
 
uv.lock CHANGED
@@ -7,7 +7,10 @@ resolution-markers = [
7
  ]
8
 
9
  [manifest]
10
- overrides = [{ name = "aiofiles", specifier = "==24.1.0" }]
 
 
 
11
 
12
  [[package]]
13
  name = "aiofiles"
@@ -165,6 +168,15 @@ wheels = [
165
  { url = "https://files.pythonhosted.org/packages/38/fc/bce832fd4fd99766c04d1ee0eead6b0ec6486fb100ae5e74c1d91292b982/certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe", size = 166393 },
166
  ]
167
 
 
 
 
 
 
 
 
 
 
168
  [[package]]
169
  name = "cfgv"
170
  version = "3.4.0"
@@ -227,7 +239,7 @@ name = "click"
227
  version = "8.1.8"
228
  source = { registry = "https://pypi.org/simple" }
229
  dependencies = [
230
- { name = "colorama", marker = "sys_platform == 'win32'" },
231
  ]
232
  sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 }
233
  wheels = [
@@ -299,6 +311,7 @@ dependencies = [
299
  { name = "gradio" },
300
  { name = "greenlet" },
301
  { name = "hume" },
 
302
  { name = "python-dotenv" },
303
  { name = "sqlalchemy" },
304
  { name = "tenacity" },
@@ -324,6 +337,7 @@ requires-dist = [
324
  { name = "gradio", specifier = ">=5.18.0" },
325
  { name = "greenlet", specifier = ">=2.0.0" },
326
  { name = "hume", specifier = ">=0.7.8" },
 
327
  { name = "python-dotenv", specifier = ">=1.0.1" },
328
  { name = "sqlalchemy", specifier = ">=2.0.0" },
329
  { name = "tenacity", specifier = ">=9.0.0" },
@@ -783,6 +797,27 @@ wheels = [
783
  { url = "https://files.pythonhosted.org/packages/80/94/cd9e9b04012c015cb6320ab3bf43bc615e248dddfeb163728e800a5d96f0/numpy-2.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:97b974d3ba0fb4612b77ed35d7627490e8e3dff56ab41454d9e8b23448940576", size = 12696208 },
784
  ]
785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786
  [[package]]
787
  name = "orjson"
788
  version = "3.10.15"
@@ -963,6 +998,15 @@ wheels = [
963
  { url = "https://files.pythonhosted.org/packages/43/b3/df14c580d82b9627d173ceea305ba898dca135feb360b6d84019d0803d3b/pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b", size = 220560 },
964
  ]
965
 
 
 
 
 
 
 
 
 
 
966
  [[package]]
967
  name = "pydantic"
968
  version = "2.10.6"
@@ -1251,6 +1295,18 @@ wheels = [
1251
  { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
1252
  ]
1253
 
 
 
 
 
 
 
 
 
 
 
 
 
1254
  [[package]]
1255
  name = "soupsieve"
1256
  version = "2.6"
@@ -1332,7 +1388,7 @@ name = "tqdm"
1332
  version = "4.67.1"
1333
  source = { registry = "https://pypi.org/simple" }
1334
  dependencies = [
1335
- { name = "colorama", marker = "sys_platform == 'win32'" },
1336
  ]
1337
  sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
1338
  wheels = [
 
7
  ]
8
 
9
  [manifest]
10
+ overrides = [
11
+ { name = "aiofiles", specifier = "==24.1.0" },
12
+ { name = "sounddevice", marker = "sys_platform == 'never'" },
13
+ ]
14
 
15
  [[package]]
16
  name = "aiofiles"
 
168
  { url = "https://files.pythonhosted.org/packages/38/fc/bce832fd4fd99766c04d1ee0eead6b0ec6486fb100ae5e74c1d91292b982/certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe", size = 166393 },
169
  ]
170
 
171
+ [[package]]
172
+ name = "cffi"
173
+ version = "1.17.1"
174
+ source = { registry = "https://pypi.org/simple" }
175
+ dependencies = [
176
+ { name = "pycparser" },
177
+ ]
178
+ sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 }
179
+
180
  [[package]]
181
  name = "cfgv"
182
  version = "3.4.0"
 
239
  version = "8.1.8"
240
  source = { registry = "https://pypi.org/simple" }
241
  dependencies = [
242
+ { name = "colorama", marker = "platform_system == 'Windows'" },
243
  ]
244
  sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 }
245
  wheels = [
 
311
  { name = "gradio" },
312
  { name = "greenlet" },
313
  { name = "hume" },
314
+ { name = "openai" },
315
  { name = "python-dotenv" },
316
  { name = "sqlalchemy" },
317
  { name = "tenacity" },
 
337
  { name = "gradio", specifier = ">=5.18.0" },
338
  { name = "greenlet", specifier = ">=2.0.0" },
339
  { name = "hume", specifier = ">=0.7.8" },
340
+ { name = "openai", specifier = ">=1.68.0" },
341
  { name = "python-dotenv", specifier = ">=1.0.1" },
342
  { name = "sqlalchemy", specifier = ">=2.0.0" },
343
  { name = "tenacity", specifier = ">=9.0.0" },
 
797
  { url = "https://files.pythonhosted.org/packages/80/94/cd9e9b04012c015cb6320ab3bf43bc615e248dddfeb163728e800a5d96f0/numpy-2.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:97b974d3ba0fb4612b77ed35d7627490e8e3dff56ab41454d9e8b23448940576", size = 12696208 },
798
  ]
799
 
800
+ [[package]]
801
+ name = "openai"
802
+ version = "1.68.0"
803
+ source = { registry = "https://pypi.org/simple" }
804
+ dependencies = [
805
+ { name = "anyio" },
806
+ { name = "distro" },
807
+ { name = "httpx" },
808
+ { name = "jiter" },
809
+ { name = "numpy" },
810
+ { name = "pydantic" },
811
+ { name = "sniffio" },
812
+ { name = "sounddevice", marker = "sys_platform == 'never'" },
813
+ { name = "tqdm" },
814
+ { name = "typing-extensions" },
815
+ ]
816
+ sdist = { url = "https://files.pythonhosted.org/packages/58/ea/58102e9bfda09edc963e6e877e39cca12706b46ebf35d5fc9da7b8af10f2/openai-1.68.0.tar.gz", hash = "sha256:c570c06c9ba10f98b891ac30a3dd7b5c89ed48094c711c7a3f35fb5ade6c0757", size = 413039 }
817
+ wheels = [
818
+ { url = "https://files.pythonhosted.org/packages/a5/b6/bd67b7031572cba7d8451d82ac4a990b3a96bbd3b037634726b48ac972c8/openai-1.68.0-py3-none-any.whl", hash = "sha256:20e279b0f3a78cb4a95f3eab2a180f3ee30c6a196aeebd6bf642a4f88ab85ee1", size = 605645 },
819
+ ]
820
+
821
  [[package]]
822
  name = "orjson"
823
  version = "3.10.15"
 
998
  { url = "https://files.pythonhosted.org/packages/43/b3/df14c580d82b9627d173ceea305ba898dca135feb360b6d84019d0803d3b/pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b", size = 220560 },
999
  ]
1000
 
1001
+ [[package]]
1002
+ name = "pycparser"
1003
+ version = "2.22"
1004
+ source = { registry = "https://pypi.org/simple" }
1005
+ sdist = { url = "https://files.pythonhosted.org/packages/1d/b2/31537cf4b1ca988837256c910a668b553fceb8f069bedc4b1c826024b52c/pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", size = 172736 }
1006
+ wheels = [
1007
+ { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 },
1008
+ ]
1009
+
1010
  [[package]]
1011
  name = "pydantic"
1012
  version = "2.10.6"
 
1295
  { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
1296
  ]
1297
 
1298
+ [[package]]
1299
+ name = "sounddevice"
1300
+ version = "0.5.1"
1301
+ source = { registry = "https://pypi.org/simple" }
1302
+ dependencies = [
1303
+ { name = "cffi" },
1304
+ ]
1305
+ sdist = { url = "https://files.pythonhosted.org/packages/80/2d/b04ae180312b81dbb694504bee170eada5372242e186f6298139fd3a0513/sounddevice-0.5.1.tar.gz", hash = "sha256:09ca991daeda8ce4be9ac91e15a9a81c8f81efa6b695a348c9171ea0c16cb041", size = 52896 }
1306
+ wheels = [
1307
+ { url = "https://files.pythonhosted.org/packages/06/d1/464b5fca3decdd0cfec8c47f7b4161a0b12972453201c1bf03811f367c5e/sounddevice-0.5.1-py3-none-any.whl", hash = "sha256:e2017f182888c3f3c280d9fbac92e5dbddac024a7e3442f6e6116bd79dab8a9c", size = 32276 },
1308
+ ]
1309
+
1310
  [[package]]
1311
  name = "soupsieve"
1312
  version = "2.6"
 
1388
  version = "4.67.1"
1389
  source = { registry = "https://pypi.org/simple" }
1390
  dependencies = [
1391
+ { name = "colorama", marker = "platform_system == 'Windows'" },
1392
  ]
1393
  sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
1394
  wheels = [