# t2 / main.py
# Last commit: 5931df7 "Update main.py" by 99i (verified)
from fastapi import FastAPI, HTTPException, Request, Body
from fastapi.responses import JSONResponse, FileResponse
from fastapi.templating import Jinja2Templates
import httpx
import os
import uuid
from pathlib import Path
import aiofiles
from typing import Optional, Dict, Any, List
from datetime import datetime, timedelta
import hashlib
import json
from pydub import AudioSegment
import re
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Base URL of the upstream TTS service (overridable via the TTS_HOST env var).
HOST: str = os.getenv("TTS_HOST", "https://99i-tts.hf.space")
# Directory, relative to the working directory, where audio files are written and served from.
AUDIO_DIR: str = "audio_files"
# Lifetime of a cache entry, in hours.
CACHE_EXPIRE_HOURS: int = 24
# Maximum number of characters per text segment when splitting long text.
MAX_SEGMENT_LENGTH: int = 300
# Create the audio directory at import time so later writes can assume it exists.
Path(AUDIO_DIR).mkdir(exist_ok=True)

# In-memory cache: cache key -> {"file_name": str, "expire_time": datetime}.
# It lives only as long as the process does.
cache_store: Dict[str, Dict[str, Any]] = {}
def get_cache_key(params: Dict[str, Any]) -> str:
    """Return a deterministic MD5 hex digest identifying ``params``.

    The dict is serialised to JSON with sorted keys first, so two dicts with the
    same content but different insertion order map to the same key.
    """
    canonical = json.dumps(params, sort_keys=True)
    digest = hashlib.md5(canonical.encode())
    return digest.hexdigest()
def validate_tts_params(params: Dict[str, Any]) -> Dict[str, str]:
    """Validate raw TTS parameters and normalise them to strings.

    Args:
        params: Raw request parameters. Recognised keys are ``t`` (text),
            ``v`` (voice), ``r`` (rate), ``p`` (pitch) and ``o`` (output format).
            A missing key and a key explicitly set to ``None`` (JSON ``null``)
            both fall back to the default value.

    Returns:
        A new dict with exactly the keys ``t``, ``v``, ``r``, ``p``, ``o``, all
        strings. ``r`` and ``p`` are normalised through ``float`` (``"0"`` ->
        ``"0.0"``), so validating an already-validated dict is a no-op.

    Raises:
        HTTPException: 400 if the text is missing, falsy or whitespace-only, or
            if rate/pitch are not finite numbers.
    """
    import math

    text = params.get("t")
    # Reject missing/falsy text and whitespace-only text alike.
    if not text or not str(text).strip():
        raise HTTPException(status_code=400, detail="文本内容不能为空")

    def _or_default(key: str, default: Any) -> Any:
        # Treat an explicit null the same as an absent key, instead of
        # producing the literal string "None" (or a float(None) TypeError).
        value = params.get(key)
        return default if value is None else value

    def _number(key: str) -> str:
        number = float(_or_default(key, 0))
        # "nan"/"inf" parse as floats but are not meaningful rates/pitches.
        if not math.isfinite(number):
            raise ValueError(f"{key} is not a finite number")
        return str(number)

    try:
        return {
            "t": str(text),
            "v": str(_or_default("v", "zh-CN-XiaoxiaoMultilingualNeural")),
            "r": _number("r"),
            "p": _number("p"),
            "o": str(_or_default("o", "audio-24khz-48kbitrate-mono-mp3")),
        }
    except (TypeError, ValueError) as e:
        # float() raises TypeError for non-numeric types (lists, dicts) and
        # ValueError for unparsable strings; both are client errors.
        raise HTTPException(status_code=400, detail="语速和音调参数必须是有效的数字") from e
def get_from_cache(cache_key: str) -> Optional[str]:
    """Return the cached audio file name for ``cache_key``, or ``None`` on a miss.

    An entry is a hit only if it has not expired AND its file still exists in
    ``AUDIO_DIR``. Stale entries (expired, or pointing at a file that is gone)
    are removed from ``cache_store``, so they stop occupying memory and the
    caller regenerates the audio instead of handing out a dead URL. The audio
    files themselves are left on disk.
    """
    cached_item = cache_store.get(cache_key)
    if cached_item is None:
        return None

    file_name = cached_item["file_name"]
    is_fresh = datetime.now() < cached_item["expire_time"]
    if is_fresh and (Path(AUDIO_DIR) / file_name).is_file():
        return file_name

    cache_store.pop(cache_key, None)
    return None
def update_cache(cache_key: str, file_name: str) -> None:
    """Record ``file_name`` under ``cache_key``, valid for ``CACHE_EXPIRE_HOURS`` from now.

    Any existing entry for the same key is overwritten.
    """
    lifetime = timedelta(hours=CACHE_EXPIRE_HOURS)
    entry = {"file_name": file_name, "expire_time": datetime.now() + lifetime}
    cache_store[cache_key] = entry
async def make_tts_request(params: Dict[str, str]) -> bytes:
    """Send one synthesis request to the upstream TTS service and return the audio bytes.

    Args:
        params: Validated TTS parameters, sent as the JSON body of
            ``POST {HOST}/tts``.

    Returns:
        The raw response body.

    Raises:
        HTTPException: 503 if the service cannot be reached. Otherwise, when the
            service answers with an error status, that status code is passed
            through, and the detail carries the service's ``detail``/``message``
            field or its raw body.
    """
    async with httpx.AsyncClient() as client:
        print("\n=== TTS请求信息 ===")
        print(f"请求URL: {HOST}/tts")
        print(f"请求参数: {json.dumps(params, ensure_ascii=False, indent=2)}")
        try:
            response = await client.post(f"{HOST}/tts", json=params)
        except httpx.RequestError as e:
            raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}") from e
        print(f"响应状态码: {response.status_code}")
        print("=== TTS请求完成 ===\n")

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            # Default to the raw body; refine it if the body is JSON.
            error_detail = e.response.text or str(e)
            try:
                error_json = e.response.json()
            except ValueError:
                pass  # not JSON (JSONDecodeError/UnicodeDecodeError): keep the raw text
            else:
                if isinstance(error_json, dict):
                    error_detail = error_json.get("detail") or error_json.get("message") or str(error_json)
                else:
                    error_detail = str(error_json)
            raise HTTPException(
                status_code=e.response.status_code,
                detail=f"TTS服务返回错误: {error_detail}",
            ) from e
        return response.content
async def save_audio_file(content: bytes) -> str:
    """Write ``content`` to a new, uniquely named ``.mp3`` file in ``AUDIO_DIR``.

    Returns:
        The bare file name (not the full path) of the written file.

    Raises:
        HTTPException: 500 if the file cannot be written. Any partially written
            file is removed (best effort) before raising.
    """
    # A random uuid4 name keeps concurrent writes from clobbering each other.
    file_name = f"{uuid.uuid4()}.mp3"
    file_path = Path(AUDIO_DIR) / file_name
    try:
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(content)
    except OSError as e:
        # Best-effort cleanup so a truncated file does not linger in AUDIO_DIR.
        try:
            file_path.unlink()
        except OSError:
            pass  # the original write error below is the one worth reporting
        raise HTTPException(status_code=500, detail=f"保存音频文件失败: {str(e)}") from e
    return file_name
async def download_audio(params: dict) -> str:
    """Return the name of a local audio file for ``params``, synthesising it if needed.

    The parameters are validated and normalised first. An unexpired cached result
    for the same normalised parameters is reused. Otherwise the TTS service is
    called, the audio is saved, and the new file is recorded in the cache.
    HTTPExceptions from validation, the request or the save propagate unchanged.
    """
    normalized = validate_tts_params(params)
    key = get_cache_key(normalized)

    hit = get_from_cache(key)
    if hit:
        return hit

    audio_bytes = await make_tts_request(normalized)
    saved_name = await save_audio_file(audio_bytes)
    update_cache(key, saved_name)
    return saved_name
def generate_audio_url(request: Request, file_name: str) -> str:
    """Build the absolute URL under which the ``/audio/{file_name}`` route serves ``file_name``.

    Relies on ``request.base_url`` ending with ``/``, so ``audio/...`` is
    appended directly without inserting a separator.
    """
    return "{}audio/{}".format(str(request.base_url), file_name)
def split_text(text: str, max_length: Optional[int] = None) -> List[str]:
    """Split long text into segments of at most ``max_length`` characters.

    The text is first cut right after every sentence-ending mark: full-width
    ``。！？`` and ASCII ``.!?``. Consecutive sentences are then packed greedily
    into segments no longer than ``max_length``. A single sentence longer than
    that is hard-split into fixed-size chunks, and its last chunk stays open so
    following sentences can join it.

    Whitespace between sentences is kept inside a segment, so English sentences
    keep their separating space. Each returned segment is stripped of
    leading/trailing whitespace.

    Args:
        text: The text to split.
        max_length: Maximum characters per segment. Defaults to
            ``MAX_SEGMENT_LENGTH``.

    Returns:
        Non-empty segments in original order. Returns ``[]`` if ``text`` has no
        non-whitespace content.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    limit = MAX_SEGMENT_LENGTH if max_length is None else max_length
    if limit <= 0:
        raise ValueError("max_length must be positive")

    # Zero-width split right after each delimiter keeps the delimiter attached
    # to the sentence it terminates.
    sentences = re.split(r"(?<=[。！？.!?])", text)

    result: List[str] = []
    current = ""
    for sentence in sentences:
        if not sentence.strip():
            continue
        if len(current) + len(sentence) <= limit:
            current += sentence
            continue
        if current:
            result.append(current.strip())
        sentence = sentence.strip()
        if len(sentence) <= limit:
            current = sentence
            continue
        # Sentence is longer than the limit on its own: cut it into fixed-size chunks.
        chunks = [sentence[i:i + limit].strip() for i in range(0, len(sentence), limit)]
        chunks = [chunk for chunk in chunks if chunk]
        result.extend(chunks[:-1])
        # Keep the (possibly short) tail open so the next sentence can join it.
        current = chunks[-1]
    if current:
        result.append(current.strip())
    return result
async def merge_audio_files(file_names: List[str], gap_ms: int = 500) -> str:
    """Concatenate audio files from ``AUDIO_DIR`` into one new MP3 file.

    ``gap_ms`` milliseconds of silence are inserted between consecutive clips
    (default 500 ms). Decoding and encoding through pydub is blocking work, so
    it runs in the event loop's default executor instead of stalling the loop.

    Args:
        file_names: Names (not paths) of MP3 files inside ``AUDIO_DIR``, in
            playback order.
        gap_ms: Length of the silence between clips, in milliseconds.

    Returns:
        The file name of the merged output inside ``AUDIO_DIR``. If only one
        file is given, its name is returned unchanged, with no re-encoding.

    Raises:
        ValueError: If ``file_names`` is empty.
    """
    import asyncio

    if not file_names:
        raise ValueError("no audio files to merge")
    if len(file_names) == 1:
        # Nothing to join; avoid a lossy decode/re-encode round trip.
        return file_names[0]

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _merge_audio_files_sync, file_names, gap_ms)


def _merge_audio_files_sync(file_names: List[str], gap_ms: int) -> str:
    """Blocking helper for ``merge_audio_files``: decode, join with silence, export."""
    silence = AudioSegment.silent(duration=gap_ms)
    combined = AudioSegment.empty()
    for index, file_name in enumerate(file_names):
        if index > 0:
            # Silence goes *between* clips, never before the first one.
            combined += silence
        combined += AudioSegment.from_mp3(str(Path(AUDIO_DIR) / file_name))

    output_file_name = f"{uuid.uuid4()}.mp3"
    output_path = Path(AUDIO_DIR) / output_file_name
    # export() hands back the open output file object; close it explicitly
    # rather than relying on garbage collection to release the descriptor.
    combined.export(str(output_path), format="mp3").close()
    return output_file_name
@app.get("/tts")
async def text_to_speech_get(
    request: Request,
    t: str,
    v: Optional[str] = "zh-CN-XiaoxiaoMultilingualNeural",
    r: Optional[str] = "0",
    p: Optional[str] = "0",
    o: Optional[str] = "audio-24khz-48kbitrate-mono-mp3"
):
    """GET endpoint: synthesise ``t`` (reusing the cache when possible) and return a playback URL.

    The query parameters are forwarded to the TTS service: ``t`` text, ``v``
    voice, ``r`` rate, ``p`` pitch, ``o`` output format.

    Returns:
        ``{"status": "success", "audio_url": ..., "cached": bool}``, where
        ``cached`` says whether an existing cached file was reused. On failure,
        ``{"status": "error", "message": ...}`` with the HTTPException's status
        code, or 500 for unexpected errors.
    """
    raw_params = {"t": t, "v": v, "r": r, "p": p, "o": o}
    try:
        # Probe the cache BEFORE download_audio, which fills it. Use the
        # normalised key, because that is what download_audio stores entries under.
        cached = get_from_cache(get_cache_key(validate_tts_params(raw_params))) is not None
        file_name = await download_audio(raw_params)
        audio_url = generate_audio_url(request, file_name)
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        # Keep validation/upstream status codes (e.g. 400) instead of masking them as 500.
        return JSONResponse(status_code=e.status_code, content={"status": "error", "message": str(e.detail)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
@app.post("/tts")
async def text_to_speech_post(request: Request, data: Dict[str, Any] = Body(...)):
    """POST endpoint: synthesise the JSON body's parameters and return a playback URL.

    The body uses the same keys as the GET ``/tts`` endpoint: ``t``, ``v``,
    ``r``, ``p``, ``o``.

    Returns:
        ``{"status": "success", "audio_url": ..., "cached": bool}``. ``cached``
        is true when an unexpired cached file was reused instead of calling the
        TTS service. Errors return ``{"status": "error", "message": ...}`` with
        the HTTPException's status code, or 500 for unexpected errors.
    """
    try:
        # The cache has to be probed BEFORE download_audio (which fills it), and
        # with the normalised parameters, because that is how download_audio keys
        # its entries. The raw body's key would never match.
        cached = get_from_cache(get_cache_key(validate_tts_params(data))) is not None
        file_name = await download_audio(data)
        audio_url = generate_audio_url(request, file_name)
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        # Keep validation/upstream status codes (e.g. 400) instead of masking them as 500.
        return JSONResponse(status_code=e.status_code, content={"status": "error", "message": str(e.detail)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
@app.post("/ltts")
async def text_to_speech_long(request: Request, data: Dict[str, Any] = Body(...)):
    """Long-text TTS: split ``t`` into segments, synthesise each one, and merge the audio.

    Every key other than ``t`` (voice, rate, pitch, format) is applied to each segment.

    Returns:
        ``{"status": "success", "audio_url": ..., "segments_count": n}``. If a
        segment fails, ``{"status": "error", "message": ..., "segment": ...}``
        with that failure's HTTP status (500 for non-HTTP errors). Other
        ValueErrors map to 400, and any remaining errors to 500.

    Raises:
        HTTPException: 400 when ``t`` is missing or empty.
    """
    text = data.get("t")
    if not text:
        raise HTTPException(status_code=400, detail="Parameter 't' (text) is required")
    try:
        # str() so non-string JSON values (e.g. numbers) don't crash the regex split.
        segments = split_text(str(text))
        if not segments:
            raise ValueError("文本分割后没有有效内容")

        base_params = {k: v for k, v in data.items() if k != "t"}
        file_names = []
        for i, segment in enumerate(segments):
            try:
                file_name = await download_audio({**base_params, "t": segment})
            except HTTPException as e:
                # Keep the validation/upstream status (e.g. 400 for a bad rate).
                return JSONResponse(
                    status_code=e.status_code,
                    content={"status": "error", "message": f"第{i+1}段文本转语音失败: {e.detail}", "segment": segment},
                )
            except Exception as e:
                return JSONResponse(
                    status_code=500,
                    content={"status": "error", "message": f"第{i+1}段文本转语音失败: {str(e)}", "segment": segment},
                )
            file_names.append(file_name)

        final_file_name = await merge_audio_files(file_names)
        audio_url = generate_audio_url(request, final_file_name)
        return {"status": "success", "audio_url": audio_url, "segments_count": len(segments)}
    except ValueError as e:
        return JSONResponse(status_code=400, content={"status": "error", "message": str(e)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
@app.post("/mtts")
async def multi_voice_tts(request: Request, data: List[Dict[str, Any]] = Body(...)):
    """Multi-voice reading: synthesise each entry with its own settings and merge the results.

    Expected body: [{"v":"voice1", "t":"text1", "r":"rate1"}, {"v":"voice2", "t":"text2", "r":"rate2"}]
    Each entry may also carry ``p`` (pitch) and ``o`` (output format). Missing
    keys fall back to the same defaults as the GET ``/tts`` endpoint.

    Returns:
        ``{"status": "success", "audio_url": ..., "voices_count": n}``. If one
        entry fails, ``{"status": "error", "message": ..., "voice_config": ...}``
        with that failure's HTTP status (500 for non-HTTP errors). Other errors
        return 500.

    Raises:
        HTTPException: 400 when the body is empty or not a list.
    """
    if not data or not isinstance(data, list):
        raise HTTPException(status_code=400, detail="请求数据必须是包含多个角色配置的数组")
    try:
        file_names = []
        for i, voice_config in enumerate(data):
            try:
                params = {
                    "t": voice_config.get("t", ""),
                    "v": voice_config.get("v", "zh-CN-XiaoxiaoMultilingualNeural"),
                    "r": voice_config.get("r", "0"),
                    "p": voice_config.get("p", "0"),
                    "o": voice_config.get("o", "audio-24khz-48kbitrate-mono-mp3")
                }
                file_name = await download_audio(params)
            except HTTPException as e:
                # Keep the validation/upstream status (e.g. 400 for empty text).
                return JSONResponse(
                    status_code=e.status_code,
                    content={"status": "error", "message": f"第{i+1}个角色的语音生成失败: {e.detail}", "voice_config": voice_config},
                )
            except Exception as e:
                return JSONResponse(
                    status_code=500,
                    content={"status": "error", "message": f"第{i+1}个角色的语音生成失败: {str(e)}", "voice_config": voice_config},
                )
            file_names.append(file_name)

        # Join every voice's audio in request order.
        final_file_name = await merge_audio_files(file_names)
        audio_url = generate_audio_url(request, final_file_name)
        return {"status": "success", "audio_url": audio_url, "voices_count": len(data)}
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
@app.get("/audio/{file_name}")
async def get_audio(file_name: str):
    """Serve a generated audio file from ``AUDIO_DIR`` as ``audio/mpeg``.

    Only regular files located directly inside ``AUDIO_DIR`` are served. A name
    that resolves anywhere else (``..``, or backslash-separated paths on
    Windows), or that names a directory, gets the same 404 as a missing file.

    Raises:
        HTTPException: 404 if no such file exists directly in ``AUDIO_DIR``.
    """
    audio_root = Path(AUDIO_DIR).resolve()
    file_path = (audio_root / file_name).resolve()
    # Containment check defeats path traversal; is_file() rejects directories.
    if file_path.parent != audio_root or not file_path.is_file():
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/mpeg")
@app.get("/voices")
async def get_voices(l: Optional[str] = "zh", d: Optional[bool] = False):
    """Proxy the upstream voice list from ``{HOST}/voices``.

    Args:
        l: Language filter forwarded as ``l``; omitted when empty.
        d: When true, an empty ``d`` query parameter is forwarded.

    Returns:
        The upstream JSON payload, unchanged.

    Raises:
        HTTPException: the upstream status code for error responses, 503 when
            the service is unreachable, and 502 when its reply is not valid JSON.
    """
    params = {}
    if l:
        params["l"] = l
    if d:
        params["d"] = ""
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{HOST}/voices", params=params)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=str(e)) from e
        except httpx.RequestError as e:
            # Network failures used to escape as an unformatted 500.
            raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}") from e
    try:
        return response.json()
    except ValueError as e:
        raise HTTPException(status_code=502, detail="TTS服务返回了无效的JSON") from e
@app.get("/cache/status")
async def cache_status():
    """Report the in-memory cache size and the configured entry lifetime in hours.

    ``cache_count`` is the raw number of stored entries. It may include expired
    entries that are still held in the dict.
    """
    summary = dict(cache_count=len(cache_store), expire_hours=CACHE_EXPIRE_HOURS)
    return summary
@app.delete("/cache/clear")
async def clear_cache():
    """Drop every entry from the in-memory cache.

    Only the ``cache_store`` mapping is emptied. Audio files already written to
    ``AUDIO_DIR`` are not deleted.
    """
    # Clear in place instead of rebinding the global: any other reference to
    # the same dict sees the cleared state, and no ``global`` statement is needed.
    cache_store.clear()
    return {"status": "success", "message": "Cache cleared"}
@app.get("/")
async def read_root(request: Request):
    """Home page (intended to show the voice list): render ``index.html`` from ``templates``."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
if __name__ == "__main__":
    import uvicorn

    # Bind to every interface (0.0.0.0) so the app is reachable from other machines.
    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)