Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,13 @@ from fastapi.responses import HTMLResponse
|
|
6 |
|
7 |
# Import your model and VAD libraries.
|
8 |
from silero_vad import VADIterator, load_silero_vad
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Constants
|
12 |
SAMPLING_RATE = 16000
|
@@ -17,29 +23,29 @@ MIN_REFRESH_SECS = 1 # Minimum interval for sending partial updates.
|
|
17 |
|
18 |
app = FastAPI()
|
19 |
|
20 |
-
class Transcriber:
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
|
44 |
def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
|
45 |
"""
|
@@ -50,10 +56,10 @@ def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
|
|
50 |
return float_data
|
51 |
|
52 |
# Initialize models.
|
53 |
-
model_name_tiny = "moonshine/tiny"
|
54 |
-
model_name_base = "moonshine/base"
|
55 |
-
transcriber_tiny = Transcriber(model_name=model_name_tiny, rate=SAMPLING_RATE)
|
56 |
-
transcriber_base = Transcriber(model_name=model_name_base, rate=SAMPLING_RATE)
|
57 |
vad_model = load_silero_vad(onnx=True)
|
58 |
vad_iterator = VADIterator(
|
59 |
model=vad_model,
|
@@ -79,10 +85,10 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
79 |
data = await websocket.receive()
|
80 |
if data["type"] == "websocket.receive":
|
81 |
if data.get("text") == "switch_to_tiny":
|
82 |
-
current_model = transcriber_tiny
|
83 |
continue
|
84 |
elif data.get("text") == "switch_to_base":
|
85 |
-
current_model = transcriber_base
|
86 |
continue
|
87 |
|
88 |
chunk = pcm16_to_float32(data["bytes"])
|
@@ -100,7 +106,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
100 |
|
101 |
if "end" in vad_result and recording:
|
102 |
recording = False
|
103 |
-
text =
|
104 |
await websocket.send_json({"type": "final", "transcript": text})
|
105 |
caption_cache.append(text)
|
106 |
speech = np.empty(0, dtype=np.float32)
|
@@ -111,7 +117,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
111 |
elif recording:
|
112 |
if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
|
113 |
recording = False
|
114 |
-
text =
|
115 |
await websocket.send_json({"type": "final", "transcript": text})
|
116 |
caption_cache.append(text)
|
117 |
speech = np.empty(0, dtype=np.float32)
|
@@ -121,14 +127,14 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
121 |
await websocket.send_json({"type": "status", "message": "speaking_stopped"})
|
122 |
|
123 |
if (current_time - last_partial_time) > MIN_REFRESH_SECS:
|
124 |
-
text =
|
125 |
if last_output != text:
|
126 |
last_output = text
|
127 |
await websocket.send_json({"type": "partial", "transcript": text})
|
128 |
last_partial_time = current_time
|
129 |
except WebSocketDisconnect:
|
130 |
if recording and speech.size:
|
131 |
-
text =
|
132 |
await websocket.send_json({"type": "final", "transcript": text})
|
133 |
print("WebSocket disconnected")
|
134 |
|
|
|
6 |
|
7 |
# Import your model and VAD libraries.
|
8 |
from silero_vad import VADIterator, load_silero_vad
|
9 |
+
|
10 |
+
from transformers import AutoProcessor, pipeline
|
11 |
+
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
|
12 |
+
|
13 |
+
processor = AutoProcessor.from_pretrained("optimum/whisper-tiny.en")
|
14 |
+
model = ORTModelForSpeechSeq2Seq.from_pretrained("optimum/whisper-tiny.en")
|
15 |
+
speech_recognition = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor)
|
16 |
|
17 |
# Constants
|
18 |
SAMPLING_RATE = 16000
|
|
|
23 |
|
24 |
app = FastAPI()
|
25 |
|
26 |
+
# class Transcriber:
|
27 |
+
# def __init__(self, model_name: str, rate: int = 16000):
|
28 |
+
# if rate != 16000:
|
29 |
+
# raise ValueError("Moonshine supports sampling rate 16000 Hz.")
|
30 |
+
# self.model = MoonshineOnnxModel(model_name=model_name)
|
31 |
+
# self.rate = rate
|
32 |
+
# self.tokenizer = load_tokenizer()
|
33 |
+
# # Statistics (optional)
|
34 |
+
# self.inference_secs = 0
|
35 |
+
# self.number_inferences = 0
|
36 |
+
# self.speech_secs = 0
|
37 |
+
# # Warmup run.
|
38 |
+
# self.__call__(np.zeros(int(rate), dtype=np.float32))
|
39 |
|
40 |
+
# def __call__(self, speech: np.ndarray) -> str:
|
41 |
+
# """Returns a transcription of the given speech (a float32 numpy array)."""
|
42 |
+
# self.number_inferences += 1
|
43 |
+
# self.speech_secs += len(speech) / self.rate
|
44 |
+
# start_time = time.time()
|
45 |
+
# tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
|
46 |
+
# text = self.tokenizer.decode_batch(tokens)[0]
|
47 |
+
# self.inference_secs += time.time() - start_time
|
48 |
+
# return text
|
49 |
|
50 |
def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
|
51 |
"""
|
|
|
56 |
return float_data
|
57 |
|
58 |
# Initialize models.
|
59 |
+
# model_name_tiny = "moonshine/tiny"
|
60 |
+
# model_name_base = "moonshine/base"
|
61 |
+
# transcriber_tiny = Transcriber(model_name=model_name_tiny, rate=SAMPLING_RATE)
|
62 |
+
# transcriber_base = Transcriber(model_name=model_name_base, rate=SAMPLING_RATE)
|
63 |
vad_model = load_silero_vad(onnx=True)
|
64 |
vad_iterator = VADIterator(
|
65 |
model=vad_model,
|
|
|
85 |
data = await websocket.receive()
|
86 |
if data["type"] == "websocket.receive":
|
87 |
if data.get("text") == "switch_to_tiny":
|
88 |
+
# current_model = transcriber_tiny
|
89 |
continue
|
90 |
elif data.get("text") == "switch_to_base":
|
91 |
+
# current_model = transcriber_base
|
92 |
continue
|
93 |
|
94 |
chunk = pcm16_to_float32(data["bytes"])
|
|
|
106 |
|
107 |
if "end" in vad_result and recording:
|
108 |
recording = False
|
109 |
+
text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
|
110 |
await websocket.send_json({"type": "final", "transcript": text})
|
111 |
caption_cache.append(text)
|
112 |
speech = np.empty(0, dtype=np.float32)
|
|
|
117 |
elif recording:
|
118 |
if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
|
119 |
recording = False
|
120 |
+
text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
|
121 |
await websocket.send_json({"type": "final", "transcript": text})
|
122 |
caption_cache.append(text)
|
123 |
speech = np.empty(0, dtype=np.float32)
|
|
|
127 |
await websocket.send_json({"type": "status", "message": "speaking_stopped"})
|
128 |
|
129 |
if (current_time - last_partial_time) > MIN_REFRESH_SECS:
|
130 |
+
text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
|
131 |
if last_output != text:
|
132 |
last_output = text
|
133 |
await websocket.send_json({"type": "partial", "transcript": text})
|
134 |
last_partial_time = current_time
|
135 |
except WebSocketDisconnect:
|
136 |
if recording and speech.size:
|
137 |
+
text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
|
138 |
await websocket.send_json({"type": "final", "transcript": text})
|
139 |
print("WebSocket disconnected")
|
140 |
|