Whisper-WebUI / backend /tests /test_backend_config.py
soiz1's picture
Upload 109 files
9aaf513 verified
import functools
from fastapi import FastAPI, UploadFile
from fastapi.testclient import TestClient
from starlette.datastructures import UploadFile as StarletteUploadFile
from io import BytesIO
import os
import requests
import pytest
import yaml
import jiwer
from backend.main import app
from modules.whisper.data_classes import *
from modules.utils.paths import *
from modules.utils.files_manager import load_yaml, save_yaml
# Model settings used by the tests. They are defined before the parameter
# dictionaries so the pipeline parameters reuse them rather than repeating
# the same literals in two places.
TEST_WHISPER_MODEL = "tiny"
TEST_COMPUTE_TYPE = "float32"

# Flat dictionary merging the dumped fields (None values excluded) of the
# Whisper, VAD, BGM-separation and diarization parameter models. With dict
# unpacking, a key from a later model overrides the same key from an earlier one.
TEST_PIPELINE_PARAMS = {
    **WhisperParams(
        model_size=TEST_WHISPER_MODEL,
        compute_type=TEST_COMPUTE_TYPE,
    ).model_dump(exclude_none=True),
    **VadParams().model_dump(exclude_none=True),
    **BGMSeparationParams().model_dump(exclude_none=True),
    **DiarizationParams().model_dump(exclude_none=True),
}

# Full dumps (None values included) of the default VAD and BGM-separation parameters.
TEST_VAD_PARAMS = VadParams().model_dump()
TEST_BGM_SEPARATION_PARAMS = BGMSeparationParams().model_dump()

# Remote location of the sample audio and the local paths used by the tests.
TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "jfk.wav")
TEST_BGM_SEPARATION_OUTPUT_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "separated_audio.zip")

# Reference transcript that predictions for the sample audio are compared against.
TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
@pytest.fixture(autouse=True)
def setup_test_file():
    """Make sure the sample audio file exists locally before a test runs.

    The file is fetched from ``TEST_FILE_DOWNLOAD_URL`` into ``TEST_FILE_PATH``
    only when it is not already present on disk, so repeated invocations are
    cheap.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """

    def download_file(url=TEST_FILE_DOWNLOAD_URL, file_path=TEST_FILE_PATH):
        # The existence check on disk is what avoids repeated downloads.
        if os.path.exists(file_path):
            return

        target_dir = os.path.dirname(file_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Bound the wait and reject error responses: otherwise an HTML error
        # page would be saved as the audio file and treated as "already
        # downloaded" on every later run.
        response = requests.get(url, timeout=60)
        response.raise_for_status()

        # Write to a temporary name first and rename afterwards, so an
        # interrupted write never leaves a truncated file at file_path.
        partial_path = f"{file_path}.part"
        with open(partial_path, "wb") as file:
            file.write(response.content)
        os.replace(partial_path, file_path)

        print(f"File downloaded to: {file_path}")

    download_file(TEST_FILE_DOWNLOAD_URL, TEST_FILE_PATH)
@pytest.fixture
def get_upload_file_instance(filepath: str = TEST_FILE_PATH) -> UploadFile:
    """Build an upload-file object holding the contents of ``filepath``.

    A new object backed by a new in-memory buffer is created on every
    request. The buffer is a stateful stream, so sharing one memoized
    instance would hand later tests a stream whose read position had already
    been advanced by earlier consumers.

    Args:
        filepath: Path of the file to load; defaults to ``TEST_FILE_PATH``.

    Returns:
        A Starlette ``UploadFile`` whose ``file`` is a ``BytesIO`` with the
        file's bytes and whose ``filename`` is the base name of ``filepath``.

    Raises:
        OSError: If ``filepath`` cannot be opened or read.
    """
    with open(filepath, "rb") as f:
        file_contents = BytesIO(f.read())

    filename = os.path.basename(filepath)
    return StarletteUploadFile(file=file_contents, filename=filename)
@functools.lru_cache
def get_client(app: FastAPI = app):
    """Return a ``TestClient`` for ``app``.

    The result is memoized by ``functools.lru_cache``, so repeated calls with
    the same argument return the same client object.
    """
    client = TestClient(app)
    return client
def calculate_wer(answer, prediction):
    """Return ``jiwer.wer(answer, prediction)``.

    ``answer`` is passed as the first (reference) argument and ``prediction``
    as the second (hypothesis) argument.
    """
    error_rate = jiwer.wer(answer, prediction)
    return error_rate