Spaces:
Running
Running
Zachary Greathouse
commited on
Zg/codebase refactor (#20)
Browse files* Decomposes frontend UI code
* Encapsulates TTS to a new TTSService
* Encapsulates interaction with voting DB to a VotingService
* Composed frontend code in frontend directory
* Services defined in core directory
* Shared utils, types, config, and constants in common directory
* Middleware in middleware directory
* Clean up type annotations and docstrings
- README.md +39 -24
- pyproject.toml +1 -1
- src/common/__init__.py +34 -0
- src/{custom_types.py β common/common_types.py} +2 -12
- src/{config.py β common/config.py} +1 -13
- src/common/constants.py +49 -0
- src/common/utils.py +103 -0
- src/constants.py +0 -170
- src/core/__init__.py +4 -0
- src/core/tts_service.py +120 -0
- src/core/voting_service.py +281 -0
- src/database/crud.py +3 -11
- src/database/database.py +1 -15
- src/database/models.py +1 -9
- src/frontend/__init__.py +3 -0
- src/frontend/components/__init__.py +4 -0
- src/{frontend.py β frontend/components/arena.py} +364 -523
- src/frontend/components/leaderboard.py +298 -0
- src/frontend/frontend.py +127 -0
- src/integrations/__init__.py +4 -4
- src/integrations/{anthropic_api.py β anthropic.py} +26 -40
- src/integrations/{elevenlabs_api.py β elevenlabs.py} +10 -26
- src/integrations/{hume_api.py β hume.py} +6 -24
- src/integrations/{openai_api.py β openai.py} +11 -30
- src/main.py +5 -68
- src/middleware/__init__.py +3 -0
- src/middleware/meta_tag_injection.py +155 -0
- src/scripts/init_db.py +1 -1
- src/scripts/test_db.py +1 -1
- src/utils.py +0 -650
README.md
CHANGED
@@ -36,33 +36,48 @@ For support or to join the conversation, visit our [Discord](https://discord.com
|
|
36 |
|
37 |
```
|
38 |
Expressive TTS Arena/
|
39 |
-
βββ public/
|
40 |
-
βββ src/
|
41 |
-
β βββ
|
42 |
-
β β βββ __init__.py
|
43 |
-
β β βββ
|
44 |
-
β β βββ
|
45 |
-
β β
|
46 |
-
β βββ
|
47 |
-
β
|
48 |
-
β β βββ
|
49 |
-
β β βββ
|
50 |
-
β β
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
β βββ scripts/
|
52 |
-
β β βββ __init__.py
|
53 |
-
β β βββ init_db.py
|
54 |
-
β β βββ test_db.py
|
55 |
-
β βββ __init__.py
|
56 |
-
β βββ
|
57 |
-
β βββ constants.py # Global constants
|
58 |
-
β βββ custom_types.py # Global custom types
|
59 |
-
β βββ frontend.py # Gradio UI components
|
60 |
-
β βββ main.py # Entry file
|
61 |
-
β βββ utils.py # Utility functions
|
62 |
βββ static/
|
63 |
-
β βββ audio/
|
64 |
β βββ css/
|
65 |
-
β β βββ styles.css
|
66 |
βββ .dockerignore
|
67 |
βββ .env.example
|
68 |
βββ .gitignore
|
|
|
36 |
|
37 |
```
|
38 |
Expressive TTS Arena/
|
39 |
+
βββ public/
|
40 |
+
βββ src/
|
41 |
+
β βββ common/
|
42 |
+
β β βββ __init__.py
|
43 |
+
β β βββ common_types.py # Application-wide custom type aliases and definitions.
|
44 |
+
β β βββ config.py # Manages application config (Singleton) loaded from env vars.
|
45 |
+
β β βββ constants.py # Application-wide constant values.
|
46 |
+
β β βββ utils.py # General-purpose utility functions used across modules.
|
47 |
+
β βββ core/
|
48 |
+
β β βββ __init__.py
|
49 |
+
β β βββ tts_service.py # Service handling Text-to-Speech provider selection and API calls.
|
50 |
+
β β βββ voting_service.py # Service managing database operations for votes and leaderboards.
|
51 |
+
β βββ database/ # Database access layer using SQLAlchemy.
|
52 |
+
β β βββ __init__.py
|
53 |
+
β β βββ crud.py # Data Access Objects (DAO) / CRUD operations for database models.
|
54 |
+
β β βββ database.py # Database connection setup (engine, session management).
|
55 |
+
β β βββ models.py # SQLAlchemy ORM models defining database tables.
|
56 |
+
β βββ frontend/
|
57 |
+
β β βββ components/
|
58 |
+
β β β β βββ __init__.py
|
59 |
+
β β β β βββ arena.py # UI definition and logic for the 'Arena' tab.
|
60 |
+
β β β β βββ leaderboard.py # UI definition and logic for the 'Leaderboard' tab.
|
61 |
+
β β βββ __init__.py
|
62 |
+
β β βββ frontend.py # Main Gradio application class; orchestrates UI components and layout.
|
63 |
+
β βββ integrations/ # Modules for interacting with external third-party APIs.
|
64 |
+
β β βββ __init__.py
|
65 |
+
β β βββ anthropic_api.py # Integration logic for the Anthropic API.
|
66 |
+
β β βββ elevenlabs_api.py # Integration logic for the ElevenLabs API.
|
67 |
+
β β βββ hume_api.py # Integration logic for the Hume API.
|
68 |
+
β βββ middleware/
|
69 |
+
β β βββ __init__.py
|
70 |
+
β β βββ meta_tag_injection.py # Middleware for injecting custom HTML meta tags into the Gradio page.
|
71 |
β βββ scripts/
|
72 |
+
β β βββ __init__.py
|
73 |
+
β β βββ init_db.py # Script to create database tables based on models.
|
74 |
+
β β βββ test_db.py # Script for testing the database connection configuration.
|
75 |
+
β βββ __init__.py
|
76 |
+
β βββ main.py # Main script to configure and run the Gradio application.
|
|
|
|
|
|
|
|
|
|
|
77 |
βββ static/
|
78 |
+
β βββ audio/ # Temporary storage for generated audio files served to the UI.
|
79 |
β βββ css/
|
80 |
+
β β βββ styles.css # Custom CSS overrides and styling for the Gradio UI.
|
81 |
βββ .dockerignore
|
82 |
βββ .env.example
|
83 |
βββ .gitignore
|
pyproject.toml
CHANGED
@@ -88,7 +88,7 @@ select = [
|
|
88 |
"TID",
|
89 |
"W",
|
90 |
]
|
91 |
-
per-file-ignores = { "src/
|
92 |
|
93 |
[tool.ruff.lint.pycodestyle]
|
94 |
max-line-length = 120
|
|
|
88 |
"TID",
|
89 |
"W",
|
90 |
]
|
91 |
+
per-file-ignores = { "src/frontend/components/arena.py" = ["E501"], "src/frontend/components/leaderboard.py" = ["E501"], "src/middleware/meta_tag_injection.py" = ["E501"] }
|
92 |
|
93 |
[tool.ruff.lint.pycodestyle]
|
94 |
max-line-length = 120
|
src/common/__init__.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import constants
|
2 |
+
from .common_types import (
|
3 |
+
ComparisonType,
|
4 |
+
LeaderboardEntry,
|
5 |
+
LeaderboardTableEntries,
|
6 |
+
Option,
|
7 |
+
OptionDetail,
|
8 |
+
OptionKey,
|
9 |
+
OptionLabel,
|
10 |
+
OptionMap,
|
11 |
+
TTSProviderName,
|
12 |
+
VotingResults,
|
13 |
+
)
|
14 |
+
from .config import Config, logger
|
15 |
+
from .utils import save_base64_audio_to_file, validate_env_var
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"ComparisonType",
|
19 |
+
"Config",
|
20 |
+
"LeaderboardEntry",
|
21 |
+
"LeaderboardTableEntries",
|
22 |
+
"Option",
|
23 |
+
"OptionDetail",
|
24 |
+
"OptionKey",
|
25 |
+
"OptionLabel",
|
26 |
+
"OptionMap",
|
27 |
+
"TTSProviderName",
|
28 |
+
"VotingResults",
|
29 |
+
"constants",
|
30 |
+
"logger",
|
31 |
+
"save_base64_audio_to_file",
|
32 |
+
"utils",
|
33 |
+
"validate_env_var",
|
34 |
+
]
|
src/{custom_types.py β common/common_types.py}
RENAMED
@@ -1,9 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
custom_types.py
|
3 |
-
|
4 |
-
This module defines custom types for the application.
|
5 |
-
"""
|
6 |
-
|
7 |
# Standard Library Imports
|
8 |
from typing import List, Literal, NamedTuple, Optional, TypedDict
|
9 |
|
@@ -12,8 +6,8 @@ TTSProviderName = Literal["Hume AI", "ElevenLabs", "OpenAI"]
|
|
12 |
|
13 |
|
14 |
ComparisonType = Literal[
|
15 |
-
"Hume AI - Hume AI",
|
16 |
-
"Hume AI - ElevenLabs",
|
17 |
"Hume AI - OpenAI",
|
18 |
"OpenAI - ElevenLabs"
|
19 |
]
|
@@ -41,7 +35,6 @@ class Option(NamedTuple):
|
|
41 |
audio (str): The relative file path to the audio file produced by the TTS provider.
|
42 |
generation_id (str): The unique identifier for this TTS generation.
|
43 |
"""
|
44 |
-
|
45 |
provider: TTSProviderName
|
46 |
audio: str
|
47 |
generation_id: str
|
@@ -49,7 +42,6 @@ class Option(NamedTuple):
|
|
49 |
|
50 |
class VotingResults(TypedDict):
|
51 |
"""Voting results data structure representing values we want to persist to the votes DB"""
|
52 |
-
|
53 |
comparison_type: ComparisonType
|
54 |
winning_provider: TTSProviderName
|
55 |
winning_option: OptionKey
|
@@ -71,7 +63,6 @@ class OptionDetail(TypedDict):
|
|
71 |
generation_id (Optional[str]): The unique identifier for this TTS generation, or None if not available.
|
72 |
audio_file_path (str): The relative file path to the generated audio file.
|
73 |
"""
|
74 |
-
|
75 |
provider: TTSProviderName
|
76 |
generation_id: Optional[str]
|
77 |
audio_file_path: str
|
@@ -85,7 +76,6 @@ class OptionMap(TypedDict):
|
|
85 |
option_a: OptionDetail,
|
86 |
option_b: OptionDetail
|
87 |
"""
|
88 |
-
|
89 |
option_a: OptionDetail
|
90 |
option_b: OptionDetail
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
from typing import List, Literal, NamedTuple, Optional, TypedDict
|
3 |
|
|
|
6 |
|
7 |
|
8 |
ComparisonType = Literal[
|
9 |
+
"Hume AI - Hume AI",
|
10 |
+
"Hume AI - ElevenLabs",
|
11 |
"Hume AI - OpenAI",
|
12 |
"OpenAI - ElevenLabs"
|
13 |
]
|
|
|
35 |
audio (str): The relative file path to the audio file produced by the TTS provider.
|
36 |
generation_id (str): The unique identifier for this TTS generation.
|
37 |
"""
|
|
|
38 |
provider: TTSProviderName
|
39 |
audio: str
|
40 |
generation_id: str
|
|
|
42 |
|
43 |
class VotingResults(TypedDict):
|
44 |
"""Voting results data structure representing values we want to persist to the votes DB"""
|
|
|
45 |
comparison_type: ComparisonType
|
46 |
winning_provider: TTSProviderName
|
47 |
winning_option: OptionKey
|
|
|
63 |
generation_id (Optional[str]): The unique identifier for this TTS generation, or None if not available.
|
64 |
audio_file_path (str): The relative file path to the generated audio file.
|
65 |
"""
|
|
|
66 |
provider: TTSProviderName
|
67 |
generation_id: Optional[str]
|
68 |
audio_file_path: str
|
|
|
76 |
option_a: OptionDetail,
|
77 |
option_b: OptionDetail
|
78 |
"""
|
|
|
79 |
option_a: OptionDetail
|
80 |
option_b: OptionDetail
|
81 |
|
src/{config.py β common/config.py}
RENAMED
@@ -1,15 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
config.py
|
3 |
-
|
4 |
-
Global configuration and logger setup for the project.
|
5 |
-
|
6 |
-
Key Features:
|
7 |
-
- Uses environment variables defined in the system (Docker in production).
|
8 |
-
- Loads a `.env` file only in development to simulate production variables locally.
|
9 |
-
- Configures the logger for consistent logging across all modules.
|
10 |
-
- Dynamically enables DEBUG logging in development and INFO logging in production (unless overridden).
|
11 |
-
"""
|
12 |
-
|
13 |
# Standard Library Imports
|
14 |
import logging
|
15 |
import os
|
@@ -75,7 +63,7 @@ class Config:
|
|
75 |
audio_dir = Path.cwd() / "static" / "audio"
|
76 |
audio_dir.mkdir(parents=True, exist_ok=True)
|
77 |
|
78 |
-
logger.
|
79 |
|
80 |
if debug:
|
81 |
logger.debug("DEBUG mode enabled.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
import logging
|
3 |
import os
|
|
|
63 |
audio_dir = Path.cwd() / "static" / "audio"
|
64 |
audio_dir.mkdir(parents=True, exist_ok=True)
|
65 |
|
66 |
+
logger.debug(f"Audio directory set to {audio_dir}")
|
67 |
|
68 |
if debug:
|
69 |
logger.debug("DEBUG mode enabled.")
|
src/common/constants.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Library Imports
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
# Third-Party Library Imports
|
5 |
+
from .common_types import ComparisonType, OptionKey, TTSProviderName
|
6 |
+
|
7 |
+
HUME_AI: TTSProviderName = "Hume AI"
|
8 |
+
ELEVENLABS: TTSProviderName = "ElevenLabs"
|
9 |
+
OPENAI: TTSProviderName = "OpenAI"
|
10 |
+
|
11 |
+
TTS_PROVIDERS: List[TTSProviderName] = [HUME_AI, OPENAI, ELEVENLABS]
|
12 |
+
|
13 |
+
HUME_TO_HUME: ComparisonType = "Hume AI - Hume AI"
|
14 |
+
HUME_TO_ELEVENLABS: ComparisonType = "Hume AI - ElevenLabs"
|
15 |
+
HUME_TO_OPENAI: ComparisonType = "Hume AI - OpenAI"
|
16 |
+
OPENAI_TO_ELEVENLABS: ComparisonType = "OpenAI - ElevenLabs"
|
17 |
+
|
18 |
+
TTS_PROVIDER_LINKS = {
|
19 |
+
"Hume AI": {
|
20 |
+
"provider_link": "https://hume.ai/",
|
21 |
+
"model_link": "https://www.hume.ai/blog/octave-the-first-text-to-speech-model-that-understands-what-its-saying"
|
22 |
+
},
|
23 |
+
"ElevenLabs": {
|
24 |
+
"provider_link": "https://elevenlabs.io/",
|
25 |
+
"model_link": "https://elevenlabs.io/blog/rvg",
|
26 |
+
},
|
27 |
+
"OpenAI": {
|
28 |
+
"provider_link": "https://openai.com/",
|
29 |
+
"model_link": "https://platform.openai.com/docs/models/gpt-4o-mini-tts",
|
30 |
+
}
|
31 |
+
}
|
32 |
+
|
33 |
+
CHARACTER_DESCRIPTION_MIN_LENGTH: int = 20
|
34 |
+
CHARACTER_DESCRIPTION_MAX_LENGTH: int = 400
|
35 |
+
|
36 |
+
TEXT_MIN_LENGTH: int = 100
|
37 |
+
TEXT_MAX_LENGTH: int = 400
|
38 |
+
|
39 |
+
OPTION_A_KEY: OptionKey = "option_a"
|
40 |
+
OPTION_B_KEY: OptionKey = "option_b"
|
41 |
+
|
42 |
+
SELECT_OPTION_A: str = "Select Option A"
|
43 |
+
SELECT_OPTION_B: str = "Select Option B"
|
44 |
+
|
45 |
+
CLIENT_ERROR_CODE = 400
|
46 |
+
SERVER_ERROR_CODE = 500
|
47 |
+
RATE_LIMIT_ERROR_CODE = 429
|
48 |
+
|
49 |
+
GENERIC_API_ERROR_MESSAGE: str = "An unexpected error occurred while processing your request. Please try again shortly."
|
src/common/utils.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Library Imports
|
2 |
+
import base64
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
# Local Application Imports
|
8 |
+
from .config import Config, logger
|
9 |
+
|
10 |
+
|
11 |
+
def _delete_files_older_than(directory: Path, minutes: int = 30) -> None:
|
12 |
+
"""
|
13 |
+
Delete all files in the specified directory that are older than a given number of minutes.
|
14 |
+
|
15 |
+
This function checks each file in the given directory and removes it if its last modification
|
16 |
+
time is older than the specified threshold. By default, the threshold is set to 30 minutes.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
directory (str): The path to the directory where files will be checked and possibly deleted.
|
20 |
+
minutes (int, optional): The age threshold in minutes. Files older than this will be deleted.
|
21 |
+
Defaults to 30 minutes.
|
22 |
+
|
23 |
+
Returns: None
|
24 |
+
"""
|
25 |
+
# Get the current time in seconds since the epoch.
|
26 |
+
now = time.time()
|
27 |
+
# Convert the minutes threshold to seconds.
|
28 |
+
cutoff = now - (minutes * 60)
|
29 |
+
dir_path = Path(directory)
|
30 |
+
|
31 |
+
# Iterate over all files in the directory.
|
32 |
+
for file_path in dir_path.iterdir():
|
33 |
+
if file_path.is_file():
|
34 |
+
file_mod_time = file_path.stat().st_mtime
|
35 |
+
# If the file's modification time is older than the cutoff, delete it.
|
36 |
+
if file_mod_time < cutoff:
|
37 |
+
try:
|
38 |
+
file_path.unlink()
|
39 |
+
logger.info(f"Deleted: {file_path}")
|
40 |
+
except Exception as e:
|
41 |
+
logger.exception(f"Error deleting {file_path}: {e}")
|
42 |
+
|
43 |
+
def save_base64_audio_to_file(base64_audio: str, filename: str, config: Config) -> str:
|
44 |
+
"""
|
45 |
+
Decode a base64-encoded audio string and write the resulting binary data to a file
|
46 |
+
within the preconfigured AUDIO_DIR directory. Prior to writing the bytes to an audio
|
47 |
+
file, all files within the directory that are more than 30 minutes old are deleted.
|
48 |
+
This function verifies the file was created, logs both the absolute and relative
|
49 |
+
file paths, and returns a path relative to the current working directory
|
50 |
+
(as required by Gradio for serving static files).
|
51 |
+
|
52 |
+
Args:
|
53 |
+
base64_audio (str): The base64-encoded string representing the audio data.
|
54 |
+
filename (str): The name of the file (including extension, e.g.,
|
55 |
+
'b4a335da-9786-483a-b0a5-37e6e4ad5fd1.mp3') where the decoded
|
56 |
+
audio will be saved.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
str: The relative file path to the saved audio file.
|
60 |
+
|
61 |
+
Raises:
|
62 |
+
FileNotFoundError: If the audio file was not created.
|
63 |
+
"""
|
64 |
+
|
65 |
+
audio_bytes = base64.b64decode(base64_audio)
|
66 |
+
file_path = Path(config.audio_dir) / filename
|
67 |
+
num_minutes = 30
|
68 |
+
|
69 |
+
_delete_files_older_than(config.audio_dir, num_minutes)
|
70 |
+
|
71 |
+
# Write the binary audio data to the file.
|
72 |
+
with file_path.open("wb") as audio_file:
|
73 |
+
audio_file.write(audio_bytes)
|
74 |
+
|
75 |
+
# Verify that the file was created.
|
76 |
+
if not file_path.exists():
|
77 |
+
raise FileNotFoundError(f"Audio file was not created at {file_path}")
|
78 |
+
|
79 |
+
# Compute a relative path for Gradio to serve (relative to the current working directory).
|
80 |
+
relative_path = file_path.relative_to(Path.cwd())
|
81 |
+
logger.debug(f"Audio file absolute path: {file_path}")
|
82 |
+
logger.debug(f"Audio file relative path: {relative_path}")
|
83 |
+
|
84 |
+
return str(relative_path)
|
85 |
+
|
86 |
+
def validate_env_var(var_name: str) -> str:
|
87 |
+
"""
|
88 |
+
Validates that an environment variable is set and returns its value.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
var_name (str): The name of the environment variable to validate.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
str: The value of the environment variable.
|
95 |
+
|
96 |
+
Raises:
|
97 |
+
ValueError: If the environment variable is not set.
|
98 |
+
"""
|
99 |
+
value = os.environ.get(var_name, "")
|
100 |
+
if not value:
|
101 |
+
raise ValueError(f"{var_name} is not set. Please ensure it is defined in your environment variables.")
|
102 |
+
return value
|
103 |
+
|
src/constants.py
DELETED
@@ -1,170 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
constants.py
|
3 |
-
|
4 |
-
This module defines global constants used throughout the project.
|
5 |
-
"""
|
6 |
-
|
7 |
-
# Standard Library Imports
|
8 |
-
from typing import Dict, List
|
9 |
-
|
10 |
-
# Third-Party Library Imports
|
11 |
-
from src.custom_types import (
|
12 |
-
ComparisonType,
|
13 |
-
OptionKey,
|
14 |
-
OptionLabel,
|
15 |
-
TTSProviderName,
|
16 |
-
)
|
17 |
-
|
18 |
-
CLIENT_ERROR_CODE = 400
|
19 |
-
SERVER_ERROR_CODE = 500
|
20 |
-
RATE_LIMIT_ERROR_CODE = 429
|
21 |
-
|
22 |
-
|
23 |
-
# UI constants
|
24 |
-
HUME_AI: TTSProviderName = "Hume AI"
|
25 |
-
ELEVENLABS: TTSProviderName = "ElevenLabs"
|
26 |
-
OPENAI: TTSProviderName = "OpenAI"
|
27 |
-
|
28 |
-
TTS_PROVIDERS: List[TTSProviderName] = ["Hume AI", "OpenAI", "ElevenLabs"]
|
29 |
-
TTS_PROVIDER_LINKS = {
|
30 |
-
"Hume AI": {
|
31 |
-
"provider_link": "https://hume.ai/",
|
32 |
-
"model_link": "https://www.hume.ai/blog/octave-the-first-text-to-speech-model-that-understands-what-its-saying"
|
33 |
-
},
|
34 |
-
"ElevenLabs": {
|
35 |
-
"provider_link": "https://elevenlabs.io/",
|
36 |
-
"model_link": "https://elevenlabs.io/blog/rvg",
|
37 |
-
},
|
38 |
-
"OpenAI": {
|
39 |
-
"provider_link": "https://openai.com/",
|
40 |
-
"model_link": "https://platform.openai.com/docs/models/gpt-4o-mini-tts",
|
41 |
-
}
|
42 |
-
}
|
43 |
-
|
44 |
-
HUME_TO_HUME: ComparisonType = "Hume AI - Hume AI"
|
45 |
-
HUME_TO_ELEVENLABS: ComparisonType = "Hume AI - ElevenLabs"
|
46 |
-
HUME_TO_OPENAI: ComparisonType = "Hume AI - OpenAI"
|
47 |
-
OPENAI_TO_ELEVENLABS: ComparisonType = "OpenAI - ElevenLabs"
|
48 |
-
|
49 |
-
CHARACTER_DESCRIPTION_MIN_LENGTH: int = 20
|
50 |
-
CHARACTER_DESCRIPTION_MAX_LENGTH: int = 400
|
51 |
-
|
52 |
-
TEXT_MIN_LENGTH: int = 100
|
53 |
-
TEXT_MAX_LENGTH: int = 400
|
54 |
-
|
55 |
-
OPTION_A_KEY: OptionKey = "option_a"
|
56 |
-
OPTION_B_KEY: OptionKey = "option_b"
|
57 |
-
OPTION_A_LABEL: OptionLabel = "Option A"
|
58 |
-
OPTION_B_LABEL: OptionLabel = "Option B"
|
59 |
-
|
60 |
-
SELECT_OPTION_A: str = "Select Option A"
|
61 |
-
SELECT_OPTION_B: str = "Select Option B"
|
62 |
-
|
63 |
-
GENERIC_API_ERROR_MESSAGE: str = "An unexpected error occurred while processing your request. Please try again shortly."
|
64 |
-
|
65 |
-
# A collection of pre-defined character descriptions categorized by theme, used to provide users with
|
66 |
-
# inspiration for generating creative, expressive text inputs for TTS, and generating novel voices.
|
67 |
-
SAMPLE_CHARACTER_DESCRIPTIONS: dict = {
|
68 |
-
"π¦ Australian Naturalist": (
|
69 |
-
"The speaker has a contagiously enthusiastic Australian accent, with the relaxed, sun-kissed vibe of a "
|
70 |
-
"wildlife expert fresh off the outback, delivering an amazing, laid-back narration."
|
71 |
-
),
|
72 |
-
"π§ Meditation Guru": (
|
73 |
-
"A mindfulness instructor with a gentle, soothing voice that flows at a slow, measured pace with natural "
|
74 |
-
"pauses. Their consistently calm, low-pitched tone has minimal variation, creating a peaceful auditory "
|
75 |
-
"experience."
|
76 |
-
),
|
77 |
-
"π¬ Noir Detective": (
|
78 |
-
"A 1940s private investigator narrating with a gravelly voice and deliberate pacing. "
|
79 |
-
"Speaks with a cynical, world-weary tone that drops lower when delivering key observations."
|
80 |
-
),
|
81 |
-
"π―οΈ Victorian Ghost Storyteller": (
|
82 |
-
"The speaker is a Victorian-era raconteur speaking with a refined English accent and formal, precise diction. Voice "
|
83 |
-
"modulates between hushed, tense whispers and dramatic declarations when describing eerie occurrences."
|
84 |
-
),
|
85 |
-
"πΏ English Naturalist": (
|
86 |
-
"Speaker is a wildlife documentarian speaking with a crisp, articulate English accent and clear enunciation. Voice "
|
87 |
-
"alternates between hushed, excited whispers and enthusiastic explanations filled with genuine wonder."
|
88 |
-
),
|
89 |
-
"π Texan Storyteller": (
|
90 |
-
"A speaker from rural Texas speaking with a warm voice and distinctive Southern drawl featuring elongated "
|
91 |
-
"vowels. Talks unhurriedly with a musical quality and occasional soft laughter."
|
92 |
-
),
|
93 |
-
"π Chill Surfer": (
|
94 |
-
"The speaker is a California surfer talking with a casual, slightly nasal voice and laid-back rhythm. Uses rising "
|
95 |
-
"inflections at sentence ends and bursts into spontaneous laughter when excited."
|
96 |
-
),
|
97 |
-
"π’ Old-School Radio Announcer": (
|
98 |
-
"The speaker has the voice of a seasoned horse race announcer, with a booming, energetic voice, a touch of "
|
99 |
-
"old-school radio charm, and the enthusiastic delivery of a viral commentator."
|
100 |
-
),
|
101 |
-
"π Obnoxious Royal": (
|
102 |
-
"Speaker is a member of the English royal family speaks in a smug and authoritative voice in an obnoxious, proper "
|
103 |
-
"English accent. They are insecure, arrogant, and prone to tantrums."
|
104 |
-
),
|
105 |
-
"π° Medieval Peasant": (
|
106 |
-
"A film portrayal of a medieval peasant speaking with a thick cockney accent and a worn voice, "
|
107 |
-
"dripping with sarcasm and self-effacing humor."
|
108 |
-
),
|
109 |
-
}
|
110 |
-
|
111 |
-
|
112 |
-
# HTML and social media metadata for the Gradio application
|
113 |
-
# These tags define SEO-friendly content and provide rich previews when shared on social platforms
|
114 |
-
META_TAGS: List[Dict[str, str]] = [
|
115 |
-
# HTML Meta Tags (description)
|
116 |
-
{
|
117 |
-
'name': 'description',
|
118 |
-
'content': 'An open-source web application for comparing and evaluating the expressiveness of different text-to-speech models, including Hume AI and ElevenLabs.'
|
119 |
-
},
|
120 |
-
# Facebook Meta Tags
|
121 |
-
{
|
122 |
-
'property': 'og:url',
|
123 |
-
'content': 'https://hume.ai'
|
124 |
-
},
|
125 |
-
{
|
126 |
-
'property': 'og:type',
|
127 |
-
'content': 'website'
|
128 |
-
},
|
129 |
-
{
|
130 |
-
'property': 'og:title',
|
131 |
-
'content': 'Expressive TTS Arena'
|
132 |
-
},
|
133 |
-
{
|
134 |
-
'property': 'og:description',
|
135 |
-
'content': 'An open-source web application for comparing and evaluating the expressiveness of different text-to-speech models, including Hume AI and ElevenLabs.'
|
136 |
-
},
|
137 |
-
{
|
138 |
-
'property': 'og:image',
|
139 |
-
'content': '/static/arena-opengraph-logo.png'
|
140 |
-
},
|
141 |
-
# Twitter Meta Tags
|
142 |
-
{
|
143 |
-
'name': 'twitter:card',
|
144 |
-
'content': 'summary_large_image'
|
145 |
-
},
|
146 |
-
{
|
147 |
-
'property': 'twitter:domain',
|
148 |
-
'content': 'hume.ai'
|
149 |
-
},
|
150 |
-
{
|
151 |
-
'property': 'twitter:url',
|
152 |
-
'content': 'https://hume.ai'
|
153 |
-
},
|
154 |
-
{
|
155 |
-
'name': 'twitter:creator',
|
156 |
-
'content': '@hume_ai'
|
157 |
-
},
|
158 |
-
{
|
159 |
-
'name': 'twitter:title',
|
160 |
-
'content': 'Expressive TTS Arena'
|
161 |
-
},
|
162 |
-
{
|
163 |
-
'name': 'twitter:description',
|
164 |
-
'content': 'An open-source web application for comparing and evaluating the expressiveness of different text-to-speech models, including Hume AI and ElevenLabs.'
|
165 |
-
},
|
166 |
-
{
|
167 |
-
'name': 'twitter:image',
|
168 |
-
'content': '/static/arena-opengraph-logo.png'
|
169 |
-
}
|
170 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/core/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .tts_service import TTSService
|
2 |
+
from .voting_service import VotingService
|
3 |
+
|
4 |
+
__all__ = ["TTSService", "VotingService"]
|
src/core/tts_service.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Library Imports
|
2 |
+
import asyncio
|
3 |
+
import random
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
# Local Application Imports
|
7 |
+
from src.common import Config, Option, OptionMap, TTSProviderName, logger
|
8 |
+
from src.common.constants import ELEVENLABS, HUME_AI, OPENAI
|
9 |
+
from src.integrations import (
|
10 |
+
text_to_speech_with_elevenlabs,
|
11 |
+
text_to_speech_with_hume,
|
12 |
+
text_to_speech_with_openai,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
class TTSService:
|
17 |
+
"""
|
18 |
+
Service for coordinating text-to-speech generation across different providers.
|
19 |
+
|
20 |
+
This class handles the logic for selecting TTS providers, making concurrent API calls,
|
21 |
+
and processing the responses into a unified format for the frontend.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, config: Config):
|
25 |
+
"""
|
26 |
+
Initialize the TTS service with application configuration.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
config (Config): Application configuration containing API settings
|
30 |
+
"""
|
31 |
+
self.config = config
|
32 |
+
self.tts_provider_functions = {
|
33 |
+
HUME_AI: text_to_speech_with_hume,
|
34 |
+
ELEVENLABS: text_to_speech_with_elevenlabs,
|
35 |
+
OPENAI: text_to_speech_with_openai,
|
36 |
+
}
|
37 |
+
|
38 |
+
def __select_providers(self, text_modified: bool) -> Tuple[TTSProviderName, TTSProviderName]:
|
39 |
+
"""
|
40 |
+
Select 2 TTS providers based on whether the text has been modified.
|
41 |
+
|
42 |
+
Probabilities:
|
43 |
+
- 50% HUME_AI, OPENAI
|
44 |
+
- 25% OPENAI, ELEVENLABS
|
45 |
+
- 20% HUME_AI, ELEVENLABS
|
46 |
+
- 5% HUME_AI, HUME_AI
|
47 |
+
|
48 |
+
If the `text_modified` argument is `True`, then 100% HUME_AI, HUME_AI
|
49 |
+
|
50 |
+
Args:
|
51 |
+
text_modified (bool): A flag indicating whether the text has been modified
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
tuple: A tuple (TTSProviderName, TTSProviderName)
|
55 |
+
"""
|
56 |
+
if text_modified:
|
57 |
+
return HUME_AI, HUME_AI
|
58 |
+
|
59 |
+
# When modifying the probability distribution, make sure the weights match the order of provider pairs
|
60 |
+
provider_pairs = [
|
61 |
+
(HUME_AI, OPENAI),
|
62 |
+
(OPENAI, ELEVENLABS),
|
63 |
+
(HUME_AI, ELEVENLABS),
|
64 |
+
(HUME_AI, HUME_AI)
|
65 |
+
]
|
66 |
+
weights = [0.5, 0.25, 0.2, 0.05]
|
67 |
+
|
68 |
+
return random.choices(provider_pairs, weights=weights, k=1)[0]
|
69 |
+
|
70 |
+
async def synthesize_speech(
|
71 |
+
self,
|
72 |
+
character_description: str,
|
73 |
+
text: str,
|
74 |
+
text_modified: bool
|
75 |
+
) -> OptionMap:
|
76 |
+
"""
|
77 |
+
Generate speech for the given text using two different TTS providers.
|
78 |
+
|
79 |
+
This method selects appropriate providers based on the text modification status,
|
80 |
+
makes concurrent API calls to those providers, and returns the results.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
character_description (str): Description of the character/voice for synthesis
|
84 |
+
text (str): The text to synthesize into speech
|
85 |
+
text_modified (bool): Whether the text has been modified from the original
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
OptionMap: A mapping of shuffled TTS options, where each option includes
|
89 |
+
its provider, audio file path, and generation ID.
|
90 |
+
"""
|
91 |
+
provider_a, provider_b = self.__select_providers(text_modified)
|
92 |
+
|
93 |
+
logger.info(f"Starting speech synthesis with providers: {provider_a} and {provider_b}")
|
94 |
+
|
95 |
+
task_a = self.tts_provider_functions[provider_a](character_description, text, self.config)
|
96 |
+
task_b = self.tts_provider_functions[provider_b](character_description, text, self.config)
|
97 |
+
|
98 |
+
(generation_id_a, audio_a), (generation_id_b, audio_b) = await asyncio.gather(task_a, task_b)
|
99 |
+
|
100 |
+
logger.info(f"Synthesis succeeded for providers: {provider_a} and {provider_b}")
|
101 |
+
|
102 |
+
option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
|
103 |
+
option_b = Option(provider=provider_b, audio=audio_b, generation_id=generation_id_b)
|
104 |
+
|
105 |
+
options = [option_a, option_b]
|
106 |
+
random.shuffle(options)
|
107 |
+
shuffled_option_a, shuffled_option_b = options
|
108 |
+
|
109 |
+
return {
|
110 |
+
"option_a": {
|
111 |
+
"provider": shuffled_option_a.provider,
|
112 |
+
"generation_id": shuffled_option_a.generation_id,
|
113 |
+
"audio_file_path": shuffled_option_a.audio,
|
114 |
+
},
|
115 |
+
"option_b": {
|
116 |
+
"provider": shuffled_option_b.provider,
|
117 |
+
"generation_id": shuffled_option_b.generation_id,
|
118 |
+
"audio_file_path": shuffled_option_b.audio,
|
119 |
+
},
|
120 |
+
}
|
src/core/voting_service.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Library Imports
|
2 |
+
import json
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
# Third-Party Library Imports
|
6 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
7 |
+
|
8 |
+
# Local Application Imports
|
9 |
+
from src.common import (
|
10 |
+
ComparisonType,
|
11 |
+
LeaderboardEntry,
|
12 |
+
OptionKey,
|
13 |
+
OptionMap,
|
14 |
+
TTSProviderName,
|
15 |
+
VotingResults,
|
16 |
+
constants,
|
17 |
+
logger,
|
18 |
+
)
|
19 |
+
from src.database import (
|
20 |
+
AsyncDBSessionMaker,
|
21 |
+
create_vote,
|
22 |
+
get_head_to_head_battle_stats,
|
23 |
+
get_head_to_head_win_rate_stats,
|
24 |
+
get_leaderboard_stats,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class VotingService:
    """
    Service for handling all database interactions related to voting and leaderboards.

    Encapsulates logic for submitting votes and retrieving formatted leaderboard statistics.
    """

    def __init__(self, db_session_maker: AsyncDBSessionMaker):
        """
        Initializes the VotingService.

        Args:
            db_session_maker: An asynchronous database session factory.
        """
        self.db_session_maker: AsyncDBSessionMaker = db_session_maker
        logger.debug("VotingService initialized.")

    async def _create_db_session(self) -> AsyncSession | None:
        """
        Creates a new database session, returning None if it's a dummy session.

        Returns:
            An active AsyncSession, or None when the factory produced a dummy
            session (e.g. no database configured); callers must skip DB work
            when None is returned.
        """
        session = self.db_session_maker()
        # Dummy session factories mark their sessions with an `is_dummy` flag.
        if getattr(session, "is_dummy", False):
            logger.debug("Using dummy DB session; operations will be skipped.")
            # Ensure dummy sessions are also closed if they hold resources.
            if hasattr(session, "close"):
                await session.close()
            return None

        logger.debug("Created new DB session.")
        return session

    def _determine_comparison_type(self, provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType:
        """
        Determine the comparison type based on the given TTS provider names.

        The result is order-insensitive for mixed-provider pairs.

        Args:
            provider_a (TTSProviderName): The first TTS provider.
            provider_b (TTSProviderName): The second TTS provider.

        Returns:
            ComparisonType: The determined comparison type.

        Raises:
            ValueError: If the combination of providers is not recognized.
        """
        if provider_a == constants.HUME_AI and provider_b == constants.HUME_AI:
            return constants.HUME_TO_HUME

        providers = (provider_a, provider_b)

        if constants.HUME_AI in providers and constants.ELEVENLABS in providers:
            return constants.HUME_TO_ELEVENLABS

        if constants.HUME_AI in providers and constants.OPENAI in providers:
            return constants.HUME_TO_OPENAI

        if constants.ELEVENLABS in providers and constants.OPENAI in providers:
            return constants.OPENAI_TO_ELEVENLABS

        raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")

    async def _persist_vote(self, voting_results: VotingResults) -> None:
        """
        Persists a vote record in the database using a dedicated session.

        Handles session creation, commit, rollback, and closure. Logs errors internally
        rather than raising, so it is safe to call from background tasks.

        Args:
            voting_results: A dictionary containing the vote details.
        """
        session = await self._create_db_session()
        if session is None:
            logger.info("Skipping vote persistence (dummy session).")
            self._log_voting_results(voting_results)
            return

        try:
            self._log_voting_results(voting_results)
            await create_vote(session, voting_results)
            logger.info("Vote successfully persisted.")
        except Exception as e:
            logger.error(f"Failed to persist vote record: {e}", exc_info=True)
            # Roll back the failed transaction so the session is left in a clean
            # state before closing (the docstring contract promises rollback).
            try:
                await session.rollback()
            except Exception:
                logger.warning("Rollback after failed vote persistence also failed.", exc_info=True)
        finally:
            await session.close()
            logger.debug("DB session closed after persisting vote.")

    def _log_voting_results(self, voting_results: VotingResults) -> None:
        """Logs the full voting results dictionary."""
        try:
            logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4, default=str))
        except TypeError:
            # default=str covers most objects; fall back to raw repr just in case.
            logger.error("Could not serialize voting results for logging.")
            logger.info(f"Voting results (raw): {voting_results}")

    def _format_leaderboard_data(self, leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]:
        """Formats raw leaderboard entries into HTML strings for the UI table."""
        formatted_data = []
        for rank, provider, model, win_rate, votes in leaderboard_data_raw:
            provider_info = constants.TTS_PROVIDER_LINKS.get(provider, {})
            provider_link = provider_info.get("provider_link", "#")
            model_link = provider_info.get("model_link", "#")

            formatted_data.append([
                f'<p style="text-align: center;">{rank}</p>',
                f'<a href="{provider_link}" target="_blank" class="provider-link">{provider}</a>',
                f'<a href="{model_link}" target="_blank" class="provider-link">{model}</a>',
                f'<p style="text-align: center;">{win_rate}</p>',
                f'<p style="text-align: center;">{votes}</p>',
            ])
        return formatted_data

    def _build_provider_matrix(self, cell_value_fn) -> List[List[str]]:
        """
        Builds a provider-by-provider HTML matrix for the UI.

        Args:
            cell_value_fn: Callable (row_provider, col_provider) -> str, used for
                off-diagonal cells; diagonal cells are rendered as "-".

        Returns:
            A list of rows, each starting with the row provider's label cell.
        """
        matrix: List[List[str]] = []
        for row_provider in constants.TTS_PROVIDERS:
            row = [f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>']
            for col_provider in constants.TTS_PROVIDERS:
                cell_value = "-" if row_provider == col_provider else cell_value_fn(row_provider, col_provider)
                row.append(f'<p style="text-align: center;">{cell_value}</p>')
            matrix.append(row)
        return matrix

    def _format_battle_counts_data(self, battle_counts_data_raw: List[List[str]]) -> List[List[str]]:
        """Formats raw battle counts into an HTML matrix for the UI."""
        battle_counts_dict = {item[0]: str(item[1]) for item in battle_counts_data_raw}

        def battle_count(row_provider, col_provider) -> str:
            # Battle counts are keyed by the canonical comparison type, which is
            # shared by both orderings of a provider pair.
            comparison_key = self._determine_comparison_type(row_provider, col_provider)
            return battle_counts_dict.get(comparison_key, "0")

        return self._build_provider_matrix(battle_count)

    def _format_win_rate_data(self, win_rate_data_raw: List[List[str]]) -> List[List[str]]:
        """Formats raw win rates into an HTML matrix for the UI."""
        # win_rate_data_raw expected as [comparison_type, first_win_rate_str, second_win_rate_str],
        # where comparison_type is the canonical 'ProviderA - ProviderB' string.
        win_rates = {}
        for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw:
            try:
                provider1, provider2 = comparison_type.split(" - ")
            except ValueError:
                logger.warning(f"Could not parse comparison_type '{comparison_type}' in win rate data.")
                continue  # Skip malformed entry
            # Only record the pair once parsing succeeded, in both orderings.
            win_rates[(provider1, provider2)] = first_win_rate
            win_rates[(provider2, provider1)] = second_win_rate

        return self._build_provider_matrix(
            lambda row_provider, col_provider: win_rates.get((row_provider, col_provider), "0%")
        )

    async def get_formatted_leaderboard_data(self) -> Tuple[
        List[List[str]],
        List[List[str]],
        List[List[str]],
    ]:
        """
        Fetches raw leaderboard stats and formats them for UI display.

        Retrieves overall rankings, battle counts, and win rates, then formats
        them into HTML strings suitable for Gradio DataFrames.

        Returns:
            A tuple containing formatted lists of lists for:
                - Leaderboard rankings table
                - Battle counts matrix
                - Win rate matrix
            Returns empty lists ([[]], [[]], [[]]) on failure or with a dummy session.
        """
        session = await self._create_db_session()
        if session is None:
            logger.info("Skipping leaderboard fetch (dummy session).")
            return [[]], [[]], [[]]

        try:
            # Fetch raw data using underlying CRUD functions
            leaderboard_data_raw = await get_leaderboard_stats(session)
            battle_counts_data_raw = await get_head_to_head_battle_stats(session)
            win_rate_data_raw = await get_head_to_head_win_rate_stats(session)
            logger.debug("Fetched raw leaderboard data successfully.")

            # Format the data for the UI
            leaderboard_data = self._format_leaderboard_data(leaderboard_data_raw)
            battle_counts_data = self._format_battle_counts_data(battle_counts_data_raw)
            win_rate_data = self._format_win_rate_data(win_rate_data_raw)

            return leaderboard_data, battle_counts_data, win_rate_data

        except Exception as e:
            logger.error(f"Failed to fetch and format leaderboard data: {e}", exc_info=True)
            return [[]], [[]], [[]]  # Return empty structure on error
        finally:
            await session.close()
            logger.debug("DB session closed after fetching leaderboard data.")

    async def submit_vote(
        self,
        option_map: OptionMap,
        selected_option: OptionKey,
        text_modified: bool,
        character_description: str,
        text: str,
    ) -> None:
        """
        Constructs and persists a vote record based on user selection and context.

        This method is designed to be called safely from background tasks, handling all internal exceptions.

        Args:
            option_map: Mapping of comparison data and TTS options.
            selected_option: The option key ('option_a' or 'option_b') selected by the user.
            text_modified: Indicates if the text was custom vs. generated.
            character_description: Description used for TTS generation.
            text: The text synthesized.
        """
        try:
            provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
            provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]

            comparison_type: ComparisonType = self._determine_comparison_type(provider_a, provider_b)

            voting_results: VotingResults = {
                "comparison_type": comparison_type,
                "winning_provider": option_map[selected_option]["provider"],
                "winning_option": selected_option,
                "option_a_provider": provider_a,
                "option_b_provider": provider_b,
                "option_a_generation_id": option_map[constants.OPTION_A_KEY]["generation_id"],
                "option_b_generation_id": option_map[constants.OPTION_B_KEY]["generation_id"],
                "character_description": character_description,
                "text": text,
                "is_custom_text": text_modified,
            }

            await self._persist_vote(voting_results)

        except KeyError as e:
            logger.error(
                f"Missing key in option_map during vote submission: {e}. OptionMap: {option_map}",
                exc_info=True
            )
        except Exception as e:
            logger.error(f"Unexpected error in submit_vote: {e}", exc_info=True)
|
src/database/crud.py
CHANGED
@@ -1,10 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
crud.py
|
3 |
-
|
4 |
-
This module defines the operations for the Expressive TTS Arena project's database.
|
5 |
-
Since vote records are never updated or deleted, only functions to create and read votes are provided.
|
6 |
-
"""
|
7 |
-
|
8 |
# Standard Library Imports
|
9 |
from typing import List
|
10 |
|
@@ -14,9 +7,9 @@ from sqlalchemy.exc import SQLAlchemyError
|
|
14 |
from sqlalchemy.ext.asyncio import AsyncSession
|
15 |
|
16 |
# Local Application Imports
|
17 |
-
from src.
|
18 |
-
|
19 |
-
from
|
20 |
|
21 |
|
22 |
async def create_vote(db: AsyncSession, vote_data: VotingResults) -> VoteResult:
|
@@ -31,7 +24,6 @@ async def create_vote(db: AsyncSession, vote_data: VotingResults) -> VoteResult:
|
|
31 |
VoteResult: The newly created vote record.
|
32 |
"""
|
33 |
try:
|
34 |
-
|
35 |
# Create vote record
|
36 |
vote = VoteResult(
|
37 |
comparison_type=vote_data["comparison_type"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
from typing import List
|
3 |
|
|
|
7 |
from sqlalchemy.ext.asyncio import AsyncSession
|
8 |
|
9 |
# Local Application Imports
|
10 |
+
from src.common import LeaderboardEntry, LeaderboardTableEntries, VotingResults, logger
|
11 |
+
|
12 |
+
from .models import VoteResult
|
13 |
|
14 |
|
15 |
async def create_vote(db: AsyncSession, vote_data: VotingResults) -> VoteResult:
|
|
|
24 |
VoteResult: The newly created vote record.
|
25 |
"""
|
26 |
try:
|
|
|
27 |
# Create vote record
|
28 |
vote = VoteResult(
|
29 |
comparison_type=vote_data["comparison_type"],
|
src/database/database.py
CHANGED
@@ -1,13 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
database.py
|
3 |
-
|
4 |
-
This module sets up the SQLAlchemy database connection for the Expressive TTS Arena project.
|
5 |
-
It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
|
6 |
-
and defines a declarative base class for ORM models.
|
7 |
-
|
8 |
-
If no DATABASE_URL environment variable is set, then create a dummy database to fail gracefully.
|
9 |
-
"""
|
10 |
-
|
11 |
# Standard Library Imports
|
12 |
from typing import Callable, Optional, TypeAlias, Union
|
13 |
|
@@ -16,14 +6,13 @@ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
16 |
from sqlalchemy.orm import DeclarativeBase
|
17 |
|
18 |
# Local Application Imports
|
19 |
-
from src.
|
20 |
|
21 |
|
22 |
# Define the SQLAlchemy Base
|
23 |
class Base(DeclarativeBase):
|
24 |
pass
|
25 |
|
26 |
-
|
27 |
class DummyAsyncSession:
|
28 |
is_dummy = True # Flag to indicate this is a dummy session.
|
29 |
|
@@ -53,11 +42,9 @@ class DummyAsyncSession:
|
|
53 |
# No-op: nothing to close.
|
54 |
pass
|
55 |
|
56 |
-
|
57 |
AsyncDBSessionMaker: TypeAlias = Union[async_sessionmaker[AsyncSession], Callable[[], DummyAsyncSession]]
|
58 |
engine: Optional[AsyncEngine] = None
|
59 |
|
60 |
-
|
61 |
def init_db(config: Config) -> AsyncDBSessionMaker:
|
62 |
"""
|
63 |
Initialize the database engine and return a session factory based on the provided configuration.
|
@@ -99,4 +86,3 @@ def init_db(config: Config) -> AsyncDBSessionMaker:
|
|
99 |
return DummyAsyncSession()
|
100 |
|
101 |
return async_dummy_session_factory
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
from typing import Callable, Optional, TypeAlias, Union
|
3 |
|
|
|
6 |
from sqlalchemy.orm import DeclarativeBase
|
7 |
|
8 |
# Local Application Imports
|
9 |
+
from src.common import Config, logger
|
10 |
|
11 |
|
12 |
# Define the SQLAlchemy Base
|
13 |
class Base(DeclarativeBase):
|
14 |
pass
|
15 |
|
|
|
16 |
class DummyAsyncSession:
|
17 |
is_dummy = True # Flag to indicate this is a dummy session.
|
18 |
|
|
|
42 |
# No-op: nothing to close.
|
43 |
pass
|
44 |
|
|
|
45 |
AsyncDBSessionMaker: TypeAlias = Union[async_sessionmaker[AsyncSession], Callable[[], DummyAsyncSession]]
|
46 |
engine: Optional[AsyncEngine] = None
|
47 |
|
|
|
48 |
def init_db(config: Config) -> AsyncDBSessionMaker:
|
49 |
"""
|
50 |
Initialize the database engine and return a session factory based on the provided configuration.
|
|
|
86 |
return DummyAsyncSession()
|
87 |
|
88 |
return async_dummy_session_factory
|
|
src/database/models.py
CHANGED
@@ -1,10 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
models.py
|
3 |
-
|
4 |
-
This module defines the SQLAlchemy ORM models for the Expressive TTS Arena project.
|
5 |
-
It currently defines the VoteResult model representing the vote_results table.
|
6 |
-
"""
|
7 |
-
|
8 |
# Standard Library Imports
|
9 |
from enum import Enum
|
10 |
|
@@ -27,14 +20,13 @@ from sqlalchemy import (
|
|
27 |
)
|
28 |
|
29 |
# Local Application Imports
|
30 |
-
from
|
31 |
|
32 |
|
33 |
class OptionEnum(str, Enum):
|
34 |
OPTION_A = "option_a"
|
35 |
OPTION_B = "option_b"
|
36 |
|
37 |
-
|
38 |
class VoteResult(Base):
|
39 |
__tablename__ = "vote_results"
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
from enum import Enum
|
3 |
|
|
|
20 |
)
|
21 |
|
22 |
# Local Application Imports
|
23 |
+
from .database import Base
|
24 |
|
25 |
|
26 |
class OptionEnum(str, Enum):
|
27 |
OPTION_A = "option_a"
|
28 |
OPTION_B = "option_b"
|
29 |
|
|
|
30 |
class VoteResult(Base):
|
31 |
__tablename__ = "vote_results"
|
32 |
|
src/frontend/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .frontend import Frontend
|
2 |
+
|
3 |
+
__all__ = ["Frontend"]
|
src/frontend/components/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .arena import Arena
|
2 |
+
from .leaderboard import Leaderboard
|
3 |
+
|
4 |
+
__all__ = ["Arena", "Leaderboard"]
|
src/{frontend.py β frontend/components/arena.py}
RENAMED
@@ -1,125 +1,170 @@
|
|
1 |
-
"""
|
2 |
-
frontend.py
|
3 |
-
|
4 |
-
Gradio UI for interacting with the Anthropic API, Hume TTS API, and ElevenLabs TTS API.
|
5 |
-
|
6 |
-
Users enter a character description, which is processed using Claude by Anthropic to generate text.
|
7 |
-
The text is then synthesized into speech using different TTS provider APIs.
|
8 |
-
Users can compare the outputs and vote for their favorite in an interactive UI.
|
9 |
-
"""
|
10 |
-
|
11 |
# Standard Library Imports
|
12 |
import asyncio
|
13 |
-
import
|
14 |
-
import json
|
15 |
import time
|
16 |
-
from typing import
|
17 |
|
18 |
# Third-Party Library Imports
|
19 |
import gradio as gr
|
20 |
|
21 |
# Local Application Imports
|
22 |
-
from src import constants
|
23 |
-
from src.
|
24 |
-
from src.
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
)
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
)
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
"""
|
65 |
-
|
66 |
|
67 |
Args:
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
"""
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
# Fetch the latest data
|
82 |
-
(
|
83 |
-
latest_leaderboard_data,
|
84 |
-
latest_battle_counts_data,
|
85 |
-
latest_win_rates_data
|
86 |
-
) = await get_leaderboard_data(self.db_session_maker)
|
87 |
-
|
88 |
-
# Generate a hash of the new data to check if it's changed
|
89 |
-
data_str = json.dumps(str(latest_leaderboard_data))
|
90 |
-
data_hash = hashlib.md5(data_str.encode()).hexdigest()
|
91 |
-
|
92 |
-
# Check if the data has changed
|
93 |
-
if data_hash == self._leaderboard_cache_hash and not force:
|
94 |
-
logger.debug("Leaderboard data unchanged since last fetch.")
|
95 |
-
return False
|
96 |
-
|
97 |
-
# Update the cache and timestamp
|
98 |
-
self._leaderboard_data = latest_leaderboard_data
|
99 |
-
self._battle_counts_data = latest_battle_counts_data
|
100 |
-
self._win_rates_data = latest_win_rates_data
|
101 |
-
self._leaderboard_cache_hash = data_hash
|
102 |
-
self._last_leaderboard_update_time = current_time
|
103 |
-
logger.info("Leaderboard data updated successfully.")
|
104 |
-
return True
|
105 |
-
|
106 |
-
async def _generate_text(self, character_description: str) -> Tuple[gr.Textbox, str]:
|
107 |
"""
|
108 |
-
Validates the
|
109 |
|
110 |
Args:
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
Returns:
|
114 |
-
|
115 |
-
-
|
116 |
-
- The generated text string.
|
117 |
|
118 |
Raises:
|
119 |
-
gr.Error: On validation or API errors.
|
120 |
"""
|
121 |
try:
|
122 |
-
|
123 |
except ValueError as ve:
|
124 |
logger.warning(f"Validation error: {ve}")
|
125 |
raise gr.Error(str(ve))
|
@@ -132,100 +177,78 @@ class Frontend:
|
|
132 |
logger.error(f"Text Generation Failed: AnthropicError while generating text: {ae!s}")
|
133 |
raise gr.Error(f'There was an issue communicating with the Anthropic API: "{ae.message}"')
|
134 |
except Exception as e:
|
135 |
-
logger.error(f"Text Generation Failed: Unexpected error while generating text: {e!s}")
|
136 |
raise gr.Error("Failed to generate text. Please try again shortly.")
|
137 |
|
138 |
def _warn_user_about_custom_text(self, text: str, generated_text: str) -> None:
|
139 |
"""
|
140 |
-
|
141 |
-
|
142 |
-
When users edit the generated text instead of using it as-is, only Hume Octave
|
143 |
-
outputs will be generated for comparison rather than comparing against other
|
144 |
-
providers. This function displays a warning to inform users of this limitation.
|
145 |
|
146 |
Args:
|
147 |
-
text
|
148 |
-
generated_text
|
149 |
-
|
150 |
-
Returns:
|
151 |
-
None: This function displays a warning but does not return any value.
|
152 |
"""
|
153 |
if text != generated_text:
|
154 |
-
gr.Warning("When custom text is used, only Hume Octave outputs are generated.")
|
155 |
|
156 |
async def _synthesize_speech(
|
157 |
self,
|
158 |
character_description: str,
|
159 |
text: str,
|
160 |
generated_text_state: str,
|
161 |
-
) -> Tuple[
|
162 |
"""
|
163 |
-
|
164 |
-
|
165 |
-
This function generates TTS outputs using different providers based on the input text and its modification
|
166 |
-
state.
|
167 |
|
168 |
-
|
169 |
-
|
170 |
|
171 |
Args:
|
172 |
-
character_description
|
173 |
-
text
|
174 |
-
generated_text_state
|
175 |
-
been modified.
|
176 |
|
177 |
Returns:
|
178 |
-
|
179 |
-
-
|
180 |
-
-
|
181 |
-
- OptionMap:
|
182 |
-
- bool: Flag indicating
|
183 |
-
- str: The
|
184 |
-
- str: The
|
185 |
-
- bool: Flag indicating whether the vote buttons should be enabled
|
186 |
|
187 |
Raises:
|
188 |
-
gr.Error:
|
189 |
"""
|
190 |
try:
|
191 |
-
|
192 |
-
|
193 |
except ValueError as ve:
|
194 |
-
logger.
|
195 |
raise gr.Error(str(ve))
|
196 |
|
197 |
-
text_modified = text != generated_text_state
|
198 |
-
provider_a, provider_b = get_random_providers(text_modified)
|
199 |
-
|
200 |
-
tts_provider_funcs = {
|
201 |
-
constants.HUME_AI: text_to_speech_with_hume,
|
202 |
-
constants.OPENAI: text_to_speech_with_openai,
|
203 |
-
constants.ELEVENLABS: text_to_speech_with_elevenlabs,
|
204 |
-
}
|
205 |
-
|
206 |
try:
|
207 |
-
|
208 |
-
|
209 |
-
# Create two tasks for concurrent execution
|
210 |
-
task_a = tts_provider_funcs[provider_a](character_description, text, self.config)
|
211 |
-
task_b = tts_provider_funcs[provider_b](character_description, text, self.config)
|
212 |
|
213 |
-
#
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
|
221 |
return (
|
222 |
gr.update(value=options_map["option_a"]["audio_file_path"], autoplay=True),
|
223 |
gr.update(value=options_map["option_b"]["audio_file_path"]),
|
224 |
options_map,
|
225 |
text_modified,
|
226 |
-
text,
|
227 |
-
character_description,
|
228 |
-
True,
|
229 |
)
|
230 |
except HumeError as he:
|
231 |
logger.error(f"Synthesis failed with HumeError during TTS generation: {he!s}")
|
@@ -237,157 +260,171 @@ class Frontend:
|
|
237 |
logger.error(f"Synthesis failed with ElevenLabsError during TTS generation: {ee!s}")
|
238 |
raise gr.Error(f'There was an issue communicating with the Elevenlabs API: "{ee.message}"')
|
239 |
except Exception as e:
|
240 |
-
logger.error(f"Synthesis failed with an unexpected error during TTS generation: {e!s}")
|
241 |
raise gr.Error("An unexpected error occurred. Please try again shortly.")
|
242 |
|
243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
self,
|
245 |
vote_submitted: bool,
|
246 |
option_map: OptionMap,
|
247 |
-
|
248 |
text_modified: bool,
|
249 |
character_description: str,
|
250 |
text: str,
|
251 |
) -> Tuple[
|
252 |
-
bool,
|
253 |
-
gr.
|
254 |
-
gr.
|
255 |
-
gr.
|
256 |
-
gr.
|
257 |
-
gr.
|
258 |
]:
|
259 |
"""
|
260 |
-
Handles user voting and updates the UI
|
|
|
|
|
261 |
|
262 |
Args:
|
263 |
-
vote_submitted
|
264 |
-
option_map
|
265 |
-
|
266 |
-
text_modified
|
267 |
-
character_description
|
268 |
-
text
|
269 |
|
270 |
Returns:
|
271 |
-
A tuple of
|
272 |
-
|
273 |
-
|
274 |
-
-
|
275 |
-
-
|
276 |
-
-
|
277 |
-
-
|
|
|
|
|
278 |
"""
|
279 |
-
|
|
|
|
|
|
|
280 |
return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
281 |
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
)
|
296 |
-
|
297 |
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
|
302 |
-
|
303 |
-
True
|
304 |
-
gr.update(visible=
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
if selected_option == constants.OPTION_A_KEY
|
309 |
-
else gr.update(value=other_label, visible=True)
|
310 |
-
),
|
311 |
-
(
|
312 |
-
gr.update(value=other_label, visible=True)
|
313 |
-
if selected_option == constants.OPTION_A_KEY
|
314 |
-
else gr.update(value=selected_label, visible=True, elem_classes="winner")
|
315 |
-
),
|
316 |
-
gr.update(interactive=True),
|
317 |
-
)
|
318 |
|
319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
"""
|
321 |
-
|
322 |
|
323 |
Returns:
|
324 |
-
|
325 |
-
-
|
326 |
-
-
|
327 |
"""
|
328 |
-
|
|
|
|
|
|
|
|
|
329 |
|
330 |
-
sample_keys = list(
|
331 |
random_sample = random.choice(sample_keys)
|
332 |
-
character_description =
|
333 |
|
334 |
logger.info(f"Randomize All: Selected '{random_sample}'")
|
335 |
|
336 |
return (
|
337 |
-
gr.update(value=random_sample), # Update dropdown
|
338 |
-
gr.update(value=character_description), # Update character description
|
339 |
)
|
340 |
|
341 |
-
|
342 |
-
"""
|
343 |
-
Asynchronously fetches and formats the latest leaderboard data.
|
344 |
-
|
345 |
-
Args:
|
346 |
-
force (bool): If True, bypass time-based throttling.
|
347 |
-
|
348 |
-
Returns:
|
349 |
-
tuple: Updated DataFrames or gr.skip() if no update needed
|
350 |
-
"""
|
351 |
-
data_updated = await self._update_leaderboard_data(force=force)
|
352 |
-
|
353 |
-
if not self._leaderboard_data:
|
354 |
-
raise gr.Error("Unable to retrieve leaderboard data. Please refresh the page or try again shortly.")
|
355 |
-
|
356 |
-
if data_updated or force:
|
357 |
-
return (
|
358 |
-
gr.update(value=self._leaderboard_data),
|
359 |
-
gr.update(value=self._battle_counts_data),
|
360 |
-
gr.update(value=self._win_rates_data)
|
361 |
-
)
|
362 |
-
return gr.skip(), gr.skip(), gr.skip()
|
363 |
-
|
364 |
-
async def _handle_tab_select(self, evt: gr.SelectData):
|
365 |
"""
|
366 |
-
|
367 |
-
|
368 |
-
Args:
|
369 |
-
evt (gr.SelectData): Event data containing information about the selected tab
|
370 |
|
371 |
Returns:
|
372 |
-
tuple
|
373 |
-
|
374 |
-
if evt.value == "Leaderboard":
|
375 |
-
return await self._refresh_leaderboard(force=False)
|
376 |
-
return gr.skip(), gr.skip(), gr.skip()
|
377 |
-
|
378 |
-
def _disable_ui(self) -> Tuple[
|
379 |
-
gr.Button,
|
380 |
-
gr.Dropdown,
|
381 |
-
gr.Textbox,
|
382 |
-
gr.Button,
|
383 |
-
gr.Textbox,
|
384 |
-
gr.Button,
|
385 |
-
gr.Button,
|
386 |
-
gr.Button
|
387 |
-
]:
|
388 |
-
"""
|
389 |
-
Disables all interactive components in the UI (except audio players)
|
390 |
"""
|
|
|
391 |
return(
|
392 |
gr.update(interactive=False), # disable Randomize All button
|
393 |
gr.update(interactive=False), # disable Character Description dropdown
|
@@ -399,19 +436,20 @@ class Frontend:
|
|
399 |
gr.update(interactive=False), # disable Select B Button
|
400 |
)
|
401 |
|
402 |
-
def _enable_ui(self, should_enable_vote_buttons) -> Tuple[
|
403 |
-
gr.Button,
|
404 |
-
gr.Dropdown,
|
405 |
-
gr.Textbox,
|
406 |
-
gr.Button,
|
407 |
-
gr.Textbox,
|
408 |
-
gr.Button,
|
409 |
-
gr.Button,
|
410 |
-
gr.Button
|
411 |
-
]:
|
412 |
"""
|
413 |
-
Enables
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
"""
|
|
|
415 |
return(
|
416 |
gr.update(interactive=True), # enable Randomize All button
|
417 |
gr.update(interactive=True), # enable Character Description dropdown
|
@@ -419,78 +457,55 @@ class Frontend:
|
|
419 |
gr.update(interactive=True), # enable Generate Text button
|
420 |
gr.update(interactive=True), # enable Input Text input
|
421 |
gr.update(interactive=True), # enable Synthesize Speech Button
|
422 |
-
gr.update(interactive=should_enable_vote_buttons), # enable Select A Button
|
423 |
-
gr.update(interactive=should_enable_vote_buttons), # enable Select B Button
|
424 |
)
|
425 |
|
426 |
-
def _reset_voting_ui(self) -> Tuple[
|
427 |
-
gr.Audio,
|
428 |
-
gr.Audio,
|
429 |
-
gr.Button,
|
430 |
-
gr.Button,
|
431 |
-
gr.Textbox,
|
432 |
-
gr.Textbox,
|
433 |
-
OptionMap,
|
434 |
-
bool,
|
435 |
-
bool,
|
436 |
-
]:
|
437 |
"""
|
438 |
-
Resets voting UI state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
"""
|
|
|
440 |
default_option_map: OptionMap = {
|
441 |
"option_a": {"provider": constants.HUME_AI, "generation_id": None, "audio_file_path": ""},
|
442 |
"option_b": {"provider": constants.HUME_AI, "generation_id": None, "audio_file_path": ""},
|
443 |
}
|
444 |
return (
|
445 |
-
gr.update(value=None), # clear audio
|
446 |
-
gr.update(value=None, autoplay=False), # clear audio
|
447 |
-
gr.update(visible=True), # show vote button A
|
448 |
-
gr.update(visible=True), # show vote button B
|
449 |
-
gr.update(visible=False, elem_classes=[]), # hide vote result A
|
450 |
-
gr.update(visible=False, elem_classes=[]), # hide vote result B
|
451 |
-
default_option_map, # Reset option_map_state
|
452 |
False, # Reset vote_submitted_state
|
453 |
False, # Reset should_enable_vote_buttons state
|
454 |
)
|
455 |
|
456 |
-
def
|
457 |
-
"""
|
458 |
-
Builds the Title section
|
459 |
"""
|
460 |
-
|
461 |
-
value="""
|
462 |
-
<div class="title-container">
|
463 |
-
<h1>Expressive TTS Arena</h1>
|
464 |
-
<div class="social-links">
|
465 |
-
<a
|
466 |
-
href="https://discord.com/invite/humeai"
|
467 |
-
target="_blank"
|
468 |
-
id="discord-link"
|
469 |
-
title="Join our Discord"
|
470 |
-
aria-label="Join our Discord server"
|
471 |
-
></a>
|
472 |
-
<a
|
473 |
-
href="https://github.com/HumeAI/expressive-tts-arena"
|
474 |
-
target="_blank"
|
475 |
-
id="github-link"
|
476 |
-
title="View on GitHub"
|
477 |
-
aria-label="View project on GitHub"
|
478 |
-
></a>
|
479 |
-
</div>
|
480 |
-
</div>
|
481 |
-
<div class="excerpt-container">
|
482 |
-
<p>
|
483 |
-
Join the community in evaluating text-to-speech models, and vote for the AI voice that best
|
484 |
-
captures the emotion, nuance, and expressiveness of human speech.
|
485 |
-
</p>
|
486 |
-
</div>
|
487 |
-
"""
|
488 |
-
)
|
489 |
|
490 |
-
|
491 |
-
|
492 |
-
Builds the Arena section
|
493 |
"""
|
|
|
|
|
494 |
# --- UI components ---
|
495 |
with gr.Row():
|
496 |
with gr.Column(scale=5):
|
@@ -525,7 +540,7 @@ class Frontend:
|
|
525 |
)
|
526 |
|
527 |
sample_character_description_dropdown = gr.Dropdown(
|
528 |
-
choices=list(
|
529 |
label="Sample Characters",
|
530 |
info="Generate text with a sample character description.",
|
531 |
value=None,
|
@@ -561,12 +576,12 @@ class Frontend:
|
|
561 |
with gr.Column():
|
562 |
with gr.Group():
|
563 |
option_a_audio_player = gr.Audio(
|
564 |
-
label=
|
565 |
type="filepath",
|
566 |
interactive=False,
|
567 |
show_download_button=False,
|
568 |
)
|
569 |
-
vote_button_a = gr.Button(constants.SELECT_OPTION_A, interactive=False)
|
570 |
vote_result_a = gr.Textbox(
|
571 |
interactive=False,
|
572 |
visible=False,
|
@@ -577,12 +592,12 @@ class Frontend:
|
|
577 |
with gr.Column():
|
578 |
with gr.Group():
|
579 |
option_b_audio_player = gr.Audio(
|
580 |
-
label=
|
581 |
type="filepath",
|
582 |
interactive=False,
|
583 |
show_download_button=False,
|
584 |
)
|
585 |
-
vote_button_b = gr.Button(constants.SELECT_OPTION_B, interactive=False)
|
586 |
vote_result_b = gr.Textbox(
|
587 |
interactive=False,
|
588 |
visible=False,
|
@@ -599,7 +614,7 @@ class Frontend:
|
|
599 |
# Track generated text state
|
600 |
generated_text_state = gr.State("")
|
601 |
# Track whether text that was used was generated or modified/custom
|
602 |
-
text_modified_state = gr.State()
|
603 |
# Track option map (option A and option B are randomized)
|
604 |
option_map_state = gr.State({}) # OptionMap state as a dictionary
|
605 |
# Track whether the user has voted for an option
|
@@ -683,7 +698,7 @@ class Frontend:
|
|
683 |
# 3. Generate text
|
684 |
# 4. Enable interactive UI components
|
685 |
sample_character_description_dropdown.select(
|
686 |
-
fn=lambda choice:
|
687 |
inputs=[sample_character_description_dropdown],
|
688 |
outputs=[character_description_input],
|
689 |
).then(
|
@@ -826,7 +841,7 @@ class Frontend:
|
|
826 |
inputs=[],
|
827 |
outputs=[vote_button_a, vote_button_b],
|
828 |
).then(
|
829 |
-
fn=self.
|
830 |
inputs=[
|
831 |
vote_submitted_state,
|
832 |
option_map_state,
|
@@ -851,7 +866,7 @@ class Frontend:
|
|
851 |
inputs=[],
|
852 |
outputs=[vote_button_a, vote_button_b],
|
853 |
).then(
|
854 |
-
fn=self.
|
855 |
inputs=[
|
856 |
vote_submitted_state,
|
857 |
option_map_state,
|
@@ -881,178 +896,4 @@ class Frontend:
|
|
881 |
outputs=[option_b_audio_player],
|
882 |
)
|
883 |
|
884 |
-
|
885 |
-
"""
|
886 |
-
Builds the Leaderboard section
|
887 |
-
"""
|
888 |
-
# --- UI components ---
|
889 |
-
with gr.Row():
|
890 |
-
with gr.Column(scale=5):
|
891 |
-
gr.HTML(
|
892 |
-
value="""
|
893 |
-
<h2 class="tab-header">π Leaderboard</h2>
|
894 |
-
<p style="padding-left: 8px;">
|
895 |
-
This leaderboard presents community voting results for different TTS providers, showing which
|
896 |
-
ones users found more expressive and natural-sounding. The win rate reflects how often each
|
897 |
-
provider was selected as the preferred option in head-to-head comparisons. Click the refresh
|
898 |
-
button to see the most up-to-date voting results.
|
899 |
-
</p>
|
900 |
-
""",
|
901 |
-
padding=False,
|
902 |
-
)
|
903 |
-
refresh_button = gr.Button(
|
904 |
-
"β» Refresh",
|
905 |
-
variant="primary",
|
906 |
-
elem_classes="refresh-btn",
|
907 |
-
scale=1,
|
908 |
-
)
|
909 |
-
|
910 |
-
with gr.Column(elem_id="leaderboard-table-container"):
|
911 |
-
leaderboard_table = gr.DataFrame(
|
912 |
-
headers=["Rank", "Provider", "Model", "Win Rate", "Votes"],
|
913 |
-
datatype=["html", "html", "html", "html", "html"],
|
914 |
-
column_widths=[80, 300, 180, 120, 116],
|
915 |
-
value=self._leaderboard_data,
|
916 |
-
min_width=680,
|
917 |
-
interactive=False,
|
918 |
-
render=True,
|
919 |
-
elem_id="leaderboard-table"
|
920 |
-
)
|
921 |
-
|
922 |
-
with gr.Column():
|
923 |
-
gr.HTML(
|
924 |
-
value="""
|
925 |
-
<h2 style="padding-top: 12px;" class="tab-header">π Head-to-Head Matchups</h2>
|
926 |
-
<p style="padding-left: 8px; width: 80%;">
|
927 |
-
These tables show how each provider performs against others in direct comparisons.
|
928 |
-
The first table shows the total number of comparisons between each pair of providers.
|
929 |
-
The second table shows the win rate (percentage) of the row provider against the column provider.
|
930 |
-
</p>
|
931 |
-
""",
|
932 |
-
padding=False
|
933 |
-
)
|
934 |
-
|
935 |
-
with gr.Row(equal_height=True):
|
936 |
-
with gr.Column(min_width=420):
|
937 |
-
battle_counts_table = gr.DataFrame(
|
938 |
-
headers=["", "Hume AI", "OpenAI", "ElevenLabs"],
|
939 |
-
datatype=["html", "html", "html", "html"],
|
940 |
-
column_widths=[132, 132, 132, 132],
|
941 |
-
value=self._battle_counts_data,
|
942 |
-
interactive=False,
|
943 |
-
)
|
944 |
-
with gr.Column(min_width=420):
|
945 |
-
win_rates_table = gr.DataFrame(
|
946 |
-
headers=["", "Hume AI", "OpenAI", "ElevenLabs"],
|
947 |
-
datatype=["html", "html", "html", "html"],
|
948 |
-
column_widths=[132, 132, 132, 132],
|
949 |
-
value=self._win_rates_data,
|
950 |
-
interactive=False,
|
951 |
-
)
|
952 |
-
|
953 |
-
with gr.Accordion(label="Citation", open=False):
|
954 |
-
with gr.Column(variant="panel"):
|
955 |
-
with gr.Column(variant="panel"):
|
956 |
-
gr.HTML(
|
957 |
-
value="""
|
958 |
-
<h2>Citation</h2>
|
959 |
-
<p style="padding: 0 8px;">
|
960 |
-
When referencing this leaderboard or its dataset in academic publications, please cite:
|
961 |
-
</p>
|
962 |
-
""",
|
963 |
-
padding=False,
|
964 |
-
)
|
965 |
-
gr.Markdown(
|
966 |
-
value="""
|
967 |
-
**BibTeX**
|
968 |
-
```BibTeX
|
969 |
-
@misc{expressive-tts-arena,
|
970 |
-
title = {Expressive TTS Arena: An Open Platform for Evaluating Text-to-Speech Expressiveness by Human Preference},
|
971 |
-
author = {Alan Cowen, Zachary Greathouse, Richard Marmorstein, Jeremy Hadfield},
|
972 |
-
year = {2025},
|
973 |
-
publisher = {Hugging Face},
|
974 |
-
howpublished = {\\url{https://huggingface.co/spaces/HumeAI/expressive-tts-arena}}
|
975 |
-
}
|
976 |
-
```
|
977 |
-
"""
|
978 |
-
)
|
979 |
-
gr.HTML(
|
980 |
-
value="""
|
981 |
-
<h2>Terms of Use</h2>
|
982 |
-
<p style="padding: 0 8px;">
|
983 |
-
Users are required to agree to the following terms before using the service:
|
984 |
-
</p>
|
985 |
-
<p style="padding: 0 8px;">
|
986 |
-
All generated audio clips are provided for research and evaluation purposes only.
|
987 |
-
The audio content may not be redistributed or used for commercial purposes without
|
988 |
-
explicit permission. Users should not upload any private or personally identifiable
|
989 |
-
information. Please report any bugs, issues, or concerns to our
|
990 |
-
<a href="https://discord.com/invite/humeai" target="_blank" class="provider-link">
|
991 |
-
Discord community
|
992 |
-
</a>.
|
993 |
-
</p>
|
994 |
-
""",
|
995 |
-
padding=False,
|
996 |
-
)
|
997 |
-
gr.HTML(
|
998 |
-
value="""
|
999 |
-
<h2>Acknowledgements</h2>
|
1000 |
-
<p style="padding: 0 8px;">
|
1001 |
-
We thank all participants who contributed their votes to help build this leaderboard.
|
1002 |
-
</p>
|
1003 |
-
""",
|
1004 |
-
padding=False,
|
1005 |
-
)
|
1006 |
-
|
1007 |
-
# Wrapper for the async refresh function
|
1008 |
-
async def async_refresh_handler():
|
1009 |
-
leaderboard_update, battle_counts_update, win_rates_update = await self._refresh_leaderboard(force=True)
|
1010 |
-
return leaderboard_update, battle_counts_update, win_rates_update
|
1011 |
-
|
1012 |
-
# Handler to re-enable the button after a refresh
|
1013 |
-
def reenable_button():
|
1014 |
-
time.sleep(3) # wait 3 seconds before enabling to prevent excessive data fetching
|
1015 |
-
return gr.update(interactive=True)
|
1016 |
-
|
1017 |
-
# Refresh button click event handler
|
1018 |
-
refresh_button.click(
|
1019 |
-
fn=lambda _=None: (gr.update(interactive=False)),
|
1020 |
-
inputs=[],
|
1021 |
-
outputs=[refresh_button],
|
1022 |
-
).then(
|
1023 |
-
fn=async_refresh_handler,
|
1024 |
-
inputs=[],
|
1025 |
-
outputs=[leaderboard_table, battle_counts_table, win_rates_table] # Update all three tables
|
1026 |
-
).then(
|
1027 |
-
fn=reenable_button,
|
1028 |
-
inputs=[],
|
1029 |
-
outputs=[refresh_button]
|
1030 |
-
)
|
1031 |
-
|
1032 |
-
return leaderboard_table, battle_counts_table, win_rates_table
|
1033 |
-
|
1034 |
-
async def build_gradio_interface(self) -> gr.Blocks:
|
1035 |
-
"""
|
1036 |
-
Builds and configures the fully constructed Gradio UI layout.
|
1037 |
-
"""
|
1038 |
-
with gr.Blocks(
|
1039 |
-
title="Expressive TTS Arena",
|
1040 |
-
css_paths="static/css/styles.css",
|
1041 |
-
) as demo:
|
1042 |
-
await self._update_leaderboard_data()
|
1043 |
-
self._build_title_section()
|
1044 |
-
|
1045 |
-
with gr.Tabs() as tabs:
|
1046 |
-
with gr.TabItem("Arena"):
|
1047 |
-
self._build_arena_section()
|
1048 |
-
with gr.TabItem("Leaderboard"):
|
1049 |
-
leaderboard_table, battle_counts_table, win_rates_table = self._build_leaderboard_section()
|
1050 |
-
|
1051 |
-
tabs.select(
|
1052 |
-
fn=self._handle_tab_select,
|
1053 |
-
inputs=[],
|
1054 |
-
outputs=[leaderboard_table, battle_counts_table, win_rates_table],
|
1055 |
-
)
|
1056 |
-
|
1057 |
-
logger.debug("Gradio interface built successfully")
|
1058 |
-
return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
import asyncio
|
3 |
+
import random
|
|
|
4 |
import time
|
5 |
+
from typing import Tuple, Union
|
6 |
|
7 |
# Third-Party Library Imports
|
8 |
import gradio as gr
|
9 |
|
10 |
# Local Application Imports
|
11 |
+
from src.common import Config, OptionKey, OptionLabel, OptionMap, constants, logger
|
12 |
+
from src.core import TTSService, VotingService
|
13 |
+
from src.integrations import AnthropicError, ElevenLabsError, HumeError, OpenAIError, generate_text_with_claude
|
14 |
+
|
15 |
+
OPTION_A_LABEL: OptionLabel = "Option A"
|
16 |
+
OPTION_B_LABEL: OptionLabel = "Option B"
|
17 |
+
|
18 |
+
# A collection of pre-defined character descriptions categorized by theme, used to provide users with
|
19 |
+
# inspiration for generating creative, expressive text inputs for TTS, and generating novel voices.
|
20 |
+
SAMPLE_CHARACTER_DESCRIPTIONS: dict = {
|
21 |
+
"π¦ Australian Naturalist": (
|
22 |
+
"The speaker has a contagiously enthusiastic Australian accent, with the relaxed, sun-kissed vibe of a "
|
23 |
+
"wildlife expert fresh off the outback, delivering an amazing, laid-back narration."
|
24 |
+
),
|
25 |
+
"π§ Meditation Guru": (
|
26 |
+
"A mindfulness instructor with a gentle, soothing voice that flows at a slow, measured pace with natural "
|
27 |
+
"pauses. Their consistently calm, low-pitched tone has minimal variation, creating a peaceful auditory "
|
28 |
+
"experience."
|
29 |
+
),
|
30 |
+
"π¬ Noir Detective": (
|
31 |
+
"A 1940s private investigator narrating with a gravelly voice and deliberate pacing. "
|
32 |
+
"Speaks with a cynical, world-weary tone that drops lower when delivering key observations."
|
33 |
+
),
|
34 |
+
"π―οΈ Victorian Ghost Storyteller": (
|
35 |
+
"The speaker is a Victorian-era raconteur speaking with a refined English accent and formal, precise diction. Voice "
|
36 |
+
"modulates between hushed, tense whispers and dramatic declarations when describing eerie occurrences."
|
37 |
+
),
|
38 |
+
"πΏ English Naturalist": (
|
39 |
+
"Speaker is a wildlife documentarian speaking with a crisp, articulate English accent and clear enunciation. Voice "
|
40 |
+
"alternates between hushed, excited whispers and enthusiastic explanations filled with genuine wonder."
|
41 |
+
),
|
42 |
+
"π Texan Storyteller": (
|
43 |
+
"A speaker from rural Texas speaking with a warm voice and distinctive Southern drawl featuring elongated "
|
44 |
+
"vowels. Talks unhurriedly with a musical quality and occasional soft laughter."
|
45 |
+
),
|
46 |
+
"π Chill Surfer": (
|
47 |
+
"The speaker is a California surfer talking with a casual, slightly nasal voice and laid-back rhythm. Uses rising "
|
48 |
+
"inflections at sentence ends and bursts into spontaneous laughter when excited."
|
49 |
+
),
|
50 |
+
"π’ Old-School Radio Announcer": (
|
51 |
+
"The speaker has the voice of a seasoned horse race announcer, with a booming, energetic voice, a touch of "
|
52 |
+
"old-school radio charm, and the enthusiastic delivery of a viral commentator."
|
53 |
+
),
|
54 |
+
"π Obnoxious Royal": (
|
55 |
+
"Speaker is a member of the English royal family speaks in a smug and authoritative voice in an obnoxious, proper "
|
56 |
+
"English accent. They are insecure, arrogant, and prone to tantrums."
|
57 |
+
),
|
58 |
+
"π° Medieval Peasant": (
|
59 |
+
"A film portrayal of a medieval peasant speaking with a thick cockney accent and a worn voice, "
|
60 |
+
"dripping with sarcasm and self-effacing humor."
|
61 |
+
),
|
62 |
+
}
|
63 |
+
|
64 |
+
class Arena:
|
65 |
+
"""
|
66 |
+
Handles the user interface logic, state management, and event handling
|
67 |
+
for the 'Arena' tab where users generate, synthesize, and compare TTS audio.
|
68 |
+
"""
|
69 |
+
def __init__(self, config: Config, tts_service: TTSService, voting_service: VotingService):
|
70 |
"""
|
71 |
+
Initializes the Arena component.
|
72 |
|
73 |
Args:
|
74 |
+
config: The application configuration object.
|
75 |
+
tts_service: The service for TTS operations.
|
76 |
+
voting_service: The service for voting/leaderboard DB operations.
|
77 |
+
"""
|
78 |
+
self.config: Config = config
|
79 |
+
self.tts_service = tts_service
|
80 |
+
self.voting_service = voting_service
|
81 |
|
82 |
+
def _validate_input_length(
|
83 |
+
self,
|
84 |
+
input_value: str,
|
85 |
+
min_length: int,
|
86 |
+
max_length: int,
|
87 |
+
input_name: str,
|
88 |
+
) -> None:
|
89 |
+
"""
|
90 |
+
Validates input string length against minimum and maximum limits.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
input_value: The string value to validate.
|
94 |
+
min_length: The minimum required length (inclusive).
|
95 |
+
max_length: The maximum allowed length (inclusive).
|
96 |
+
input_name: A descriptive name of the input field (e.g., "character description")
|
97 |
+
used for error messages.
|
98 |
+
|
99 |
+
Raises:
|
100 |
+
ValueError: If the input length is outside the specified bounds.
|
101 |
+
"""
|
102 |
+
stripped_value = input_value.strip()
|
103 |
+
value_length = len(stripped_value)
|
104 |
+
logger.debug(f"Validating length for '{input_name}': {value_length} characters")
|
105 |
+
|
106 |
+
if value_length < min_length:
|
107 |
+
raise ValueError(
|
108 |
+
f"Your {input_name} is too short. Please enter at least "
|
109 |
+
f"{min_length} characters. (Current length: {value_length})"
|
110 |
+
)
|
111 |
+
if value_length > max_length:
|
112 |
+
raise ValueError(
|
113 |
+
f"Your {input_name} is too long. Please limit it to "
|
114 |
+
f"{max_length} characters. (Current length: {value_length})"
|
115 |
+
)
|
116 |
+
|
117 |
+
def _validate_character_description_length(self, character_description: str) -> None:
|
118 |
+
"""
|
119 |
+
Validates the character description length using predefined constants.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
character_description: The input character description to validate.
|
123 |
+
|
124 |
+
Raises:
|
125 |
+
ValueError: If the character description length is invalid.
|
126 |
"""
|
127 |
+
self._validate_input_length(
|
128 |
+
character_description,
|
129 |
+
constants.CHARACTER_DESCRIPTION_MIN_LENGTH,
|
130 |
+
constants.CHARACTER_DESCRIPTION_MAX_LENGTH,
|
131 |
+
"character description",
|
132 |
+
)
|
133 |
+
|
134 |
+
def _validate_text_length(self, text: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
"""
|
136 |
+
Validates the input text length using predefined constants.
|
137 |
|
138 |
Args:
|
139 |
+
text: The input text to validate.
|
140 |
+
|
141 |
+
Raises:
|
142 |
+
ValueError: If the text length is invalid.
|
143 |
+
"""
|
144 |
+
self._validate_input_length(
|
145 |
+
text,
|
146 |
+
constants.TEXT_MIN_LENGTH,
|
147 |
+
constants.TEXT_MAX_LENGTH,
|
148 |
+
"text",
|
149 |
+
)
|
150 |
+
|
151 |
+
async def _generate_text(self, character_description: str) -> Tuple[dict, str]:
|
152 |
+
"""
|
153 |
+
Validates the character description and generates text using the Anthropic API.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
character_description: The user-provided text for character description.
|
157 |
|
158 |
Returns:
|
159 |
+
A tuple containing:
|
160 |
+
- A Gradio update dictionary for the text input component.
|
161 |
+
- The generated text string (also used for state).
|
162 |
|
163 |
Raises:
|
164 |
+
gr.Error: On validation failure or Anthropic API errors.
|
165 |
"""
|
166 |
try:
|
167 |
+
self._validate_character_description_length(character_description)
|
168 |
except ValueError as ve:
|
169 |
logger.warning(f"Validation error: {ve}")
|
170 |
raise gr.Error(str(ve))
|
|
|
177 |
logger.error(f"Text Generation Failed: AnthropicError while generating text: {ae!s}")
|
178 |
raise gr.Error(f'There was an issue communicating with the Anthropic API: "{ae.message}"')
|
179 |
except Exception as e:
|
180 |
+
logger.error(f"Text Generation Failed: Unexpected error while generating text: {e!s}", exc_info=True)
|
181 |
raise gr.Error("Failed to generate text. Please try again shortly.")
|
182 |
|
183 |
def _warn_user_about_custom_text(self, text: str, generated_text: str) -> None:
|
184 |
"""
|
185 |
+
Displays a Gradio warning if the input text differs from the generated text state.
|
186 |
+
This informs the user that using custom text limits the comparison to only Hume outputs.
|
|
|
|
|
|
|
187 |
|
188 |
Args:
|
189 |
+
text: The current text in the input component.
|
190 |
+
generated_text: The original text generated by the system (stored in state).
|
|
|
|
|
|
|
191 |
"""
|
192 |
if text != generated_text:
|
193 |
+
gr.Warning("When custom text is used, only Hume Octave outputs are generated for comparison.")
|
194 |
|
195 |
async def _synthesize_speech(
|
196 |
self,
|
197 |
character_description: str,
|
198 |
text: str,
|
199 |
generated_text_state: str,
|
200 |
+
) -> Tuple[dict, dict, OptionMap, bool, str, str, bool]:
|
201 |
"""
|
202 |
+
Validates inputs and synthesizes two TTS outputs for comparison.
|
|
|
|
|
|
|
203 |
|
204 |
+
Generates TTS audio using different providers (or only Hume if text was
|
205 |
+
modified), updates UI state, and returns audio paths and metadata.
|
206 |
|
207 |
Args:
|
208 |
+
character_description: The description used for voice generation.
|
209 |
+
text: The text content to synthesize.
|
210 |
+
generated_text_state: The previously generated text state to check for modifications.
|
|
|
211 |
|
212 |
Returns:
|
213 |
+
A tuple containing:
|
214 |
+
- dict: Gradio update for the first audio player (Option A).
|
215 |
+
- dict: Gradio update for the second audio player (Option B).
|
216 |
+
- OptionMap: Mapping of options ('option_a', 'option_b') to provider details.
|
217 |
+
- bool: Flag indicating if the input text was modified from the generated state.
|
218 |
+
- str: The text string that was synthesized (for state).
|
219 |
+
- str: The character description used (for state).
|
220 |
+
- bool: Flag indicating whether the vote buttons should be enabled.
|
221 |
|
222 |
Raises:
|
223 |
+
gr.Error: On validation failure or errors during TTS synthesis API calls.
|
224 |
"""
|
225 |
try:
|
226 |
+
self._validate_character_description_length(character_description)
|
227 |
+
self._validate_text_length(text)
|
228 |
except ValueError as ve:
|
229 |
+
logger.error(f"Validation error during speech synthesis: {ve}")
|
230 |
raise gr.Error(str(ve))
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
try:
|
233 |
+
text_modified = text != generated_text_state
|
234 |
+
options_map: OptionMap = await self.tts_service.synthesize_speech(character_description, text, text_modified)
|
|
|
|
|
|
|
235 |
|
236 |
+
# Ensure options_map has the expected keys before accessing
|
237 |
+
if "option_a" not in options_map or "option_b" not in options_map:
|
238 |
+
logger.error(f"Invalid options_map received from TTS service: {options_map}")
|
239 |
+
raise gr.Error("Internal error: Failed to retrieve synthesis results correctly.")
|
240 |
+
if not options_map.get("option_a") or not options_map.get("option_b"):
|
241 |
+
logger.error(f"Missing data in options_map from TTS service: {options_map}")
|
242 |
+
raise gr.Error("Internal error: Missing synthesis results.")
|
243 |
|
244 |
return (
|
245 |
gr.update(value=options_map["option_a"]["audio_file_path"], autoplay=True),
|
246 |
gr.update(value=options_map["option_b"]["audio_file_path"]),
|
247 |
options_map,
|
248 |
text_modified,
|
249 |
+
text, # text_state update
|
250 |
+
character_description, # character_description_state update
|
251 |
+
True, # should_enable_vote_buttons update
|
252 |
)
|
253 |
except HumeError as he:
|
254 |
logger.error(f"Synthesis failed with HumeError during TTS generation: {he!s}")
|
|
|
260 |
logger.error(f"Synthesis failed with ElevenLabsError during TTS generation: {ee!s}")
|
261 |
raise gr.Error(f'There was an issue communicating with the Elevenlabs API: "{ee.message}"')
|
262 |
except Exception as e:
|
263 |
+
logger.error(f"Synthesis failed with an unexpected error during TTS generation: {e!s}", exc_info=True)
|
264 |
raise gr.Error("An unexpected error occurred. Please try again shortly.")
|
265 |
|
266 |
+
def _determine_selected_option(self, selected_option_button_value: str) -> Tuple[OptionKey, OptionKey]:
|
267 |
+
"""
|
268 |
+
Determines the selected option key ('option_a'/'option_b') based on the button value.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
selected_option_button_value: The value property of the clicked vote button
|
272 |
+
(e.g., constants.SELECT_OPTION_A).
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
A tuple (selected_option_key, other_option_key).
|
276 |
+
|
277 |
+
Raises:
|
278 |
+
ValueError: If the button value is not one of the expected constants.
|
279 |
+
"""
|
280 |
+
if selected_option_button_value == constants.SELECT_OPTION_A:
|
281 |
+
selected_option, other_option = constants.OPTION_A_KEY, constants.OPTION_B_KEY
|
282 |
+
elif selected_option_button_value == constants.SELECT_OPTION_B:
|
283 |
+
selected_option, other_option = constants.OPTION_B_KEY, constants.OPTION_A_KEY
|
284 |
+
else:
|
285 |
+
logger.error(f"Invalid selected button value received: {selected_option_button_value}")
|
286 |
+
raise ValueError(f"Invalid selected button: {selected_option_button_value}")
|
287 |
+
|
288 |
+
return selected_option, other_option
|
289 |
+
|
290 |
+
async def _submit_vote(
|
291 |
self,
|
292 |
vote_submitted: bool,
|
293 |
option_map: OptionMap,
|
294 |
+
clicked_option_button_value: str, # Renamed for clarity (it's the button's value, not the component)
|
295 |
text_modified: bool,
|
296 |
character_description: str,
|
297 |
text: str,
|
298 |
) -> Tuple[
|
299 |
+
Union[bool, gr.skip],
|
300 |
+
Union[dict, gr.skip],
|
301 |
+
Union[dict, gr.skip],
|
302 |
+
Union[dict, gr.skip],
|
303 |
+
Union[dict, gr.skip],
|
304 |
+
Union[dict, gr.skip]
|
305 |
]:
|
306 |
"""
|
307 |
+
Handles user voting, submits results asynchronously, and updates the UI.
|
308 |
+
|
309 |
+
Prevents duplicate votes and updates button visibility and result textboxes.
|
310 |
|
311 |
Args:
|
312 |
+
vote_submitted: Boolean state indicating if a vote was already submitted for this pair.
|
313 |
+
option_map: The OptionMap dictionary containing details of the two options.
|
314 |
+
clicked_option_button_value: The value of the button that was clicked (e.g., constants.SELECT_OPTION_A).
|
315 |
+
text_modified: Boolean state indicating if the text was modified by the user.
|
316 |
+
character_description: The character description used for synthesis (from state).
|
317 |
+
text: The text used for synthesis (from state).
|
318 |
|
319 |
Returns:
|
320 |
+
A tuple of updates for various UI components and state variables,
|
321 |
+
or multiple gr.skip() objects if the vote is ignored (e.g., duplicate).
|
322 |
+
Elements are:
|
323 |
+
- bool | gr.skip: Update for vote_submitted_state (True if vote processed).
|
324 |
+
- dict | gr.skip: Update for vote_button_a (visibility).
|
325 |
+
- dict | gr.skip: Update for vote_button_b (visibility).
|
326 |
+
- dict | gr.skip: Update for vote_result_a (visibility, value, style).
|
327 |
+
- dict | gr.skip: Update for vote_result_b (visibility, value, style).
|
328 |
+
- dict | gr.skip: Update for synthesize_speech_button (interactivity).
|
329 |
"""
|
330 |
+
# If option_map is empty/invalid or vote already submitted, do nothing
|
331 |
+
if not isinstance(option_map, dict) or not option_map or vote_submitted:
|
332 |
+
logger.warning(f"Vote submission skipped. Option map valid: {isinstance(option_map, dict) and bool(option_map)}, Vote submitted: {vote_submitted}")
|
333 |
+
# Return gr.skip() for all outputs
|
334 |
return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
335 |
|
336 |
+
try:
|
337 |
+
selected_option, other_option = self._determine_selected_option(clicked_option_button_value)
|
338 |
+
|
339 |
+
# Ensure keys exist before accessing
|
340 |
+
if selected_option not in option_map or other_option not in option_map:
|
341 |
+
logger.error(f"Selected/Other option key missing in option_map: {selected_option}, {other_option}, Map: {option_map}")
|
342 |
+
raise gr.Error("Internal error: Could not process vote due to inconsistent data.")
|
343 |
+
if "provider" not in option_map[selected_option] or "provider" not in option_map[other_option]:
|
344 |
+
logger.error(f"Provider missing in option_map entry: Map: {option_map}")
|
345 |
+
raise gr.Error("Internal error: Could not process vote due to missing provider data.")
|
346 |
+
|
347 |
+
selected_provider = option_map[selected_option]["provider"]
|
348 |
+
other_provider = option_map[other_option]["provider"]
|
349 |
+
|
350 |
+
# Process vote in the background without blocking the UI
|
351 |
+
asyncio.create_task(
|
352 |
+
self.voting_service.submit_vote(
|
353 |
+
option_map,
|
354 |
+
selected_option,
|
355 |
+
text_modified,
|
356 |
+
character_description,
|
357 |
+
text,
|
358 |
+
)
|
359 |
)
|
360 |
+
logger.info(f"Vote submitted: Selected '{selected_provider}', Other '{other_provider}'")
|
361 |
|
362 |
+
# Build result labels
|
363 |
+
selected_label = f"{selected_provider} π"
|
364 |
+
other_label = f"{other_provider}"
|
365 |
|
366 |
+
# Determine which result box gets which label
|
367 |
+
result_a_update = gr.update(value=other_label, visible=True)
|
368 |
+
result_b_update = gr.update(value=selected_label, visible=True, elem_classes="winner")
|
369 |
+
if selected_option == constants.OPTION_A_KEY:
|
370 |
+
result_a_update = gr.update(value=selected_label, visible=True, elem_classes="winner")
|
371 |
+
result_b_update = gr.update(value=other_label, visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
|
373 |
+
|
374 |
+
return (
|
375 |
+
True, # Update vote_submitted_state to True
|
376 |
+
gr.update(visible=False), # Hide vote button A
|
377 |
+
gr.update(visible=False), # Hide vote button B
|
378 |
+
result_a_update, # Show/update result textbox A
|
379 |
+
result_b_update, # Show/update result textbox B
|
380 |
+
gr.update(interactive=True), # Re-enable synthesize speech button
|
381 |
+
)
|
382 |
+
except ValueError as ve: # Catch error from _determine_selected_option
|
383 |
+
logger.error(f"Vote submission failed due to invalid button value: {ve}", exc_info=True)
|
384 |
+
# Optionally raise gr.Error or just skip updates
|
385 |
+
gr.Error("An internal error occurred while processing your vote.")
|
386 |
+
return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
387 |
+
except Exception as e:
|
388 |
+
logger.error(f"Vote submission failed unexpectedly: {e!s}", exc_info=True)
|
389 |
+
gr.Error("An unexpected error occurred while submitting your vote.")
|
390 |
+
# Still return skips to avoid partial UI updates
|
391 |
+
return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
392 |
+
|
393 |
+
async def _randomize_character_description(self) -> Tuple[dict, dict]:
|
394 |
"""
|
395 |
+
Selects a random character description from the predefined samples.
|
396 |
|
397 |
Returns:
|
398 |
+
A tuple containing Gradio update dictionaries for:
|
399 |
+
- The sample character dropdown component.
|
400 |
+
- The character description input component.
|
401 |
"""
|
402 |
+
# Ensure SAMPLE_CHARACTER_DESCRIPTIONS is not empty
|
403 |
+
if not SAMPLE_CHARACTER_DESCRIPTIONS:
|
404 |
+
logger.warning("SAMPLE_CHARACTER_DESCRIPTIONS is empty. Cannot randomize.")
|
405 |
+
# Return updates that clear the fields or do nothing
|
406 |
+
return gr.update(value=None), gr.update(value="")
|
407 |
|
408 |
+
sample_keys = list(SAMPLE_CHARACTER_DESCRIPTIONS.keys())
|
409 |
random_sample = random.choice(sample_keys)
|
410 |
+
character_description = SAMPLE_CHARACTER_DESCRIPTIONS[random_sample]
|
411 |
|
412 |
logger.info(f"Randomize All: Selected '{random_sample}'")
|
413 |
|
414 |
return (
|
415 |
+
gr.update(value=random_sample), # Update dropdown selection
|
416 |
+
gr.update(value=character_description), # Update character description text
|
417 |
)
|
418 |
|
419 |
+
def _disable_ui(self) -> Tuple[dict, dict, dict, dict, dict, dict, dict, dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
"""
|
421 |
+
Disables interactive UI components during processing.
|
|
|
|
|
|
|
422 |
|
423 |
Returns:
|
424 |
+
A tuple of Gradio update dictionaries to set interactive=False
|
425 |
+
for relevant buttons, dropdowns, and textboxes.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
"""
|
427 |
+
logger.debug("Disabling UI components.")
|
428 |
return(
|
429 |
gr.update(interactive=False), # disable Randomize All button
|
430 |
gr.update(interactive=False), # disable Character Description dropdown
|
|
|
436 |
gr.update(interactive=False), # disable Select B Button
|
437 |
)
|
438 |
|
439 |
+
def _enable_ui(self, should_enable_vote_buttons: bool) -> Tuple[dict, dict, dict, dict, dict, dict, dict, dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
"""
|
441 |
+
Enables interactive UI components after processing.
|
442 |
+
|
443 |
+
Args:
|
444 |
+
should_enable_vote_buttons: Boolean indicating if the voting buttons
|
445 |
+
should be enabled (based on synthesis success).
|
446 |
+
|
447 |
+
Returns:
|
448 |
+
A tuple of Gradio update dictionaries to set interactive=True
|
449 |
+
for relevant buttons, dropdowns, and textboxes. Vote buttons'
|
450 |
+
interactivity depends on the input argument.
|
451 |
"""
|
452 |
+
logger.debug(f"Enabling UI components. Enable vote buttons: {should_enable_vote_buttons}")
|
453 |
return(
|
454 |
gr.update(interactive=True), # enable Randomize All button
|
455 |
gr.update(interactive=True), # enable Character Description dropdown
|
|
|
457 |
gr.update(interactive=True), # enable Generate Text button
|
458 |
gr.update(interactive=True), # enable Input Text input
|
459 |
gr.update(interactive=True), # enable Synthesize Speech Button
|
460 |
+
gr.update(interactive=should_enable_vote_buttons), # enable/disable Select A Button
|
461 |
+
gr.update(interactive=should_enable_vote_buttons), # enable/disable Select B Button
|
462 |
)
|
463 |
|
464 |
+
def _reset_voting_ui(self) -> Tuple[dict, dict, dict, dict, dict, dict, OptionMap, bool, bool]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
"""
|
466 |
+
Resets the voting UI elements to their initial state before new synthesis.
|
467 |
+
|
468 |
+
Clears audio players, makes vote buttons visible, hides result textboxes,
|
469 |
+
and resets associated state variables.
|
470 |
+
|
471 |
+
Returns:
|
472 |
+
A tuple containing updates for UI components and state variables:
|
473 |
+
- dict: Update for audio player A (clear value).
|
474 |
+
- dict: Update for audio player B (clear value, disable autoplay).
|
475 |
+
- dict: Update for vote button A (make visible).
|
476 |
+
- dict: Update for vote button B (make visible).
|
477 |
+
- dict: Update for vote result A (hide, clear style).
|
478 |
+
- dict: Update for vote result B (hide, clear style).
|
479 |
+
- OptionMap: Reset option_map_state to a default placeholder.
|
480 |
+
- bool: Reset vote_submitted_state to False.
|
481 |
+
- bool: Reset should_enable_vote_buttons state to False.
|
482 |
"""
|
483 |
+
logger.debug("Resetting voting UI.")
|
484 |
default_option_map: OptionMap = {
|
485 |
"option_a": {"provider": constants.HUME_AI, "generation_id": None, "audio_file_path": ""},
|
486 |
"option_b": {"provider": constants.HUME_AI, "generation_id": None, "audio_file_path": ""},
|
487 |
}
|
488 |
return (
|
489 |
+
gr.update(value=None, label=OPTION_A_LABEL), # clear audio player A, reset label
|
490 |
+
gr.update(value=None, autoplay=False, label=OPTION_B_LABEL), # clear audio player B, ensure autoplay off, reset label
|
491 |
+
gr.update(visible=True, interactive=False), # show vote button A, ensure non-interactive until enabled
|
492 |
+
gr.update(visible=True, interactive=False), # show vote button B, ensure non-interactive until enabled
|
493 |
+
gr.update(value="", visible=False, elem_classes=[]), # hide vote result A, clear text/style
|
494 |
+
gr.update(value="", visible=False, elem_classes=[]), # hide vote result B, clear text/style
|
495 |
+
default_option_map, # Reset option_map_state
|
496 |
False, # Reset vote_submitted_state
|
497 |
False, # Reset should_enable_vote_buttons state
|
498 |
)
|
499 |
|
500 |
+
def build_arena_section(self) -> None:
|
|
|
|
|
501 |
"""
|
502 |
+
Constructs the Gradio UI layout for the Arena tab and registers event handlers.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
|
504 |
+
This method defines all the components within the Arena tab and connects
|
505 |
+
button clicks, dropdown selections, etc., to their corresponding handler functions.
|
|
|
506 |
"""
|
507 |
+
logger.debug("Building Arena UI section...")
|
508 |
+
|
509 |
# --- UI components ---
|
510 |
with gr.Row():
|
511 |
with gr.Column(scale=5):
|
|
|
540 |
)
|
541 |
|
542 |
sample_character_description_dropdown = gr.Dropdown(
|
543 |
+
choices=list(SAMPLE_CHARACTER_DESCRIPTIONS.keys()),
|
544 |
label="Sample Characters",
|
545 |
info="Generate text with a sample character description.",
|
546 |
value=None,
|
|
|
576 |
with gr.Column():
|
577 |
with gr.Group():
|
578 |
option_a_audio_player = gr.Audio(
|
579 |
+
label=OPTION_A_LABEL,
|
580 |
type="filepath",
|
581 |
interactive=False,
|
582 |
show_download_button=False,
|
583 |
)
|
584 |
+
vote_button_a = gr.Button(value=constants.SELECT_OPTION_A, interactive=False)
|
585 |
vote_result_a = gr.Textbox(
|
586 |
interactive=False,
|
587 |
visible=False,
|
|
|
592 |
with gr.Column():
|
593 |
with gr.Group():
|
594 |
option_b_audio_player = gr.Audio(
|
595 |
+
label=OPTION_B_LABEL,
|
596 |
type="filepath",
|
597 |
interactive=False,
|
598 |
show_download_button=False,
|
599 |
)
|
600 |
+
vote_button_b = gr.Button(value=constants.SELECT_OPTION_B, interactive=False)
|
601 |
vote_result_b = gr.Textbox(
|
602 |
interactive=False,
|
603 |
visible=False,
|
|
|
614 |
# Track generated text state
|
615 |
generated_text_state = gr.State("")
|
616 |
# Track whether text that was used was generated or modified/custom
|
617 |
+
text_modified_state = gr.State(False)
|
618 |
# Track option map (option A and option B are randomized)
|
619 |
option_map_state = gr.State({}) # OptionMap state as a dictionary
|
620 |
# Track whether the user has voted for an option
|
|
|
698 |
# 3. Generate text
|
699 |
# 4. Enable interactive UI components
|
700 |
sample_character_description_dropdown.select(
|
701 |
+
fn=lambda choice: SAMPLE_CHARACTER_DESCRIPTIONS.get(choice, ""),
|
702 |
inputs=[sample_character_description_dropdown],
|
703 |
outputs=[character_description_input],
|
704 |
).then(
|
|
|
841 |
inputs=[],
|
842 |
outputs=[vote_button_a, vote_button_b],
|
843 |
).then(
|
844 |
+
fn=self._submit_vote,
|
845 |
inputs=[
|
846 |
vote_submitted_state,
|
847 |
option_map_state,
|
|
|
866 |
inputs=[],
|
867 |
outputs=[vote_button_a, vote_button_b],
|
868 |
).then(
|
869 |
+
fn=self._submit_vote,
|
870 |
inputs=[
|
871 |
vote_submitted_state,
|
872 |
option_map_state,
|
|
|
896 |
outputs=[option_b_audio_player],
|
897 |
)
|
898 |
|
899 |
+
logger.debug("Arena UI section built.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/frontend/components/leaderboard.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Library Imports
|
2 |
+
import hashlib
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
from typing import List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
# Third-Party Library Imports
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
# Local Application Imports
|
11 |
+
from src.common import logger
|
12 |
+
from src.core import VotingService
|
13 |
+
|
14 |
+
|
15 |
+
class Leaderboard:
|
16 |
+
"""
|
17 |
+
Manages the state, data fetching, and UI construction for the Leaderboard tab.
|
18 |
+
|
19 |
+
Includes caching and throttling for leaderboard data updates.
|
20 |
+
"""
|
21 |
+
def __init__(self, voting_service: VotingService):
|
22 |
+
"""
|
23 |
+
Initializes the Leaderboard component.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
voting_service: The service for voting/leaderboard DB operations.
|
27 |
+
"""
|
28 |
+
self.voting_service = voting_service
|
29 |
+
|
30 |
+
# leaderboard update state
|
31 |
+
self.leaderboard_data: List[List[str]] = [[]]
|
32 |
+
self.battle_counts_data: List[List[str]] = [[]]
|
33 |
+
self.win_rates_data: List[List[str]] = [[]]
|
34 |
+
self.leaderboard_cache_hash: Optional[str] = None
|
35 |
+
self.last_leaderboard_update_time: float = 0.0
|
36 |
+
self.min_refresh_interval: int = 30
|
37 |
+
|
38 |
+
async def _update_leaderboard_data(self, force: bool = False) -> bool:
|
39 |
+
"""
|
40 |
+
Fetches leaderboard data from the source if cache is stale or force=True.
|
41 |
+
|
42 |
+
Updates internal state variables (leaderboard_data, battle_counts_data,
|
43 |
+
win_rates_data, cache_hash, last_update_time) if new data is fetched.
|
44 |
+
Uses time-based throttling defined by `min_refresh_interval`.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
force: If True, bypasses cache hash check and time throttling.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
True if the leaderboard data state was updated, False otherwise.
|
51 |
+
"""
|
52 |
+
current_time = time.time()
|
53 |
+
time_since_last_update = current_time - self.last_leaderboard_update_time
|
54 |
+
|
55 |
+
# Skip update if throttled and not forced
|
56 |
+
if not force and time_since_last_update < self.min_refresh_interval:
|
57 |
+
logger.debug(f"Skipping leaderboard update (throttled): last updated {time_since_last_update:.1f}s ago.")
|
58 |
+
return False
|
59 |
+
|
60 |
+
try:
|
61 |
+
# Fetch the latest data
|
62 |
+
(
|
63 |
+
latest_leaderboard_data,
|
64 |
+
latest_battle_counts_data,
|
65 |
+
latest_win_rates_data
|
66 |
+
) = await self.voting_service.get_formatted_leaderboard_data()
|
67 |
+
|
68 |
+
# Check if data is valid before proceeding
|
69 |
+
if not latest_leaderboard_data or not latest_leaderboard_data[0]:
|
70 |
+
logger.error("Invalid data received from get_leaderboard_data.")
|
71 |
+
return False
|
72 |
+
|
73 |
+
# Generate a hash of the primary leaderboard data to check for changes
|
74 |
+
# Use a stable serialization format (sort_keys=True)
|
75 |
+
data_str = json.dumps(latest_leaderboard_data, sort_keys=True)
|
76 |
+
new_data_hash = hashlib.md5(data_str.encode()).hexdigest()
|
77 |
+
|
78 |
+
# Skip if data hasn't changed and not forced
|
79 |
+
if not force and new_data_hash == self.leaderboard_cache_hash:
|
80 |
+
logger.debug("Leaderboard data unchanged since last fetch.")
|
81 |
+
return False
|
82 |
+
|
83 |
+
# Update the state and cache
|
84 |
+
self.leaderboard_data = latest_leaderboard_data
|
85 |
+
self.battle_counts_data = latest_battle_counts_data
|
86 |
+
self.win_rates_data = latest_win_rates_data
|
87 |
+
self.leaderboard_cache_hash = new_data_hash
|
88 |
+
self.last_leaderboard_update_time = current_time
|
89 |
+
logger.info("Leaderboard data updated successfully.")
|
90 |
+
return True
|
91 |
+
|
92 |
+
except Exception as e:
|
93 |
+
logger.error(f"Failed to update leaderboard data: {e!s}", exc_info=True)
|
94 |
+
return False
|
95 |
+
|
96 |
+
async def refresh_leaderboard(
|
97 |
+
self, force: bool = False
|
98 |
+
) -> Tuple[Union[dict, gr.skip], Union[dict, gr.skip], Union[dict, gr.skip]]:
|
99 |
+
"""
|
100 |
+
Refreshes leaderboard data state and returns Gradio updates for the tables.
|
101 |
+
|
102 |
+
Calls `_update_leaderboard_data` and returns updates only if data changed
|
103 |
+
or `force` is True. Returns gr.skip() otherwise.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
force: If True, forces `_update_leaderboard_data` to bypass throttling/cache.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
A tuple of Gradio update dictionaries for the leaderboard, battle counts,
|
110 |
+
and win rates tables, or gr.skip() for each if no update is needed.
|
111 |
+
|
112 |
+
Raises:
|
113 |
+
gr.Error: If leaderboard data is empty/invalid after attempting an update.
|
114 |
+
(Changed from previous: now raises only if data is *still* bad)
|
115 |
+
"""
|
116 |
+
data_updated = await self._update_leaderboard_data(force=force)
|
117 |
+
|
118 |
+
if not self.leaderboard_data or not isinstance(self.leaderboard_data[0], list):
|
119 |
+
logger.error("Leaderboard data is empty or invalid after update attempt.")
|
120 |
+
raise gr.Error("Unable to retrieve leaderboard data. Please refresh the page or try again shortly.")
|
121 |
+
|
122 |
+
if data_updated or force:
|
123 |
+
logger.debug("Returning leaderboard table updates.")
|
124 |
+
return (
|
125 |
+
gr.update(value=self.leaderboard_data),
|
126 |
+
gr.update(value=self.battle_counts_data),
|
127 |
+
gr.update(value=self.win_rates_data)
|
128 |
+
)
|
129 |
+
logger.debug("Skipping leaderboard table updates (no data change).")
|
130 |
+
return gr.skip(), gr.skip(), gr.skip()
|
131 |
+
|
132 |
+
async def build_leaderboard_section(self) -> Tuple[gr.DataFrame, gr.DataFrame, gr.DataFrame]:
|
133 |
+
"""
|
134 |
+
Constructs the Gradio UI layout for the Leaderboard tab.
|
135 |
+
|
136 |
+
Defines the DataFrames, HTML descriptions, and refresh button logic.
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
A tuple containing the Gradio DataFrame components for:
|
140 |
+
- Main Leaderboard table
|
141 |
+
- Battle Counts table
|
142 |
+
- Win Rates table
|
143 |
+
These components are needed by the main Frontend class to wire up events.
|
144 |
+
"""
|
145 |
+
logger.debug("Building Leaderboard UI section...")
|
146 |
+
# Pre-load leaderboard data before building UI that depends on it
|
147 |
+
await self._update_leaderboard_data(force=True)
|
148 |
+
|
149 |
+
# --- UI components ---
|
150 |
+
with gr.Row():
|
151 |
+
with gr.Column(scale=5):
|
152 |
+
gr.HTML(
|
153 |
+
value="""
|
154 |
+
<h2 class="tab-header">π Leaderboard</h2>
|
155 |
+
<p style="padding-left: 8px;">
|
156 |
+
This leaderboard presents community voting results for different TTS providers, showing which
|
157 |
+
ones users found more expressive and natural-sounding. The win rate reflects how often each
|
158 |
+
provider was selected as the preferred option in head-to-head comparisons. Click the refresh
|
159 |
+
button to see the most up-to-date voting results.
|
160 |
+
</p>
|
161 |
+
""",
|
162 |
+
padding=False,
|
163 |
+
)
|
164 |
+
refresh_button = gr.Button(
|
165 |
+
"β» Refresh",
|
166 |
+
variant="primary",
|
167 |
+
elem_classes="refresh-btn",
|
168 |
+
scale=1,
|
169 |
+
)
|
170 |
+
|
171 |
+
with gr.Column(elem_id="leaderboard-table-container"):
|
172 |
+
leaderboard_table = gr.DataFrame(
|
173 |
+
headers=["Rank", "Provider", "Model", "Win Rate", "Votes"],
|
174 |
+
datatype=["html", "html", "html", "html", "html"],
|
175 |
+
column_widths=[80, 300, 180, 120, 116],
|
176 |
+
value=self.leaderboard_data,
|
177 |
+
min_width=680,
|
178 |
+
interactive=False,
|
179 |
+
render=True,
|
180 |
+
elem_id="leaderboard-table"
|
181 |
+
)
|
182 |
+
|
183 |
+
with gr.Column():
|
184 |
+
gr.HTML(
|
185 |
+
value="""
|
186 |
+
<h2 style="padding-top: 12px;" class="tab-header">π Head-to-Head Matchups</h2>
|
187 |
+
<p style="padding-left: 8px; width: 80%;">
|
188 |
+
These tables show how each provider performs against others in direct comparisons.
|
189 |
+
The first table shows the total number of comparisons between each pair of providers.
|
190 |
+
The second table shows the win rate (percentage) of the row provider against the column provider.
|
191 |
+
</p>
|
192 |
+
""",
|
193 |
+
padding=False
|
194 |
+
)
|
195 |
+
|
196 |
+
with gr.Row(equal_height=True):
|
197 |
+
with gr.Column(min_width=420):
|
198 |
+
battle_counts_table = gr.DataFrame(
|
199 |
+
headers=["", "Hume AI", "OpenAI", "ElevenLabs"],
|
200 |
+
datatype=["html", "html", "html", "html"],
|
201 |
+
column_widths=[132, 132, 132, 132],
|
202 |
+
value=self.battle_counts_data,
|
203 |
+
interactive=False,
|
204 |
+
)
|
205 |
+
with gr.Column(min_width=420):
|
206 |
+
win_rates_table = gr.DataFrame(
|
207 |
+
headers=["", "Hume AI", "OpenAI", "ElevenLabs"],
|
208 |
+
datatype=["html", "html", "html", "html"],
|
209 |
+
column_widths=[132, 132, 132, 132],
|
210 |
+
value=self.win_rates_data,
|
211 |
+
interactive=False,
|
212 |
+
)
|
213 |
+
|
214 |
+
with gr.Accordion(label="Citation", open=False):
|
215 |
+
with gr.Column(variant="panel"):
|
216 |
+
with gr.Column(variant="panel"):
|
217 |
+
gr.HTML(
|
218 |
+
value="""
|
219 |
+
<h2>Citation</h2>
|
220 |
+
<p style="padding: 0 8px;">
|
221 |
+
When referencing this leaderboard or its dataset in academic publications, please cite:
|
222 |
+
</p>
|
223 |
+
""",
|
224 |
+
padding=False,
|
225 |
+
)
|
226 |
+
gr.Markdown(
|
227 |
+
value="""
|
228 |
+
**BibTeX**
|
229 |
+
```BibTeX
|
230 |
+
@misc{expressive-tts-arena,
|
231 |
+
title = {Expressive TTS Arena: An Open Platform for Evaluating Text-to-Speech Expressiveness by Human Preference},
|
232 |
+
author = {Alan Cowen, Zachary Greathouse, Richard Marmorstein, Jeremy Hadfield},
|
233 |
+
year = {2025},
|
234 |
+
publisher = {Hugging Face},
|
235 |
+
howpublished = {\\url{https://huggingface.co/spaces/HumeAI/expressive-tts-arena}}
|
236 |
+
}
|
237 |
+
```
|
238 |
+
"""
|
239 |
+
)
|
240 |
+
gr.HTML(
|
241 |
+
value="""
|
242 |
+
<h2>Terms of Use</h2>
|
243 |
+
<p style="padding: 0 8px;">
|
244 |
+
Users are required to agree to the following terms before using the service:
|
245 |
+
</p>
|
246 |
+
<p style="padding: 0 8px;">
|
247 |
+
All generated audio clips are provided for research and evaluation purposes only.
|
248 |
+
The audio content may not be redistributed or used for commercial purposes without
|
249 |
+
explicit permission. Users should not upload any private or personally identifiable
|
250 |
+
information. Please report any bugs, issues, or concerns to our
|
251 |
+
<a href="https://discord.com/invite/humeai" target="_blank" class="provider-link">
|
252 |
+
Discord community
|
253 |
+
</a>.
|
254 |
+
</p>
|
255 |
+
""",
|
256 |
+
padding=False,
|
257 |
+
)
|
258 |
+
gr.HTML(
|
259 |
+
value="""
|
260 |
+
<h2>Acknowledgements</h2>
|
261 |
+
<p style="padding: 0 8px;">
|
262 |
+
We thank all participants who contributed their votes to help build this leaderboard.
|
263 |
+
</p>
|
264 |
+
""",
|
265 |
+
padding=False,
|
266 |
+
)
|
267 |
+
|
268 |
+
# Wrapper for the async refresh function
|
269 |
+
async def async_refresh_handler() -> Tuple[Union[dict, gr.skip], Union[dict, gr.skip], Union[dict, gr.skip]]:
|
270 |
+
"""Async helper to call refresh_leaderboard and handle its tuple return."""
|
271 |
+
logger.debug("Refresh button clicked, calling async_refresh_handler.")
|
272 |
+
return await self.refresh_leaderboard(force=True)
|
273 |
+
|
274 |
+
# Handler to re-enable the button after a short delay
|
275 |
+
def reenable_button() -> dict: # Returns a Gradio update dict
|
276 |
+
"""Waits briefly and returns an update to re-enable the refresh button."""
|
277 |
+
throttle_delay = 3 # seconds
|
278 |
+
time.sleep(throttle_delay) # Okay in Gradio event handlers (runs in thread)
|
279 |
+
return gr.update(interactive=True)
|
280 |
+
|
281 |
+
# Refresh button click event handler
|
282 |
+
refresh_button.click(
|
283 |
+
fn=lambda _=None: (gr.update(interactive=False)), # Disable button immediately
|
284 |
+
inputs=[],
|
285 |
+
outputs=[refresh_button],
|
286 |
+
).then(
|
287 |
+
fn=async_refresh_handler,
|
288 |
+
inputs=[],
|
289 |
+
outputs=[leaderboard_table, battle_counts_table, win_rates_table] # Update all three tables
|
290 |
+
).then(
|
291 |
+
fn=reenable_button, # Re-enable the button after a delay
|
292 |
+
inputs=[],
|
293 |
+
outputs=[refresh_button]
|
294 |
+
)
|
295 |
+
|
296 |
+
logger.debug("Leaderboard UI section built.")
|
297 |
+
# Return the component instances needed by the Frontend class
|
298 |
+
return leaderboard_table, battle_counts_table, win_rates_table
|
src/frontend/frontend.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Library Imports
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
# Third-Party Library Imports
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
# Local Application Imports
|
8 |
+
from src.common import Config, logger
|
9 |
+
from src.core import TTSService, VotingService
|
10 |
+
from src.database import AsyncDBSessionMaker
|
11 |
+
|
12 |
+
from .components import Arena, Leaderboard
|
13 |
+
|
14 |
+
|
15 |
+
class Frontend:
|
16 |
+
"""
|
17 |
+
Main frontend class orchestrating the Gradio UI application.
|
18 |
+
|
19 |
+
Initializes and manages the Arena and Leaderboard components, builds the overall UI structure (Tabs, HTML),
|
20 |
+
and handles top-level events like tab selection.
|
21 |
+
"""
|
22 |
+
def __init__(self, config: Config, db_session_maker: AsyncDBSessionMaker):
|
23 |
+
"""
|
24 |
+
Initializes the Frontend application controller.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
config: The application configuration object.
|
28 |
+
db_session_maker: An asynchronous database session factory.
|
29 |
+
"""
|
30 |
+
self.config = config
|
31 |
+
|
32 |
+
# Instantiate services
|
33 |
+
self.tts_service: TTSService = TTSService(config)
|
34 |
+
self.voting_service: VotingService = VotingService(db_session_maker)
|
35 |
+
logger.debug("Frontend initialized with TTSService and VotingService.")
|
36 |
+
|
37 |
+
# Initialize components with dependencies
|
38 |
+
self.arena = Arena(config, self.tts_service, self.voting_service)
|
39 |
+
self.leaderboard = Leaderboard(self.voting_service)
|
40 |
+
logger.debug("Frontend initialized with Arena and Leaderboard components.")
|
41 |
+
|
42 |
+
async def _handle_tab_select(self, evt: gr.SelectData) -> Tuple[
|
43 |
+
Union[dict, gr.skip],
|
44 |
+
Union[dict, gr.skip],
|
45 |
+
Union[dict, gr.skip],
|
46 |
+
]:
|
47 |
+
"""
|
48 |
+
Handles tab selection events. Refreshes leaderboard if its tab is selected.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
evt: Gradio SelectData event, containing the selected tab's value (label).
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
A tuple of Gradio update dictionaries for the leaderboard tables if the Leaderboard tab was selected
|
55 |
+
and data needed refreshing, otherwise a tuple of gr.skip() objects.
|
56 |
+
"""
|
57 |
+
selected_tab = evt.value
|
58 |
+
if selected_tab == "Leaderboard":
|
59 |
+
# Refresh leaderboard, but don't force it (allow cache/throttle)
|
60 |
+
return await self.leaderboard.refresh_leaderboard(force=False)
|
61 |
+
# Return skip updates for other tabs
|
62 |
+
return gr.skip(), gr.skip(), gr.skip()
|
63 |
+
|
64 |
+
async def build_gradio_interface(self) -> gr.Blocks:
|
65 |
+
"""
|
66 |
+
Builds and configures the complete Gradio Blocks UI.
|
67 |
+
|
68 |
+
Pre-loads initial leaderboard data, defines layout (HTML, Tabs), integrates Arena and Leaderboard sections,
|
69 |
+
and sets up tab selection handler.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
The fully constructed Gradio Blocks application instance.
|
73 |
+
"""
|
74 |
+
logger.info("Building Gradio interface...")
|
75 |
+
|
76 |
+
with gr.Blocks(title="Expressive TTS Arena", css_paths="static/css/styles.css") as demo:
|
77 |
+
# --- Header HTML ---
|
78 |
+
gr.HTML(
|
79 |
+
value="""
|
80 |
+
<div class="title-container">
|
81 |
+
<h1>Expressive TTS Arena</h1>
|
82 |
+
<div class="social-links">
|
83 |
+
<a
|
84 |
+
href="https://discord.com/invite/humeai"
|
85 |
+
target="_blank"
|
86 |
+
id="discord-link"
|
87 |
+
title="Join our Discord"
|
88 |
+
aria-label="Join our Discord server"
|
89 |
+
></a>
|
90 |
+
<a
|
91 |
+
href="https://github.com/HumeAI/expressive-tts-arena"
|
92 |
+
target="_blank"
|
93 |
+
id="github-link"
|
94 |
+
title="View on GitHub"
|
95 |
+
aria-label="View project on GitHub"
|
96 |
+
></a>
|
97 |
+
</div>
|
98 |
+
</div>
|
99 |
+
<div class="excerpt-container">
|
100 |
+
<p>
|
101 |
+
Join the community in evaluating text-to-speech models, and vote for the AI voice that best
|
102 |
+
captures the emotion, nuance, and expressiveness of human speech.
|
103 |
+
</p>
|
104 |
+
</div>
|
105 |
+
"""
|
106 |
+
)
|
107 |
+
|
108 |
+
# --- Tabs ---
|
109 |
+
with gr.Tabs() as tabs:
|
110 |
+
with gr.TabItem("Arena"):
|
111 |
+
self.arena.build_arena_section()
|
112 |
+
with gr.TabItem("Leaderboard"):
|
113 |
+
(
|
114 |
+
leaderboard_table,
|
115 |
+
battle_counts_table,
|
116 |
+
win_rates_table
|
117 |
+
) = await self.leaderboard.build_leaderboard_section()
|
118 |
+
|
119 |
+
# --- Top-level Event Handlers ---
|
120 |
+
tabs.select(
|
121 |
+
fn=self._handle_tab_select,
|
122 |
+
inputs=[],
|
123 |
+
outputs=[leaderboard_table, battle_counts_table, win_rates_table],
|
124 |
+
)
|
125 |
+
|
126 |
+
logger.debug("Gradio interface built successfully")
|
127 |
+
return demo
|
src/integrations/__init__.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
from .
|
2 |
-
from .
|
3 |
-
from .
|
4 |
-
from .
|
5 |
|
6 |
__all__ = [
|
7 |
"AnthropicConfig",
|
|
|
1 |
+
from .anthropic import AnthropicConfig, AnthropicError, generate_text_with_claude
|
2 |
+
from .elevenlabs import ElevenLabsConfig, ElevenLabsError, text_to_speech_with_elevenlabs
|
3 |
+
from .hume import HumeConfig, HumeError, text_to_speech_with_hume
|
4 |
+
from .openai import OpenAIConfig, OpenAIError, text_to_speech_with_openai
|
5 |
|
6 |
__all__ = [
|
7 |
"AnthropicConfig",
|
src/integrations/{anthropic_api.py β anthropic.py}
RENAMED
@@ -1,18 +1,6 @@
|
|
1 |
-
"""
|
2 |
-
anthropic_api.py
|
3 |
-
|
4 |
-
This file defines the asynchronous interaction with the Anthropic API, focusing on generating text using the Claude
|
5 |
-
model. It includes functionality for input validation, asynchronous API request handling, and processing API responses.
|
6 |
-
|
7 |
-
Key Features:
|
8 |
-
- Encapsulates all logic related to the Anthropic API.
|
9 |
-
- Implements asynchronous retry logic for handling transient API errors.
|
10 |
-
- Validates the response content to ensure API compatibility.
|
11 |
-
- Provides detailed logging for debugging and error tracking.
|
12 |
-
"""
|
13 |
-
|
14 |
# Standard Library Imports
|
15 |
import logging
|
|
|
16 |
from dataclasses import dataclass, field
|
17 |
from typing import List, Optional, Union
|
18 |
|
@@ -22,11 +10,11 @@ from anthropic.types import Message, ModelParam, TextBlock, ToolUseBlock
|
|
22 |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_exponential
|
23 |
|
24 |
# Local Application Imports
|
25 |
-
from src.
|
26 |
-
from src.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, SERVER_ERROR_CODE
|
27 |
-
from src.utils import
|
28 |
|
29 |
-
|
30 |
<role>
|
31 |
You are an expert at generating micro-content optimized for text-to-speech synthesis.
|
32 |
Your absolute priority is delivering complete, untruncated responses within strict length limits.
|
@@ -54,7 +42,7 @@ Your absolute priority is delivering complete, untruncated responses within stri
|
|
54 |
class AnthropicConfig:
|
55 |
"""Immutable configuration for interacting with the Anthropic API using the asynchronous client."""
|
56 |
api_key: str = field(init=False)
|
57 |
-
system_prompt: str =
|
58 |
model: ModelParam = "claude-3-5-sonnet-latest"
|
59 |
max_tokens: int = 300
|
60 |
|
@@ -64,15 +52,13 @@ class AnthropicConfig:
|
|
64 |
raise ValueError("Anthropic Model is not set.")
|
65 |
if not self.max_tokens:
|
66 |
raise ValueError("Anthropic Max Tokens is not set.")
|
|
|
|
|
67 |
|
68 |
# Compute the API key from the environment.
|
69 |
computed_api_key = validate_env_var("ANTHROPIC_API_KEY")
|
70 |
object.__setattr__(self, "api_key", computed_api_key)
|
71 |
|
72 |
-
# Compute the system prompt using max_tokens and other logic.
|
73 |
-
computed_prompt = PROMPT_TEMPLATE.format(max_tokens=self.max_tokens)
|
74 |
-
object.__setattr__(self, "system_prompt", computed_prompt)
|
75 |
-
|
76 |
@property
|
77 |
def client(self):
|
78 |
"""
|
@@ -181,20 +167,21 @@ async def generate_text_with_claude(character_description: str, config: Config)
|
|
181 |
UnretryableAnthropicError: For unretryable API errors.
|
182 |
AnthropicError: For other errors communicating with the Anthropic API.
|
183 |
"""
|
|
|
|
|
|
|
|
|
184 |
try:
|
185 |
-
anthropic_config = config.anthropic_config
|
186 |
prompt = anthropic_config.build_expressive_prompt(character_description)
|
187 |
-
|
188 |
-
|
189 |
-
assert anthropic_config.system_prompt is not None, "system_prompt must be set."
|
190 |
-
|
191 |
-
response: Message = await anthropic_config.client.messages.create(
|
192 |
model=anthropic_config.model,
|
193 |
max_tokens=anthropic_config.max_tokens,
|
194 |
system=anthropic_config.system_prompt,
|
195 |
messages=[{"role": "user", "content": prompt}],
|
196 |
)
|
197 |
-
|
|
|
|
|
198 |
|
199 |
if not hasattr(response, "content") or response.content is None:
|
200 |
logger.error("Response is missing 'content'. Response: %s", response)
|
@@ -204,26 +191,25 @@ async def generate_text_with_claude(character_description: str, config: Config)
|
|
204 |
|
205 |
if isinstance(blocks, list):
|
206 |
result = "\n\n".join(block.text for block in blocks if isinstance(block, TextBlock))
|
207 |
-
logger.debug(f"Processed response from list: {truncate_text(result)}")
|
208 |
return result
|
209 |
|
210 |
if isinstance(blocks, TextBlock):
|
211 |
-
logger.debug(f"Processed response from single TextBlock: {truncate_text(blocks.text)}")
|
212 |
return blocks.text
|
213 |
|
214 |
logger.warning(f"Unexpected response type: {type(blocks)}")
|
215 |
return str(blocks or "No content generated.")
|
216 |
|
217 |
except APIError as e:
|
218 |
-
|
219 |
-
|
|
|
|
|
220 |
|
221 |
-
if (
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
raise UnretryableAnthropicError(message=clean_message, original_exception=e) from e
|
227 |
|
228 |
raise AnthropicError(message=clean_message, original_exception=e) from e
|
229 |
|
@@ -236,7 +222,7 @@ async def generate_text_with_claude(character_description: str, config: Config)
|
|
236 |
raise AnthropicError(message=clean_message, original_exception=e) from e
|
237 |
|
238 |
|
239 |
-
def
|
240 |
"""
|
241 |
Extracts a clean, user-friendly error message from an Anthropic API error response.
|
242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
import logging
|
3 |
+
import time
|
4 |
from dataclasses import dataclass, field
|
5 |
from typing import List, Optional, Union
|
6 |
|
|
|
10 |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_exponential
|
11 |
|
12 |
# Local Application Imports
|
13 |
+
from src.common import Config, logger
|
14 |
+
from src.common.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, RATE_LIMIT_ERROR_CODE, SERVER_ERROR_CODE
|
15 |
+
from src.common.utils import validate_env_var
|
16 |
|
17 |
+
SYSTEM_PROMPT: str = """
|
18 |
<role>
|
19 |
You are an expert at generating micro-content optimized for text-to-speech synthesis.
|
20 |
Your absolute priority is delivering complete, untruncated responses within strict length limits.
|
|
|
42 |
class AnthropicConfig:
|
43 |
"""Immutable configuration for interacting with the Anthropic API using the asynchronous client."""
|
44 |
api_key: str = field(init=False)
|
45 |
+
system_prompt: str = SYSTEM_PROMPT
|
46 |
model: ModelParam = "claude-3-5-sonnet-latest"
|
47 |
max_tokens: int = 300
|
48 |
|
|
|
52 |
raise ValueError("Anthropic Model is not set.")
|
53 |
if not self.max_tokens:
|
54 |
raise ValueError("Anthropic Max Tokens is not set.")
|
55 |
+
if not self.system_prompt:
|
56 |
+
raise ValueError("Anthropic system prompt is not set.")
|
57 |
|
58 |
# Compute the API key from the environment.
|
59 |
computed_api_key = validate_env_var("ANTHROPIC_API_KEY")
|
60 |
object.__setattr__(self, "api_key", computed_api_key)
|
61 |
|
|
|
|
|
|
|
|
|
62 |
@property
|
63 |
def client(self):
|
64 |
"""
|
|
|
167 |
UnretryableAnthropicError: For unretryable API errors.
|
168 |
AnthropicError: For other errors communicating with the Anthropic API.
|
169 |
"""
|
170 |
+
logger.debug("Generating text with Anthropic.")
|
171 |
+
anthropic_config = config.anthropic_config
|
172 |
+
client = anthropic_config.client
|
173 |
+
start_time = time.time()
|
174 |
try:
|
|
|
175 |
prompt = anthropic_config.build_expressive_prompt(character_description)
|
176 |
+
response: Message = await client.messages.create(
|
|
|
|
|
|
|
|
|
177 |
model=anthropic_config.model,
|
178 |
max_tokens=anthropic_config.max_tokens,
|
179 |
system=anthropic_config.system_prompt,
|
180 |
messages=[{"role": "user", "content": prompt}],
|
181 |
)
|
182 |
+
|
183 |
+
elapsed_time = time.time() - start_time
|
184 |
+
logger.info(f"Anthropic API request completed in {elapsed_time:.2f} seconds.")
|
185 |
|
186 |
if not hasattr(response, "content") or response.content is None:
|
187 |
logger.error("Response is missing 'content'. Response: %s", response)
|
|
|
191 |
|
192 |
if isinstance(blocks, list):
|
193 |
result = "\n\n".join(block.text for block in blocks if isinstance(block, TextBlock))
|
|
|
194 |
return result
|
195 |
|
196 |
if isinstance(blocks, TextBlock):
|
|
|
197 |
return blocks.text
|
198 |
|
199 |
logger.warning(f"Unexpected response type: {type(blocks)}")
|
200 |
return str(blocks or "No content generated.")
|
201 |
|
202 |
except APIError as e:
|
203 |
+
elapsed_time = time.time() - start_time
|
204 |
+
logger.error(f"Anthropic API request failed after {elapsed_time:.2f} seconds: {e!s}")
|
205 |
+
logger.error(f"Full Anthropic API error: {e!s}")
|
206 |
+
clean_message = __extract_anthropic_error_message(e)
|
207 |
|
208 |
+
if hasattr(e, 'status_code') and e.status_code is not None:
|
209 |
+
if e.status_code == RATE_LIMIT_ERROR_CODE:
|
210 |
+
raise AnthropicError(message=clean_message, original_exception=e) from e
|
211 |
+
if CLIENT_ERROR_CODE <= e.status_code < SERVER_ERROR_CODE:
|
212 |
+
raise UnretryableAnthropicError(message=clean_message, original_exception=e) from e
|
|
|
213 |
|
214 |
raise AnthropicError(message=clean_message, original_exception=e) from e
|
215 |
|
|
|
222 |
raise AnthropicError(message=clean_message, original_exception=e) from e
|
223 |
|
224 |
|
225 |
+
def __extract_anthropic_error_message(e: APIError) -> str:
|
226 |
"""
|
227 |
Extracts a clean, user-friendly error message from an Anthropic API error response.
|
228 |
|
src/integrations/{elevenlabs_api.py β elevenlabs.py}
RENAMED
@@ -1,17 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
elevenlabs_api.py
|
3 |
-
|
4 |
-
This file defines the interaction with the ElevenLabs text-to-speech (TTS) API using the
|
5 |
-
ElevenLabs Python SDK. It includes functionality for API request handling and processing API responses.
|
6 |
-
|
7 |
-
Key Features:
|
8 |
-
- Encapsulates all logic related to the ElevenLabs TTS API.
|
9 |
-
- Implements retry logic using Tenacity for handling transient API errors.
|
10 |
-
- Handles received audio and processes it for playback on the web.
|
11 |
-
- Provides detailed logging for debugging and error tracking.
|
12 |
-
- Utilizes robust error handling (EAFP) to validate API responses.
|
13 |
-
"""
|
14 |
-
|
15 |
# Standard Library Imports
|
16 |
import logging
|
17 |
import random
|
@@ -25,9 +11,8 @@ from elevenlabs.core import ApiError
|
|
25 |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
|
26 |
|
27 |
# Local Application Imports
|
28 |
-
from src.
|
29 |
-
from src.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, SERVER_ERROR_CODE
|
30 |
-
from src.utils import save_base64_audio_to_file, validate_env_var
|
31 |
|
32 |
|
33 |
@dataclass(frozen=True)
|
@@ -55,7 +40,6 @@ class ElevenLabsConfig:
|
|
55 |
"""
|
56 |
return AsyncElevenLabs(api_key=self.api_key)
|
57 |
|
58 |
-
|
59 |
class ElevenLabsError(Exception):
|
60 |
"""Custom exception for errors related to the ElevenLabs TTS API."""
|
61 |
|
@@ -64,7 +48,6 @@ class ElevenLabsError(Exception):
|
|
64 |
self.original_exception = original_exception
|
65 |
self.message = message
|
66 |
|
67 |
-
|
68 |
class UnretryableElevenLabsError(ElevenLabsError):
|
69 |
"""Custom exception for errors related to the ElevenLabs TTS API that should not be retried."""
|
70 |
|
@@ -73,7 +56,6 @@ class UnretryableElevenLabsError(ElevenLabsError):
|
|
73 |
self.original_exception = original_exception
|
74 |
self.message = message
|
75 |
|
76 |
-
|
77 |
@retry(
|
78 |
retry=retry_if_exception(lambda e: not isinstance(e, UnretryableElevenLabsError)),
|
79 |
stop=stop_after_attempt(2),
|
@@ -113,7 +95,7 @@ async def text_to_speech_with_elevenlabs(
|
|
113 |
)
|
114 |
|
115 |
elapsed_time = time.time() - start_time
|
116 |
-
logger.info(f"Elevenlabs API request completed in {elapsed_time:.2f} seconds")
|
117 |
|
118 |
previews = response.previews
|
119 |
if not previews:
|
@@ -129,10 +111,13 @@ async def text_to_speech_with_elevenlabs(
|
|
129 |
|
130 |
except ApiError as e:
|
131 |
logger.error(f"ElevenLabs API request failed: {e!s}")
|
132 |
-
clean_message =
|
133 |
|
134 |
-
if e.status_code is not None
|
135 |
-
|
|
|
|
|
|
|
136 |
|
137 |
raise ElevenLabsError(message=clean_message, original_exception=e) from e
|
138 |
|
@@ -144,8 +129,7 @@ async def text_to_speech_with_elevenlabs(
|
|
144 |
|
145 |
raise ElevenLabsError(message=error_message, original_exception=e) from e
|
146 |
|
147 |
-
|
148 |
-
def _extract_elevenlabs_error_message(e: ApiError) -> str:
|
149 |
"""
|
150 |
Extracts a clean, user-friendly error message from an ElevenLabs API error response.
|
151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
import logging
|
3 |
import random
|
|
|
11 |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
|
12 |
|
13 |
# Local Application Imports
|
14 |
+
from src.common import Config, logger, save_base64_audio_to_file, validate_env_var
|
15 |
+
from src.common.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, RATE_LIMIT_ERROR_CODE, SERVER_ERROR_CODE
|
|
|
16 |
|
17 |
|
18 |
@dataclass(frozen=True)
|
|
|
40 |
"""
|
41 |
return AsyncElevenLabs(api_key=self.api_key)
|
42 |
|
|
|
43 |
class ElevenLabsError(Exception):
|
44 |
"""Custom exception for errors related to the ElevenLabs TTS API."""
|
45 |
|
|
|
48 |
self.original_exception = original_exception
|
49 |
self.message = message
|
50 |
|
|
|
51 |
class UnretryableElevenLabsError(ElevenLabsError):
|
52 |
"""Custom exception for errors related to the ElevenLabs TTS API that should not be retried."""
|
53 |
|
|
|
56 |
self.original_exception = original_exception
|
57 |
self.message = message
|
58 |
|
|
|
59 |
@retry(
|
60 |
retry=retry_if_exception(lambda e: not isinstance(e, UnretryableElevenLabsError)),
|
61 |
stop=stop_after_attempt(2),
|
|
|
95 |
)
|
96 |
|
97 |
elapsed_time = time.time() - start_time
|
98 |
+
logger.info(f"Elevenlabs API request completed in {elapsed_time:.2f} seconds.")
|
99 |
|
100 |
previews = response.previews
|
101 |
if not previews:
|
|
|
111 |
|
112 |
except ApiError as e:
|
113 |
logger.error(f"ElevenLabs API request failed: {e!s}")
|
114 |
+
clean_message = __extract_elevenlabs_error_message(e)
|
115 |
|
116 |
+
if hasattr(e, 'status_code') and e.status_code is not None:
|
117 |
+
if e.status_code == RATE_LIMIT_ERROR_CODE:
|
118 |
+
raise ElevenLabsError(message=clean_message, original_exception=e) from e
|
119 |
+
if CLIENT_ERROR_CODE <= e.status_code < SERVER_ERROR_CODE:
|
120 |
+
raise UnretryableElevenLabsError(message=clean_message, original_exception=e) from e
|
121 |
|
122 |
raise ElevenLabsError(message=clean_message, original_exception=e) from e
|
123 |
|
|
|
129 |
|
130 |
raise ElevenLabsError(message=error_message, original_exception=e) from e
|
131 |
|
132 |
+
def __extract_elevenlabs_error_message(e: ApiError) -> str:
|
|
|
133 |
"""
|
134 |
Extracts a clean, user-friendly error message from an ElevenLabs API error response.
|
135 |
|
src/integrations/{hume_api.py β hume.py}
RENAMED
@@ -1,16 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
hume_api.py
|
3 |
-
|
4 |
-
This file defines the interaction with the Hume text-to-speech (TTS) API using the
|
5 |
-
Hume Python SDK. It includes functionality for API request handling and processing API responses.
|
6 |
-
|
7 |
-
Key Features:
|
8 |
-
- Encapsulates all logic related to the Hume TTS API.
|
9 |
-
- Implements retry logic for handling transient API errors.
|
10 |
-
- Handles received audio and processes it for playback on the web.
|
11 |
-
- Provides detailed logging for debugging and error tracking.
|
12 |
-
"""
|
13 |
-
|
14 |
# Standard Library Imports
|
15 |
import logging
|
16 |
import time
|
@@ -24,9 +11,8 @@ from hume.tts.types import Format, FormatMp3, PostedUtterance, ReturnTts
|
|
24 |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
|
25 |
|
26 |
# Local Application Imports
|
27 |
-
from src.
|
28 |
-
from src.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, RATE_LIMIT_ERROR_CODE, SERVER_ERROR_CODE
|
29 |
-
from src.utils import save_base64_audio_to_file, validate_env_var
|
30 |
|
31 |
|
32 |
@dataclass(frozen=True)
|
@@ -58,7 +44,6 @@ class HumeConfig:
|
|
58 |
timeout=self.request_timeout
|
59 |
)
|
60 |
|
61 |
-
|
62 |
class HumeError(Exception):
|
63 |
"""Custom exception for errors related to the Hume TTS API."""
|
64 |
|
@@ -67,7 +52,6 @@ class HumeError(Exception):
|
|
67 |
self.original_exception = original_exception
|
68 |
self.message = message
|
69 |
|
70 |
-
|
71 |
class UnretryableHumeError(HumeError):
|
72 |
"""Custom exception for errors related to the Hume TTS API that should not be retried."""
|
73 |
|
@@ -76,7 +60,6 @@ class UnretryableHumeError(HumeError):
|
|
76 |
self.original_exception = original_exception
|
77 |
self.message = message
|
78 |
|
79 |
-
|
80 |
@retry(
|
81 |
retry=retry_if_exception(lambda e: not isinstance(e, UnretryableHumeError)),
|
82 |
stop=stop_after_attempt(2),
|
@@ -123,7 +106,7 @@ async def text_to_speech_with_hume(
|
|
123 |
)
|
124 |
|
125 |
elapsed_time = time.time() - start_time
|
126 |
-
logger.info(f"Hume API request completed in {elapsed_time:.2f} seconds")
|
127 |
|
128 |
generations = response.generations
|
129 |
if not generations:
|
@@ -140,10 +123,10 @@ async def text_to_speech_with_hume(
|
|
140 |
except ApiError as e:
|
141 |
elapsed_time = time.time() - start_time
|
142 |
logger.error(f"Hume API request failed after {elapsed_time:.2f} seconds: {e!s}")
|
143 |
-
clean_message =
|
144 |
logger.error(f"Full Hume API error: {e!s}")
|
145 |
|
146 |
-
if e.status_code is not None:
|
147 |
if e.status_code == RATE_LIMIT_ERROR_CODE:
|
148 |
rate_limit_error_message = "We're working on scaling capacity. Please try again in a few seconds."
|
149 |
raise HumeError(message=rate_limit_error_message, original_exception=e) from e
|
@@ -160,8 +143,7 @@ async def text_to_speech_with_hume(
|
|
160 |
|
161 |
raise HumeError(message=clean_message, original_exception=e) from e
|
162 |
|
163 |
-
|
164 |
-
def _extract_hume_api_error_message(e: ApiError) -> str:
|
165 |
"""
|
166 |
Extracts a clean, user-friendly error message from a Hume API error response.
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
import logging
|
3 |
import time
|
|
|
11 |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
|
12 |
|
13 |
# Local Application Imports
|
14 |
+
from src.common import Config, logger, save_base64_audio_to_file, validate_env_var
|
15 |
+
from src.common.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, RATE_LIMIT_ERROR_CODE, SERVER_ERROR_CODE
|
|
|
16 |
|
17 |
|
18 |
@dataclass(frozen=True)
|
|
|
44 |
timeout=self.request_timeout
|
45 |
)
|
46 |
|
|
|
47 |
class HumeError(Exception):
|
48 |
"""Custom exception for errors related to the Hume TTS API."""
|
49 |
|
|
|
52 |
self.original_exception = original_exception
|
53 |
self.message = message
|
54 |
|
|
|
55 |
class UnretryableHumeError(HumeError):
|
56 |
"""Custom exception for errors related to the Hume TTS API that should not be retried."""
|
57 |
|
|
|
60 |
self.original_exception = original_exception
|
61 |
self.message = message
|
62 |
|
|
|
63 |
@retry(
|
64 |
retry=retry_if_exception(lambda e: not isinstance(e, UnretryableHumeError)),
|
65 |
stop=stop_after_attempt(2),
|
|
|
106 |
)
|
107 |
|
108 |
elapsed_time = time.time() - start_time
|
109 |
+
logger.info(f"Hume API request completed in {elapsed_time:.2f} seconds.")
|
110 |
|
111 |
generations = response.generations
|
112 |
if not generations:
|
|
|
123 |
except ApiError as e:
|
124 |
elapsed_time = time.time() - start_time
|
125 |
logger.error(f"Hume API request failed after {elapsed_time:.2f} seconds: {e!s}")
|
126 |
+
clean_message = __extract_hume_api_error_message(e)
|
127 |
logger.error(f"Full Hume API error: {e!s}")
|
128 |
|
129 |
+
if hasattr(e, 'status_code') and e.status_code is not None:
|
130 |
if e.status_code == RATE_LIMIT_ERROR_CODE:
|
131 |
rate_limit_error_message = "We're working on scaling capacity. Please try again in a few seconds."
|
132 |
raise HumeError(message=rate_limit_error_message, original_exception=e) from e
|
|
|
143 |
|
144 |
raise HumeError(message=clean_message, original_exception=e) from e
|
145 |
|
146 |
+
def __extract_hume_api_error_message(e: ApiError) -> str:
|
|
|
147 |
"""
|
148 |
Extracts a clean, user-friendly error message from a Hume API error response.
|
149 |
|
src/integrations/{openai_api.py β openai.py}
RENAMED
@@ -1,17 +1,3 @@
|
|
1 |
-
"""
|
2 |
-
openai_api.py
|
3 |
-
|
4 |
-
This file defines the interaction with the OpenAI text-to-speech (TTS) API using the
|
5 |
-
OpenAI Python SDK. It includes functionality for API request handling and processing API responses.
|
6 |
-
|
7 |
-
Key Features:
|
8 |
-
- Encapsulates all logic related to the OpenAI TTS API.
|
9 |
-
- Implements retry logic using Tenacity for handling transient API errors.
|
10 |
-
- Handles received audio and processes it for playback on the web.
|
11 |
-
- Provides detailed logging for debugging and error tracking.
|
12 |
-
- Utilizes robust error handling (EAFP) to validate API responses.
|
13 |
-
"""
|
14 |
-
|
15 |
# Standard Library Imports
|
16 |
import logging
|
17 |
import random
|
@@ -25,9 +11,9 @@ from openai import APIError, AsyncOpenAI
|
|
25 |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
|
26 |
|
27 |
# Local Application Imports
|
28 |
-
from src.
|
29 |
-
from src.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, SERVER_ERROR_CODE
|
30 |
-
from src.utils import validate_env_var
|
31 |
|
32 |
|
33 |
@dataclass(frozen=True)
|
@@ -68,7 +54,6 @@ class OpenAIConfig:
|
|
68 |
openai_base_voices = ["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
|
69 |
return random.choice(openai_base_voices)
|
70 |
|
71 |
-
|
72 |
class OpenAIError(Exception):
|
73 |
"""Custom exception for errors related to the OpenAI TTS API."""
|
74 |
|
@@ -77,7 +62,6 @@ class OpenAIError(Exception):
|
|
77 |
self.original_exception = original_exception
|
78 |
self.message = message
|
79 |
|
80 |
-
|
81 |
class UnretryableOpenAIError(OpenAIError):
|
82 |
"""Custom exception for errors related to the OpenAI TTS API that should not be retried."""
|
83 |
|
@@ -86,7 +70,6 @@ class UnretryableOpenAIError(OpenAIError):
|
|
86 |
self.original_exception = original_exception
|
87 |
self.message = message
|
88 |
|
89 |
-
|
90 |
@retry(
|
91 |
retry=retry_if_exception(lambda e: not isinstance(e, UnretryableOpenAIError)),
|
92 |
stop=stop_after_attempt(2),
|
@@ -135,7 +118,7 @@ async def text_to_speech_with_openai(
|
|
135 |
voice=voice, # OpenAI requires a base voice to be specified
|
136 |
) as response:
|
137 |
elapsed_time = time.time() - start_time
|
138 |
-
logger.info(f"OpenAI API request completed in {elapsed_time:.2f} seconds")
|
139 |
|
140 |
filename = f"openai_{voice}_{start_time}"
|
141 |
audio_file_path = Path(config.audio_dir) / filename
|
@@ -148,14 +131,13 @@ async def text_to_speech_with_openai(
|
|
148 |
elapsed_time = time.time() - start_time
|
149 |
logger.error(f"OpenAI API request failed after {elapsed_time:.2f} seconds: {e!s}")
|
150 |
logger.error(f"Full OpenAI API error: {e!s}")
|
151 |
-
clean_message =
|
152 |
|
153 |
-
if (
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
raise UnretryableOpenAIError(message=clean_message, original_exception=e) from e
|
159 |
|
160 |
raise OpenAIError(message=clean_message, original_exception=e) from e
|
161 |
|
@@ -167,8 +149,7 @@ async def text_to_speech_with_openai(
|
|
167 |
|
168 |
raise OpenAIError(message=clean_message, original_exception=e) from e
|
169 |
|
170 |
-
|
171 |
-
def _extract_openai_error_message(e: APIError) -> str:
|
172 |
"""
|
173 |
Extracts a clean, user-friendly error message from an OpenAI API error response.
|
174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
import logging
|
3 |
import random
|
|
|
11 |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
|
12 |
|
13 |
# Local Application Imports
|
14 |
+
from src.common import Config, logger
|
15 |
+
from src.common.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, RATE_LIMIT_ERROR_CODE, SERVER_ERROR_CODE
|
16 |
+
from src.common.utils import validate_env_var
|
17 |
|
18 |
|
19 |
@dataclass(frozen=True)
|
|
|
54 |
openai_base_voices = ["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
|
55 |
return random.choice(openai_base_voices)
|
56 |
|
|
|
57 |
class OpenAIError(Exception):
|
58 |
"""Custom exception for errors related to the OpenAI TTS API."""
|
59 |
|
|
|
62 |
self.original_exception = original_exception
|
63 |
self.message = message
|
64 |
|
|
|
65 |
class UnretryableOpenAIError(OpenAIError):
|
66 |
"""Custom exception for errors related to the OpenAI TTS API that should not be retried."""
|
67 |
|
|
|
70 |
self.original_exception = original_exception
|
71 |
self.message = message
|
72 |
|
|
|
73 |
@retry(
|
74 |
retry=retry_if_exception(lambda e: not isinstance(e, UnretryableOpenAIError)),
|
75 |
stop=stop_after_attempt(2),
|
|
|
118 |
voice=voice, # OpenAI requires a base voice to be specified
|
119 |
) as response:
|
120 |
elapsed_time = time.time() - start_time
|
121 |
+
logger.info(f"OpenAI API request completed in {elapsed_time:.2f} seconds.")
|
122 |
|
123 |
filename = f"openai_{voice}_{start_time}"
|
124 |
audio_file_path = Path(config.audio_dir) / filename
|
|
|
131 |
elapsed_time = time.time() - start_time
|
132 |
logger.error(f"OpenAI API request failed after {elapsed_time:.2f} seconds: {e!s}")
|
133 |
logger.error(f"Full OpenAI API error: {e!s}")
|
134 |
+
clean_message = __extract_openai_error_message(e)
|
135 |
|
136 |
+
if hasattr(e, 'status_code') and e.status_code is not None:
|
137 |
+
if e.status_code == RATE_LIMIT_ERROR_CODE:
|
138 |
+
raise OpenAIError(message=clean_message, original_exception=e) from e
|
139 |
+
if CLIENT_ERROR_CODE <= e.status_code < SERVER_ERROR_CODE:
|
140 |
+
raise UnretryableOpenAIError(message=clean_message, original_exception=e) from e
|
|
|
141 |
|
142 |
raise OpenAIError(message=clean_message, original_exception=e) from e
|
143 |
|
|
|
149 |
|
150 |
raise OpenAIError(message=clean_message, original_exception=e) from e
|
151 |
|
152 |
+
def __extract_openai_error_message(e: APIError) -> str:
|
|
|
153 |
"""
|
154 |
Extracts a clean, user-friendly error message from an OpenAI API error response.
|
155 |
|
src/main.py
CHANGED
@@ -1,80 +1,17 @@
|
|
1 |
-
"""
|
2 |
-
main.py
|
3 |
-
|
4 |
-
This module is the entry point for the app. It loads configuration and starts the Gradio app.
|
5 |
-
"""
|
6 |
-
|
7 |
# Standard Library Imports
|
8 |
import asyncio
|
9 |
from pathlib import Path
|
10 |
-
from typing import Awaitable, Callable
|
11 |
|
12 |
# Third-Party Library Imports
|
13 |
import gradio as gr
|
14 |
-
from fastapi import FastAPI
|
15 |
-
from fastapi.responses import Response
|
16 |
from fastapi.staticfiles import StaticFiles
|
17 |
-
from starlette.middleware.base import BaseHTTPMiddleware
|
18 |
-
|
19 |
-
from src.config import Config, logger
|
20 |
-
from src.constants import META_TAGS
|
21 |
-
from src.database import init_db
|
22 |
|
23 |
# Local Application Imports
|
|
|
|
|
24 |
from src.frontend import Frontend
|
25 |
-
from src.
|
26 |
-
|
27 |
-
|
28 |
-
class ResponseModifierMiddleware(BaseHTTPMiddleware):
|
29 |
-
"""
|
30 |
-
FastAPI middleware that safely intercepts and modifies the HTML response from the root endpoint
|
31 |
-
to inject custom meta tags into the document head.
|
32 |
-
|
33 |
-
This middleware specifically targets the root path ('/') and leaves all other endpoint
|
34 |
-
responses unmodified. It uses BeautifulSoup to properly parse and modify the HTML,
|
35 |
-
ensuring that JavaScript functionality remains intact.
|
36 |
-
"""
|
37 |
-
async def dispatch(
|
38 |
-
self,
|
39 |
-
request: Request,
|
40 |
-
call_next: Callable[[Request], Awaitable[Response]]
|
41 |
-
) -> Response:
|
42 |
-
# Process the request and get the response
|
43 |
-
response = await call_next(request)
|
44 |
-
|
45 |
-
# Only intercept responses from the root endpoint and HTML content
|
46 |
-
if request.url.path == "/" and response.headers.get("content-type", "").startswith("text/html"):
|
47 |
-
# Get the response body
|
48 |
-
response_body = b""
|
49 |
-
async for chunk in response.body_iterator:
|
50 |
-
response_body += chunk
|
51 |
-
|
52 |
-
try:
|
53 |
-
# Decode, modify, and re-encode the content
|
54 |
-
content = response_body.decode("utf-8")
|
55 |
-
modified_content = update_meta_tags(content, META_TAGS).encode("utf-8")
|
56 |
-
|
57 |
-
# Update content-length header to reflect modified content size
|
58 |
-
headers = dict(response.headers)
|
59 |
-
headers["content-length"] = str(len(modified_content))
|
60 |
-
|
61 |
-
# Create a new response with the modified content
|
62 |
-
return Response(
|
63 |
-
content=modified_content,
|
64 |
-
status_code=response.status_code,
|
65 |
-
headers=headers,
|
66 |
-
media_type=response.media_type
|
67 |
-
)
|
68 |
-
except Exception:
|
69 |
-
# If there's an error, return the original response
|
70 |
-
return Response(
|
71 |
-
content=response_body,
|
72 |
-
status_code=response.status_code,
|
73 |
-
headers=dict(response.headers),
|
74 |
-
media_type=response.media_type
|
75 |
-
)
|
76 |
-
|
77 |
-
return response
|
78 |
|
79 |
|
80 |
async def main():
|
@@ -89,7 +26,7 @@ async def main():
|
|
89 |
demo = await frontend.build_gradio_interface()
|
90 |
|
91 |
app = FastAPI()
|
92 |
-
app.add_middleware(
|
93 |
|
94 |
public_dir = Path("public")
|
95 |
app.mount("/static", StaticFiles(directory=public_dir), name="static")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Standard Library Imports
|
2 |
import asyncio
|
3 |
from pathlib import Path
|
|
|
4 |
|
5 |
# Third-Party Library Imports
|
6 |
import gradio as gr
|
7 |
+
from fastapi import FastAPI
|
|
|
8 |
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Local Application Imports
|
11 |
+
from src.common import Config, logger
|
12 |
+
from src.database import init_db
|
13 |
from src.frontend import Frontend
|
14 |
+
from src.middleware import MetaTagInjectionMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
async def main():
|
|
|
26 |
demo = await frontend.build_gradio_interface()
|
27 |
|
28 |
app = FastAPI()
|
29 |
+
app.add_middleware(MetaTagInjectionMiddleware)
|
30 |
|
31 |
public_dir = Path("public")
|
32 |
app.mount("/static", StaticFiles(directory=public_dir), name="static")
|
src/middleware/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from src.middleware.meta_tag_injection import MetaTagInjectionMiddleware
|
2 |
+
|
3 |
+
__all__ = ["MetaTagInjectionMiddleware"]
|
src/middleware/meta_tag_injection.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Library Imports
|
2 |
+
from typing import Awaitable, Callable, Dict, List
|
3 |
+
|
4 |
+
# Third-Party Library Imports
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
from fastapi import Request
|
7 |
+
from fastapi.responses import Response
|
8 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
9 |
+
|
10 |
+
# SEO and social-preview metadata injected into the application's HTML <head>.
# Each dict holds the attributes of a single <meta> tag, keyed by either
# 'name' or 'property' plus its 'content' value.
_DESCRIPTION = (
    'An open-source web application for comparing and evaluating the expressiveness '
    'of different text-to-speech models, including Hume AI and ElevenLabs.'
)

META_TAGS: List[Dict[str, str]] = [
    # HTML Meta Tags (description)
    {'name': 'description', 'content': _DESCRIPTION},
    # Facebook (Open Graph) Meta Tags
    {'property': 'og:url', 'content': 'https://hume.ai'},
    {'property': 'og:type', 'content': 'website'},
    {'property': 'og:title', 'content': 'Expressive TTS Arena'},
    {'property': 'og:description', 'content': _DESCRIPTION},
    {'property': 'og:image', 'content': '/static/arena-opengraph-logo.png'},
    # Twitter Meta Tags
    {'name': 'twitter:card', 'content': 'summary_large_image'},
    {'property': 'twitter:domain', 'content': 'hume.ai'},
    {'property': 'twitter:url', 'content': 'https://hume.ai'},
    {'name': 'twitter:creator', 'content': '@hume_ai'},
    {'name': 'twitter:title', 'content': 'Expressive TTS Arena'},
    {'name': 'twitter:description', 'content': _DESCRIPTION},
    {'name': 'twitter:image', 'content': '/static/arena-opengraph-logo.png'},
]
|
69 |
+
|
70 |
+
def _update_meta_tags(html_content: str, meta_tags: List[Dict[str, str]]) -> str:
    """
    Safely updates the HTML content by adding or replacing meta tags in the head section
    without affecting other elements, especially scripts and event handlers.

    NOTE: this helper intentionally uses a single leading underscore. A double-underscore
    name would be mangled when referenced inside the middleware class body below
    (``__update_meta_tags`` becomes ``_MetaTagInjectionMiddleware__update_meta_tags``),
    which raises ``NameError`` at request time and silently disables tag injection.

    Args:
        html_content: The original HTML content as a string.
        meta_tags: A list of dictionaries with meta tag attributes to add.

    Returns:
        The modified HTML content with updated meta tags.
    """
    soup = BeautifulSoup(html_content, 'html.parser')
    head = soup.head

    # Remove existing meta tags that would conflict with our new ones.
    for meta_tag in meta_tags:
        # A tag is identified by either its 'name' or its 'property' attribute.
        attr_type = 'name' if 'name' in meta_tag else 'property'
        attr_value = meta_tag.get(attr_type)

        # Find and remove existing meta tags with the same name/property.
        for tag in head.find_all('meta', attrs={attr_type: attr_value}):
            tag.decompose()

    # Append the new meta tags to the head section.
    for meta_info in meta_tags:
        new_meta = soup.new_tag('meta')
        for attr, value in meta_info.items():
            new_meta[attr] = value
        head.append(new_meta)

    return str(soup)


class MetaTagInjectionMiddleware(BaseHTTPMiddleware):
    """
    FastAPI middleware that safely intercepts and modifies the HTML response from the root endpoint
    to inject custom meta tags into the document head.

    This middleware specifically targets the root path ('/') and leaves all other endpoint
    responses unmodified. It uses BeautifulSoup to properly parse and modify the HTML,
    ensuring that JavaScript functionality remains intact.
    """
    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Intercept the root endpoint's HTML response and inject META_TAGS into its <head>."""
        # Process the request and get the response.
        response = await call_next(request)

        # Only intercept responses from the root endpoint and HTML content.
        if request.url.path == "/" and response.headers.get("content-type", "").startswith("text/html"):
            # Buffer the (potentially streamed) response body.
            response_body = b""
            async for chunk in response.body_iterator:
                response_body += chunk

            try:
                # Decode, modify, and re-encode the content.
                content = response_body.decode("utf-8")
                modified_content = _update_meta_tags(content, META_TAGS).encode("utf-8")

                # Update content-length header to reflect modified content size.
                headers = dict(response.headers)
                headers["content-length"] = str(len(modified_content))

                # Create a new response with the modified content.
                return Response(
                    content=modified_content,
                    status_code=response.status_code,
                    headers=headers,
                    media_type=response.media_type
                )
            except Exception:
                # On any failure, fall back to the original, unmodified response.
                return Response(
                    content=response_body,
                    status_code=response.status_code,
                    headers=dict(response.headers),
                    media_type=response.media_type
                )

        return response
|
src/scripts/init_db.py
CHANGED
@@ -12,7 +12,7 @@ import sys
|
|
12 |
from sqlalchemy.ext.asyncio import create_async_engine
|
13 |
|
14 |
# Local Application Imports
|
15 |
-
from src.
|
16 |
from src.database import Base
|
17 |
|
18 |
|
|
|
12 |
from sqlalchemy.ext.asyncio import create_async_engine
|
13 |
|
14 |
# Local Application Imports
|
15 |
+
from src.common import Config, logger
|
16 |
from src.database import Base
|
17 |
|
18 |
|
src/scripts/test_db.py
CHANGED
@@ -33,7 +33,7 @@ import sys
|
|
33 |
from sqlalchemy import text
|
34 |
|
35 |
# Local Application Imports
|
36 |
-
from src.
|
37 |
from src.database import engine, init_db
|
38 |
|
39 |
|
|
|
33 |
from sqlalchemy import text
|
34 |
|
35 |
# Local Application Imports
|
36 |
+
from src.common import Config, logger
|
37 |
from src.database import engine, init_db
|
38 |
|
39 |
|
src/utils.py
DELETED
@@ -1,650 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
utils.py
|
3 |
-
|
4 |
-
This file contains utility functions that are shared across the project.
|
5 |
-
These functions provide reusable logic to simplify code in other modules.
|
6 |
-
"""
|
7 |
-
|
8 |
-
# Standard Library Imports
import base64
import json
import os
import random
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, cast

# Third-Party Library Imports
from bs4 import BeautifulSoup
from sqlalchemy.ext.asyncio import AsyncSession

# Local Application Imports
from src import constants
from src.config import Config, logger
from src.custom_types import (
    ComparisonType,
    LeaderboardEntry,
    Option,
    OptionKey,
    OptionMap,
    TTSProviderName,
    VotingResults,
)
from src.database import (
    AsyncDBSessionMaker,
    create_vote,
    get_head_to_head_battle_stats,
    get_head_to_head_win_rate_stats,
    get_leaderboard_stats,
)
|
40 |
-
|
41 |
-
|
42 |
-
def truncate_text(text: str, max_length: int = 50) -> str:
    """
    Shorten a string to at most `max_length` characters, appending "..." when cut.

    Args:
        text (str): The text to truncate.
        max_length (int): The maximum number of characters to keep.

    Returns:
        str: The (possibly truncated) text.

    Examples:
        >>> truncate_text("Hello, World!", 5)
        'Hello...'
        >>> truncate_text("Short string", 20)
        'Short string'
        >>> truncate_text("Edge case with zero length", 0)
        ''
    """
    # A non-positive limit cannot be honored meaningfully; warn and return "".
    if max_length <= 0:
        logger.warning(f"Invalid max_length={max_length}. Returning empty string.")
        return ""

    if len(text) <= max_length:
        # Nothing to cut — return the input unchanged.
        return text

    logger.debug(f"Truncated text to {max_length} characters.")
    return f"{text[:max_length]}..."
|
70 |
-
|
71 |
-
|
72 |
-
def validate_character_description_length(character_description: str) -> None:
    """
    Ensure a character description falls within the configured length bounds.

    The description is stripped of surrounding whitespace before its length is
    checked against CHARACTER_DESCRIPTION_MIN_LENGTH / CHARACTER_DESCRIPTION_MAX_LENGTH.

    Args:
        character_description (str): The input character description to validate.

    Raises:
        ValueError: If the stripped description is shorter than the minimum or
            longer than the maximum allowed length.
    """
    stripped = character_description.strip()
    length = len(stripped)
    min_length = constants.CHARACTER_DESCRIPTION_MIN_LENGTH
    max_length = constants.CHARACTER_DESCRIPTION_MAX_LENGTH

    logger.debug(f"Voice description length being validated: {length} characters")

    if length < min_length:
        raise ValueError(
            f"Your character description is too short. Please enter at least "
            f"{min_length} characters. "
            f"(Current length: {length})"
        )
    if length > max_length:
        raise ValueError(
            f"Your character description is too long. Please limit it to "
            f"{max_length} characters. "
            f"(Current length: {length})"
        )

    logger.debug(f"Character description length validation passed for character_description: {truncate_text(stripped)}")
|
102 |
-
|
103 |
-
|
104 |
-
def validate_text_length(text: str) -> None:
    """
    Validates that a text input is within specified minimum and maximum length limits.

    The text is stripped of surrounding whitespace before its length is checked
    against TEXT_MIN_LENGTH / TEXT_MAX_LENGTH.

    Args:
        text (str): The input text to validate.

    Raises:
        ValueError: If the stripped text is shorter than TEXT_MIN_LENGTH or
            longer than TEXT_MAX_LENGTH.
    """
    stripped_text = text.strip()
    text_length = len(stripped_text)

    # Fix: the original debug messages said "Voice description" / "Character
    # description" (copy-pasted from the sibling validator); they now correctly
    # describe the text input this function validates.
    logger.debug(f"Text length being validated: {text_length} characters")

    if text_length < constants.TEXT_MIN_LENGTH:
        raise ValueError(
            f"Your text is too short. Please enter at least "
            f"{constants.TEXT_MIN_LENGTH} characters. "
            f"(Current length: {text_length})"
        )
    if text_length > constants.TEXT_MAX_LENGTH:
        raise ValueError(
            f"Your text is too long. Please limit it to "
            f"{constants.TEXT_MAX_LENGTH} characters. "
            f"(Current length: {text_length})"
        )

    truncated_text = truncate_text(stripped_text)
    logger.debug(f"Text length validation passed for text: {truncated_text}")
|
134 |
-
|
135 |
-
|
136 |
-
def _delete_files_older_than(directory: Path, minutes: int = 30) -> None:
    """
    Remove every regular file in `directory` whose mtime is older than `minutes`.

    Each entry in the directory is inspected; files last modified before the
    cutoff are unlinked. Deletion failures are logged (with traceback) and do
    not interrupt the sweep.

    Args:
        directory (Path): Directory whose files are checked and possibly deleted.
        minutes (int, optional): Age threshold in minutes. Defaults to 30.

    Returns: None
    """
    # Cutoff expressed in seconds since the epoch.
    cutoff = time.time() - minutes * 60

    for entry in Path(directory).iterdir():
        if not entry.is_file():
            continue
        # Keep files modified at or after the cutoff.
        if entry.stat().st_mtime >= cutoff:
            continue
        try:
            entry.unlink()
            logger.info(f"Deleted: {entry}")
        except Exception as e:
            logger.exception(f"Error deleting {entry}: {e}")
|
167 |
-
|
168 |
-
|
169 |
-
def save_base64_audio_to_file(base64_audio: str, filename: str, config: Config) -> str:
    """
    Decode base64 audio and write it into the configured AUDIO_DIR directory.

    Before writing, files in the directory older than 30 minutes are pruned.
    The created file is verified to exist, both absolute and relative paths are
    logged, and a path relative to the current working directory is returned
    (as required by Gradio for serving static files).

    Args:
        base64_audio (str): The base64-encoded string representing the audio data.
        filename (str): Name of the output file, including extension
            (e.g. 'b4a335da-9786-483a-b0a5-37e6e4ad5fd1.mp3').
        config (Config): Application configuration providing `audio_dir`.

    Returns:
        str: The relative file path to the saved audio file.

    Raises:
        FileNotFoundError: If the audio file was not created.
    """
    decoded_audio = base64.b64decode(base64_audio)
    target_path = Path(config.audio_dir) / filename

    # Prune stale audio files (older than 30 minutes) before writing.
    _delete_files_older_than(config.audio_dir, 30)

    # Write the binary audio data to disk.
    target_path.write_bytes(decoded_audio)

    # Sanity-check that the write actually produced a file.
    if not target_path.exists():
        raise FileNotFoundError(f"Audio file was not created at {target_path}")

    # Gradio serves static files by path relative to the working directory.
    relative_path = target_path.relative_to(Path.cwd())
    logger.debug(f"Audio file absolute path: {target_path}")
    logger.debug(f"Audio file relative path: {relative_path}")

    return str(relative_path)
|
211 |
-
|
212 |
-
|
213 |
-
def get_random_providers(text_modified: bool) -> Tuple[TTSProviderName, TTSProviderName]:
    """
    Pick a pair of TTS providers, weighted by comparison type.

    Probabilities:
    - 50% HUME_AI, OPENAI
    - 25% OPENAI, ELEVENLABS
    - 20% HUME_AI, ELEVENLABS
    -  5% HUME_AI, HUME_AI

    When `text_modified` is True (custom text input), the pair is always
    HUME_AI, HUME_AI.

    Args:
        text_modified (bool): Whether the text was modified (custom text input).

    Returns:
        tuple: A tuple (TTSProviderName, TTSProviderName)
    """
    if text_modified:
        return constants.HUME_AI, constants.HUME_AI

    # Pairs and their weights are kept together so the distribution cannot
    # drift out of sync when edited.
    weighted_pairs = [
        ((constants.HUME_AI, constants.OPENAI), 0.5),
        ((constants.OPENAI, constants.ELEVENLABS), 0.25),
        ((constants.HUME_AI, constants.ELEVENLABS), 0.2),
        ((constants.HUME_AI, constants.HUME_AI), 0.05),
    ]
    pairs = [pair for pair, _ in weighted_pairs]
    weights = [weight for _, weight in weighted_pairs]
    return random.choices(pairs, weights=weights, k=1)[0]
|
244 |
-
|
245 |
-
|
246 |
-
def create_shuffled_tts_options(option_a: Option, option_b: Option) -> OptionMap:
    """
    Randomly shuffle two TTS generation options into an OptionMap.

    The two options are shuffled in place and assigned to the 'option_a' and
    'option_b' slots of the returned mapping, each carrying its provider,
    generation ID, and audio file path.

    Args:
        option_a (Option): The first TTS generation option.
        option_b (Option): The second TTS generation option.

    Returns:
        OptionMap: A mapping of shuffled TTS options, where each option includes
                   its provider, audio file path, and generation ID.
    """
    candidates = [option_a, option_b]
    random.shuffle(candidates)
    first, second = candidates

    def _to_entry(option: Option) -> dict:
        # Flatten an Option into the per-slot mapping stored in the OptionMap.
        return {
            "provider": option.provider,
            "generation_id": option.generation_id,
            "audio_file_path": option.audio,
        }

    return {
        "option_a": _to_entry(first),
        "option_b": _to_entry(second),
    }
|
279 |
-
|
280 |
-
|
281 |
-
def determine_selected_option(selected_option_button: str) -> Tuple[OptionKey, OptionKey]:
    """
    Map the user's button choice to (selected_option, other_option) keys.

    Args:
        selected_option_button (str): The option selected by the user, expected
            to be either constants.SELECT_OPTION_A or constants.SELECT_OPTION_B.

    Returns:
        tuple: (selected_option, other_option) where other_option is the
            alternative OptionKey.

    Raises:
        ValueError: If the button value is not a recognized option.
    """
    if selected_option_button == constants.SELECT_OPTION_A:
        return constants.OPTION_A_KEY, constants.OPTION_B_KEY
    if selected_option_button == constants.SELECT_OPTION_B:
        return constants.OPTION_B_KEY, constants.OPTION_A_KEY
    raise ValueError(f"Invalid selected button: {selected_option_button}")
|
303 |
-
|
304 |
-
|
305 |
-
def _determine_comparison_type(provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType:
    """
    Determine the comparison type from the pair of TTS provider names.

    Order of the two providers does not matter.

    Args:
        provider_a (TTSProviderName): The first TTS provider.
        provider_b (TTSProviderName): The second TTS provider.

    Returns:
        ComparisonType: The determined comparison type.

    Raises:
        ValueError: If the combination of providers is not recognized.
    """
    # Using a set makes the match order-independent; {X} covers the X-vs-X case.
    pair = {provider_a, provider_b}

    if pair == {constants.HUME_AI}:
        return constants.HUME_TO_HUME
    if pair == {constants.HUME_AI, constants.ELEVENLABS}:
        return constants.HUME_TO_ELEVENLABS
    if pair == {constants.HUME_AI, constants.OPENAI}:
        return constants.HUME_TO_OPENAI
    if pair == {constants.ELEVENLABS, constants.OPENAI}:
        return constants.OPENAI_TO_ELEVENLABS

    raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")
|
335 |
-
|
336 |
-
|
337 |
-
def _log_voting_results(voting_results: VotingResults) -> None:
    """Log the complete voting results as pretty-printed JSON."""
    serialized_results = json.dumps(voting_results, indent=4)
    logger.info("Voting results:\n%s", serialized_results)
|
341 |
-
|
342 |
-
|
343 |
-
async def _create_db_session(db_session_maker: AsyncDBSessionMaker) -> Optional[AsyncSession]:
    """
    Create a new database session, returning None when the maker yields a dummy.

    A dummy session (flagged by an `is_dummy` attribute) may be used in
    development or testing environments where database operations should be
    simulated but not actually performed; such a session is closed immediately
    and None is returned so callers can skip persistence.

    Args:
        db_session_maker (AsyncDBSessionMaker): A callable that returns a new
            async database session.

    Returns:
        Optional[AsyncSession]: A usable database session, or None when the
            session maker produced a dummy session.
    """
    session = db_session_maker()

    # Fix: the return annotation previously claimed `AsyncSession`, but this
    # branch deliberately returns None for dummy sessions; the signature now
    # reflects that with Optional[AsyncSession].
    if getattr(session, "is_dummy", False):
        await session.close()
        return None

    return session
|
364 |
-
|
365 |
-
|
366 |
-
async def _persist_vote(db_session_maker: AsyncDBSessionMaker, voting_results: VotingResults) -> None:
    """
    Asynchronously persist a vote record in the database and handle potential failures.
    Designed to work safely in a background task context: all exceptions from the
    write are caught and logged rather than propagated.

    Args:
        db_session_maker (AsyncDBSessionMaker): A callable that returns a new async database session.
        voting_results (VotingResults): A dictionary containing the details of the vote to persist.

    Returns:
        None
    """
    # Create session
    # NOTE(review): _create_db_session returns None for dummy sessions; in that
    # case `cast` below still passes None to create_vote — presumably create_vote
    # (or the except clause) tolerates this. TODO confirm.
    session = await _create_db_session(db_session_maker)
    # Results are logged before attempting persistence so they survive DB failures.
    _log_voting_results(voting_results)
    try:
        await create_vote(cast(AsyncSession, session), voting_results)
    except Exception as e:
        # Log the error with traceback
        logger.error(f"Failed to create vote record: {e}", exc_info=True)
    finally:
        # Always ensure the session is closed
        if session is not None:
            await session.close()
|
391 |
-
|
392 |
-
|
393 |
-
async def submit_voting_results(
    option_map: OptionMap,
    selected_option: OptionKey,
    text_modified: bool,
    character_description: str,
    text: str,
    db_session_maker: AsyncDBSessionMaker,
) -> None:
    """
    Build the voting-results payload and persist a new vote record.

    Designed to run as a background task: every exception is caught and logged
    at this level so nothing propagates out of the task.

    Args:
        option_map (OptionMap): Mapping of comparison data and TTS options.
        selected_option (OptionKey): The option selected by the user.
        text_modified (bool): Whether the text was modified from the original generated text.
        character_description (str): Description of the voice/character used for TTS generation.
        text (str): The text that was synthesized into speech.
        db_session_maker (AsyncDBSessionMaker): Factory function for creating async database sessions.

    Returns:
        None
    """
    try:
        option_a_data = option_map[constants.OPTION_A_KEY]
        option_b_data = option_map[constants.OPTION_B_KEY]
        provider_a: TTSProviderName = option_a_data["provider"]
        provider_b: TTSProviderName = option_b_data["provider"]

        voting_results: VotingResults = {
            "comparison_type": _determine_comparison_type(provider_a, provider_b),
            "winning_provider": option_map[selected_option]["provider"],
            "winning_option": selected_option,
            "option_a_provider": provider_a,
            "option_b_provider": provider_b,
            "option_a_generation_id": option_a_data["generation_id"],
            "option_b_generation_id": option_b_data["generation_id"],
            "character_description": character_description,
            "text": text,
            "is_custom_text": text_modified,
        }

        await _persist_vote(db_session_maker, voting_results)
    except Exception as e:
        # Top-level catch prevents unhandled exceptions in the background task.
        logger.error(f"Background task error in submit_voting_results: {e}", exc_info=True)
|
441 |
-
|
442 |
-
|
443 |
-
async def get_leaderboard_data(
    db_session_maker: AsyncDBSessionMaker
) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
    """
    Fetch and format all leaderboard data from the voting results database.

    Three datasets are retrieved and formatted for the UI:
    1. Provider rankings with overall performance metrics
    2. Head-to-head battle counts between providers
    3. Win rate percentages for each provider against others

    On any failure an empty placeholder (`[[]]`) is returned for each dataset
    and the error is logged with its traceback.

    Args:
        db_session_maker (AsyncDBSessionMaker): Factory function for creating async database sessions.

    Returns:
        Tuple containing three List[List[str]] datasets:
        - leaderboard_data: Provider rankings with performance metrics
        - battle_counts_data: Number of comparisons between each provider pair
        - win_rate_data: Win percentages in head-to-head matchups
    """
    session = await _create_db_session(db_session_maker)
    try:
        db_session = cast(AsyncSession, session)
        raw_leaderboard = await get_leaderboard_stats(db_session)
        raw_battle_counts = await get_head_to_head_battle_stats(db_session)
        raw_win_rates = await get_head_to_head_win_rate_stats(db_session)

        logger.info("Fetched leaderboard data successfully.")

        return (
            _format_leaderboard_data(raw_leaderboard),
            _format_battle_counts_data(raw_battle_counts),
            _format_win_rate_data(raw_win_rates),
        )
    except Exception as e:
        logger.error(f"Failed to fetch leaderboard data: {e}", exc_info=True)
        return [[]], [[]], [[]]
    finally:
        # Always ensure the session is closed.
        if session is not None:
            await session.close()
|
485 |
-
|
486 |
-
def _format_leaderboard_data(leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]:
    """
    Formats raw leaderboard data for display in the UI.

    Converts LeaderboardEntry objects into HTML-formatted strings with appropriate
    styling and links for provider and model information.

    Row layout (by index): 0 = rank, 1 = provider name, 2 = model name,
    3 and 4 = performance metrics rendered as centered text.

    Args:
        leaderboard_data_raw (List[LeaderboardEntry]): Raw leaderboard data from the database.

    Returns:
        List[List[str]]: Formatted HTML strings for each cell in the leaderboard table.
    """
    return [
        [
            # Rank column, centered.
            f'<p style="text-align: center;">{row[0]}</p>',
            # Provider cell links out to the provider's site (looked up by name).
            f"""<a href="{constants.TTS_PROVIDER_LINKS[row[1]]["provider_link"]}"
            target="_blank"
            class="provider-link"
            >{row[1]}</a>
            """,
            # Model cell links to the provider's model page.
            f"""<a href="{constants.TTS_PROVIDER_LINKS[row[1]]["model_link"]}"
            target="_blank"
            class="provider-link"
            >{row[2]}</a>
            """,
            f'<p style="text-align: center;">{row[3]}</p>',
            f'<p style="text-align: center;">{row[4]}</p>',
        ] for row in leaderboard_data_raw
    ]
|
516 |
-
|
517 |
-
|
518 |
-
def _format_battle_counts_data(battle_counts_data_raw: List[List[str]]) -> List[List[str]]:
    """
    Formats battle count data into a matrix format for the UI.

    Creates a provider-by-provider matrix showing the number of direct comparisons
    between each pair of providers. Diagonal cells show dashes as providers aren't
    compared against themselves.

    Args:
        battle_counts_data_raw (List[List[str]]): Raw battle count data from the database,
            where each inner list contains [comparison_type, count].

    Returns:
        List[List[str]]: HTML-formatted matrix of battle counts between providers.
    """
    # comparison_type string -> count, for O(1) lookup while building the matrix.
    battle_counts_dict = {item[0]: item[1] for item in battle_counts_data_raw}
    # Create canonical comparison keys based on your expected database formats.
    # Pairs are stored in one canonical order, so lookups below try both orders.
    comparison_keys = {
        ("Hume AI", "OpenAI"): "Hume AI - OpenAI",
        ("Hume AI", "ElevenLabs"): "Hume AI - ElevenLabs",
        ("OpenAI", "ElevenLabs"): "OpenAI - ElevenLabs"
    }
    return [
        [
            # First column: row label (provider name).
            f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'
        ] + [
            # Diagonal cells get "-"; otherwise look up the canonical key for the
            # pair (either order), defaulting to "0" when no battles are recorded.
            f"""
            <p style="text-align: center;">
                {"-" if row_provider == col_provider
                else battle_counts_dict.get(
                    comparison_keys.get((row_provider, col_provider)) or
                    comparison_keys.get((col_provider, row_provider), "unknown"),
                    "0"
                )
                }
            </p>
            """ for col_provider in constants.TTS_PROVIDERS
        ]
        for row_provider in constants.TTS_PROVIDERS
    ]
|
558 |
-
|
559 |
-
|
560 |
-
def _format_win_rate_data(win_rate_data_raw: List[List[str]]) -> List[List[str]]:
    """
    Formats win rate data into a matrix format for the UI.

    Creates a provider-by-provider matrix showing the percentage of times the row
    provider won against the column provider. Diagonal cells show dashes as
    providers aren't compared against themselves.

    Args:
        win_rate_data_raw (List[List[str]]): Raw win rate data from the database,
            where each inner list contains [comparison_type, first_win_rate, second_win_rate].

    Returns:
        List[List[str]]: HTML-formatted matrix of win rates between providers.
    """
    # Create a clean lookup dictionary with provider pairs as keys.
    # Each comparison row yields two entries — one per direction of the pair.
    win_rates = {}
    for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw:
        # comparison_type is formatted as "<provider1> - <provider2>".
        provider1, provider2 = comparison_type.split(" - ")
        win_rates[(provider1, provider2)] = first_win_rate
        win_rates[(provider2, provider1)] = second_win_rate

    return [
        [
            # First column: row label (provider name).
            f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'
        ] + [
            # Diagonal cells get "-"; missing matchups default to "0%".
            f"""
            <p style="text-align: center;">
                {"-" if row_provider == col_provider else win_rates.get((row_provider, col_provider), "0%")}
            </p>
            """
            for col_provider in constants.TTS_PROVIDERS
        ]
        for row_provider in constants.TTS_PROVIDERS
    ]
|
595 |
-
|
596 |
-
|
597 |
-
def validate_env_var(var_name: str) -> str:
    """
    Return the value of a required environment variable.

    Args:
        var_name (str): The name of the environment variable to validate.

    Returns:
        str: The value of the environment variable.

    Raises:
        ValueError: If the environment variable is unset or empty.
    """
    value = os.environ.get(var_name, "")
    if value:
        return value
    # Unset and empty are treated the same: both are configuration errors.
    raise ValueError(f"{var_name} is not set. Please ensure it is defined in your environment variables.")
|
614 |
-
|
615 |
-
|
616 |
-
def update_meta_tags(html_content: str, meta_tags: List[Dict[str, str]]) -> str:
    """
    Safely updates the HTML content by adding or replacing meta tags in the head
    section without affecting other elements, especially scripts and event handlers.

    Args:
        html_content (str): The original HTML content as a string.
        meta_tags (List[Dict[str, str]]): A list of dictionaries with meta tag
            attributes to add; each dict is keyed by either 'name' or 'property'
            plus the remaining tag attributes.

    Returns:
        str: The modified HTML content with updated meta tags.
    """
    # Parse the HTML
    soup = BeautifulSoup(html_content, 'html.parser')
    head = soup.head

    # Fix: the original assumed a <head> was always present and raised an
    # AttributeError on head-less documents. Create one when it is missing.
    if head is None:
        head = soup.new_tag('head')
        if soup.html is not None:
            soup.html.insert(0, head)
        else:
            soup.insert(0, head)

    # Remove existing meta tags that would conflict with our new ones
    for meta_tag in meta_tags:
        # Determine if we're looking for 'name' or 'property' attribute
        attr_type = 'name' if 'name' in meta_tag else 'property'
        attr_value = meta_tag.get(attr_type)

        # Find and remove existing meta tags with the same name/property
        for existing_tag in head.find_all('meta', attrs={attr_type: attr_value}):
            existing_tag.decompose()

    # Add the new meta tags to the head section
    for meta_info in meta_tags:
        new_meta = soup.new_tag('meta')
        for attr, value in meta_info.items():
            new_meta[attr] = value
        head.append(new_meta)

    return str(soup)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|