Update main.py
Browse files
main.py
CHANGED
@@ -1,188 +1,310 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException, Request, Body
|
2 |
-
from fastapi.responses import JSONResponse, FileResponse
|
3 |
-
from fastapi.templating import Jinja2Templates
|
4 |
-
import httpx
|
5 |
-
import os
|
6 |
-
import uuid
|
7 |
-
from pathlib import Path
|
8 |
-
import aiofiles
|
9 |
-
from typing import Optional, Dict, Any
|
10 |
-
from datetime import datetime, timedelta
|
11 |
-
import hashlib
|
12 |
-
import json
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
return
|
97 |
-
except
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
""
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Request, Body
|
2 |
+
from fastapi.responses import JSONResponse, FileResponse
|
3 |
+
from fastapi.templating import Jinja2Templates
|
4 |
+
import httpx
|
5 |
+
import os
|
6 |
+
import uuid
|
7 |
+
from pathlib import Path
|
8 |
+
import aiofiles
|
9 |
+
from typing import Optional, Dict, Any, List
|
10 |
+
from datetime import datetime, timedelta
|
11 |
+
import hashlib
|
12 |
+
import json
|
13 |
+
from pydub import AudioSegment
|
14 |
+
import re
|
15 |
+
|
16 |
+
app = FastAPI()
|
17 |
+
templates = Jinja2Templates(directory="templates")
|
18 |
+
|
19 |
+
# 配置
|
20 |
+
HOST = os.getenv("TTS_HOST", "https://99i-tts.hf.space")
|
21 |
+
AUDIO_DIR = "audio_files"
|
22 |
+
CACHE_EXPIRE_HOURS = 24
|
23 |
+
MAX_SEGMENT_LENGTH = 300
|
24 |
+
Path(AUDIO_DIR).mkdir(exist_ok=True)
|
25 |
+
|
26 |
+
# 内存缓存
|
27 |
+
cache_store = {}
|
28 |
+
|
29 |
+
def get_cache_key(params: Dict[str, Any]) -> str:
|
30 |
+
"""生成缓存键"""
|
31 |
+
param_str = json.dumps(params, sort_keys=True)
|
32 |
+
return hashlib.md5(param_str.encode()).hexdigest()
|
33 |
+
|
34 |
+
def validate_tts_params(params: Dict[str, Any]) -> Dict[str, str]:
|
35 |
+
"""验证并转换TTS参数"""
|
36 |
+
if not params.get("t"):
|
37 |
+
raise HTTPException(status_code=400, detail="文本内容不能为空")
|
38 |
+
|
39 |
+
try:
|
40 |
+
return {
|
41 |
+
"t": str(params.get("t", "")),
|
42 |
+
"v": str(params.get("v", "zh-CN-XiaoxiaoMultilingualNeural")),
|
43 |
+
"r": str(float(params.get("r", 0))),
|
44 |
+
"p": str(float(params.get("p", 0))),
|
45 |
+
"o": str(params.get("o", "audio-24khz-48kbitrate-mono-mp3"))
|
46 |
+
}
|
47 |
+
except ValueError:
|
48 |
+
raise HTTPException(status_code=400, detail="语速和音调参数必须是有效的数字")
|
49 |
+
|
50 |
+
def get_from_cache(cache_key: str) -> Optional[str]:
|
51 |
+
"""从缓存获取音频文件名"""
|
52 |
+
if cache_key in cache_store:
|
53 |
+
cached_item = cache_store[cache_key]
|
54 |
+
if datetime.now() < cached_item["expire_time"]:
|
55 |
+
return cached_item["file_name"]
|
56 |
+
return None
|
57 |
+
|
58 |
+
def update_cache(cache_key: str, file_name: str) -> None:
|
59 |
+
"""更新缓存"""
|
60 |
+
cache_store[cache_key] = {
|
61 |
+
"file_name": file_name,
|
62 |
+
"expire_time": datetime.now() + timedelta(hours=CACHE_EXPIRE_HOURS)
|
63 |
+
}
|
64 |
+
|
65 |
+
async def make_tts_request(params: Dict[str, str]) -> bytes:
|
66 |
+
"""发送TTS请求并获取音频数据"""
|
67 |
+
async with httpx.AsyncClient() as client:
|
68 |
+
try:
|
69 |
+
print("\n=== TTS请求信息 ===")
|
70 |
+
print(f"请求URL: {HOST}/tts")
|
71 |
+
print(f"请求参数: {json.dumps(params, ensure_ascii=False, indent=2)}")
|
72 |
+
response = await client.post(f"{HOST}/tts", json=params)
|
73 |
+
print(f"响应状态码: {response.status_code}")
|
74 |
+
print("=== TTS请求完成 ===\n")
|
75 |
+
response.raise_for_status()
|
76 |
+
return response.content
|
77 |
+
except httpx.RequestError as e:
|
78 |
+
raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}")
|
79 |
+
except httpx.HTTPStatusError as e:
|
80 |
+
error_detail = "未知错误"
|
81 |
+
try:
|
82 |
+
error_json = e.response.json()
|
83 |
+
if isinstance(error_json, dict):
|
84 |
+
error_detail = error_json.get("detail") or error_json.get("message") or str(error_json)
|
85 |
+
except:
|
86 |
+
error_detail = e.response.text or str(e)
|
87 |
+
raise HTTPException(status_code=e.response.status_code, detail=f"TTS服务返回错误: {error_detail}")
|
88 |
+
|
89 |
+
async def save_audio_file(content: bytes) -> str:
|
90 |
+
"""保存音频文件并返回文件名"""
|
91 |
+
file_name = f"{uuid.uuid4()}.mp3"
|
92 |
+
file_path = Path(AUDIO_DIR) / file_name
|
93 |
+
try:
|
94 |
+
async with aiofiles.open(file_path, "wb") as f:
|
95 |
+
await f.write(content)
|
96 |
+
return file_name
|
97 |
+
except IOError as e:
|
98 |
+
raise HTTPException(status_code=500, detail=f"保存音频文件失败: {str(e)}")
|
99 |
+
|
100 |
+
async def download_audio(params: dict) -> str:
|
101 |
+
"""下载音频文件并保存到本地"""
|
102 |
+
params = validate_tts_params(params)
|
103 |
+
cache_key = get_cache_key(params)
|
104 |
+
|
105 |
+
# 检查缓存
|
106 |
+
cached_file = get_from_cache(cache_key)
|
107 |
+
if cached_file:
|
108 |
+
return cached_file
|
109 |
+
|
110 |
+
# 下载并保存音频
|
111 |
+
content = await make_tts_request(params)
|
112 |
+
file_name = await save_audio_file(content)
|
113 |
+
update_cache(cache_key, file_name)
|
114 |
+
return file_name
|
115 |
+
|
116 |
+
def generate_audio_url(request: Request, file_name: str) -> str:
|
117 |
+
"""生成音频URL"""
|
118 |
+
base_url = str(request.base_url)
|
119 |
+
return f"{base_url}audio/{file_name}"
|
120 |
+
|
121 |
+
def split_text(text: str) -> List[str]:
|
122 |
+
"""将长文本分割成较短的段落"""
|
123 |
+
segments = re.split(r'([。!?.!?])', text)
|
124 |
+
segments = [''.join(i) for i in zip(segments[0::2], segments[1::2] + [''] * (len(segments[0::2]) - len(segments[1::2])))]
|
125 |
+
|
126 |
+
result = []
|
127 |
+
current_segment = ""
|
128 |
+
|
129 |
+
for segment in segments:
|
130 |
+
segment = segment.strip()
|
131 |
+
if not segment:
|
132 |
+
continue
|
133 |
+
|
134 |
+
if len(current_segment) + len(segment) <= MAX_SEGMENT_LENGTH:
|
135 |
+
current_segment += segment
|
136 |
+
else:
|
137 |
+
if current_segment:
|
138 |
+
result.append(current_segment)
|
139 |
+
if len(segment) > MAX_SEGMENT_LENGTH:
|
140 |
+
for i in range(0, len(segment), MAX_SEGMENT_LENGTH):
|
141 |
+
result.append(segment[i:i + MAX_SEGMENT_LENGTH])
|
142 |
+
current_segment = ""
|
143 |
+
else:
|
144 |
+
current_segment = segment
|
145 |
+
|
146 |
+
if current_segment:
|
147 |
+
result.append(current_segment)
|
148 |
+
|
149 |
+
return result
|
150 |
+
|
151 |
+
async def merge_audio_files(file_names: List[str]) -> str:
|
152 |
+
"""合并多个音频文件,在音频间添加静音间隔"""
|
153 |
+
combined = AudioSegment.empty()
|
154 |
+
silence = AudioSegment.silent(duration=500) # 创建500毫秒的静音
|
155 |
+
|
156 |
+
for i, file_name in enumerate(file_names):
|
157 |
+
file_path = Path(AUDIO_DIR) / file_name
|
158 |
+
audio = AudioSegment.from_mp3(str(file_path))
|
159 |
+
if i > 0: # 从第二个音频开始,在前面添加静音
|
160 |
+
combined += silence
|
161 |
+
combined += audio
|
162 |
+
|
163 |
+
output_file_name = f"{uuid.uuid4()}.mp3"
|
164 |
+
output_path = Path(AUDIO_DIR) / output_file_name
|
165 |
+
combined.export(str(output_path), format="mp3")
|
166 |
+
return output_file_name
|
167 |
+
|
168 |
+
@app.get("/tts")
|
169 |
+
async def text_to_speech_get(
|
170 |
+
request: Request,
|
171 |
+
t: str,
|
172 |
+
v: Optional[str] = "zh-CN-XiaoxiaoMultilingualNeural",
|
173 |
+
r: Optional[str] = "0",
|
174 |
+
p: Optional[str] = "0",
|
175 |
+
o: Optional[str] = "audio-24khz-48kbitrate-mono-mp3"
|
176 |
+
):
|
177 |
+
"""GET方式转发TTS请求并返回在线播放链接"""
|
178 |
+
try:
|
179 |
+
file_name = await download_audio({"t": t, "v": v, "r": r, "p": p, "o": o})
|
180 |
+
audio_url = generate_audio_url(request, file_name)
|
181 |
+
return {"status": "success", "audio_url": audio_url, "cached": False}
|
182 |
+
except Exception as e:
|
183 |
+
return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
|
184 |
+
|
185 |
+
@app.post("/tts")
|
186 |
+
async def text_to_speech_post(request: Request, data: Dict[str, Any] = Body(...)):
|
187 |
+
"""POST方式转发TTS请求并返回在线播放链接"""
|
188 |
+
try:
|
189 |
+
file_name = await download_audio(data)
|
190 |
+
audio_url = generate_audio_url(request, file_name)
|
191 |
+
cache_key = get_cache_key(data)
|
192 |
+
cached = cache_key in cache_store and datetime.now() < cache_store[cache_key]["expire_time"]
|
193 |
+
return {"status": "success", "audio_url": audio_url, "cached": cached}
|
194 |
+
except Exception as e:
|
195 |
+
return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
|
196 |
+
|
197 |
+
@app.post("/ltts")
|
198 |
+
async def text_to_speech_long(request: Request, data: Dict[str, Any] = Body(...)):
|
199 |
+
"""处理长文本TTS请求"""
|
200 |
+
text = data.get("t")
|
201 |
+
if not text:
|
202 |
+
raise HTTPException(status_code=400, detail="Parameter 't' (text) is required")
|
203 |
+
|
204 |
+
try:
|
205 |
+
segments = split_text(text)
|
206 |
+
if not segments:
|
207 |
+
raise ValueError("文本分割后没有有效内容")
|
208 |
+
|
209 |
+
file_names = []
|
210 |
+
base_params = {k: v for k, v in data.items() if k != "t"}
|
211 |
+
|
212 |
+
for i, segment in enumerate(segments):
|
213 |
+
try:
|
214 |
+
segment_params = {**base_params, "t": segment}
|
215 |
+
file_name = await download_audio(segment_params)
|
216 |
+
file_names.append(file_name)
|
217 |
+
except Exception as e:
|
218 |
+
return JSONResponse(
|
219 |
+
status_code=500,
|
220 |
+
content={"status": "error", "message": f"第{i+1}段文本转语音失败: {str(e)}", "segment": segment}
|
221 |
+
)
|
222 |
+
|
223 |
+
final_file_name = await merge_audio_files(file_names)
|
224 |
+
audio_url = generate_audio_url(request, final_file_name)
|
225 |
+
return {"status": "success", "audio_url": audio_url, "segments_count": len(segments)}
|
226 |
+
except ValueError as e:
|
227 |
+
return JSONResponse(status_code=400, content={"status": "error", "message": str(e)})
|
228 |
+
except Exception as e:
|
229 |
+
return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
|
230 |
+
|
231 |
+
@app.post("/mtts")
|
232 |
+
async def multi_voice_tts(request: Request, data: List[Dict[str, Any]] = Body(...)):
|
233 |
+
"""处理多角色朗读请求
|
234 |
+
接收格式:[{"v":"voice1", "t":"text1", "r":"rate1"}, {"v":"voice2", "t":"text2", "r":"rate2"}]
|
235 |
+
"""
|
236 |
+
if not data or not isinstance(data, list):
|
237 |
+
raise HTTPException(status_code=400, detail="请求数据必须是包含多个角色配置的数组")
|
238 |
+
|
239 |
+
try:
|
240 |
+
file_names = []
|
241 |
+
for i, voice_config in enumerate(data):
|
242 |
+
try:
|
243 |
+
# 为每个角色生成音频
|
244 |
+
params = {
|
245 |
+
"t": voice_config.get("t", ""),
|
246 |
+
"v": voice_config.get("v", "zh-CN-XiaoxiaoMultilingualNeural"),
|
247 |
+
"r": voice_config.get("r", "0"),
|
248 |
+
"p": voice_config.get("p", "0"),
|
249 |
+
"o": voice_config.get("o", "audio-24khz-48kbitrate-mono-mp3")
|
250 |
+
}
|
251 |
+
file_name = await download_audio(params)
|
252 |
+
file_names.append(file_name)
|
253 |
+
except Exception as e:
|
254 |
+
return JSONResponse(
|
255 |
+
status_code=500,
|
256 |
+
content={"status": "error", "message": f"第{i+1}个角色的语音生成失败: {str(e)}", "voice_config": voice_config}
|
257 |
+
)
|
258 |
+
|
259 |
+
# 合并所有角色的音频
|
260 |
+
final_file_name = await merge_audio_files(file_names)
|
261 |
+
audio_url = generate_audio_url(request, final_file_name)
|
262 |
+
return {"status": "success", "audio_url": audio_url, "voices_count": len(data)}
|
263 |
+
except Exception as e:
|
264 |
+
return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
|
265 |
+
|
266 |
+
@app.get("/audio/{file_name}")
|
267 |
+
async def get_audio(file_name: str):
|
268 |
+
"""返回音频文件"""
|
269 |
+
file_path = Path(AUDIO_DIR) / file_name
|
270 |
+
if not file_path.exists():
|
271 |
+
raise HTTPException(status_code=404, detail="Audio file not found")
|
272 |
+
return FileResponse(file_path, media_type="audio/mpeg")
|
273 |
+
|
274 |
+
@app.get("/voices")
|
275 |
+
async def get_voices(l: Optional[str] = "zh", d: Optional[bool] = False):
|
276 |
+
"""获取语音列表"""
|
277 |
+
params = {}
|
278 |
+
if l:
|
279 |
+
params["l"] = l
|
280 |
+
if d:
|
281 |
+
params["d"] = ""
|
282 |
+
|
283 |
+
async with httpx.AsyncClient() as client:
|
284 |
+
try:
|
285 |
+
response = await client.get(f"{HOST}/voices", params=params)
|
286 |
+
response.raise_for_status()
|
287 |
+
return response.json()
|
288 |
+
except httpx.HTTPStatusError as e:
|
289 |
+
raise HTTPException(status_code=e.response.status_code, detail=str(e))
|
290 |
+
|
291 |
+
@app.get("/cache/status")
|
292 |
+
async def cache_status():
|
293 |
+
"""获取缓存状态"""
|
294 |
+
return {"cache_count": len(cache_store), "expire_hours": CACHE_EXPIRE_HOURS}
|
295 |
+
|
296 |
+
@app.delete("/cache/clear")
|
297 |
+
async def clear_cache():
|
298 |
+
"""清除所有缓存"""
|
299 |
+
global cache_store
|
300 |
+
cache_store = {}
|
301 |
+
return {"status": "success", "message": "Cache cleared"}
|
302 |
+
|
303 |
+
@app.get("/")
|
304 |
+
async def read_root(request: Request):
|
305 |
+
"""主页,展示语音列表"""
|
306 |
+
return templates.TemplateResponse("index.html", {"request": request})
|
307 |
+
|
308 |
+
if __name__ == "__main__":
|
309 |
+
import uvicorn
|
310 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|