Spaces:
Running
Running
File size: 6,882 Bytes
9ed181c 5ed9749 9ed181c 5ed9749 9ed181c 5ed9749 9ed181c 5ed9749 9ed181c 5ed9749 9ed181c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# Standard Library Imports
import logging
import random
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Tuple, Union
# Third-Party Library Imports
from openai import APIError, AsyncOpenAI
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed
# Local Application Imports
from src.common import Config, logger
from src.common.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, RATE_LIMIT_ERROR_CODE, SERVER_ERROR_CODE
from src.common.utils import validate_env_var
@dataclass(frozen=True)
class OpenAIConfig:
    """
    Immutable configuration for interacting with the OpenAI TTS API.

    Attributes:
        api_key (str): Value returned by ``validate_env_var("OPENAI_API_KEY")`` when the instance is
            created. It is not a constructor argument and is excluded from ``repr()`` so the secret
            does not end up in logs or tracebacks that print the configuration.
        model (str): Name of the TTS model to request.
        response_format (str): Audio format requested from the API.
    """

    # init=False: computed in __post_init__; repr=False: keep the secret out of the generated __repr__.
    api_key: str = field(init=False, repr=False)
    model: str = "gpt-4o-mini-tts"
    response_format: Literal['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm'] = "mp3"

    def __post_init__(self) -> None:
        """Validate required attributes and set computed fields."""
        computed_api_key = validate_env_var("OPENAI_API_KEY")
        # The dataclass is frozen, so plain attribute assignment would raise FrozenInstanceError;
        # object.__setattr__ bypasses the frozen guard for this one-time initialization.
        object.__setattr__(self, "api_key", computed_api_key)

    @property
    def client(self) -> AsyncOpenAI:
        """
        Build an asynchronous OpenAI client authenticated with this configuration's API key.

        A new client object is constructed on every access; nothing is cached on the instance.

        Returns:
            AsyncOpenAI: Configured async client instance.
        """
        return AsyncOpenAI(api_key=self.api_key)

    @staticmethod
    def select_random_base_voice() -> str:
        """
        Randomly selects one of OpenAI's base voice options for TTS.

        OpenAI's Python SDK doesn't export a type for their base voice names,
        so we use a hardcoded list of the available voice options.

        Returns:
            str: A randomly selected OpenAI base voice name (e.g., 'alloy', 'nova', etc.)
        """
        openai_base_voices = ["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
        return random.choice(openai_base_voices)
class OpenAIError(Exception):
    """Exception raised when a call to the OpenAI TTS API fails."""

    def __init__(self, message: str, original_exception: Union[Exception, None] = None):
        """
        Store the error message and, optionally, the exception that caused it.

        Args:
            message (str): Human-readable description of the failure; also used as the
                exception's string representation.
            original_exception (Union[Exception, None]): The underlying exception, if any.
        """
        self.message = message
        self.original_exception = original_exception
        super().__init__(message)
class UnretryableOpenAIError(OpenAIError):
    """Custom exception for errors related to the OpenAI TTS API that should not be retried."""

    def __init__(self, message: str, original_exception: Union[Exception, None] = None):
        """
        Create a non-retryable OpenAI error.

        Args:
            message (str): Human-readable description of the failure.
            original_exception (Union[Exception, None]): The underlying exception, if any.
        """
        # OpenAIError.__init__ already stores `message` and `original_exception` on the
        # instance, so they are not assigned a second time here.
        super().__init__(message, original_exception)
@retry(
    retry=retry_if_exception(lambda e: not isinstance(e, UnretryableOpenAIError)),
    stop=stop_after_attempt(2),
    wait=wait_fixed(2),
    before=before_log(logger, logging.DEBUG),
    after=after_log(logger, logging.DEBUG),
    reraise=True,
)
async def text_to_speech_with_openai(
    character_description: str,
    text: str,
    config: Config,
) -> Tuple[None, str]:
    """
    Asynchronously synthesizes speech using the OpenAI TTS API and writes the audio to a file.

    This function uses the OpenAI Python SDK to send a streaming request to the OpenAI TTS API, passing
    the character description as the model instructions and a randomly selected base voice. The streamed
    audio is written to a file inside ``config.audio_dir``, named after the voice and the request start
    timestamp, with the configured response format as its extension.

    The call is attempted at most twice with a fixed 2 second wait between attempts; any exception other
    than ``UnretryableOpenAIError`` triggers the retry, and the final exception is re-raised.

    Args:
        character_description (str): Description used for voice synthesis.
        text (str): Text to be converted to speech.
        config (Config): Application configuration containing OpenAI API settings.

    Returns:
        Tuple[None, str]: A tuple containing:
            - None: Placeholder; this function does not return a generation ID.
            - audio_file_path (str): Path to the saved audio file, relative to the current working
              directory when the file lives beneath it, otherwise the path as constructed.

    Raises:
        OpenAIError: For errors communicating with the OpenAI API (including rate limiting and
            any unexpected failure while producing the audio file).
        UnretryableOpenAIError: For client-side HTTP errors (status code 4xx) other than rate limiting.
    """
    logger.debug(f"Synthesizing speech with OpenAI. Text length: {len(text)} characters.")
    openai_config = config.openai_config
    client = openai_config.client
    start_time = time.time()

    try:
        voice = openai_config.select_random_base_voice()
        async with client.audio.speech.with_streaming_response.create(
            model=openai_config.model,
            input=text,
            instructions=character_description,
            response_format=openai_config.response_format,
            voice=voice,  # OpenAI requires a base voice to be specified
        ) as response:
            elapsed_time = time.time() - start_time
            logger.info(f"OpenAI API request completed in {elapsed_time:.2f} seconds.")

            # Give the file the extension of the requested audio format so that players and
            # MIME-type detection can identify it.
            filename = f"openai_{voice}_{start_time}.{openai_config.response_format}"
            audio_file_path = Path(config.audio_dir) / filename
            await response.stream_to_file(audio_file_path)

            # relative_to() raises ValueError when the audio directory is not located under the
            # current working directory. The audio has already been saved at this point, so fall
            # back to the unmodified path instead of failing (and re-running the paid API call).
            try:
                returned_path = audio_file_path.relative_to(Path.cwd())
            except ValueError:
                returned_path = audio_file_path
            return None, str(returned_path)

    except APIError as e:
        elapsed_time = time.time() - start_time
        logger.error(f"OpenAI API request failed after {elapsed_time:.2f} seconds: {e!s}")
        logger.error(f"Full OpenAI API error: {e!s}")
        clean_message = __extract_openai_error_message(e)

        # Not every APIError carries an HTTP status (e.g. connection failures), hence getattr.
        status_code = getattr(e, "status_code", None)
        if status_code is not None:
            is_rate_limited = status_code == RATE_LIMIT_ERROR_CODE
            is_client_error = CLIENT_ERROR_CODE <= status_code < SERVER_ERROR_CODE
            # Client errors will not succeed on retry; rate limiting is the one exception.
            if is_client_error and not is_rate_limited:
                raise UnretryableOpenAIError(message=clean_message, original_exception=e) from e

        raise OpenAIError(message=clean_message, original_exception=e) from e

    except Exception as e:
        error_type = type(e).__name__
        error_message = str(e) if str(e) else f"An error of type {error_type} occurred"
        logger.error("Error during OpenAI API call: %s - %s", error_type, error_message)
        # Internal details stay in the log; callers only see the generic message.
        clean_message = GENERIC_API_ERROR_MESSAGE
        raise OpenAIError(message=clean_message, original_exception=e) from e
def __extract_openai_error_message(e: APIError) -> str:
    """
    Extracts a clean, user-friendly error message from an OpenAI API error response.

    Two body layouts are recognised:
        - the raw API payload, ``{"error": {"message": "..."}}``
        - the already-unwrapped error object, ``{"message": "..."}``
          (NOTE(review): the OpenAI Python SDK appears to store this unwrapped form in
          ``body`` for HTTP status errors; confirm against the installed SDK version.)

    Args:
        e (APIError): The OpenAI API error exception containing response information.

    Returns:
        str: The message found in the error body when it is a non-empty string, otherwise
            ``GENERIC_API_ERROR_MESSAGE``.
    """
    body = getattr(e, 'body', None)
    if isinstance(body, dict):
        nested_error = body.get('error')
        # Prefer the nested error object when present; otherwise treat the body itself as the error.
        error_details = nested_error if isinstance(nested_error, dict) else body
        message = error_details.get('message')
        # Guard against null or non-string messages so the declared return type always holds.
        if isinstance(message, str) and message:
            return message
    return GENERIC_API_ERROR_MESSAGE
|