Spaces:
Running
Running
import functools | |
from fastapi import FastAPI, UploadFile | |
from fastapi.testclient import TestClient | |
from starlette.datastructures import UploadFile as StarletteUploadFile | |
from io import BytesIO | |
import os | |
import requests | |
import pytest | |
import yaml | |
import jiwer | |
from backend.main import app | |
from modules.whisper.data_classes import * | |
from modules.utils.paths import * | |
from modules.utils.files_manager import load_yaml, save_yaml | |
# Whisper model size and compute type used by the tests. TEST_PIPELINE_PARAMS
# below is built from these two constants so the values cannot drift apart.
TEST_WHISPER_MODEL = "tiny"
TEST_COMPUTE_TYPE = "float32"

# One flat dict holding every pipeline parameter (Whisper, VAD, BGM separation,
# diarization). Each model is dumped with None-valued fields omitted; if two
# models share a key, the one unpacked later in this literal wins.
TEST_PIPELINE_PARAMS = {
    **WhisperParams(model_size=TEST_WHISPER_MODEL,
                    compute_type=TEST_COMPUTE_TYPE).model_dump(exclude_none=True),
    **VadParams().model_dump(exclude_none=True),
    **BGMSeparationParams().model_dump(exclude_none=True),
    **DiarizationParams().model_dump(exclude_none=True),
}

# Default VAD / BGM-separation parameters. Unlike TEST_PIPELINE_PARAMS, these
# dumps keep None-valued fields.
TEST_VAD_PARAMS = VadParams().model_dump()
TEST_BGM_SEPARATION_PARAMS = BGMSeparationParams().model_dump()

# Sample audio clip: setup_test_file() downloads it from this URL to
# TEST_FILE_PATH if the file is not already present.
TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "jfk.wav")
# Location of the BGM-separation zip archive; it is used by tests outside this view.
TEST_BGM_SEPARATION_OUTPUT_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "separated_audio.zip")
# Reference transcript of the sample clip, used as ground truth in WER checks.
TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
def setup_test_file():
    """Ensure the sample audio clip exists at TEST_FILE_PATH, downloading it if needed.

    If the file is already on disk nothing is fetched.

    Raises:
        requests.HTTPError: if the download returns an HTTP error status.
        requests.Timeout: if the server does not respond within the timeout.
    """
    def download_file(url=TEST_FILE_DOWNLOAD_URL, file_path=TEST_FILE_PATH):
        if os.path.exists(file_path):
            return
        directory = os.path.dirname(file_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Bounded wait so a stalled server cannot hang the test run forever.
        response = requests.get(url, timeout=60)
        # Fail loudly on HTTP errors. Otherwise an error page would be saved
        # as the .wav and, because of the existence check above, reused on
        # every later run.
        response.raise_for_status()
        with open(file_path, "wb") as file:
            file.write(response.content)
        print(f"File downloaded to: {file_path}")

    download_file(TEST_FILE_DOWNLOAD_URL, TEST_FILE_PATH)
def get_upload_file_instance(filepath: str = TEST_FILE_PATH) -> UploadFile:
    """Wrap the file at *filepath* in an in-memory Starlette ``UploadFile``.

    The whole file is read into a ``BytesIO`` buffer, so the returned object
    does not keep the on-disk file open. Its filename is the basename of
    *filepath*.
    """
    with open(filepath, "rb") as source:
        payload = BytesIO(source.read())
    return StarletteUploadFile(file=payload, filename=os.path.basename(filepath))
def get_client(app: FastAPI = app):
    """Return a new ``TestClient`` bound to *app* (the backend's app by default)."""
    test_client = TestClient(app)
    return test_client
def calculate_wer(answer, prediction):
    """Return the word error rate of *prediction* against the reference *answer*.

    Both strings are passed to ``jiwer.wer`` unchanged. This function does no
    extra normalization of casing or punctuation.
    """
    error_rate = jiwer.wer(answer, prediction)
    return error_rate