File size: 5,594 Bytes
62da328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import os
from typing import Any, Optional


class FishAudioModel:
    r"""Provides access to FishAudio's Text-to-Speech (TTS) and Speech-to-Text
    (STT) models.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
    ) -> None:
        r"""Initialize an instance of FishAudioModel.

        Args:
            api_key (Optional[str]): API key for FishAudio service. If not
                provided, the environment variable `FISHAUDIO_API_KEY` will be
                used.
            url (Optional[str]): Base URL for FishAudio API. If not provided,
                the environment variable `FISHAUDIO_API_BASE_URL` will be used,
                falling back to ``"https://api.fish.audio"``.
        """
        from fish_audio_sdk import Session

        self._api_key = api_key or os.environ.get("FISHAUDIO_API_KEY")
        self._url = url or os.environ.get(
            "FISHAUDIO_API_BASE_URL", "https://api.fish.audio"
        )
        # NOTE(review): `self._api_key` can still be None here when neither
        # the argument nor the environment variable is set; whether `Session`
        # rejects that up front is not visible from this file.
        self.session = Session(apikey=self._api_key, base_url=self._url)

    def text_to_speech(
        self,
        input: str,
        storage_path: str,
        reference_id: Optional[str] = None,
        reference_audio: Optional[str] = None,
        reference_audio_text: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Convert text to speech and save the output to a file.

        Nothing is returned; the synthesized audio is streamed into
        ``storage_path``.

        Args:
            input (str): The text to convert to speech.
            storage_path (str): The file path where the resulting speech will
                be saved. Missing parent directories are created.
            reference_id (Optional[str]): An optional reference ID to
                associate with the request. Only used when
                ``reference_audio`` is not given. (default: :obj:`None`)
            reference_audio (Optional[str]): Path to an audio file for
                reference speech. (default: :obj:`None`)
            reference_audio_text (Optional[str]): Text for the reference audio.
                Required when ``reference_audio`` is given.
                (default: :obj:`None`)
            **kwargs (Any): Additional parameters to pass to the TTS request.

        Raises:
            FileNotFoundError: If ``reference_audio`` is given but is not an
                existing file.
            ValueError: If ``reference_audio`` is given without
                ``reference_audio_text``.
        """
        # Validate caller input before importing the SDK or touching the
        # filesystem, so an invalid request leaves no directories behind.
        if reference_audio:
            if not os.path.isfile(reference_audio):
                raise FileNotFoundError(
                    f"Reference audio file not found: {reference_audio}"
                )
            if not reference_audio_text:
                raise ValueError("reference_audio_text should be provided")

        from fish_audio_sdk import ReferenceAudio, TTSRequest

        if reference_audio:
            # Read the reference clip fully before opening the output file so
            # the two files are never held open at the same time.
            with open(reference_audio, "rb") as audio_file:
                reference_bytes = audio_file.read()
            request = TTSRequest(
                text=input,
                references=[
                    ReferenceAudio(
                        audio=reference_bytes,
                        text=reference_audio_text,
                    )
                ],
                **kwargs,
            )
        else:
            request = TTSRequest(
                reference_id=reference_id, text=input, **kwargs
            )

        directory = os.path.dirname(storage_path)
        if directory:
            # exist_ok avoids the race between an exists() check and creation.
            os.makedirs(directory, exist_ok=True)

        with open(storage_path, "wb") as f:
            for chunk in self.session.tts(request):
                f.write(chunk)

    def speech_to_text(
        self,
        audio_file_path: str,
        language: Optional[str] = None,
        ignore_timestamps: Optional[bool] = None,
        **kwargs: Any,
    ) -> str:
        r"""Convert speech to text from an audio file.

        Args:
            audio_file_path (str): The path to the audio file to transcribe.
            language (Optional[str]): The language of the audio. (default:
                :obj:`None`)
            ignore_timestamps (Optional[bool]): Whether to ignore timestamps.
                (default: :obj:`None`)
            **kwargs (Any): Additional parameters to pass to the STT request.

        Returns:
            str: The transcribed text from the audio.

        Raises:
            FileNotFoundError: If the audio file cannot be found.
        """
        # Check the path before importing the SDK so bad input fails fast.
        if not os.path.isfile(audio_file_path):
            raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

        from fish_audio_sdk import ASRRequest

        with open(audio_file_path, "rb") as audio_file:
            audio_data = audio_file.read()

        response = self.session.asr(
            ASRRequest(
                audio=audio_data,
                language=language,
                ignore_timestamps=ignore_timestamps,
                **kwargs,
            )
        )
        return response.text