File size: 11,968 Bytes
5931df7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import asyncio
import hashlib
import json
import math
import os
import re
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Any, List

import aiofiles
import httpx
from fastapi import FastAPI, HTTPException, Request, Body
from fastapi.responses import JSONResponse, FileResponse
from fastapi.templating import Jinja2Templates
from pydub import AudioSegment

app = FastAPI()
# Jinja2 templates are loaded from ./templates (rendered by the "/" route).
templates = Jinja2Templates(directory="templates")

# Configuration
# Base URL of the upstream TTS service; override with the TTS_HOST env var.
HOST = os.getenv("TTS_HOST", "https://99i-tts.hf.space")
# Directory (relative to the working directory) where generated MP3s are written.
AUDIO_DIR = "audio_files"
# Lifetime of an in-memory cache entry, in hours.
CACHE_EXPIRE_HOURS = 24
# Maximum number of characters per segment produced by split_text.
MAX_SEGMENT_LENGTH = 300
# Created at import time so later writes into AUDIO_DIR have a target directory.
Path(AUDIO_DIR).mkdir(exist_ok=True)

# In-memory cache: cache key -> {"file_name": str, "expire_time": datetime}.
# It is a plain dict, so its contents are not persisted across restarts.
cache_store: Dict[str, Dict[str, Any]] = {}

def get_cache_key(params: Dict[str, Any]) -> str:
    """Return an MD5 hex digest that identifies *params*.

    The dict is serialized with sorted keys, so two dicts with the same
    items produce the same key regardless of insertion order.
    """
    canonical = json.dumps(params, sort_keys=True).encode()
    digest = hashlib.md5(canonical)
    return digest.hexdigest()

def validate_tts_params(params: Dict[str, Any]) -> Dict[str, str]:
    """Validate TTS request parameters and normalize them to strings.

    Args:
        params: Raw request parameters. ``t`` (text) is required; ``v``
            (voice), ``r`` (rate), ``p`` (pitch) and ``o`` (output format)
            fall back to defaults when absent.

    Returns:
        A new dict with keys ``t``, ``v``, ``r``, ``p``, ``o``. ``r`` and
        ``p`` are normalized through ``float`` (e.g. ``"10"`` -> ``"10.0"``),
        so equivalent inputs yield identical dicts and hence cache keys.

    Raises:
        HTTPException: 400 if the text is empty, or if rate/pitch is not a
            finite number.
    """
    if not params.get("t"):
        raise HTTPException(status_code=400, detail="文本内容不能为空")

    try:
        rate = float(params.get("r", 0))
        pitch = float(params.get("p", 0))
    except (TypeError, ValueError) as e:
        # float() raises TypeError (not ValueError) for JSON null, lists or
        # objects; both must map to a client error rather than escape.
        raise HTTPException(status_code=400, detail="语速和音调参数必须是有效的数字") from e

    # float() happily parses "nan"/"inf", which are not usable rate/pitch values.
    if not (math.isfinite(rate) and math.isfinite(pitch)):
        raise HTTPException(status_code=400, detail="语速和音调参数必须是有效的数字")

    return {
        "t": str(params.get("t", "")),
        "v": str(params.get("v", "zh-CN-XiaoxiaoMultilingualNeural")),
        "r": str(rate),
        "p": str(pitch),
        "o": str(params.get("o", "audio-24khz-48kbitrate-mono-mp3")),
    }

def get_from_cache(cache_key: str) -> Optional[str]:
    """Return the cached audio file name for *cache_key*, or None.

    An entry is valid until its ``expire_time``. Expired entries are removed
    from ``cache_store`` when looked up, so stale keys do not accumulate in
    memory forever.
    """
    cached_item = cache_store.get(cache_key)
    if cached_item is None:
        return None
    if datetime.now() < cached_item["expire_time"]:
        return cached_item["file_name"]
    # Expired: evict it so the store does not keep growing with dead keys.
    cache_store.pop(cache_key, None)
    return None

def update_cache(cache_key: str, file_name: str) -> None:
    """Store *file_name* under *cache_key* for CACHE_EXPIRE_HOURS hours."""
    entry = dict(file_name=file_name)
    entry["expire_time"] = datetime.now() + timedelta(hours=CACHE_EXPIRE_HOURS)
    cache_store[cache_key] = entry

async def make_tts_request(params: Dict[str, str], *, timeout: float = 60.0) -> bytes:
    """POST *params* to the upstream ``{HOST}/tts`` endpoint and return the audio bytes.

    Args:
        params: Validated TTS parameters (as produced by ``validate_tts_params``).
        timeout: Request timeout in seconds. httpx's default (5 s) is short
            for synthesizing a few hundred characters of speech.

    Returns:
        The raw response body (the synthesized audio).

    Raises:
        HTTPException: 503 on transport failures (``httpx.RequestError``);
            the upstream status code when the service answers with an error.
    """
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            print("\n=== TTS请求信息 ===")
            print(f"请求URL: {HOST}/tts")
            print(f"请求参数: {json.dumps(params, ensure_ascii=False, indent=2)}")
            response = await client.post(f"{HOST}/tts", json=params)
            print(f"响应状态码: {response.status_code}")
            print("=== TTS请求完成 ===\n")
            response.raise_for_status()
            return response.content
        except httpx.RequestError as e:
            raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}") from e
        except httpx.HTTPStatusError as e:
            # Fall back to the raw body text; prefer a structured message when
            # the body is a JSON object carrying "detail" or "message".
            error_detail = e.response.text or str(e)
            try:
                error_json = e.response.json()
            except ValueError:
                # Not JSON (json.JSONDecodeError is a ValueError subclass).
                error_json = None
            if isinstance(error_json, dict):
                error_detail = (
                    error_json.get("detail")
                    or error_json.get("message")
                    or str(error_json)
                )
            raise HTTPException(
                status_code=e.response.status_code,
                detail=f"TTS服务返回错误: {error_detail}",
            ) from e

async def save_audio_file(content: bytes) -> str:
    """Write *content* to a freshly named ``<uuid4>.mp3`` inside AUDIO_DIR.

    Returns the bare file name (not the full path). An IOError while
    writing is reported as HTTPException(500).
    """
    target_name = f"{uuid.uuid4()}.mp3"
    target_path = Path(AUDIO_DIR).joinpath(target_name)
    try:
        async with aiofiles.open(target_path, "wb") as handle:
            await handle.write(content)
    except IOError as err:
        raise HTTPException(status_code=500, detail=f"保存音频文件失败: {str(err)}")
    return target_name

async def download_audio(params: dict) -> str:
    """Return the local file name holding the synthesized audio for *params*.

    The parameters are validated first. A live cache entry is returned as
    is; otherwise the audio is fetched from the TTS service, written to
    disk and recorded in the cache.
    """
    normalized = validate_tts_params(params)
    key = get_cache_key(normalized)

    hit = get_from_cache(key)
    if hit:
        return hit

    audio_bytes = await make_tts_request(normalized)
    saved_name = await save_audio_file(audio_bytes)
    update_cache(key, saved_name)
    return saved_name

def generate_audio_url(request: Request, file_name: str) -> str:
    """Return the absolute URL at which ``/audio/{file_name}`` is served.

    Assumes the stringified ``request.base_url`` already ends with ``/``,
    since no separator is inserted before ``audio/``.
    """
    return "%saudio/%s" % (str(request.base_url), file_name)

def split_text(text: str, max_length: Optional[int] = None) -> List[str]:
    """Split long text into segments of at most ``max_length`` characters.

    The text is cut after sentence-ending punctuation (full-width ``。！？``
    and ASCII ``.!?``), then consecutive sentences are packed greedily into
    segments. The whitespace that originally separated two sentences is
    kept inside a segment, so "Hi. There." is not fused into "Hi.There.".
    A single sentence longer than the limit is hard-split into
    fixed-size chunks.

    Args:
        text: Text to split.
        max_length: Maximum segment length in characters. Defaults to
            ``MAX_SEGMENT_LENGTH`` when omitted.

    Returns:
        Non-empty segments in their original order; an empty or
        whitespace-only ``text`` yields ``[]``.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    limit = MAX_SEGMENT_LENGTH if max_length is None else max_length
    if limit <= 0:
        raise ValueError("max_length must be positive")

    # With one capturing group, re.split returns an odd-length list that
    # alternates text, delimiter, text, ..., text. Pair each text chunk with
    # the delimiter that follows it (the final chunk has none).
    pieces = re.split(r'([。！？.!?])', text)
    sentences = [body + mark for body, mark in zip(pieces[0::2], pieces[1::2] + [""])]

    result: List[str] = []
    current = ""
    for sentence in sentences:
        stripped = sentence.strip()
        if not stripped:
            continue
        if current:
            # Keep the leading whitespace that separated this sentence from
            # the previous one; only trailing whitespace is dropped.
            joined = current + sentence.rstrip()
            if len(joined) <= limit:
                current = joined
                continue
            result.append(current)
            current = ""
        if len(stripped) <= limit:
            current = stripped
        else:
            # Sentence alone exceeds the limit: emit fixed-size chunks.
            result.extend(
                stripped[start:start + limit]
                for start in range(0, len(stripped), limit)
            )

    if current:
        result.append(current)
    return result

async def merge_audio_files(file_names: List[str]) -> str:
    """Concatenate audio files into one new MP3 with 500 ms of silence between clips.

    The pydub calls used here (``from_mp3`` / ``export``) are synchronous,
    so the whole merge runs in the event loop's default executor instead of
    stalling every other request while it decodes and encodes.

    Args:
        file_names: File names relative to ``AUDIO_DIR``, in playback order.

    Returns:
        The file name (relative to ``AUDIO_DIR``) of the merged MP3.
    """

    def _merge() -> str:
        # Blocking part of the merge; executed off the event loop.
        audio_dir = Path(AUDIO_DIR)
        combined = AudioSegment.empty()
        gap = AudioSegment.silent(duration=500)  # 500 ms of silence
        for index, name in enumerate(file_names):
            if index > 0:  # the gap goes before every clip except the first
                combined += gap
            combined += AudioSegment.from_mp3(str(audio_dir / name))

        output_file_name = f"{uuid.uuid4()}.mp3"
        # pydub's export() returns the open file object it wrote to; close it
        # right away instead of leaving the descriptor to garbage collection.
        combined.export(str(audio_dir / output_file_name), format="mp3").close()
        return output_file_name

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _merge)

@app.get("/tts")
async def text_to_speech_get(
    request: Request,
    t: str,
    v: Optional[str] = "zh-CN-XiaoxiaoMultilingualNeural",
    r: Optional[str] = "0",
    p: Optional[str] = "0",
    o: Optional[str] = "audio-24khz-48kbitrate-mono-mp3"
):
    """Synthesize speech from query parameters and return a playback URL.

    ``cached`` in the response tells whether the audio came from the
    in-memory cache instead of a fresh upstream request. Errors are
    returned as ``{"status": "error", "message": ...}``. An HTTPException
    keeps its own status code (e.g. 400 for invalid parameters); any other
    error is reported as 500.
    """
    try:
        params = validate_tts_params({"t": t, "v": v, "r": r, "p": p, "o": o})
        # Must be checked before downloading: afterwards an entry always exists.
        cached = get_from_cache(get_cache_key(params)) is not None
        file_name = await download_audio(params)
        audio_url = generate_audio_url(request, file_name)
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"status": "error", "message": str(e)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})

@app.post("/tts")
async def text_to_speech_post(request: Request, data: Dict[str, Any] = Body(...)):
    """Synthesize speech from a JSON body and return a playback URL.

    ``cached`` in the response tells whether the audio came from the
    in-memory cache instead of a fresh upstream request. The lookup uses the
    validated parameters, which are the same ones ``download_audio`` keys the
    cache by. An HTTPException keeps its own status code; any other error is
    reported as 500.
    """
    try:
        params = validate_tts_params(data)
        # Must be checked before downloading: afterwards an entry always exists.
        cached = get_from_cache(get_cache_key(params)) is not None
        file_name = await download_audio(params)
        audio_url = generate_audio_url(request, file_name)
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"status": "error", "message": str(e)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})

@app.post("/ltts")
async def text_to_speech_long(request: Request, data: Dict[str, Any] = Body(...)):
    """Synthesize long text by splitting it into segments and merging the audio.

    ``t`` is split with ``split_text``. Every segment is synthesized with the
    remaining keys of ``data`` as TTS parameters, and the clips are merged in
    order. A failing segment aborts the request; the response reports which
    segment failed, keeping the HTTPException status code (e.g. 400 for an
    invalid rate) and using 500 otherwise.
    """
    text = data.get("t")
    if not text:
        raise HTTPException(status_code=400, detail="Parameter 't' (text) is required")

    try:
        # str() so a non-string JSON value (e.g. a number) can still be split,
        # matching validate_tts_params which also stringifies "t".
        segments = split_text(str(text))
        if not segments:
            raise ValueError("文本分割后没有有效内容")

        base_params = {k: v for k, v in data.items() if k != "t"}
        file_names = []
        for i, segment in enumerate(segments):
            try:
                file_names.append(await download_audio({**base_params, "t": segment}))
            except Exception as e:
                status = e.status_code if isinstance(e, HTTPException) else 500
                return JSONResponse(
                    status_code=status,
                    content={"status": "error", "message": f"第{i+1}段文本转语音失败: {str(e)}", "segment": segment}
                )

        if len(file_names) == 1:
            # Nothing to merge: reuse the segment file and skip a lossy MP3 re-encode.
            final_file_name = file_names[0]
        else:
            final_file_name = await merge_audio_files(file_names)
        audio_url = generate_audio_url(request, final_file_name)
        return {"status": "success", "audio_url": audio_url, "segments_count": len(segments)}
    except ValueError as e:
        return JSONResponse(status_code=400, content={"status": "error", "message": str(e)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})

@app.post("/mtts")
async def multi_voice_tts(request: Request, data: List[Dict[str, Any]] = Body(...)):
    """Synthesize a multi-voice reading and merge it into a single audio file.

    Expected body: [{"v":"voice1", "t":"text1", "r":"rate1"}, {"v":"voice2", "t":"text2", "r":"rate2"}]
    Each entry may also carry ``p`` and ``o``; missing keys fall back to the
    defaults below. Clips are merged in list order. A failing entry aborts
    the request, keeping the HTTPException status code (500 otherwise).
    """
    if not data or not isinstance(data, list):
        raise HTTPException(status_code=400, detail="请求数据必须是包含多个角色配置的数组")

    try:
        file_names = []
        for i, voice_config in enumerate(data):
            try:
                # Build this voice's TTS parameters with per-key defaults.
                params = {
                    "t": voice_config.get("t", ""),
                    "v": voice_config.get("v", "zh-CN-XiaoxiaoMultilingualNeural"),
                    "r": voice_config.get("r", "0"),
                    "p": voice_config.get("p", "0"),
                    "o": voice_config.get("o", "audio-24khz-48kbitrate-mono-mp3")
                }
                file_names.append(await download_audio(params))
            except Exception as e:
                status = e.status_code if isinstance(e, HTTPException) else 500
                return JSONResponse(
                    status_code=status,
                    content={"status": "error", "message": f"第{i+1}个角色的语音生成失败: {str(e)}", "voice_config": voice_config}
                )

        if len(file_names) == 1:
            # Nothing to merge: reuse the clip and skip a lossy MP3 re-encode.
            final_file_name = file_names[0]
        else:
            final_file_name = await merge_audio_files(file_names)
        audio_url = generate_audio_url(request, final_file_name)
        return {"status": "success", "audio_url": audio_url, "voices_count": len(data)}
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})

@app.get("/audio/{file_name}")
async def get_audio(file_name: str):
    """Serve a generated MP3 from AUDIO_DIR.

    ``file_name`` comes straight from the URL, so the resolved path must be
    a regular file located directly inside AUDIO_DIR. Anything else is
    answered with 404: ``..``, directories, or names that escape the folder
    (e.g. via backslashes on Windows).
    """
    audio_dir = Path(AUDIO_DIR).resolve()
    file_path = (audio_dir / file_name).resolve()
    if file_path.parent != audio_dir or not file_path.is_file():
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/mpeg")

@app.get("/voices")
async def get_voices(l: Optional[str] = "zh", d: Optional[bool] = False):
    """Proxy the upstream voice list.

    Args:
        l: Language filter, forwarded as ``l`` when non-empty.
        d: When true, a bare ``d`` query flag is forwarded.

    Returns:
        The upstream JSON response body.

    Raises:
        HTTPException: the upstream status code on an HTTP error response;
            503 when the upstream cannot be reached (as in make_tts_request).
    """
    params = {}
    if l:
        params["l"] = l
    if d:
        params["d"] = ""

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{HOST}/voices", params=params)
            response.raise_for_status()
            return response.json()
        except httpx.RequestError as e:
            raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}") from e
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e

@app.get("/cache/status")
async def cache_status():
    """Report the number of live (non-expired) cache entries and the TTL in hours.

    Expired entries may still sit in ``cache_store`` until they are looked
    up again, so they are excluded from the count rather than reported as
    cached.
    """
    now = datetime.now()
    live_count = sum(1 for item in cache_store.values() if now < item["expire_time"])
    return {"cache_count": live_count, "expire_hours": CACHE_EXPIRE_HOURS}

@app.delete("/cache/clear")
async def clear_cache():
    """Remove every in-memory cache entry.

    The dict is emptied in place rather than rebound, so any other
    reference to ``cache_store`` (e.g. one imported by name elsewhere) sees
    the cleared state too. Audio files already written to AUDIO_DIR are not
    deleted.
    """
    cache_store.clear()
    return {"status": "success", "message": "Cache cleared"}

@app.get("/")
async def read_root(request: Request):
    """Render the ``index.html`` template as the home page.

    (Per the original docstring, the page presents the voice list.)
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)