File size: 5,222 Bytes
87337b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import asyncio
from dataclasses import dataclass
import aiohttp
import json
from datetime import datetime
from typing import AsyncIterator
from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.config import BaseConfig
@dataclass
class MinimaxTTSConfig(BaseConfig):
    """Settings consumed by the MiniMax text-to-speech client in this module."""

    # Credential sent as the `Authorization: Bearer <api_key>` request header.
    api_key: str = ""
    # Value of the "model" field in the request body.
    model: str = "speech-01-turbo"
    # Value of "voice_setting.voice_id" in the request body.
    voice_id: str = "male-qn-qingse"
    # Value of "audio_setting.sample_rate" in the request body
    # (presumably Hz -- confirm against the MiniMax API docs).
    sample_rate: int = 32000
    # Endpoint the request is POSTed to; "?GroupId=<group_id>" is appended.
    url: str = "https://api.minimax.chat/v1/t2a_v2"
    # Value of the "GroupId" query parameter appended to `url`.
    group_id: str = ""
    # Timeout budget for a request, in seconds.
    request_timeout_seconds: int = 10
class MinimaxTTS:
    """Streaming client for the MiniMax text-to-audio HTTP endpoint.

    One instance wraps a `MinimaxTTSConfig`; each call to `get` performs one
    streaming POST request and yields the audio carried by the response.
    """

    def __init__(self, config: MinimaxTTSConfig):
        self.config = config

    async def get(self, ten_env: AsyncTenEnv, text: str) -> AsyncIterator[bytes]:
        """Synthesize `text` and yield the audio as it is streamed back.

        The request asks for mono PCM at `config.sample_rate`. The response is
        read as newline-delimited `data:` lines, each holding a JSON object
        whose `data.audio` field is hex-encoded audio.

        Args:
            ten_env: Environment object used only for logging here.
            text: The text to synthesize.

        Yields:
            Decoded audio bytes, one item per `data:` line that carries audio.

        Raises:
            RuntimeError: If the server answers with a non-200 status.

        Timeouts and `aiohttp.ClientError`s are logged and end the stream
        without raising. Malformed lines are logged and skipped.
        """
        payload = json.dumps(
            {
                "model": self.config.model,
                "text": text,
                "stream": True,
                "voice_setting": {
                    "voice_id": self.config.voice_id,
                    "speed": 1.0,
                    "vol": 1.0,
                    "pitch": 0,
                },
                "pronunciation_dict": {"tone": []},
                "audio_setting": {
                    "sample_rate": self.config.sample_rate,
                    "format": "pcm",
                    "channel": 1,
                },
            }
        )
        url = f"{self.config.url}?GroupId={self.config.group_id}"
        headers = {
            "accept": "application/json, text/plain, */*",
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }

        # Apply the configured timeout to connecting and to each socket read,
        # not to the request as a whole: a healthy long stream must not be
        # cut off, but a stalled server must not hang the generator.
        # A non-positive value keeps aiohttp's session defaults.
        post_kwargs = {}
        timeout_s = self.config.request_timeout_seconds
        if timeout_s and timeout_s > 0:
            post_kwargs["timeout"] = aiohttp.ClientTimeout(
                total=None, sock_connect=timeout_s, sock_read=timeout_s
            )

        start_time = datetime.now()
        ten_env.log_info(f"Start request, url: {self.config.url}, text: {text}")
        ttfb = None

        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(
                    url, headers=headers, data=payload, **post_kwargs
                ) as response:
                    trace_id = self._header_or_warn(
                        ten_env, response.headers, "Trace-Id"
                    )
                    alb_receive_time = self._header_or_warn(
                        ten_env, response.headers, "alb_receive_time"
                    )
                    ten_env.log_info(
                        f"get response trace-id: {trace_id}, alb_receive_time: {alb_receive_time}, cost_time {self._duration_in_ms_since(start_time)}ms"
                    )

                    if response.status != 200:
                        raise RuntimeError(
                            f"Request failed with status {response.status}"
                        )

                    pending = bytearray()
                    async for chunk in response.content.iter_chunked(1024):
                        if not chunk:
                            continue

                        # Time-to-first-byte is taken when the first body bytes
                        # arrive, before the consumer gets a chance to run.
                        if ttfb is None:
                            ttfb = self._duration_in_ms_since(start_time)
                            ten_env.log_info(f"trace-id: {trace_id}, ttfb {ttfb}ms")

                        pending += chunk
                        # Only a chunk containing a newline can complete a line,
                        # so the accumulated buffer is not rescanned otherwise.
                        if b"\n" not in chunk:
                            continue

                        *lines, pending = pending.split(b"\n")
                        for line in lines:
                            audio = self._decode_line(ten_env, bytes(line))
                            if audio:
                                yield audio

                    # A final line without a trailing newline is still data.
                    if pending:
                        audio = self._decode_line(ten_env, bytes(pending))
                        if audio:
                            yield audio
            except asyncio.TimeoutError:
                # Checked before ClientError: aiohttp's timeout exceptions
                # derive from both and should be reported as timeouts.
                ten_env.log_error("Request timed out")
            except aiohttp.ClientError as e:
                ten_env.log_error(f"Client error occurred: {e}")
            finally:
                ten_env.log_info(
                    f"http loop done, cost_time {self._duration_in_ms_since(start_time)}ms"
                )

    @staticmethod
    def _header_or_warn(ten_env, response_headers, name: str) -> str:
        """Return response header `name`, or log a warning and return ""."""
        value = response_headers.get(name)
        if value is None:
            ten_env.log_warn(f"get response, no {name}")
            return ""
        return value

    def _decode_line(self, ten_env, line: bytes) -> bytes:
        """Decode one response line, logging and skipping it if malformed."""
        try:
            return self._extract_audio(line)
        except ValueError as e:
            # Covers JSONDecodeError, UnicodeDecodeError and bad hex digits.
            ten_env.log_warn(f"Error decoding line: {e}")
            return b""

    @staticmethod
    def _extract_audio(line: bytes) -> bytes:
        """Return the audio bytes carried by one response line, or b"" if none.

        Only lines starting with `data:` are considered. Objects that contain
        an `extra_info` key are skipped (presumably the closing summary event
        -- confirm against the MiniMax API docs).

        Raises:
            ValueError: If the line is not valid UTF-8, not valid JSON, or the
                audio field is not valid hexadecimal.
        """
        if not line.startswith(b"data:"):
            return b""

        event = json.loads(line[5:].decode("utf-8").strip())
        if not isinstance(event, dict) or "extra_info" in event:
            return b""

        data = event.get("data")
        if not isinstance(data, dict):
            return b""

        audio = data.get("audio")
        if not audio or not isinstance(audio, str):
            return b""
        return bytes.fromhex(audio)

    def _duration_in_ms(self, start: datetime, end: datetime) -> int:
        """Return the time from `start` to `end` in whole milliseconds."""
        return int((end - start).total_seconds() * 1000)

    def _duration_in_ms_since(self, start: datetime) -> int:
        """Return the whole milliseconds elapsed between `start` and now."""
        return self._duration_in_ms(start, datetime.now())
|