expressive-tts-arena / src /core /tts_service.py
Zachary Greathouse
Update probability distribution for provider selection for comparison. (#23)
e07c653 unverified
raw
history blame contribute delete
4.35 kB
# Standard Library Imports
import asyncio
import random
from typing import Tuple
# Local Application Imports
from src.common import Config, Option, OptionMap, TTSProviderName, logger
from src.common.constants import ELEVENLABS, HUME_AI, OPENAI
from src.integrations import (
text_to_speech_with_elevenlabs,
text_to_speech_with_hume,
text_to_speech_with_openai,
)
class TTSService:
    """
    Coordinates text-to-speech generation across the supported providers.

    The service chooses a pair of providers, runs both synthesis calls concurrently,
    and returns the two results in random order under the keys "option_a" and
    "option_b".
    """

    def __init__(self, config: Config):
        """
        Store the configuration and register one synthesis function per provider.

        Args:
            config (Config): Application configuration; handed unchanged to every
                provider synthesis function.
        """
        self.config = config
        # Lookup table: provider name -> function that synthesizes speech with it.
        self.tts_provider_functions = {
            HUME_AI: text_to_speech_with_hume,
            ELEVENLABS: text_to_speech_with_elevenlabs,
            OPENAI: text_to_speech_with_openai,
        }

    def __select_providers(self, text_modified: bool) -> Tuple[TTSProviderName, TTSProviderName]:
        """
        Choose the two providers whose output will be compared.

        When `text_modified` is `True` the result is always (HUME_AI, HUME_AI).
        Otherwise one of the following pairs is drawn, each with equal weight:
            - HUME_AI & OPENAI
            - HUME_AI & ELEVENLABS
            - OPENAI & ELEVENLABS

        Args:
            text_modified (bool): Whether the text has been modified.

        Returns:
            tuple: The selected (TTSProviderName, TTSProviderName) pair.
        """
        if text_modified:
            return HUME_AI, HUME_AI

        # Every candidate pair is written next to its weight, so changing the
        # distribution cannot leave pairs and weights out of step.
        weighted_pairs = [
            ((HUME_AI, OPENAI), 1),
            ((HUME_AI, ELEVENLABS), 1),
            ((OPENAI, ELEVENLABS), 1),
        ]
        candidate_pairs, pair_weights = zip(*weighted_pairs)
        (chosen_pair,) = random.choices(candidate_pairs, weights=pair_weights, k=1)
        return chosen_pair

    async def synthesize_speech(
        self,
        character_description: str,
        text: str,
        text_modified: bool
    ) -> OptionMap:
        """
        Synthesize the given text with two providers and return both results.

        The provider pair is chosen from `text_modified`, both synthesis calls are
        awaited together, and the resulting options are shuffled before being
        assigned to "option_a" and "option_b".

        Args:
            character_description (str): Description of the character/voice for synthesis.
            text (str): The text to synthesize into speech.
            text_modified (bool): Whether the text has been modified from the original.

        Returns:
            OptionMap: A mapping with keys "option_a" and "option_b"; each value holds
                the option's "provider", "generation_id" and "audio_file_path".
        """

        def as_payload(option):
            # Flatten an Option into the dictionary shape returned to the caller.
            return {
                "provider": option.provider,
                "generation_id": option.generation_id,
                "audio_file_path": option.audio,
            }

        provider_a, provider_b = self.__select_providers(text_modified)
        logger.info(f"Starting speech synthesis with providers: {provider_a} and {provider_b}")

        # Create both awaitables first so that gather can run them concurrently.
        pending = [
            self.tts_provider_functions[provider](character_description, text, self.config)
            for provider in (provider_a, provider_b)
        ]
        result_a, result_b = await asyncio.gather(*pending)
        # Each provider function yields a (generation_id, audio) pair.
        generation_id_a, audio_a = result_a
        generation_id_b, audio_b = result_b
        logger.info(f"Synthesis succeeded for providers: {provider_a} and {provider_b}")

        options = [
            Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a),
            Option(provider=provider_b, audio=audio_b, generation_id=generation_id_b),
        ]
        # Randomize which option lands in which slot of the result.
        random.shuffle(options)
        first_option, second_option = options

        return {
            "option_a": as_payload(first_option),
            "option_b": as_payload(second_option),
        }