File size: 3,742 Bytes
9aaf513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from typing import Optional
import os

from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR,
                                 INSANELY_FAST_WHISPER_MODELS_DIR, WHISPER_MODELS_DIR, UVR_MODELS_DIR)
from modules.whisper.faster_whisper_inference import FasterWhisperInference
from modules.whisper.whisper_Inference import WhisperInference
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
from modules.whisper.data_classes import *


class WhisperFactory:
    """Factory for building the transcription pipeline of a chosen Whisper backend."""

    @staticmethod
    def create_whisper_inference(
        whisper_type: str,
        whisper_model_dir: str = WHISPER_MODELS_DIR,
        faster_whisper_model_dir: str = FASTER_WHISPER_MODELS_DIR,
        insanely_fast_whisper_model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
        diarization_model_dir: str = DIARIZATION_MODELS_DIR,
        uvr_model_dir: str = UVR_MODELS_DIR,
        output_dir: str = OUTPUT_DIR,
    ) -> "BaseTranscriptionPipeline":
        """
        Create a whisper inference instance based on the provided whisper_type.

        Parameters
        ----------
        whisper_type : str
            The Whisper implementation to use. Surrounding whitespace is stripped and the
            value is lower-cased before it is compared against the ``.value`` of these
            ``WhisperImpl`` members (the exact strings are defined by ``WhisperImpl``):
            - ``WhisperImpl.FASTER_WHISPER``: https://github.com/SYSTRAN/faster-whisper
            - ``WhisperImpl.WHISPER``: https://github.com/openai/whisper
            - ``WhisperImpl.INSANELY_FAST_WHISPER``: https://github.com/Vaibhavs10/insanely-fast-whisper
            Any value that matches none of them falls back to the faster-whisper
            implementation; no error is raised.
        whisper_model_dir : str
            Model directory passed to ``WhisperInference``.
        faster_whisper_model_dir : str
            Model directory passed to ``FasterWhisperInference``.
        insanely_fast_whisper_model_dir : str
            Model directory passed to ``InsanelyFastWhisperInference``.
        diarization_model_dir : str
            Directory path for the diarization model (forwarded to every implementation).
        uvr_model_dir : str
            Directory path for the UVR model (forwarded to every implementation).
        output_dir : str
            Directory path where output files will be saved (forwarded to every implementation).

        Returns
        -------
        BaseTranscriptionPipeline
            An instance of the inference class selected by ``whisper_type``.

        Notes
        -----
        Every call sets the ``KMP_DUPLICATE_LIB_OK`` environment variable to ``'True'``
        for the current process.
        """
        # Temporary workaround for https://github.com/jhj0517/Whisper-WebUI/issues/144
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

        requested = whisper_type.strip().lower()

        # (enum member, pipeline class, model directory) for each supported backend,
        # checked in the same order as the original if/elif chain.
        backends = (
            (WhisperImpl.FASTER_WHISPER, FasterWhisperInference, faster_whisper_model_dir),
            (WhisperImpl.WHISPER, WhisperInference, whisper_model_dir),
            (WhisperImpl.INSANELY_FAST_WHISPER, InsanelyFastWhisperInference,
             insanely_fast_whisper_model_dir),
        )

        # Unrecognised types deliberately fall back to faster-whisper instead of failing.
        pipeline_cls, model_dir = next(
            (
                (candidate_cls, candidate_dir)
                for impl, candidate_cls, candidate_dir in backends
                if requested == impl.value
            ),
            (FasterWhisperInference, faster_whisper_model_dir),
        )

        # All implementations are constructed with the same keyword arguments;
        # only the class and its model directory differ.
        return pipeline_cls(
            model_dir=model_dir,
            output_dir=output_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir
        )