File size: 6,882 Bytes
9ed181c
 
 
 
 
 
 
 
 
 
 
 
 
5ed9749
 
 
9ed181c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ed9749
9ed181c
 
 
 
 
 
 
 
 
 
 
 
5ed9749
9ed181c
5ed9749
 
 
 
 
9ed181c
 
 
 
 
 
 
 
 
 
 
5ed9749
9ed181c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Standard Library Imports
import logging
import random
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Tuple, Union

# Third-Party Library Imports
from openai import APIError, AsyncOpenAI
from tenacity import after_log, before_log, retry, retry_if_exception, stop_after_attempt, wait_fixed

# Local Application Imports
from src.common import Config, logger
from src.common.constants import CLIENT_ERROR_CODE, GENERIC_API_ERROR_MESSAGE, RATE_LIMIT_ERROR_CODE, SERVER_ERROR_CODE
from src.common.utils import validate_env_var


@dataclass(frozen=True)
class OpenAIConfig:
    """
    Frozen settings bundle used when calling the OpenAI TTS API.

    Attributes:
        api_key: Value returned by ``validate_env_var("OPENAI_API_KEY")``. It is
            populated in ``__post_init__`` and is not a constructor argument.
        model: Name of the TTS model sent with each request.
        response_format: Audio format requested from the API.
    """

    api_key: str = field(init=False)
    model: str = "gpt-4o-mini-tts"
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = "mp3"

    def __post_init__(self) -> None:
        """Resolve the API key and store it on the (frozen) instance."""
        # Regular assignment is blocked on a frozen dataclass, so the field is
        # written through object.__setattr__ instead.
        object.__setattr__(self, "api_key", validate_env_var("OPENAI_API_KEY"))

    @property
    def client(self) -> AsyncOpenAI:
        """
        Build an asynchronous OpenAI client authenticated with ``api_key``.

        A new client object is constructed on every access; nothing is cached
        on the config instance.

        Returns:
            AsyncOpenAI: Client instance configured with this config's API key.
        """
        return AsyncOpenAI(api_key=self.api_key)

    @staticmethod
    def select_random_base_voice() -> str:
        """
        Pick one of OpenAI's base TTS voices at random.

        The candidate names are hardcoded here because, per the original
        author's note, the OpenAI Python SDK does not export a type listing them.

        Returns:
            str: One of the hardcoded base voice names (e.g., 'alloy', 'nova').
        """
        base_voice_names = (
            "alloy",
            "ash",
            "coral",
            "echo",
            "fable",
            "onyx",
            "nova",
            "sage",
            "shimmer",
        )
        return random.choice(base_voice_names)

class OpenAIError(Exception):
    """
    Exception raised for failures related to the OpenAI TTS API.

    Attributes:
        message: Human-readable description of the failure; also passed to
            ``Exception.__init__`` so it becomes ``args[0]``.
        original_exception: The underlying exception that triggered this
            error, or ``None`` when there is none.
    """

    def __init__(self, message: str, original_exception: Union[Exception, None] = None):
        self.message = message
        self.original_exception = original_exception
        super().__init__(message)

class UnretryableOpenAIError(OpenAIError):
    """
    OpenAI TTS API error that must not be retried.

    Carries the same ``message`` and ``original_exception`` attributes as
    ``OpenAIError``; the distinct type exists so retry logic can tell the two
    apart with ``isinstance``.
    """

    def __init__(self, message: str, original_exception: Union[Exception, None] = None):
        # The parent initializer already stores both attributes on the instance.
        super().__init__(message=message, original_exception=original_exception)

@retry(
    retry=retry_if_exception(lambda e: not isinstance(e, UnretryableOpenAIError)),
    stop=stop_after_attempt(2),
    wait=wait_fixed(2),
    before=before_log(logger, logging.DEBUG),
    after=after_log(logger, logging.DEBUG),
    reraise=True,
)
async def text_to_speech_with_openai(
    character_description: str,
    text: str,
    config: Config,
) -> Tuple[None, str]:
    """
    Asynchronously synthesize speech with the OpenAI TTS API and stream the audio to a file.

    A base voice is chosen at random, the request is sent with the character description
    as the ``instructions`` and ``text`` as the ``input``, and the streamed response body is
    written to ``config.audio_dir`` as ``openai_{voice}_{timestamp}.{response_format}``.

    The call is wrapped by a retry policy: at most two attempts, two seconds apart, for
    any exception other than ``UnretryableOpenAIError``.

    Args:
        character_description (str): Description used for voice synthesis.
        text (str): Text to be converted to speech.
        config (Config): Application configuration containing OpenAI API settings.

    Returns:
        Tuple[None, str]: A tuple containing:
            - None: This function does not produce a generation ID.
            - audio_file_path (str): Path to the saved audio file, relative to the
              current working directory.

    Raises:
        OpenAIError: For rate-limit responses, server-side errors, API errors without a
            status code, and any other exception raised while calling the API or
            writing the file.
        UnretryableOpenAIError: For client-side HTTP errors (status code 4xx) other than
            the rate-limit status code.
    """
    logger.debug(f"Synthesizing speech with OpenAI. Text length: {len(text)} characters.")
    openai_config = config.openai_config
    client = openai_config.client
    start_time = time.time()
    try:
        voice = openai_config.select_random_base_voice()
        async with client.audio.speech.with_streaming_response.create(
            model=openai_config.model,
            input=text,
            instructions=character_description,
            response_format=openai_config.response_format,
            voice=voice,  # OpenAI requires a base voice to be specified
        ) as response:
            elapsed_time = time.time() - start_time
            logger.info(f"OpenAI API request completed in {elapsed_time:.2f} seconds.")

            # Include the requested audio format as the file extension. Without it the
            # name ends in the fractional part of the float timestamp, which looks like
            # a (meaningless) extension to anything that inspects the suffix.
            filename = f"openai_{voice}_{start_time}.{openai_config.response_format}"
            audio_file_path = Path(config.audio_dir) / filename
            await response.stream_to_file(audio_file_path)
            relative_audio_file_path = audio_file_path.relative_to(Path.cwd())

            return None, str(relative_audio_file_path)

    except APIError as e:
        elapsed_time = time.time() - start_time
        logger.error(f"OpenAI API request failed after {elapsed_time:.2f} seconds: {e!s}")
        # repr() adds the concrete exception class, which str() above does not show.
        logger.error(f"Full OpenAI API error: {e!r}")
        clean_message = __extract_openai_error_message(e)

        # Not every APIError carries an HTTP status (e.g. connection-level failures).
        status_code = getattr(e, "status_code", None)
        is_unretryable_client_error = (
            status_code is not None
            and status_code != RATE_LIMIT_ERROR_CODE  # rate limits stay retryable
            and CLIENT_ERROR_CODE <= status_code < SERVER_ERROR_CODE
        )
        if is_unretryable_client_error:
            raise UnretryableOpenAIError(message=clean_message, original_exception=e) from e

        raise OpenAIError(message=clean_message, original_exception=e) from e

    except Exception as e:
        error_type = type(e).__name__
        error_message = str(e) if str(e) else f"An error of type {error_type} occurred"
        logger.error("Error during OpenAI API call: %s - %s", error_type, error_message)
        clean_message = GENERIC_API_ERROR_MESSAGE

        raise OpenAIError(message=clean_message, original_exception=e) from e

def __extract_openai_error_message(e: APIError) -> str:
    """
    Extracts a clean, user-friendly error message from an OpenAI API error response.

    Two body layouts are accepted:
      - the raw API payload, ``{"error": {"message": "..."}}``
      - the already-unwrapped error object, ``{"message": "...", ...}``

    If no non-empty string message can be found, the generic fallback message is
    returned so callers always receive a displayable ``str``.

    Args:
        e (APIError): The OpenAI API error exception containing response information.

    Returns:
        str: A clean, user-friendly error message suitable for display to end users.
    """
    body = getattr(e, "body", None)
    if not isinstance(body, dict):
        return GENERIC_API_ERROR_MESSAGE

    # NOTE(review): openai-python's status errors appear to be constructed with the
    # inner "error" object as `body` (the SDK unwraps the payload before building the
    # exception), so the message usually sits at the top level. The nested form is
    # still honoured for raw payloads. Confirm against the pinned SDK version.
    nested_error = body.get("error")
    error_details = nested_error if isinstance(nested_error, dict) else body

    message = error_details.get("message")
    if isinstance(message, str) and message.strip():
        return message

    return GENERIC_API_ERROR_MESSAGE