TaiYouWeb committed on
Commit
14b80ce
·
verified ·
1 Parent(s): 56bdf87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -96
app.py CHANGED
@@ -9,19 +9,14 @@ import json
9
  from typing import Optional
10
 
11
  import torch
12
- import gradio as gr # 添加Gradio库
13
 
14
  from config import model_config
15
-
16
- from fastapi import FastAPI, File, Form, UploadFile, HTTPException
17
- from fastapi.responses import StreamingResponse, Response
18
-
19
- import uvicorn
20
 
21
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
22
  model_dir = snapshot_download(model_config['model_dir'])
23
 
24
- # 初始化模型
25
  model = AutoModel(
26
  model=model_dir,
27
  trust_remote_code=False,
@@ -39,15 +34,12 @@ def transcribe_audio(file_path, vad_model="fsmn-vad", vad_kwargs='{"max_single_s
39
  merge_vad=True, merge_length_s=15, batch_size_threshold_s=50,
40
  hotword=" ", spk_model="cam++", ban_emo_unk=False):
41
  try:
42
- # 将字符串转换为字典
43
  vad_kwargs = json.loads(vad_kwargs)
44
 
45
- # 使用文件路径作为输入
46
  temp_file_path = file_path
47
 
48
- # 生成结果
49
  res = model.generate(
50
- input=temp_file_path, # 使用文件路径作为输入
51
  cache={},
52
  language=language,
53
  use_itn=use_itn,
@@ -60,18 +52,15 @@ def transcribe_audio(file_path, vad_model="fsmn-vad", vad_kwargs='{"max_single_s
60
  ban_emo_unk=ban_emo_unk
61
  )
62
 
63
- # 处理结果
64
  text = rich_transcription_postprocess(res[0]["text"])
65
 
66
  return text
67
 
68
  except Exception as e:
69
- # 捕获异常并返回错误信息
70
  return str(e)
71
 
72
- # 创建Gradio界面
73
  inputs = [
74
- gr.Audio(type="filepath"), # 设置为'filepath'来支持文件路径
75
  gr.Textbox(value="fsmn-vad", label="VAD Model"),
76
  gr.Textbox(value='{"max_single_segment_time": 30000}', label="VAD Kwargs"),
77
  gr.Slider(1, 8, value=4, step=1, label="NCPU"),
@@ -97,84 +86,21 @@ gr.Interface(
97
  ).launch()
98
 
99
 
100
- class SynthesizeResponse(Response):
101
- media_type = 'text/plain'
102
-
103
- app = FastAPI()
104
-
105
- @app.post('/asr', response_class=SynthesizeResponse)
106
- async def generate(
107
- file: UploadFile = File(...),
108
- vad_model: str = Form("fsmn-vad"),
109
- vad_kwargs: str = Form('{"max_single_segment_time": 30000}'),
110
- ncpu: int = Form(4),
111
- batch_size: int = Form(1),
112
- language: str = Form("auto"),
113
- use_itn: bool = Form(True),
114
- batch_size_s: int = Form(60),
115
- merge_vad: bool = Form(True),
116
- merge_length_s: int = Form(15),
117
- batch_size_threshold_s: int = Form(50),
118
- hotword: Optional[str] = Form(" "),
119
- spk_model: str = Form("cam++"),
120
- ban_emo_unk: bool = Form(False),
121
- ) -> StreamingResponse:
122
- try:
123
- # 将字符串转换为字典
124
- vad_kwargs = json.loads(vad_kwargs)
125
-
126
- # 创建临时文件并保存上传的音频文件
127
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
128
- temp_file_path = temp_file.name
129
- input_wav_bytes = await file.read()
130
- temp_file.write(input_wav_bytes)
131
-
132
- try:
133
- # 初始化模型
134
- model = AutoModel(
135
- model=model_dir,
136
- trust_remote_code=False,
137
- remote_code="./model.py",
138
- vad_model=vad_model,
139
- vad_kwargs=vad_kwargs,
140
- ncpu=ncpu,
141
- batch_size=batch_size,
142
- hub="ms",
143
- device=device,
144
- )
145
-
146
- # 生成结果
147
- res = model.generate(
148
- input=temp_file_path, # 使用临时文件路径作为输入
149
- cache={},
150
- language=language,
151
- use_itn=use_itn,
152
- batch_size_s=batch_size_s,
153
- merge_vad=merge_vad,
154
- merge_length_s=merge_length_s,
155
- batch_size_threshold_s=batch_size_threshold_s,
156
- hotword=hotword,
157
- spk_model=spk_model,
158
- ban_emo_unk=ban_emo_unk
159
- )
160
-
161
- # 处理结果
162
- text = rich_transcription_postprocess(res[0]["text"])
163
-
164
- # 返回结果
165
- return StreamingResponse(io.BytesIO(text.encode('utf-8')), media_type="text/plain")
166
-
167
- finally:
168
- # 确保在处理完毕后删除临时文件
169
- if os.path.exists(temp_file_path):
170
- os.remove(temp_file_path)
171
-
172
- except Exception as e:
173
- raise HTTPException(status_code=500, detail=str(e))
174
-
175
- @app.get("/root")
176
- async def read_root():
177
- return {"message": "Hello World"}
178
-
179
- if __name__ == "__main__":
180
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
9
  from typing import Optional
10
 
11
  import torch
12
+ import gradio as gr
13
 
14
  from config import model_config
15
+ from gradio_client import Client, handle_file
 
 
 
 
16
 
17
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
18
  model_dir = snapshot_download(model_config['model_dir'])
19
 
 
20
  model = AutoModel(
21
  model=model_dir,
22
  trust_remote_code=False,
 
34
  merge_vad=True, merge_length_s=15, batch_size_threshold_s=50,
35
  hotword=" ", spk_model="cam++", ban_emo_unk=False):
36
  try:
 
37
  vad_kwargs = json.loads(vad_kwargs)
38
 
 
39
  temp_file_path = file_path
40
 
 
41
  res = model.generate(
42
+ input=temp_file_path,
43
  cache={},
44
  language=language,
45
  use_itn=use_itn,
 
52
  ban_emo_unk=ban_emo_unk
53
  )
54
 
 
55
  text = rich_transcription_postprocess(res[0]["text"])
56
 
57
  return text
58
 
59
  except Exception as e:
 
60
  return str(e)
61
 
 
62
  inputs = [
63
+ gr.Audio(type="filepath"),
64
  gr.Textbox(value="fsmn-vad", label="VAD Model"),
65
  gr.Textbox(value='{"max_single_segment_time": 30000}', label="VAD Kwargs"),
66
  gr.Slider(1, 8, value=4, step=1, label="NCPU"),
 
86
  ).launch()
87
 
88
 
89
+ client = Client("TaiYouWeb/funasr-svsmall-cpu")
90
+ result = client.predict(
91
+ file_path=handle_file('https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav'),
92
+ vad_model="fsmn-vad",
93
+ vad_kwargs="{"max_single_segment_time": 30000}",
94
+ ncpu=4,
95
+ batch_size=1,
96
+ language="auto",
97
+ use_itn=True,
98
+ batch_size_s=60,
99
+ merge_vad=True,
100
+ merge_length_s=15,
101
+ batch_size_threshold_s=50,
102
+ hotword=" ",
103
+ spk_model="cam++",
104
+ ban_emo_unk=False,
105
+ api_name="/asr"
106
+ )