99i committed on
Commit
5931df7
·
verified ·
1 Parent(s): 0d8e608

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +310 -188
main.py CHANGED
@@ -1,188 +1,310 @@
1
- from fastapi import FastAPI, HTTPException, Request, Body
2
- from fastapi.responses import JSONResponse, FileResponse
3
- from fastapi.templating import Jinja2Templates
4
- import httpx
5
- import os
6
- import uuid
7
- from pathlib import Path
8
- import aiofiles
9
- from typing import Optional, Dict, Any
10
- from datetime import datetime, timedelta
11
- import hashlib
12
- import json
13
-
14
# Application object and the Jinja2 loader for the "templates" directory.
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Configuration
# Base URL of the upstream TTS service; TTS_HOST overrides the default.
# (Author's note: replace with the actual environment variable name.)
HOST = os.environ.get("TTS_HOST", "https://99i-tts.hf.space")
AUDIO_DIR = "audio_files"
CACHE_EXPIRE_HOURS = 24  # cache lifetime, in hours
Path(AUDIO_DIR).mkdir(exist_ok=True)

# In-memory cache (Redis or similar is recommended for production).
cache_store = dict()
26
-
27
def get_cache_key(params: Dict[str, Any]) -> str:
    """Return the cache key for a set of request parameters.

    The key is the MD5 hex digest of the parameters serialised as JSON with
    sorted keys, so insertion order of the dict does not affect the result.
    """
    serialized = json.dumps(params, sort_keys=True)
    digest = hashlib.md5(serialized.encode())
    return digest.hexdigest()
31
-
32
async def download_audio(params: dict, use_post: bool = False) -> str:
    """Fetch synthesized audio from the upstream TTS service and store it locally.

    A cache key is derived from ``params``. If an unexpired cache entry exists,
    its file name is returned without contacting the upstream service.
    Otherwise the audio is requested from ``{HOST}/tts``, written to
    ``AUDIO_DIR`` under a fresh UUID-based ``.mp3`` name, and cached.

    Args:
        params: TTS parameters forwarded to the upstream service.
        use_post: Send the parameters as a JSON POST body instead of a GET
            query string.

    Returns:
        The file name (not the full path) of the stored audio file.

    Raises:
        HTTPException: With the upstream status code when the service answers
            with an error status, or 503 when the service cannot be reached.
    """
    cache_key = get_cache_key(params)

    cached_item = cache_store.get(cache_key)
    if cached_item is not None:
        if datetime.now() < cached_item["expire_time"]:
            return cached_item["file_name"]
        # Drop the stale entry so expired keys do not accumulate in memory.
        cache_store.pop(cache_key, None)

    async with httpx.AsyncClient() as client:
        try:
            if use_post:
                response = await client.post(f"{HOST}/tts", json=params)
            else:
                response = await client.get(f"{HOST}/tts", params=params)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            # Connection errors and timeouts previously escaped as unhandled
            # exceptions; report them as "service unavailable" instead.
            raise HTTPException(status_code=503, detail=str(e))

    # Generate a unique file name and persist the audio bytes.
    file_name = f"{uuid.uuid4()}.mp3"
    file_path = Path(AUDIO_DIR) / file_name
    async with aiofiles.open(file_path, "wb") as f:
        await f.write(response.content)

    cache_store[cache_key] = {
        "file_name": file_name,
        "expire_time": datetime.now() + timedelta(hours=CACHE_EXPIRE_HOURS),
    }
    return file_name
68
-
69
@app.get("/tts")
async def text_to_speech_get(
    request: Request,
    t: str,
    v: Optional[str] = "zh-CN-XiaoxiaoMultilingualNeural",
    r: Optional[int] = 0,
    p: Optional[int] = 0,
    o: Optional[str] = "audio-24khz-48kbitrate-mono-mp3"
):
    """Forward a TTS request given as query parameters and return a playback URL.

    The upstream service is tried with GET first and with POST if that fails.

    Returns:
        ``{"status": "success", "audio_url": ..., "cached": ...}`` on success,
        where ``cached`` tells whether an unexpired cache entry already existed
        for these parameters. On failure, a JSON error body whose HTTP status
        is the one carried by the ``HTTPException`` (500 for anything else).
    """
    params = {"t": t, "v": v, "r": r, "p": p, "o": o}

    # Look at the cache *before* downloading; afterwards the entry always exists.
    entry = cache_store.get(get_cache_key(params))
    cached = entry is not None and datetime.now() < entry["expire_time"]

    try:
        try:
            file_name = await download_audio(params, use_post=False)
        except Exception:
            # Was a bare ``except:``, which also intercepted task cancellation.
            file_name = await download_audio(params, use_post=True)

        base_url = str(request.base_url)
        audio_url = f"{base_url}audio/{file_name}"
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        # Keep the real status code instead of flattening everything to 500.
        return JSONResponse(
            status_code=e.status_code,
            content={"status": "error", "message": e.detail}
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)}
        )
102
-
103
@app.post("/tts")
async def text_to_speech_post(
    request: Request,
    data: Dict[str, Any] = Body(...)
):
    """Forward a TTS request given as a JSON body and return a playback URL.

    Missing optional fields fall back to the same defaults as the GET route.

    Raises:
        HTTPException: 400 when the ``t`` (text) field is missing or empty.

    Returns:
        ``{"status": "success", "audio_url": ..., "cached": ...}`` on success,
        or a JSON error body whose HTTP status is the one carried by the
        ``HTTPException`` (500 for anything else).
    """
    params = {
        "t": data.get("t"),
        "v": data.get("v", "zh-CN-XiaoxiaoMultilingualNeural"),
        "r": data.get("r", 0),
        "p": data.get("p", 0),
        "o": data.get("o", "audio-24khz-48kbitrate-mono-mp3")
    }

    # Required-parameter check.
    if not params["t"]:
        raise HTTPException(status_code=400, detail="Parameter 't' (text) is required")

    # Determine the cache state before downloading: download_audio always
    # stores an entry, so checking afterwards would report True every time.
    entry = cache_store.get(get_cache_key(params))
    cached = entry is not None and datetime.now() < entry["expire_time"]

    try:
        file_name = await download_audio(params, use_post=True)
        base_url = str(request.base_url)
        audio_url = f"{base_url}audio/{file_name}"
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        # Keep the upstream / validation status code.
        return JSONResponse(
            status_code=e.status_code,
            content={"status": "error", "message": e.detail}
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)}
        )
136
-
137
@app.get("/audio/{file_name}")
async def get_audio(file_name: str):
    """Serve a previously generated audio file from ``AUDIO_DIR``.

    Raises:
        HTTPException: 404 when the name contains a path component or does not
            refer to a regular file inside ``AUDIO_DIR``.
    """
    # Only bare file names are valid; anything with a directory part is refused
    # so the lookup cannot leave AUDIO_DIR.
    if Path(file_name).name != file_name:
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = Path(AUDIO_DIR) / file_name
    # is_file() (rather than exists()) also rejects directories such as "..".
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/mpeg")
144
-
145
@app.get("/voices")
async def get_voices(
    l: Optional[str] = "zh",
    d: Optional[bool] = False
):
    """Proxy the upstream ``/voices`` endpoint and return its JSON payload.

    Args:
        l: Forwarded as the ``l`` query parameter when non-empty.
        d: When true, an empty ``d`` query parameter is forwarded.

    Raises:
        HTTPException: With the upstream status code on an error response, or
            503 when the upstream service cannot be reached.
    """
    params = {}
    if l:
        params["l"] = l
    if d:
        params["d"] = ""

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{HOST}/voices", params=params)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            # Connection errors / timeouts used to surface as unhandled errors.
            raise HTTPException(status_code=503, detail=str(e))
164
-
165
@app.get("/cache/status")
async def cache_status():
    """Report the number of entries in the in-memory cache and the TTL in hours."""
    entry_count = len(cache_store)
    return dict(cache_count=entry_count, expire_hours=CACHE_EXPIRE_HOURS)
172
-
173
@app.delete("/cache/clear")
async def clear_cache():
    """Discard every cache entry by rebinding ``cache_store`` to a new dict."""
    global cache_store
    cache_store = dict()
    return dict(status="success", message="Cache cleared")
179
-
180
@app.get("/")
async def read_root(request: Request):
    """Render the home page (``index.html``), which shows the voice list."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
184
-
185
-
186
if __name__ == "__main__":
    # Run the app with uvicorn on all interfaces, port 7860, when executed
    # directly as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request, Body
2
+ from fastapi.responses import JSONResponse, FileResponse
3
+ from fastapi.templating import Jinja2Templates
4
+ import httpx
5
+ import os
6
+ import uuid
7
+ from pathlib import Path
8
+ import aiofiles
9
+ from typing import Optional, Dict, Any, List
10
+ from datetime import datetime, timedelta
11
+ import hashlib
12
+ import json
13
+ from pydub import AudioSegment
14
+ import re
15
+
16
# Application object and the Jinja2 loader for the "templates" directory.
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Configuration
# Base URL of the upstream TTS service; TTS_HOST overrides the default.
HOST = os.environ.get("TTS_HOST", "https://99i-tts.hf.space")
AUDIO_DIR = "audio_files"
CACHE_EXPIRE_HOURS = 24  # cache lifetime, in hours
MAX_SEGMENT_LENGTH = 300  # upper bound on characters per segment in split_text
Path(AUDIO_DIR).mkdir(exist_ok=True)

# In-memory cache
cache_store = dict()
28
+
29
def get_cache_key(params: Dict[str, Any]) -> str:
    """Build a deterministic cache key for a set of TTS parameters.

    The parameters are serialised to JSON with sorted keys, so dicts with the
    same items yield the same key regardless of insertion order; the MD5 hex
    digest of that JSON text is returned.
    """
    canonical = json.dumps(params, sort_keys=True).encode()
    return hashlib.md5(canonical).hexdigest()
33
+
34
def validate_tts_params(params: Dict[str, Any]) -> Dict[str, str]:
    """Validate raw TTS parameters and normalise every value to a string.

    Args:
        params: Raw parameters; ``t`` (text) is required, while ``v`` (voice),
            ``r`` (rate), ``p`` (pitch) and ``o`` (output format) are optional.
            An optional key that is absent *or* explicitly ``None`` takes its
            default value.

    Returns:
        A dict with exactly the keys ``t``, ``v``, ``r``, ``p`` and ``o``.
        ``r`` and ``p`` are rendered as the string form of a float.

    Raises:
        HTTPException: 400 when ``t`` is missing/empty, or when ``r`` or ``p``
            cannot be converted to a number.
    """
    if not params.get("t"):
        raise HTTPException(status_code=400, detail="文本内容不能为空")

    def pick(key: str, default: Any) -> Any:
        # Treat an explicit None (e.g. JSON null) like a missing key; otherwise
        # str(None) would yield the literal "None" and float(None) would raise.
        value = params.get(key)
        return default if value is None else value

    try:
        rate = float(pick("r", 0))
        pitch = float(pick("p", 0))
    except (TypeError, ValueError):
        # TypeError covers non-numeric, non-string values such as lists, which
        # previously escaped this handler and surfaced as a server error.
        raise HTTPException(status_code=400, detail="语速和音调参数必须是有效的数字")

    return {
        "t": str(params["t"]),
        "v": str(pick("v", "zh-CN-XiaoxiaoMultilingualNeural")),
        "r": str(rate),
        "p": str(pitch),
        "o": str(pick("o", "audio-24khz-48kbitrate-mono-mp3")),
    }
49
+
50
def get_from_cache(cache_key: str) -> Optional[str]:
    """Return the cached audio file name for ``cache_key``, if still usable.

    An entry is usable when it has not expired and its file is still present
    in ``AUDIO_DIR``. Unusable entries are removed from ``cache_store`` so
    stale keys do not accumulate and dead links are never handed out.

    Returns:
        The cached file name, or ``None`` when there is no usable entry.
    """
    entry = cache_store.get(cache_key)
    if entry is None:
        return None

    file_name = entry["file_name"]
    still_fresh = datetime.now() < entry["expire_time"]
    if still_fresh and (Path(AUDIO_DIR) / file_name).is_file():
        return file_name

    # Expired, or the audio file was removed from disk: evict the entry.
    cache_store.pop(cache_key, None)
    return None
57
+
58
def update_cache(cache_key: str, file_name: str) -> None:
    """Record ``file_name`` under ``cache_key`` with a fresh expiry time.

    The expiry is the current time plus ``CACHE_EXPIRE_HOURS`` hours.
    """
    expires_at = datetime.now() + timedelta(hours=CACHE_EXPIRE_HOURS)
    cache_store[cache_key] = dict(file_name=file_name, expire_time=expires_at)
64
+
65
async def make_tts_request(params: Dict[str, str]) -> bytes:
    """POST ``params`` as JSON to the upstream ``/tts`` endpoint and return the audio bytes.

    The request URL, parameters and response status are printed to stdout.

    Raises:
        HTTPException: 503 when the upstream service cannot be reached, or the
            upstream status code (with the best available error description)
            when it answers with an error status.
    """
    async with httpx.AsyncClient() as client:
        try:
            print("\n=== TTS请求信息 ===")
            print(f"请求URL: {HOST}/tts")
            print(f"请求参数: {json.dumps(params, ensure_ascii=False, indent=2)}")
            response = await client.post(f"{HOST}/tts", json=params)
            print(f"响应状态码: {response.status_code}")
            print("=== TTS请求完成 ===\n")
            response.raise_for_status()
            return response.content
        except httpx.RequestError as e:
            raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}")
        except httpx.HTTPStatusError as e:
            error_detail = "未知错误"
            try:
                error_json = e.response.json()
            except ValueError:
                # Body is not JSON (was a bare ``except:``): fall back to the
                # raw text, or to the exception message when the body is empty.
                error_detail = e.response.text or str(e)
            else:
                if isinstance(error_json, dict):
                    error_detail = error_json.get("detail") or error_json.get("message") or str(error_json)
                else:
                    # Valid JSON that is not an object (e.g. a bare string):
                    # report the raw body rather than "unknown error".
                    error_detail = e.response.text or error_detail
            raise HTTPException(status_code=e.response.status_code, detail=f"TTS服务返回错误: {error_detail}")
88
+
89
async def save_audio_file(content: bytes) -> str:
    """Write ``content`` to a new UUID-named ``.mp3`` file in ``AUDIO_DIR``.

    Returns:
        The generated file name (without the directory).

    Raises:
        HTTPException: 500 when the file cannot be written.
    """
    generated_name = f"{uuid.uuid4()}.mp3"
    target = Path(AUDIO_DIR) / generated_name
    try:
        async with aiofiles.open(target, "wb") as handle:
            await handle.write(content)
    except IOError as e:
        raise HTTPException(status_code=500, detail=f"保存音频文件失败: {str(e)}")
    return generated_name
99
+
100
async def download_audio(params: dict) -> str:
    """Return the file name of the audio for ``params``, downloading it if needed.

    The parameters are validated and normalised first. If the cache yields a
    file name for them it is returned directly; otherwise the audio is
    requested from the TTS service, saved locally and recorded in the cache.
    """
    normalized = validate_tts_params(params)
    key = get_cache_key(normalized)

    file_name = get_from_cache(key)
    if not file_name:
        # Cache miss: fetch, persist, then remember the result.
        audio_bytes = await make_tts_request(normalized)
        file_name = await save_audio_file(audio_bytes)
        update_cache(key, file_name)
    return file_name
115
+
116
def generate_audio_url(request: Request, file_name: str) -> str:
    """Return the URL of ``file_name`` under ``audio/``, relative to the request's base URL."""
    return "{}audio/{}".format(str(request.base_url), file_name)
120
+
121
def split_text(text: str, max_length: Optional[int] = None) -> List[str]:
    """Split long text into segments of at most ``max_length`` characters.

    The text is first cut after each sentence-ending punctuation mark. Each
    sentence is stripped of surrounding whitespace, and consecutive sentences
    are packed together (concatenated directly) while the combined length
    stays within the limit. A single sentence longer than the limit is cut
    into fixed-size chunks.

    Args:
        text: The text to split.
        max_length: Maximum characters per segment. Defaults to the
            module-level ``MAX_SEGMENT_LENGTH``.

    Returns:
        The non-empty segments in their original order (empty list for blank
        input).

    Raises:
        ValueError: If the effective limit is smaller than 1.
    """
    limit = MAX_SEGMENT_LENGTH if max_length is None else max_length
    if limit < 1:
        raise ValueError("max_length must be a positive integer")

    # With a capturing group, re.split alternates text and delimiter:
    # [text, delim, text, delim, ..., text]. Glue each delimiter back onto the
    # text that precedes it; the final text has no delimiter.
    pieces = re.split(r'([。!?.!?])', text)
    sentences = [
        pieces[i] + (pieces[i + 1] if i + 1 < len(pieces) else "")
        for i in range(0, len(pieces), 2)
    ]

    result: List[str] = []
    current = ""
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue

        if len(current) + len(sentence) <= limit:
            current += sentence
            continue

        # The sentence does not fit: flush what has been accumulated so far.
        if current:
            result.append(current)
        if len(sentence) > limit:
            # Over-long sentence: hard-cut into limit-sized chunks.
            result.extend(
                sentence[i:i + limit] for i in range(0, len(sentence), limit)
            )
            current = ""
        else:
            current = sentence

    if current:
        result.append(current)
    return result
150
+
151
def _merge_audio_files_sync(file_names: List[str]) -> str:
    """Blocking worker: concatenate the MP3 files with 500 ms of silence between them."""
    silence = AudioSegment.silent(duration=500)  # 500 ms gap between clips
    combined = AudioSegment.empty()

    for index, file_name in enumerate(file_names):
        clip = AudioSegment.from_mp3(str(Path(AUDIO_DIR) / file_name))
        if index > 0:  # no leading silence before the first clip
            combined += silence
        combined += clip

    output_file_name = f"{uuid.uuid4()}.mp3"
    output_path = Path(AUDIO_DIR) / output_file_name
    # export() returns the output file handle; close it instead of leaking it.
    combined.export(str(output_path), format="mp3").close()
    return output_file_name


async def merge_audio_files(file_names: List[str]) -> str:
    """Merge several audio files from ``AUDIO_DIR`` into one, separated by silence.

    Decoding and encoding are blocking work, so they run in the default
    executor rather than on the event loop.

    Args:
        file_names: Names of MP3 files inside ``AUDIO_DIR``, in playback order.

    Returns:
        The file name of the merged MP3, stored in ``AUDIO_DIR``.

    Raises:
        ValueError: If ``file_names`` is empty.
    """
    import asyncio

    if not file_names:
        raise ValueError("没有可合并的音频文件")

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _merge_audio_files_sync, list(file_names))
167
+
168
@app.get("/tts")
async def text_to_speech_get(
    request: Request,
    t: str,
    v: Optional[str] = "zh-CN-XiaoxiaoMultilingualNeural",
    r: Optional[str] = "0",
    p: Optional[str] = "0",
    o: Optional[str] = "audio-24khz-48kbitrate-mono-mp3"
):
    """Forward a TTS request given as query parameters and return a playback URL.

    Returns:
        ``{"status": "success", "audio_url": ..., "cached": ...}`` on success,
        where ``cached`` tells whether the cache already held audio for these
        parameters. On failure, a JSON error body whose HTTP status is the one
        carried by the ``HTTPException`` (500 for anything else).
    """
    params = {"t": t, "v": v, "r": r, "p": p, "o": o}
    try:
        # Inspect the cache before downloading, using the same normalised
        # parameters download_audio derives its key from.
        cached = get_from_cache(get_cache_key(validate_tts_params(params))) is not None
        file_name = await download_audio(params)
        audio_url = generate_audio_url(request, file_name)
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        # Preserve 400/503/upstream statuses instead of reporting them as 500.
        return JSONResponse(status_code=e.status_code, content={"status": "error", "message": e.detail})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
184
+
185
@app.post("/tts")
async def text_to_speech_post(request: Request, data: Dict[str, Any] = Body(...)):
    """Forward a TTS request given as a JSON body and return a playback URL.

    Returns:
        ``{"status": "success", "audio_url": ..., "cached": ...}`` on success,
        where ``cached`` tells whether the cache already held audio for these
        parameters. On failure, a JSON error body whose HTTP status is the one
        carried by the ``HTTPException`` (500 for anything else).
    """
    try:
        # The cache is keyed on the *validated* parameters, so the lookup must
        # use them too, and must happen before download_audio stores the entry.
        cached = get_from_cache(get_cache_key(validate_tts_params(data))) is not None
        file_name = await download_audio(data)
        audio_url = generate_audio_url(request, file_name)
        return {"status": "success", "audio_url": audio_url, "cached": cached}
    except HTTPException as e:
        # Preserve 400/503/upstream statuses instead of reporting them as 500.
        return JSONResponse(status_code=e.status_code, content={"status": "error", "message": e.detail})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
196
+
197
@app.post("/ltts")
async def text_to_speech_long(request: Request, data: Dict[str, Any] = Body(...)):
    """Synthesize long text by splitting it, converting each segment and merging the audio.

    Every field of ``data`` other than ``t`` is passed along unchanged with
    each segment.

    Raises:
        HTTPException: 400 when ``t`` is missing, empty or not a string.

    Returns:
        ``{"status": "success", "audio_url": ..., "segments_count": ...}`` on
        success. When a segment fails, a JSON error body naming the segment,
        with the failing ``HTTPException``'s status (500 for other errors).
    """
    text = data.get("t")
    if not text:
        raise HTTPException(status_code=400, detail="Parameter 't' (text) is required")
    if not isinstance(text, str):
        # A non-string value would fail inside the regex split and surface as 500.
        raise HTTPException(status_code=400, detail="Parameter 't' (text) must be a string")

    try:
        segments = split_text(text)
        if not segments:
            raise ValueError("文本分割后没有有效内容")

        file_names = []
        base_params = {k: v for k, v in data.items() if k != "t"}

        for i, segment in enumerate(segments):
            try:
                segment_params = {**base_params, "t": segment}
                file_name = await download_audio(segment_params)
                file_names.append(file_name)
            except Exception as e:
                # Keep the status of validation/upstream failures; only
                # unexpected errors are reported as 500.
                is_http_error = isinstance(e, HTTPException)
                status_code = e.status_code if is_http_error else 500
                reason = e.detail if is_http_error else str(e)
                return JSONResponse(
                    status_code=status_code,
                    content={"status": "error", "message": f"第{i+1}段文本转语音失败: {reason}", "segment": segment}
                )

        final_file_name = await merge_audio_files(file_names)
        audio_url = generate_audio_url(request, final_file_name)
        return {"status": "success", "audio_url": audio_url, "segments_count": len(segments)}
    except ValueError as e:
        return JSONResponse(status_code=400, content={"status": "error", "message": str(e)})
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
230
+
231
@app.post("/mtts")
async def multi_voice_tts(request: Request, data: List[Dict[str, Any]] = Body(...)):
    """Synthesize a multi-voice reading and return a single merged audio URL.

    Expected body:
    [{"v":"voice1", "t":"text1", "r":"rate1"}, {"v":"voice2", "t":"text2", "r":"rate2"}]

    Each item is converted separately (missing fields take the usual
    defaults) and the resulting clips are merged in the given order.

    Raises:
        HTTPException: 400 when the body is empty or not a list.

    Returns:
        ``{"status": "success", "audio_url": ..., "voices_count": ...}`` on
        success. When an item fails, a JSON error body naming the item, with
        the failing ``HTTPException``'s status (500 for other errors).
    """
    if not data or not isinstance(data, list):
        raise HTTPException(status_code=400, detail="请求数据必须是包含多个角色配置的数组")

    try:
        file_names = []
        for i, voice_config in enumerate(data):
            try:
                # Generate the audio for this voice.
                params = {
                    "t": voice_config.get("t", ""),
                    "v": voice_config.get("v", "zh-CN-XiaoxiaoMultilingualNeural"),
                    "r": voice_config.get("r", "0"),
                    "p": voice_config.get("p", "0"),
                    "o": voice_config.get("o", "audio-24khz-48kbitrate-mono-mp3")
                }
                file_name = await download_audio(params)
                file_names.append(file_name)
            except Exception as e:
                # Keep the status of validation/upstream failures (e.g. 400
                # for an empty text); only unexpected errors become 500.
                is_http_error = isinstance(e, HTTPException)
                status_code = e.status_code if is_http_error else 500
                reason = e.detail if is_http_error else str(e)
                return JSONResponse(
                    status_code=status_code,
                    content={"status": "error", "message": f"第{i+1}个角色的语音生成失败: {reason}", "voice_config": voice_config}
                )

        # Merge the clips of all voices.
        final_file_name = await merge_audio_files(file_names)
        audio_url = generate_audio_url(request, final_file_name)
        return {"status": "success", "audio_url": audio_url, "voices_count": len(data)}
    except Exception as e:
        return JSONResponse(status_code=500, content={"status": "error", "message": str(e)})
265
+
266
@app.get("/audio/{file_name}")
async def get_audio(file_name: str):
    """Serve a previously generated audio file from ``AUDIO_DIR``.

    Raises:
        HTTPException: 404 when the name contains a path component or does not
            refer to a regular file inside ``AUDIO_DIR``.
    """
    # Only bare file names are valid; anything with a directory part is refused
    # so the lookup cannot leave AUDIO_DIR.
    if Path(file_name).name != file_name:
        raise HTTPException(status_code=404, detail="Audio file not found")

    file_path = Path(AUDIO_DIR) / file_name
    # is_file() (rather than exists()) also rejects directories such as "..".
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(file_path, media_type="audio/mpeg")
273
+
274
@app.get("/voices")
async def get_voices(l: Optional[str] = "zh", d: Optional[bool] = False):
    """Proxy the upstream ``/voices`` endpoint and return its JSON payload.

    Args:
        l: Forwarded as the ``l`` query parameter when non-empty.
        d: When true, an empty ``d`` query parameter is forwarded.

    Raises:
        HTTPException: With the upstream status code on an error response, or
            503 when the upstream service cannot be reached.
    """
    params = {}
    if l:
        params["l"] = l
    if d:
        params["d"] = ""

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{HOST}/voices", params=params)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.RequestError as e:
            # Same mapping as the TTS request path: unreachable upstream -> 503.
            raise HTTPException(status_code=503, detail=f"请求TTS服务失败: {str(e)}")
290
+
291
@app.get("/cache/status")
async def cache_status():
    """Report the number of entries in the in-memory cache and the TTL in hours."""
    entry_count = len(cache_store)
    return dict(cache_count=entry_count, expire_hours=CACHE_EXPIRE_HOURS)
295
+
296
@app.delete("/cache/clear")
async def clear_cache():
    """Discard every cache entry by rebinding ``cache_store`` to a new dict."""
    global cache_store
    cache_store = dict()
    return dict(status="success", message="Cache cleared")
302
+
303
@app.get("/")
async def read_root(request: Request):
    """Render the home page (``index.html``), which shows the voice list."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
307
+
308
if __name__ == "__main__":
    # Run the app with uvicorn on all interfaces, port 7860, when executed
    # directly as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)