|
import asyncio
import hashlib
import json
import math
import os
import re
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Any, List

import aiofiles
import httpx
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from pydub import AudioSegment
|
|
|
# Application instance and the Jinja2 template loader ("templates" directory,
# relative to the working directory).
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Base URL of the upstream TTS service; override with the TTS_HOST env var.
HOST: str = os.getenv("TTS_HOST", "https://99i-tts.hf.space")
# Local directory where generated MP3 files are written and served from.
AUDIO_DIR: str = "audio_files"
# Lifetime of an in-memory cache entry, in hours.
CACHE_EXPIRE_HOURS: int = 24
# Maximum number of characters per text chunk produced by split_text.
MAX_SEGMENT_LENGTH: int = 300

# Ensure the output directory exists before any request is handled.
Path(AUDIO_DIR).mkdir(exist_ok=True)

# In-memory cache: cache key -> {"file_name": str, "expire_time": datetime}.
cache_store: Dict[str, Dict[str, Any]] = {}
|
|
|
def get_cache_key(params: Dict[str, Any]) -> str:
    """Return a deterministic cache key for a set of TTS parameters.

    The parameters are serialized to JSON with sorted keys, so dicts with the
    same contents (in any insertion order) yield the same key. The key is the
    hex MD5 digest of that serialization; it is used for lookup, not security.
    """
    canonical = json.dumps(params, sort_keys=True).encode()
    digest = hashlib.md5(canonical)
    return digest.hexdigest()
|
|
|
def validate_tts_params(params: Dict[str, Any]) -> Dict[str, str]:
    """Validate raw TTS parameters and normalize them to strings.

    Args:
        params: Request parameters. Recognized keys are ``t`` (text, required),
            ``v`` (voice), ``r`` (rate), ``p`` (pitch) and ``o`` (output format).
            A missing key, or one explicitly set to ``None``, falls back to its
            default.

    Returns:
        A dict with exactly the keys ``t``, ``v``, ``r``, ``p``, ``o``. ``r`` and
        ``p`` are rendered via ``str(float(...))`` (e.g. ``"0.0"``).

    Raises:
        HTTPException: 400 if the text is empty, or if rate/pitch are not
            finite numbers.
    """
    text = params.get("t")
    if not text:
        raise HTTPException(status_code=400, detail="文本内容不能为空")

    def _value(key: str, default: Any) -> Any:
        # Treat an explicit None (e.g. JSON null) like a missing key, instead
        # of producing the string "None" or crashing float().
        value = params.get(key)
        return default if value is None else value

    try:
        rate = float(_value("r", 0))
        pitch = float(_value("p", 0))
    except (TypeError, ValueError):
        # TypeError covers non-scalar inputs such as lists/dicts in a JSON body.
        raise HTTPException(status_code=400, detail="语速和音调参数必须是有效的数字") from None
    if not (math.isfinite(rate) and math.isfinite(pitch)):
        # float() accepts "nan"/"inf", which are not meaningful rate/pitch values.
        raise HTTPException(status_code=400, detail="语速和音调参数必须是有效的数字")

    return {
        "t": str(text),
        "v": str(_value("v", "zh-CN-XiaoxiaoMultilingualNeural")),
        "r": str(rate),
        "p": str(pitch),
        "o": str(_value("o", "audio-24khz-48kbitrate-mono-mp3")),
    }
|
|
|
def get_from_cache(cache_key: str) -> Optional[str]:
    """Return the cached audio file name for ``cache_key``, or ``None``.

    An entry counts as a hit only while it has not expired AND its MP3 still
    exists in AUDIO_DIR. Stale entries are removed from ``cache_store`` so the
    cache does not accumulate dead entries forever.
    """
    cached_item = cache_store.get(cache_key)
    if cached_item is None:
        return None

    file_name = cached_item["file_name"]
    if datetime.now() < cached_item["expire_time"] and (Path(AUDIO_DIR) / file_name).is_file():
        return file_name

    # Expired, or the file vanished from disk: drop the entry so the caller
    # regenerates the audio instead of handing out a URL that 404s.
    cache_store.pop(cache_key, None)
    return None
|
|
|
def update_cache(cache_key: str, file_name: str) -> None:
    """Record ``file_name`` under ``cache_key`` with a fresh expiry time.

    The entry expires CACHE_EXPIRE_HOURS hours from now. An existing entry for
    the same key is overwritten.
    """
    expires_at = datetime.now() + timedelta(hours=CACHE_EXPIRE_HOURS)
    cache_store[cache_key] = {"file_name": file_name, "expire_time": expires_at}
|
|
|
def _extract_error_detail(error: "httpx.HTTPStatusError") -> str:
    """Best-effort human-readable message from an upstream error response."""
    response = error.response
    try:
        payload = response.json()
    except ValueError:
        # Body is not JSON (json.JSONDecodeError subclasses ValueError).
        return response.text or str(error)
    if isinstance(payload, dict):
        return str(payload.get("detail") or payload.get("message") or payload)
    # Non-dict JSON (list/string/number): report it verbatim rather than
    # collapsing it to an "unknown error".
    return str(payload)


async def make_tts_request(params: Dict[str, str], timeout: float = 60.0) -> bytes:
    """POST ``params`` to ``{HOST}/tts`` and return the raw audio bytes.

    Args:
        params: Normalized TTS parameters (see ``validate_tts_params``).
        timeout: Seconds allowed for the upstream call. httpx's own default is
            5 seconds, which speech synthesis of a long segment can exceed.

    Raises:
        HTTPException: 503 if the service cannot be reached (including
            timeouts); the upstream status code if it answers with an error.
    """
    print("\n=== TTS请求信息 ===")
    print(f"请求URL: {HOST}/tts")
    print(f"请求参数: {json.dumps(params, ensure_ascii=False, indent=2)}")

    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            response = await client.post(f"{HOST}/tts", json=params)
            print(f"响应状态码: {response.status_code}")
            print("=== TTS请求完成 ===\n")
            response.raise_for_status()
        except httpx.RequestError as e:
            raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}") from e
        except httpx.HTTPStatusError as e:
            raise HTTPException(
                status_code=e.response.status_code,
                detail=f"TTS服务返回错误: {_extract_error_detail(e)}",
            ) from e
        return response.content
|
|
|
async def save_audio_file(content: bytes) -> str:
    """Write ``content`` to a new, uniquely named ``.mp3`` file in AUDIO_DIR.

    Returns:
        The generated file name (``<uuid4>.mp3``), relative to AUDIO_DIR.

    Raises:
        HTTPException: 500 if the file cannot be written. Any partially
            written file is removed first.
    """
    file_name = f"{uuid.uuid4()}.mp3"
    file_path = Path(AUDIO_DIR) / file_name
    try:
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)
    except OSError as e:
        # Don't leave a truncated MP3 behind; nothing references it yet.
        try:
            file_path.unlink()
        except OSError:
            pass  # best effort: the file may never have been created
        raise HTTPException(status_code=500, detail=f"保存音频文件失败: {str(e)}") from e
    return file_name
|
|
|
async def download_audio(params: dict) -> str:
    """Return the name of an MP3 in AUDIO_DIR for the given TTS parameters.

    The parameters are validated and normalized first, and the normalized form
    determines the cache key. On a cache hit the cached file name is returned
    without contacting the TTS service. Otherwise the audio is fetched, saved
    under a fresh name, and recorded in the cache.

    Raises:
        HTTPException: propagated from validation, the upstream request, or
            the file write.
    """
    normalized = validate_tts_params(params)
    key = get_cache_key(normalized)

    hit = get_from_cache(key)
    if hit:
        return hit

    audio_bytes = await make_tts_request(normalized)
    saved_name = await save_audio_file(audio_bytes)
    update_cache(key, saved_name)
    return saved_name
|
|
|
def generate_audio_url(request: Request, file_name: str) -> str:
    """Build the absolute URL at which ``/audio/{file_name}`` serves the file.

    The result is ``str(request.base_url)`` followed directly by
    ``audio/<file_name>``; this assumes the base URL string ends with "/".
    """
    root = str(request.base_url)
    return "".join((root, "audio/", str(file_name)))
|
|
|
def split_text(text: str, max_length: Optional[int] = None) -> List[str]:
    """Split long text into chunks no longer than ``max_length`` characters.

    The text is cut after sentence-ending punctuation (Chinese 。！？ and ASCII
    . ! ?). Consecutive sentences are then packed greedily into chunks. A
    single sentence longer than the limit is hard-split into fixed-size
    pieces. Where the source had whitespace between two sentences, one space
    is kept when they are joined, so English text does not become
    "Hello.World.".

    Args:
        text: The text to split.
        max_length: Maximum chunk length. Defaults to MAX_SEGMENT_LENGTH.

    Returns:
        Non-empty, whitespace-trimmed chunks in their original order.

    Raises:
        ValueError: if ``max_length`` is not positive.
    """
    limit = MAX_SEGMENT_LENGTH if max_length is None else max_length
    if limit <= 0:
        raise ValueError("max_length must be positive")

    # With one capturing group, re.split alternates [text, delim, ..., text].
    # Glue each delimiter back onto the sentence it terminates.
    parts = re.split(r"([。！？.!?])", text)
    sentences = [body + end for body, end in zip(parts[0::2], parts[1::2] + [""])]

    chunks: List[str] = []
    current = ""
    for sentence in sentences:
        stripped = sentence.strip()
        if not stripped:
            continue
        joiner = " " if current and sentence[0].isspace() else ""
        if len(current) + len(joiner) + len(stripped) <= limit:
            current += joiner + stripped
            continue
        if current:
            chunks.append(current)
        if len(stripped) > limit:
            # No sentence boundary fits within the limit: fall back to slices.
            pieces = (stripped[i:i + limit].strip() for i in range(0, len(stripped), limit))
            chunks.extend(piece for piece in pieces if piece)
            current = ""
        else:
            current = stripped
    if current:
        chunks.append(current)
    return chunks
|
|
|
async def merge_audio_files(file_names: List[str], silence_ms: int = 500) -> str:
    """Concatenate MP3 files from AUDIO_DIR into a new MP3 with silent gaps.

    Args:
        file_names: Names of files in AUDIO_DIR, in playback order.
        silence_ms: Length in milliseconds of the silence inserted between
            consecutive clips.

    Returns:
        The name of the newly written ``<uuid4>.mp3`` in AUDIO_DIR.
    """
    def _merge() -> str:
        combined = AudioSegment.empty()
        gap = AudioSegment.silent(duration=silence_ms)
        for index, file_name in enumerate(file_names):
            clip = AudioSegment.from_mp3(str(Path(AUDIO_DIR) / file_name))
            if index > 0:
                combined += gap
            combined += clip

        output_file_name = f"{uuid.uuid4()}.mp3"
        output_path = Path(AUDIO_DIR) / output_file_name
        # pydub's export() returns the file object it opened for the path;
        # close it explicitly instead of relying on garbage collection.
        combined.export(str(output_path), format="mp3").close()
        return output_file_name

    # Decoding/encoding goes through ffmpeg and is blocking; run it in a worker
    # thread so the event loop keeps serving other requests meanwhile.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _merge)
|
|
|
@app.get("/tts")
async def text_to_speech_get(
    request: Request,
    t: str,
    v: Optional[str] = "zh-CN-XiaoxiaoMultilingualNeural",
    r: Optional[str] = "0",
    p: Optional[str] = "0",
    o: Optional[str] = "audio-24khz-48kbitrate-mono-mp3"
):
    """Synthesize ``t`` from query parameters and return a playback URL.

    On success returns ``{"status": "success", "audio_url": ..., "cached": bool}``.
    ``cached`` reports whether the audio came from the in-memory cache. Errors
    are returned as ``{"status": "error", "message": ...}`` with the status of
    the underlying failure (e.g. 400 for invalid parameters, 503 when the TTS
    service is unreachable), or 500 for unexpected errors.
    """
    params = {"t": t, "v": v, "r": r, "p": p, "o": o}
    try:
        # Probe the cache with the same normalized key download_audio uses,
        # before download_audio populates it.
        cached = get_from_cache(get_cache_key(validate_tts_params(params))) is not None
        file_name = await download_audio(params)
        audio_url = generate_audio_url(request, file_name)
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"status": "error", "message": str(e.detail)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
|
|
|
@app.post("/tts")
async def text_to_speech_post(request: Request, data: Dict[str, Any] = Body(...)):
    """Synthesize speech from a JSON body and return a playback URL.

    ``data`` uses the keys accepted by ``validate_tts_params``. On success
    returns ``{"status": "success", "audio_url": ..., "cached": bool}``.
    ``cached`` is determined BEFORE synthesis, using the normalized cache key,
    so it is true only when the audio was actually served from the cache.
    Errors are returned as ``{"status": "error", "message": ...}`` with the
    underlying HTTP status (500 for unexpected errors).
    """
    try:
        cached = get_from_cache(get_cache_key(validate_tts_params(data))) is not None
        file_name = await download_audio(data)
        audio_url = generate_audio_url(request, file_name)
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"status": "error", "message": str(e.detail)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
|
|
|
@app.post("/ltts")
async def text_to_speech_long(request: Request, data: Dict[str, Any] = Body(...)):
    """Synthesize long text by chunking it, synthesizing each chunk and merging.

    ``data`` uses the same keys as ``POST /tts``. ``t`` is split with
    ``split_text``, and the remaining keys are applied to every chunk. Returns
    ``{"status": "success", "audio_url": ..., "segments_count": n}``.

    Raises:
        HTTPException: 400 if ``t`` is missing or empty. Other failures are
            returned as ``{"status": "error", ...}`` JSON responses.
    """
    text = data.get("t")
    if not text:
        raise HTTPException(status_code=400, detail="Parameter 't' (text) is required")

    try:
        # str(): a non-string JSON value (e.g. a number) would crash re.split.
        segments = split_text(str(text))
        if not segments:
            raise ValueError("文本分割后没有有效内容")

        base_params = {k: v for k, v in data.items() if k != "t"}
        file_names = []
        for i, segment in enumerate(segments):
            try:
                file_name = await download_audio({**base_params, "t": segment})
            except Exception as e:
                # Keep the validation/upstream status (e.g. 400, 503) instead of
                # reporting every per-segment failure as 500.
                is_http = isinstance(e, HTTPException)
                status = e.status_code if is_http else 500
                reason = str(e.detail) if is_http else str(e)
                return JSONResponse(
                    status_code=status,
                    content={"status": "error", "message": f"第{i+1}段文本转语音失败: {reason}", "segment": segment},
                )
            file_names.append(file_name)

        # A single chunk needs no merge; skip the ffmpeg decode/re-encode.
        if len(file_names) == 1:
            final_file_name = file_names[0]
        else:
            final_file_name = await merge_audio_files(file_names)
        audio_url = generate_audio_url(request, final_file_name)
        return {"status": "success", "audio_url": audio_url, "segments_count": len(segments)}
    except ValueError as e:
        return JSONResponse(status_code=400, content={"status": "error", "message": str(e)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
|
|
|
@app.post("/mtts")
async def multi_voice_tts(request: Request, data: List[Dict[str, Any]] = Body(...)):
    """Synthesize a multi-voice script and merge it into one audio file.

    Expected body:
    ``[{"v": "voice1", "t": "text1", "r": "rate1"}, {"v": "voice2", "t": "text2", "r": "rate2"}]``.
    Each item may also carry ``p`` (pitch) and ``o`` (output format). Missing
    keys get the defaults spelled out below. Items are synthesized in order
    and merged with ``merge_audio_files``. Returns
    ``{"status": "success", "audio_url": ..., "voices_count": n}``.

    Raises:
        HTTPException: 400 if ``data`` is empty or not a list. Per-item
            failures are returned as ``{"status": "error", ...}`` JSON.
    """
    if not data or not isinstance(data, list):
        raise HTTPException(status_code=400, detail="请求数据必须是包含多个角色配置的数组")

    try:
        file_names = []
        for i, voice_config in enumerate(data):
            try:
                params = {
                    "t": voice_config.get("t", ""),
                    "v": voice_config.get("v", "zh-CN-XiaoxiaoMultilingualNeural"),
                    "r": voice_config.get("r", "0"),
                    "p": voice_config.get("p", "0"),
                    "o": voice_config.get("o", "audio-24khz-48kbitrate-mono-mp3")
                }
                file_name = await download_audio(params)
            except Exception as e:
                # Keep the validation/upstream status (e.g. 400 for empty text,
                # 503 for an unreachable service) instead of always 500.
                is_http = isinstance(e, HTTPException)
                status = e.status_code if is_http else 500
                reason = str(e.detail) if is_http else str(e)
                return JSONResponse(
                    status_code=status,
                    content={"status": "error", "message": f"第{i+1}个角色的语音生成失败: {reason}", "voice_config": voice_config},
                )
            file_names.append(file_name)

        # A single item needs no merge; skip the ffmpeg decode/re-encode.
        if len(file_names) == 1:
            final_file_name = file_names[0]
        else:
            final_file_name = await merge_audio_files(file_names)
        audio_url = generate_audio_url(request, final_file_name)
        return {"status": "success", "audio_url": audio_url, "voices_count": len(data)}
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
|
|
|
@app.get("/audio/{file_name}")
async def get_audio(file_name: str):
    """Serve a generated MP3 from AUDIO_DIR.

    ``file_name`` comes straight from the URL. It is therefore resolved and
    required to be a regular file located directly inside AUDIO_DIR, so names
    such as ``..`` (or backslash-separated paths on Windows) cannot reach
    other files or directories.

    Raises:
        HTTPException: 404 if no such audio file exists.
    """
    audio_root = Path(AUDIO_DIR).resolve()
    file_path = (audio_root / file_name).resolve()
    if file_path.parent != audio_root or not file_path.is_file():
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/mpeg")
|
|
|
@app.get("/voices")
async def get_voices(l: Optional[str] = "zh", d: Optional[bool] = False):
    """Proxy the upstream ``{HOST}/voices`` list.

    Args:
        l: Language filter, forwarded as ``l``. It is omitted when empty.
        d: When true, a bare ``d`` flag is forwarded to the upstream service.

    Raises:
        HTTPException: the upstream status on an HTTP error response, 503 when
            the upstream cannot be reached, and 502 when it returns a body that
            is not valid JSON.
    """
    params = {}
    if l:
        params["l"] = l
    if d:
        params["d"] = ""

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{HOST}/voices", params=params)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            # Connection failures/timeouts previously escaped as a generic 500.
            raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}") from e
        except ValueError as e:
            # response.json() raises a ValueError subclass on a non-JSON body.
            raise HTTPException(status_code=502, detail=f"Invalid response from voices service: {e}") from e
|
|
|
@app.get("/cache/status")
async def cache_status():
    """Report the number of in-memory cache entries and their lifetime in hours.

    The count is the raw size of ``cache_store``. It may include entries that
    have expired but have not been removed yet.
    """
    entries = len(cache_store)
    return {"cache_count": entries, "expire_hours": CACHE_EXPIRE_HOURS}
|
|
|
@app.delete("/cache/clear")
async def clear_cache():
    """Drop every entry from the in-memory cache.

    Only the cache mapping is emptied. Audio files already written to
    AUDIO_DIR are left on disk.
    """
    cache_store.clear()
    return {"status": "success", "message": "Cache cleared"}
|
|
|
@app.get("/")
async def read_root(request: Request):
    """Serve the home page by rendering ``index.html`` from the templates directory."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
if __name__ == "__main__":
    # Running the module directly starts uvicorn on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)
|
|