Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from importlib.metadata import version
|
2 |
from timeit import default_timer as timer
|
3 |
|
@@ -6,7 +7,11 @@ import numpy as np
|
|
6 |
|
7 |
import onnx_asr
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
10 |
models = {
|
11 |
name: onnx_asr.load_model(name)
|
12 |
for name in [
|
@@ -22,14 +27,22 @@ models = {
|
|
22 |
|
23 |
|
24 |
def recognize(audio: tuple[int, np.ndarray]):
|
|
|
|
|
|
|
25 |
sample_rate, waveform = audio
|
|
|
26 |
try:
|
27 |
waveform = waveform.astype(np.float32) / 2 ** (8 * waveform.itemsize - 1)
|
|
|
|
|
|
|
28 |
results = []
|
29 |
for name, model in models.items():
|
30 |
start = timer()
|
31 |
result = model.recognize(waveform, sample_rate=sample_rate, language="ru")
|
32 |
time = timer() - start
|
|
|
33 |
results.append([name, result, f"{time:.3f} s."])
|
34 |
except Exception as e:
|
35 |
raise gr.Error(f"{e} Audio: sample_rate: {sample_rate}, waveform.shape: {waveform.shape}.") from e
|
|
|
1 |
+
import logging
|
2 |
from importlib.metadata import version
|
3 |
from timeit import default_timer as timer
|
4 |
|
|
|
7 |
|
8 |
import onnx_asr
|
9 |
|
10 |
+
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.WARNING)
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
logger.setLevel(logging.DEBUG)
|
13 |
+
logger.info("onnx_asr version: %s", version("onnx_asr"))
|
14 |
+
|
15 |
models = {
|
16 |
name: onnx_asr.load_model(name)
|
17 |
for name in [
|
|
|
27 |
|
28 |
|
29 |
def recognize(audio: tuple[int, np.ndarray]):
|
30 |
+
if audio is None:
|
31 |
+
return None
|
32 |
+
|
33 |
sample_rate, waveform = audio
|
34 |
+
logger.debug("recognize: sample_rate %s, waveform.shape %s.", sample_rate, waveform.shape)
|
35 |
try:
|
36 |
waveform = waveform.astype(np.float32) / 2 ** (8 * waveform.itemsize - 1)
|
37 |
+
if waveform.ndim == 2:
|
38 |
+
waveform = waveform.mean(axis=1)
|
39 |
+
|
40 |
results = []
|
41 |
for name, model in models.items():
|
42 |
start = timer()
|
43 |
result = model.recognize(waveform, sample_rate=sample_rate, language="ru")
|
44 |
time = timer() - start
|
45 |
+
logger.debug("recognized by %s: result '%s', time %.3f s.", name, result, time)
|
46 |
results.append([name, result, f"{time:.3f} s."])
|
47 |
except Exception as e:
|
48 |
raise gr.Error(f"{e} Audio: sample_rate: {sample_rate}, waveform.shape: {waveform.shape}.") from e
|