File size: 3,653 Bytes
9aaf513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import functools
import uuid
import numpy as np
from fastapi import (
    File,
    UploadFile,
)
import gradio as gr
from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
from typing import List, Dict
from sqlalchemy.orm import Session
from datetime import datetime
from modules.whisper.data_classes import *
from modules.utils.paths import BACKEND_CACHE_DIR
from modules.whisper.faster_whisper_inference import FasterWhisperInference
from backend.common.audio import read_audio
from backend.common.models import QueueResponse
from backend.common.config_loader import load_server_config
from backend.db.task.dao import (
    add_task_to_db,
    get_db_session,
    update_task_status_in_db
)
from backend.db.task.models import TaskStatus, TaskType

# Router exposing the transcription endpoints under the /transcription prefix.
transcription_router = APIRouter(prefix="/transcription", tags=["Transcription"])


@functools.lru_cache
def get_pipeline() -> 'FasterWhisperInference':
    """Build and return the process-wide FasterWhisper inference pipeline.

    Cached with ``functools.lru_cache`` so the model is loaded exactly once
    per process and shared by every request/background task.
    """
    whisper_config = load_server_config()["whisper"]
    pipeline = FasterWhisperInference(output_dir=BACKEND_CACHE_DIR)
    pipeline.update_model(
        model_size=whisper_config["model_size"],
        compute_type=whisper_config["compute_type"],
    )
    return pipeline


def run_transcription(
    audio: np.ndarray,
    params: TranscriptionPipelineParams,
    identifier: str,
) -> List[Segment]:
    """Run the whisper pipeline on ``audio``, persisting progress to the DB.

    Intended to be executed as a FastAPI background task: marks the task
    IN_PROGRESS, runs the cached pipeline, then stores the segments and
    elapsed time on completion. If the pipeline raises, the task is marked
    FAILED (instead of being stuck IN_PROGRESS forever) and the exception
    is re-raised.

    Args:
        audio: Decoded audio samples to transcribe.
        params: Full pipeline parameters (whisper / vad / bgm / diarization).
        identifier: UUID of the task row created by the enqueuing endpoint.

    Returns:
        The transcription segments, dumped to plain dicts via
        ``Segment.model_dump()`` (the ``List[Segment]`` annotation is kept
        for interface compatibility).
    """
    update_task_status_in_db(
        identifier=identifier,
        update_data={
            "uuid": identifier,
            "status": TaskStatus.IN_PROGRESS,
            # NOTE(review): datetime.utcnow() is deprecated in 3.12+; kept
            # as-is to preserve the naive-UTC timestamps the DB layer
            # currently receives — confirm before migrating to aware times.
            "updated_at": datetime.utcnow()
        },
    )

    try:
        segments, elapsed_time = get_pipeline().run(
            audio,
            gr.Progress(),
            "SRT",   # output file format expected by the pipeline
            False,   # positional flag forwarded to FasterWhisperInference.run
            *params.to_list()
        )
        segments = [seg.model_dump() for seg in segments]
    except Exception as e:
        # Without this, any pipeline error leaves the task IN_PROGRESS
        # forever. NOTE(review): assumes TaskStatus defines FAILED and the
        # task model accepts an "error" field — verify against backend.db.
        update_task_status_in_db(
            identifier=identifier,
            update_data={
                "uuid": identifier,
                "status": TaskStatus.FAILED,
                "error": str(e),
                "updated_at": datetime.utcnow()
            },
        )
        raise

    update_task_status_in_db(
        identifier=identifier,
        update_data={
            "uuid": identifier,
            "status": TaskStatus.COMPLETED,
            "result": segments,
            "updated_at": datetime.utcnow(),
            "duration": elapsed_time
        },
    )
    return segments


@transcription_router.post(
    "/",
    response_model=QueueResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Transcribe Audio",
    description="Process the provided audio or video file to generate a transcription.",
)
async def transcription(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio or video file to transcribe."),
    whisper_params: WhisperParams = Depends(),
    vad_params: VadParams = Depends(),
    bgm_separation_params: BGMSeparationParams = Depends(),
    diarization_params: DiarizationParams = Depends(),
) -> QueueResponse:
    """Queue a transcription task for the uploaded file.

    Decodes the upload, records a QUEUED task in the DB, and schedules the
    actual transcription as a background task so the request can return
    immediately with the task identifier.
    """
    if isinstance(file, np.ndarray):
        # Pre-decoded samples — only possible when this coroutine is called
        # directly (not through HTTP, where FastAPI always passes UploadFile).
        audio, info = file, None
    else:
        audio, info = await read_audio(file=file)

    params = TranscriptionPipelineParams(
        whisper=whisper_params,
        vad=vad_params,
        bgm_separation=bgm_separation_params,
        diarization=diarization_params
    )

    identifier = add_task_to_db(
        status=TaskStatus.QUEUED,
        # getattr: an ndarray input has no .filename attribute — the original
        # code raised AttributeError on that path.
        file_name=getattr(file, "filename", None),
        audio_duration=info.duration if info else None,
        language=params.whisper.lang,
        task_type=TaskType.TRANSCRIPTION,
        task_params=params.to_dict(),
    )

    background_tasks.add_task(
        run_transcription,
        audio=audio,
        params=params,
        identifier=identifier,
    )

    return QueueResponse(
        identifier=identifier,
        status=TaskStatus.QUEUED,
        message="Transcription task has been queued",
    )