File size: 2,469 Bytes
9aaf513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import functools
from fastapi import FastAPI, UploadFile
from fastapi.testclient import TestClient
from starlette.datastructures import UploadFile as StarletteUploadFile
from io import BytesIO
import os
import requests
import pytest
import yaml
import jiwer

from backend.main import app
from modules.whisper.data_classes import *
from modules.utils.paths import *
from modules.utils.files_manager import load_yaml, save_yaml

# Whisper model size and compute type used by the tests. Defined first so the
# pipeline parameters below reuse them instead of repeating the literals.
TEST_WHISPER_MODEL = "tiny"
TEST_COMPUTE_TYPE = "float32"

# Flat parameter dict for the whole pipeline: the dumped fields of every stage
# merged into one mapping, with None-valued fields dropped. On a key collision
# the later stage in this literal wins.
TEST_PIPELINE_PARAMS = {
    **WhisperParams(model_size=TEST_WHISPER_MODEL, compute_type=TEST_COMPUTE_TYPE).model_dump(exclude_none=True),
    **VadParams().model_dump(exclude_none=True),
    **BGMSeparationParams().model_dump(exclude_none=True),
    **DiarizationParams().model_dump(exclude_none=True),
}
# Default parameters of the individual stages (None-valued fields are kept here).
TEST_VAD_PARAMS = VadParams().model_dump()
TEST_BGM_SEPARATION_PARAMS = BGMSeparationParams().model_dump()

# Remote location of the sample audio and where it is stored locally.
TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "jfk.wav")
TEST_BGM_SEPARATION_OUTPUT_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "separated_audio.zip")

# Reference transcript without punctuation (presumably of the sample audio above).
TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"


@pytest.fixture(autouse=True)
def setup_test_file():
    """Make sure the sample audio file is on disk before each test runs.

    Downloads ``TEST_FILE_DOWNLOAD_URL`` to ``TEST_FILE_PATH`` when the file
    does not exist yet; otherwise does nothing.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    # No memoization here: a nested function is re-created on every fixture
    # invocation, so a cache on it could never hit. The existence check below
    # already prevents repeated downloads.
    def download_file(url=TEST_FILE_DOWNLOAD_URL, file_path=TEST_FILE_PATH):
        if os.path.exists(file_path):
            return

        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Bound the wait so an unreachable host cannot hang the test run.
        response = requests.get(url, timeout=60)
        # Fail before writing anything: an error page saved as the audio file
        # would be picked up by the existence check on every later run.
        response.raise_for_status()

        with open(file_path, "wb") as file:
            file.write(response.content)

        print(f"File downloaded to: {file_path}")

    download_file(TEST_FILE_DOWNLOAD_URL, TEST_FILE_PATH)


@pytest.fixture
def get_upload_file_instance(filepath: str = TEST_FILE_PATH) -> UploadFile:
    """Build an in-memory upload object from the file at ``filepath``.

    The file is read fully into a ``BytesIO`` buffer and wrapped in a
    Starlette ``UploadFile`` whose filename is the base name of ``filepath``.

    A new object is created for every test on purpose: the wrapped stream
    keeps a read position and can be closed, so sharing one cached instance
    would hand later tests an exhausted or closed file.

    Args:
        filepath: Path of the file to load. Defaults to ``TEST_FILE_PATH``.

    Returns:
        The upload object wrapping the file's bytes.

    Raises:
        OSError: If ``filepath`` cannot be opened or read.
    """
    with open(filepath, "rb") as f:
        file_contents = BytesIO(f.read())
    filename = os.path.basename(filepath)
    return StarletteUploadFile(file=file_contents, filename=filename)


@functools.lru_cache
def get_client(app: FastAPI = app):
    """Return a ``TestClient`` bound to ``app``.

    The result is memoized by ``functools.lru_cache``, so repeated calls with
    the same arguments return the same client object.
    """
    client = TestClient(app)
    return client


def calculate_wer(answer, prediction):
    """Return the word error rate of ``prediction`` measured against ``answer``.

    Thin wrapper around ``jiwer.wer``: ``answer`` is passed as the first
    argument (the reference) and ``prediction`` as the second (the hypothesis).
    """
    word_error_rate = jiwer.wer(answer, prediction)
    return word_error_rate