bcci committed (verified)
Commit 50babed · 1 Parent(s): 4bbccfd

Update app.py

Files changed (1)
  1. app.py +39 -33
app.py CHANGED
@@ -6,7 +6,13 @@ from fastapi.responses import HTMLResponse
 
 # Import your model and VAD libraries.
 from silero_vad import VADIterator, load_silero_vad
-from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
+
+from transformers import AutoProcessor, pipeline
+from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
+
+processor = AutoProcessor.from_pretrained("optimum/whisper-tiny.en")
+model = ORTModelForSpeechSeq2Seq.from_pretrained("optimum/whisper-tiny.en")
+pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor)
 
 # Constants
 SAMPLING_RATE = 16000
@@ -17,29 +23,29 @@ MIN_REFRESH_SECS = 1 # Minimum interval for sending partial updates.
 
 app = FastAPI()
 
-class Transcriber:
-    def __init__(self, model_name: str, rate: int = 16000):
-        if rate != 16000:
-            raise ValueError("Moonshine supports sampling rate 16000 Hz.")
-        self.model = MoonshineOnnxModel(model_name=model_name)
-        self.rate = rate
-        self.tokenizer = load_tokenizer()
-        # Statistics (optional)
-        self.inference_secs = 0
-        self.number_inferences = 0
-        self.speech_secs = 0
-        # Warmup run.
-        self.__call__(np.zeros(int(rate), dtype=np.float32))
+# class Transcriber:
+#     def __init__(self, model_name: str, rate: int = 16000):
+#         if rate != 16000:
+#             raise ValueError("Moonshine supports sampling rate 16000 Hz.")
+#         self.model = MoonshineOnnxModel(model_name=model_name)
+#         self.rate = rate
+#         self.tokenizer = load_tokenizer()
+#         # Statistics (optional)
+#         self.inference_secs = 0
+#         self.number_inferences = 0
+#         self.speech_secs = 0
+#         # Warmup run.
+#         self.__call__(np.zeros(int(rate), dtype=np.float32))
 
-    def __call__(self, speech: np.ndarray) -> str:
-        """Returns a transcription of the given speech (a float32 numpy array)."""
-        self.number_inferences += 1
-        self.speech_secs += len(speech) / self.rate
-        start_time = time.time()
-        tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
-        text = self.tokenizer.decode_batch(tokens)[0]
-        self.inference_secs += time.time() - start_time
-        return text
+#     def __call__(self, speech: np.ndarray) -> str:
+#         """Returns a transcription of the given speech (a float32 numpy array)."""
+#         self.number_inferences += 1
+#         self.speech_secs += len(speech) / self.rate
+#         start_time = time.time()
+#         tokens = self.model.generate(speech[np.newaxis, :].astype(np.float32))
+#         text = self.tokenizer.decode_batch(tokens)[0]
+#         self.inference_secs += time.time() - start_time
+#         return text
 
 def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
     """
@@ -50,10 +56,10 @@ def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
     return float_data
 
 # Initialize models.
-model_name_tiny = "moonshine/tiny"
-model_name_base = "moonshine/base"
-transcriber_tiny = Transcriber(model_name=model_name_tiny, rate=SAMPLING_RATE)
-transcriber_base = Transcriber(model_name=model_name_base, rate=SAMPLING_RATE)
+# model_name_tiny = "moonshine/tiny"
+# model_name_base = "moonshine/base"
+# transcriber_tiny = Transcriber(model_name=model_name_tiny, rate=SAMPLING_RATE)
+# transcriber_base = Transcriber(model_name=model_name_base, rate=SAMPLING_RATE)
 vad_model = load_silero_vad(onnx=True)
 vad_iterator = VADIterator(
     model=vad_model,
@@ -79,10 +85,10 @@ async def websocket_endpoint(websocket: WebSocket):
             data = await websocket.receive()
             if data["type"] == "websocket.receive":
                 if data.get("text") == "switch_to_tiny":
-                    current_model = transcriber_tiny
+                    # current_model = transcriber_tiny
                     continue
                 elif data.get("text") == "switch_to_base":
-                    current_model = transcriber_base
+                    # current_model = transcriber_base
                     continue
 
                 chunk = pcm16_to_float32(data["bytes"])
@@ -100,7 +106,7 @@
 
                 if "end" in vad_result and recording:
                     recording = False
-                    text = current_model(speech)
+                    text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
                     await websocket.send_json({"type": "final", "transcript": text})
                     caption_cache.append(text)
                     speech = np.empty(0, dtype=np.float32)
@@ -111,7 +117,7 @@
                 elif recording:
                     if (len(speech) / SAMPLING_RATE) > MAX_SPEECH_SECS:
                         recording = False
-                        text = current_model(speech)
+                        text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
                         await websocket.send_json({"type": "final", "transcript": text})
                         caption_cache.append(text)
                         speech = np.empty(0, dtype=np.float32)
@@ -121,14 +127,14 @@
                         await websocket.send_json({"type": "status", "message": "speaking_stopped"})
 
                     if (current_time - last_partial_time) > MIN_REFRESH_SECS:
-                        text = current_model(speech)
+                        text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
                         if last_output != text:
                             last_output = text
                             await websocket.send_json({"type": "partial", "transcript": text})
                         last_partial_time = current_time
     except WebSocketDisconnect:
         if recording and speech.size:
-            text = current_model(speech)
+            text = pipe({"sampling_rate": 16000, "raw": speech})["text"]
             await websocket.send_json({"type": "final", "transcript": text})
         print("WebSocket disconnected")
 
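Note: the diff shows only the edges of pcm16_to_float32 (the signature, the docstring opening, and return float_data); its body is unchanged context. For reference, a conventional PCM16-to-float32 conversion looks like the sketch below. This is an assumption about the elided body, not necessarily the exact code in app.py:

import numpy as np

def pcm16_to_float32(pcm_data: bytes) -> np.ndarray:
    """
    Convert little-endian 16-bit PCM bytes to float32 samples in [-1.0, 1.0].
    Illustrative reconstruction; the real body is elided in this diff.
    """
    int_data = np.frombuffer(pcm_data, dtype=np.int16)
    float_data = int_data.astype(np.float32) / 32768.0
    return float_data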
 
 
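The diff also elides the VADIterator arguments (only model=vad_model is visible). For context, a typical Silero VAD streaming loop looks like this; the threshold value and any tuning parameters below are assumptions, not values from this commit:

import numpy as np
from silero_vad import VADIterator, load_silero_vad

vad_model = load_silero_vad(onnx=True)
# threshold is illustrative; the commit's constructor arguments are elided.
vad_iterator = VADIterator(model=vad_model, sampling_rate=16000, threshold=0.5)

# Silero's streaming API consumes fixed-size windows (512 samples at 16 kHz) and
# returns {"start": ...} when speech begins, {"end": ...} when it stops, else None.
stream = np.zeros(16000, dtype=np.float32)  # stand-in for live audio
for i in range(0, len(stream) - 511, 512):
    vad_result = vad_iterator(stream[i : i + 512])
    if vad_result and "end" in vad_result:
        print("utterance ended at sample", vad_result["end"])
vad_iterator.reset_states()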
140