"""OpenAI text-to-speech integration: configuration, error types, and the speech synthesis helper."""
# Standard Library Imports | |
import logging | |
import random | |
import time | |
from dataclasses import dataclass, field | |
from pathlib import Path | |
from typing import Literal, Tuple, Union | |
# Third-Party Library Imports | |
from openai import APIError, AsyncOpenAI | |
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed | |
# Local Application Imports | |
from src.common import Config, logger | |
from src.common.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, RATE_LIMIT_ERROR_CODE, SERVER_ERROR_CODE | |
from src.common.utils import validate_env_var | |
@dataclass(frozen=True)
class OpenAIConfig:
    """Immutable configuration for interacting with the OpenAI TTS API.

    Attributes:
        api_key: Value returned by ``validate_env_var("OPENAI_API_KEY")``;
            computed in ``__post_init__`` and not accepted by the constructor.
        model: Name of the TTS model sent with each request.
        response_format: Audio format requested from the API.
    """

    api_key: str = field(init=False)
    model: str = "gpt-4o-mini-tts"
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = "mp3"

    def __post_init__(self) -> None:
        """Validate required attributes and set computed fields."""
        computed_api_key = validate_env_var("OPENAI_API_KEY")
        # The dataclass is frozen, so plain attribute assignment raises
        # FrozenInstanceError; object.__setattr__ bypasses that guard.
        object.__setattr__(self, "api_key", computed_api_key)

    @property
    def client(self) -> "AsyncOpenAI":
        """
        Build an asynchronous OpenAI client configured with this API key.

        The client is created on access rather than at construction time.
        It is not cached: every access returns a new ``AsyncOpenAI`` instance.

        Returns:
            AsyncOpenAI: Configured async client instance.
        """
        return AsyncOpenAI(api_key=self.api_key)

    @staticmethod
    def select_random_base_voice() -> str:
        """
        Randomly select one of OpenAI's base voice options for TTS.

        OpenAI's Python SDK doesn't export a type for their base voice names,
        so we use a hardcoded list of the available voice options.

        Returns:
            str: A randomly selected OpenAI base voice name (e.g., 'alloy', 'nova', etc.)
        """
        openai_base_voices = ("alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer")
        return random.choice(openai_base_voices)
class OpenAIError(Exception):
    """Error raised when a call to the OpenAI TTS API fails.

    Attributes:
        message: Human-readable description of the failure.
        original_exception: The underlying exception that triggered this
            error, or ``None`` when there is no underlying cause.
    """

    def __init__(self, message: str, original_exception: Union[Exception, None] = None):
        # Record the details on the instance first, then let Exception store
        # the message so that str(error) yields it.
        self.message = message
        self.original_exception = original_exception
        super().__init__(message)
class UnretryableOpenAIError(OpenAIError):
    """OpenAI TTS API error that should not be retried.

    Carries the same ``message`` and ``original_exception`` attributes as
    ``OpenAIError``; the distinct type exists so callers can tell retryable
    failures from unretryable ones.
    """

    def __init__(self, message: str, original_exception: Union[Exception, None] = None):
        # The base initializer already stores both message and original_exception.
        super().__init__(message, original_exception)
# NOTE(review): the retry policy below (2 attempts, fixed 2-second wait) is an
# assumption — confirm the intended values. Errors flagged as unretryable are
# re-raised immediately; everything else is retried.
@retry(
    retry=retry_if_exception(lambda e: not isinstance(e, UnretryableOpenAIError)),
    stop=stop_after_attempt(2),
    wait=wait_fixed(2),
    before=before_log(logger, logging.DEBUG),
    after=after_log(logger, logging.DEBUG),
    reraise=True,
)
async def text_to_speech_with_openai(
    character_description: str,
    text: str,
    config: Config,
) -> Tuple[None, str]:
    """
    Asynchronously synthesize speech with the OpenAI TTS API and stream the audio to a file.

    A base voice is picked at random, the request is sent through the OpenAI
    Python SDK's streaming-response API, and the audio is streamed into
    ``config.audio_dir`` under the name ``openai_<voice>_<start timestamp>``
    (no file extension is appended).

    Args:
        character_description (str): Passed to the API as the ``instructions`` for the voice.
        text (str): Text to be converted to speech.
        config (Config): Application configuration containing OpenAI API settings.

    Returns:
        Tuple[None, str]: A tuple containing:
            - None: this function does not produce a generation ID.
            - audio_file_path (str): Path to the saved audio file, relative to the
              current working directory when the file lives beneath it; otherwise
              the path as built from ``config.audio_dir``.

    Raises:
        OpenAIError: For rate-limit responses, API errors outside the client-error
            range, API errors without a status code, and any unexpected exception.
        UnretryableOpenAIError: For client-side HTTP errors (status code 4xx)
            other than the rate-limit status.
    """
    logger.debug(f"Synthesizing speech with OpenAI. Text length: {len(text)} characters.")
    openai_config = config.openai_config
    client = openai_config.client
    start_time = time.time()

    try:
        voice = openai_config.select_random_base_voice()
        async with client.audio.speech.with_streaming_response.create(
            model=openai_config.model,
            input=text,
            instructions=character_description,
            response_format=openai_config.response_format,
            voice=voice,  # OpenAI requires a base voice to be specified
        ) as response:
            elapsed_time = time.time() - start_time
            logger.info(f"OpenAI API request completed in {elapsed_time:.2f} seconds.")

            filename = f"openai_{voice}_{start_time}"
            audio_file_path = Path(config.audio_dir) / filename
            await response.stream_to_file(audio_file_path)

            # relative_to() raises ValueError when the audio directory is a
            # relative path or lies outside the working directory. The audio
            # has already been written at that point, so return the path
            # unchanged instead of reporting the synthesis as failed.
            try:
                returned_path = audio_file_path.relative_to(Path.cwd())
            except ValueError:
                returned_path = audio_file_path

            return None, str(returned_path)

    except APIError as e:
        elapsed_time = time.time() - start_time
        logger.error(f"OpenAI API request failed after {elapsed_time:.2f} seconds: {e!s}")
        logger.error(f"Full OpenAI API error: {e!s}")
        clean_message = __extract_openai_error_message(e)

        # Only 4xx responses other than the rate-limit status are treated as
        # unretryable; rate limits and all remaining failures stay retryable.
        status_code = getattr(e, "status_code", None)
        if (
            status_code is not None
            and status_code != RATE_LIMIT_ERROR_CODE
            and CLIENT_ERROR_CODE <= status_code < SERVER_ERROR_CODE
        ):
            raise UnretryableOpenAIError(message=clean_message, original_exception=e) from e

        raise OpenAIError(message=clean_message, original_exception=e) from e

    except Exception as e:
        error_type = type(e).__name__
        error_message = str(e) if str(e) else f"An error of type {error_type} occurred"
        logger.error("Error during OpenAI API call: %s - %s", error_type, error_message)
        # Internal details stay in the log; callers get the generic message.
        clean_message = GENERIC_API_ERROR_MESSAGE
        raise OpenAIError(message=clean_message, original_exception=e) from e
def __extract_openai_error_message(e: APIError) -> str:
    """
    Extracts a clean, user-friendly error message from an OpenAI API error response.

    Two body shapes are accepted: the raw response payload
    ``{"error": {"message": ...}}`` and the already-unwrapped error object
    ``{"message": ...}``. The nested form is checked first.

    Args:
        e (APIError): The OpenAI API error exception containing response information.

    Returns:
        str: The message found in the error body when it is a non-empty string,
        otherwise ``GENERIC_API_ERROR_MESSAGE``.
    """
    body = getattr(e, "body", None)
    if isinstance(body, dict):
        nested_error = body.get("error")
        candidates = (nested_error, body) if isinstance(nested_error, dict) else (body,)
        for candidate in candidates:
            message = candidate.get("message")
            # Only a non-empty string honours the declared return type.
            if isinstance(message, str) and message:
                return message

    return GENERIC_API_ERROR_MESSAGE